jaxsnn是一种基于事件、受机器学习启发的SNN(突触神经网络)训练和模拟方法,包括对神经形态后端(BrainScaleS-2)的支持。
项目描述
jaxsnn
jaxsnn
是一种基于事件、受机器学习启发的SNN(突触神经网络)训练和模拟方法,包括对神经形态后端(BrainScaleS-2)的支持。我们基于 jax 构建,它是提供autograd和XLA功能用于高性能机器学习研究的Python库。
构建软件
该软件基于现有库,例如 jax、optax 和 tree-math。当使用神经形态BrainScaleS-2后端时,需要平台软件栈。
我们提供包含所有构建时和运行时依赖项的容器镜像(基于 Singularity格式)。您可以从中下载最新版本:此处。
对于所有后续步骤,我们假设最新的Singularity容器位于 /containers/stable/latest
。
基于Github的构建
要从公共资源构建此项目,请遵循以下指南
# 1) Most of the following steps will be executed within a singularity container
# To keep the steps clutter-free, we start by defining an alias
shopt -s expand_aliases
alias c="singularity exec --app dls /containers/stable/latest"
# 2) Prepare a fresh workspace and change directory into it
mkdir workspace && cd workspace
# 3) Fetch a current copy of the symwaf2ic build tool
git clone https://github.com/electronicvisions/waf -b symwaf2ic symwaf2ic
# 4) Build symwaf2ic
c make -C symwaf2ic
ln -s symwaf2ic/waf
# 5) Setup your workspace and clone all dependencies (--clone-depth=1 to skip history)
c ./waf setup --repo-db-url=https://github.com/electronicvisions/projects --project=jaxsnn
# 6) Load PPU cross-compiler toolchain (or build https://github.com/electronicvisions/oppulance)
module load ppu-toolchain
# 7) Build the project
# Adjust -j1 to your own needs, beware that high parallelism will increase memory consumption!
c ./waf configure
c ./waf build -j1
# 8) Install the project to ./bin and ./lib
c ./waf install
# 9) If you run programs outside waf, you'll need to add ./lib and ./bin to your path specifications
export SINGULARITYENV_PREPEND_PATH=`pwd`/bin:$SINGULARITYENV_PREPEND_PATH
export SINGULARITYENV_LD_LIBRARY_PATH=`pwd`/lib:$SINGULARITYENV_LD_LIBRARY_PATH
export PYTHONPATH=`pwd`/lib:$PYTHONPATH
结构
jaxsnn
分为两部分。SNN的培训以init/apply风格进行。
时间离散
jaxsnn.discrete
通过以离散方式处理时间来模拟 SNN。它使用固定大小的欧拉步来推进网络向前在时间上前进,这受到了 norse 的启发。
时间连续
jaxsnn.event
对时间进行连续处理,并允许从一个事件跳转到下一个事件。其核心功能包括 step
函数,该函数执行以下三件事:
- 找到下一个阈值交叉点
- 将神经元积分到当前时间点
- 在阈值交叉后应用不连续性
jaxsnn.event.leaky_integrate_and_fire
提供多种神经元类型,可用于构建更大的网络。每种神经元类型定义了上述提到的三个函数。
BSS-2 连接
jaxsnn.event.hardware
提供了连接到 BSS-2 系统 的功能,并在专用类神经形态硬件上进行学习实验。
第一步
我们提供了多个 jaxsnn
的使用示例。
使用代理梯度在 Yin-Yang 数据集上进行的时离散学习
python -m jaxsnn.discrete.tasks.yinyang
基于事件的二层前馈网络,具有解析梯度
python -m jaxsnn.event.tasks.yinyang_analytical
基于事件的循环网络(权重设置为模拟二层前馈网络),使用 EventProp 算法计算梯度
python -m jaxsnn.event.tasks.yinyang_event_prop
BSS-2
如果您想使用 BSS-2 系统,提供了一个工作示例
python -m jaxsnn.event.tasks.hardware.yinyang
操作点校准脚本为 src/pyjaxsnn/jaxsnn/event/hardware/calib/neuron_calib.py
。示例
srun -p cube --wafer 69 --fpga-without-aout 0 --pty c python ./neuron_calib.py \
--wafer W69F0 \
--threshold 150 \
--tau-syn 6e-6 \
--tau-mem 12e-6 \
--refractory-time 30e-6 \
--synapse-dac-bias 1000
--calib-dir src/pyjaxsnn/jaxsnn/event/hardware/calib
如果您想研究不同的硬件伪影(如 BSS-2 上的尖峰时间噪声)对 SNN 性能的影响,请参阅此示例
python -m jaxsnn.event.tasks.hardware.yinyang_mock
您可以在 BSS-2 上的实际执行和纯软件模拟模式之间切换,在这种模式下,硬件由第二个软件网络模拟。您可以向第一个网络中的尖峰添加噪声或限制动态范围(如 BSS-2 上那样)。
文档
多个笔记本帮助您开始使用 jaxsnn
。
event_based_snn.ipynb
详细介绍了如何在 JAX 中使用基于事件的软件进行基于梯度的 SNN 学习ttfs.ipynb
探索了如何解析地计算尖峰时间以及如何构建由 LIF 神经元组成的小型网络event_prop.ipynb
比较了 EventProp 算法的梯度与解析梯度(TTFS)
待办事项
- 数值:在
EventPropLIF
神经元模块中,梯度目前无法正确地跨越多层。这个问题是因为在custom_vjp
中没有正确调整输入队列的状态。因此,只能通过使用RecurrentEventPropLIF
递归地定义多层网络。 - 硬件神经元模块
HardwareRecurrentLIF
(可以模拟多个前馈层)与种群/投影之间的映射尚未干净地实现,而是被硬编码到任务中(实验返回两个层的尖峰列表,这些列表被合并在一起,投影是硬编码的)。 - 目前,在每个任务和实验中,都会向硬件的尖峰数据添加少量噪声。这是因为
jaxsnn
梯度计算无法处理多个时间点完全相同的尖峰,这在 BSS-2 上由于周期分辨率可能会发生。这应该直接移动到experiment
类中,或者软件应该进行调整以处理这种情况。 - 绘图:当前绘图不从保存的数据中加载,而是在每个任务的末尾运行。它应该设置为独立运行,并从文件中加载数据。
致谢
此存储库中的软件是由海德堡大学的员工和学生开发的,作为电子视觉(s)小组在基尔霍夫物理研究所进行的科研工作的组成部分。
本研究获得了欧盟地平线2020框架计划下的资助,资助协议为785907(HBP SGA2)和945539(HBP SGA3),德国研究联合会(DFG,德国研究基金会)在德国卓越战略EXC 2181/1-390900948(海德堡结构卓越集群)的资助下,德国联邦教育与研究部在项目编号16ES1127的框架内作为Pilotinnovationswettbewerb Energieeffizientes KI-System的一部分,以及来自赫尔姆霍兹协会倡议和网络基金[高级计算架构(ACA)]下的项目SO-092,以及来自曼弗雷德·斯特拉克基金会和2018年劳滕施莱格尔研究奖(授予卡尔海因茨·梅耶)。
许可
SPDX-License-Identifier: LGPL-2.1-or-later
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源代码发行版
构建发行版
jaxsnn-0.1.0.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 34811b4077e02afa418506f94cbaaeb4344948c5645390bc8baba1859cc96b03 |
|
MD5 | 795c803abb83b76b4f283b5ad602a5e7 |
|
BLAKE2b-256 | 232efa5d96107170826b7b7e086ade4cbeda83dd4bdfd7eaee870fb0eab1c56d |
jaxsnn-0.1.0-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 1a630a4e84f133166136f8a9fdb6c46bfb5320385a6422002f1b3913ac37265c |
|
MD5 | 8e5d81abcaf79d145ce895d05fefa888 |
|
BLAKE2b-256 | eec38e3468b35ca92a75e37e96632883a34adc9095a54fc2d36cd645f967172d |