跳转到主要内容

jaxsnn是一种基于事件、受机器学习启发的SNN(突触神经网络)训练和模拟方法,包括对神经形态后端(BrainScaleS-2)的支持。

项目描述

jaxsnn

jaxsnn 是一种基于事件、受机器学习启发的SNN(突触神经网络)训练和模拟方法,包括对神经形态后端(BrainScaleS-2)的支持。我们基于 jax 构建,它是提供autograd和XLA功能用于高性能机器学习研究的Python库。

构建软件

该软件基于现有库,例如 jaxoptaxtree-math。当使用神经形态BrainScaleS-2后端时,需要平台软件栈。

我们提供包含所有构建时和运行时依赖项的容器镜像(基于 Singularity格式)。您可以从中下载最新版本:此处

对于所有后续步骤,我们假设最新的Singularity容器位于 /containers/stable/latest

基于Github的构建

要从公共资源构建此项目,请遵循以下指南

# 1) Most of the following steps will be executed within a singularity container
#    To keep the steps clutter-free, we start by defining an alias
shopt -s expand_aliases
alias c="singularity exec --app dls /containers/stable/latest"

# 2) Prepare a fresh workspace and change directory into it
mkdir workspace && cd workspace

# 3) Fetch a current copy of the symwaf2ic build tool
git clone https://github.com/electronicvisions/waf -b symwaf2ic symwaf2ic

# 4) Build symwaf2ic
c make -C symwaf2ic
ln -s symwaf2ic/waf

# 5) Setup your workspace and clone all dependencies (--clone-depth=1 to skip history)
c ./waf setup --repo-db-url=https://github.com/electronicvisions/projects --project=jaxsnn

# 6) Load PPU cross-compiler toolchain (or build https://github.com/electronicvisions/oppulance)
module load ppu-toolchain

# 7) Build the project
#    Adjust -j1 to your own needs, beware that high parallelism will increase memory consumption!
c ./waf configure
c ./waf build -j1

# 8) Install the project to ./bin and ./lib
c ./waf install

# 9) If you run programs outside waf, you'll need to add ./lib and ./bin to your path specifications
export SINGULARITYENV_PREPEND_PATH=`pwd`/bin:$SINGULARITYENV_PREPEND_PATH
export SINGULARITYENV_LD_LIBRARY_PATH=`pwd`/lib:$SINGULARITYENV_LD_LIBRARY_PATH
export PYTHONPATH=`pwd`/lib:$PYTHONPATH

结构

jaxsnn 分为两部分。SNN的培训以init/apply风格进行。

时间离散

jaxsnn.discrete 通过以离散方式处理时间来模拟 SNN。它使用固定大小的欧拉步来推进网络向前在时间上前进,这受到了 norse 的启发。

时间连续

jaxsnn.event 对时间进行连续处理,并允许从一个事件跳转到下一个事件。其核心功能包括 step 函数,该函数执行以下三件事:

  1. 找到下一个阈值交叉点
  2. 将神经元积分到当前时间点
  3. 在阈值交叉后应用不连续性

jaxsnn.event.leaky_integrate_and_fire 提供多种神经元类型,可用于构建更大的网络。每种神经元类型定义了上述提到的三个函数。

BSS-2 连接

jaxsnn.event.hardware 提供了连接到 BSS-2 系统 的功能,并在专用类神经形态硬件上进行学习实验。

第一步

我们提供了多个 jaxsnn 的使用示例。

使用代理梯度在 Yin-Yang 数据集上进行的时离散学习

python -m jaxsnn.discrete.tasks.yinyang

基于事件的二层前馈网络,具有解析梯度

python -m jaxsnn.event.tasks.yinyang_analytical

基于事件的循环网络(权重设置为模拟二层前馈网络),使用 EventProp 算法计算梯度

python -m jaxsnn.event.tasks.yinyang_event_prop

BSS-2

如果您想使用 BSS-2 系统,提供了一个工作示例

python -m jaxsnn.event.tasks.hardware.yinyang

操作点校准脚本为 src/pyjaxsnn/jaxsnn/event/hardware/calib/neuron_calib.py。示例

srun -p cube --wafer 69 --fpga-without-aout 0 --pty c python ./neuron_calib.py \
	--wafer           W69F0 \
	--threshold         150 \
	--tau-syn          6e-6 \
	--tau-mem         12e-6 \
	--refractory-time 30e-6 \
	--synapse-dac-bias 1000
	--calib-dir src/pyjaxsnn/jaxsnn/event/hardware/calib

如果您想研究不同的硬件伪影(如 BSS-2 上的尖峰时间噪声)对 SNN 性能的影响,请参阅此示例

python -m jaxsnn.event.tasks.hardware.yinyang_mock

您可以在 BSS-2 上的实际执行和纯软件模拟模式之间切换,在这种模式下,硬件由第二个软件网络模拟。您可以向第一个网络中的尖峰添加噪声或限制动态范围(如 BSS-2 上那样)。

文档

多个笔记本帮助您开始使用 jaxsnn

  • event_based_snn.ipynb 详细介绍了如何在 JAX 中使用基于事件的软件进行基于梯度的 SNN 学习
  • ttfs.ipynb 探索了如何解析地计算尖峰时间以及如何构建由 LIF 神经元组成的小型网络
  • event_prop.ipynb 比较了 EventProp 算法的梯度与解析梯度(TTFS)

待办事项

  • 数值:在 EventPropLIF 神经元模块中,梯度目前无法正确地跨越多层。这个问题是因为在 custom_vjp 中没有正确调整输入队列的状态。因此,只能通过使用 RecurrentEventPropLIF 递归地定义多层网络。
  • 硬件神经元模块 HardwareRecurrentLIF(可以模拟多个前馈层)与种群/投影之间的映射尚未干净地实现,而是被硬编码到任务中(实验返回两个层的尖峰列表,这些列表被合并在一起,投影是硬编码的)。
  • 目前,在每个任务和实验中,都会向硬件的尖峰数据添加少量噪声。这是因为 jaxsnn 梯度计算无法处理多个时间点完全相同的尖峰,这在 BSS-2 上由于周期分辨率可能会发生。这应该直接移动到 experiment 类中,或者软件应该进行调整以处理这种情况。
  • 绘图:当前绘图不从保存的数据中加载,而是在每个任务的末尾运行。它应该设置为独立运行,并从文件中加载数据。

致谢

此存储库中的软件是由海德堡大学的员工和学生开发的,作为电子视觉(s)小组在基尔霍夫物理研究所进行的科研工作的组成部分。

本研究获得了欧盟地平线2020框架计划下的资助,资助协议为785907(HBP SGA2)和945539(HBP SGA3),德国研究联合会(DFG,德国研究基金会)在德国卓越战略EXC 2181/1-390900948(海德堡结构卓越集群)的资助下,德国联邦教育与研究部在项目编号16ES1127的框架内作为Pilotinnovationswettbewerb Energieeffizientes KI-System的一部分,以及来自赫尔姆霍兹协会倡议和网络基金[高级计算架构(ACA)]下的项目SO-092,以及来自曼弗雷德·斯特拉克基金会和2018年劳滕施莱格尔研究奖(授予卡尔海因茨·梅耶)。

许可

SPDX-License-Identifier: LGPL-2.1-or-later

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源代码发行版

jaxsnn-0.1.0.tar.gz (65.1 kB 查看哈希值

上传时间 源代码

构建发行版

jaxsnn-0.1.0-py3-none-any.whl (88.4 kB 查看哈希值

上传时间 Python 3

支持