Python中快速且易于使用的无限神经网络
项目描述
支持乌克兰! 🇺🇦
思想自由是所有科学的基石。现在,我们的自由正被乌克兰平民轰炸所压制。《不要反对战争——反对战争!supportukrainenow.org》
神经 tangent
ICLR 2020 视频 | 论文 | 快速入门 | 安装指南 | 参考文档 | 发行说明
概述
神经 tangent 是一个高级神经网络 API,用于指定复杂、分层、有限和 无限 宽度的神经网络。神经 tangent 允许研究人员像定义、训练和评估有限网络一样轻松地定义、训练和评估无限网络。该库已被用于 >100 篇 研究论文。
(宽度或通道数)无限神经网络是高斯过程(GPs),其核函数由其架构决定。请参阅 此处 列出的由神经 tangent 的创造者撰写的关于神经网络无限宽度极限的论文。
神经 tangent 允许您从卷积、池化、残差连接、非线性等常见构建块构建神经网络模型,并获得有限模型,还可以获得相应的 GP 的核函数。
该库是用 Python 编写的,使用 JAX 并利用 XLA 来在 CPU、GPU 或 TPU 上运行。核计算高度优化,以速度和内存效率为目标,并且可以自动跨多个加速器进行近似完美的扩展。
神经 tangent 仍在开发中。我们非常欢迎贡献!
内容
Colab 笔记本
通过在 Colaboratory 中玩以下交互式笔记本开始使用神经 tangent 是一种简单的方法。它们展示了神经 tangent 的主要功能,并展示了如何在研究中使用它。
- 神经 tangent 烹饪书
- 权重空间线性化
- 函数空间线性化
- 神经网络相图
- 性能基准:用于 Myrtle 内核 的简单基准。另请参阅 性能
- [新] 经验 NTK
- [新] 自动 NNGP/NTK 的逐元素非线性
安装
要使用 GPU,请首先遵循 JAX 的 GPU 安装说明。否则,通过运行以下命令在 CPU 上安装 JAX:
pip install jax jaxlib --upgrade
安装 JAX 后,运行以下命令安装神经 tangent:
pip install neural-tangents
或者,要从 GitHub 源使用最新版本:
git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .
现在,您可以通过调用以下命令运行示例和测试:
pip install .[testing]
set -e; for f in examples/*.py; do python $f; done # Run examples
set -e; for f in tests/*.py; do python $f; done # Run tests
5 分钟入门
请参阅此 Colab 获取详细的教程。下面是一个非常快速的介绍。
我们的库紧密遵循JAX的API来指定神经网络,stax
。在stax
中,一个网络由一对函数(init_fn, apply_fn)
定义,分别初始化可训练参数和计算网络的输出。下面是一个定义3层网络并计算其输出y
(给定输入x
)的示例。
from jax import random
from jax.example_libraries import stax
init_fn, apply_fn = stax.serial(
stax.Dense(512), stax.Relu,
stax.Dense(512), stax.Relu,
stax.Dense(1)
)
key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)
y = apply_fn(params, x) # (10, 1) jnp.ndarray outputs of the neural network
Neural Tangents被设计为stax
的即插即用替代品,将(init_fn, apply_fn)
元组扩展为三元组(init_fn, apply_fn, kernel_fn)
,其中kernel_fn
是给定架构的无穷网络(GP)的核函数。下面是一个计算两个输入批次x1
和x2
之间GP协方差的示例。
from jax import random
from neural_tangents import stax
init_fn, apply_fn, kernel_fn = stax.serial(
stax.Dense(512), stax.Relu(),
stax.Dense(512), stax.Relu(),
stax.Dense(1)
)
key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))
kernel = kernel_fn(x1, x2, 'nngp')
请注意,kernel_fn
可以计算对应于神经网络高斯过程(NNGP)和神经网络切线(NT)核的两个协方差矩阵。NNGP核对应于贝叶斯无穷神经网络。NTK对应于(连续)梯度下降训练的无穷网络。在上面的例子中,我们计算了NNGP核,但我们也可以计算NT或两者。
# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) jnp.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) jnp.ndarray
# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp # True
both.ntk == ntk # True
# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))
此外,如果没有指定第三个参数,则kernel_fn
将返回一个包含附加元数据的Kernel
命名元组。这对于以下方式组合kernel_fn
的应用程序可能很有用
kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)
使用在MSE损失上训练的无穷网络进行推理简化为经典的GP推理,为此我们也提供了方便的工具
import neural_tangents as nt
x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1)) # training targets
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
y_train)
y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) jnp.ndarray test predictions of an infinite Bayesian network
y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) jnp.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)
# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp # True
both.ntk == y_test_ntk # True
# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
无穷宽Resnet
我们可以使用相同的nt.stax
构建块定义一个更复杂、(无穷)宽残差网络
from neural_tangents import stax
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
Main = stax.serial(
stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
channels, (3, 3), strides, padding='SAME')
return stax.serial(stax.FanOut(2),
stax.parallel(Main, Shortcut),
stax.FanInSum())
def WideResnetGroup(n, channels, strides=(1, 1)):
blocks = []
blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
for _ in range(n - 1):
blocks += [WideResnetBlock(channels, (1, 1))]
return stax.serial(*blocks)
def WideResnet(block_size, k, num_classes):
return stax.serial(
stax.Conv(16, (3, 3), padding='SAME'),
WideResnetGroup(block_size, int(16 * k)),
WideResnetGroup(block_size, int(32 * k), (2, 2)),
WideResnetGroup(block_size, int(64 * k), (2, 2)),
stax.AvgPool((8, 8)),
stax.Flatten(),
stax.Dense(num_classes, 1., 0.))
init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)
包描述
neural_tangents
(nt
)包包含以下模块和函数
-
stax
- 构建神经网络(如Conv
、Relu
、serial
、parallel
等)的原始函数 -
predict
- 无穷网络的预测-
predict.gradient_descent_mse
- 使用单无穷宽度/线性化网络在MSE损失上进行训练,并通过连续梯度下降进行任意有限或无穷(t=None
)时间的推理。以闭式形式计算。 -
predict.gradient_descent
- 使用单无穷宽度/线性化网络在任意损失上进行训练,并通过连续(动量)梯度下降进行任意有限时间的推理。使用常微分方程求解器进行计算。 -
predict.gradient_descent_mse_ensemble
- 使用无穷无穷宽度网络集合的推理,可以是完全贝叶斯(get='nngp'
)或使用MSE损失和连续梯度下降进行推理(get='ntk'
)。有限时间的贝叶斯推理(例如,t=1., get='nngp'
)解释为仅对顶层进行梯度下降,因为它收敛到精确的高斯过程推理与NNGP(t=None, get='nngp'
)。以闭式形式计算。 -
predict.gp_inference
- 使用NNGP(get='nngp'
)、NTK(get='ntk'
)或两者(get=('nngp', 'ntk')
)的精确闭式形式高斯过程推理。相当于predict.gradient_descent_mse_ensemble
使用t=None
(无穷训练时间),但API略有不同(接受预计算的核矩阵k_train_train
而不是kernel_fn
和x_train
)。
-
-
monte_carlo_kernel_fn
- 计算任何(init_fn, apply_fn)
的蒙特卡洛核估计,不一定通过nt.stax
指定,使无穷网络的核计算无需闭式表达式。 -
用于研究具有“宽但有限”的神经网络训练动态的工具,如
linearize
、taylor_expand
、empirical_kernel_fn
等。有关详细信息,请参阅宽但有限网络的训练动态。
技术难题
nt.stax
与jax.example_libraries.stax
我们指出以下我们库与JAX库之间的区别。
- 所有
nt.stax
层都是通过函数调用实例化的,即nt.stax.Relu()
与jax.example_libraries.stax.Relu
。 - 所有具有可训练参数的层默认使用NTK参数化。然而,
Dense
和Conv
层也支持通过parameterization
关键字参数使用标准参数化。 nt.stax
和jax.example_libraries.stax
可能有不同的层和选项可用(例如,nt.stax
层支持CIRCULAR
填充,有LayerNorm
,但没有BatchNorm
)。
CPU和TPU性能
对于带有池化的CNN,由于核心利用率低(10-20%,看起来像是XLA:CPU问题)以及过度填充,我们的CPU和TPU性能不佳。我们将研究提高性能,但在此期间建议使用NVIDIA GPU。见性能。
宽但有限网络的训练动态
无限网络的核kernel_fn(x1, x2).ntk
结合nt.predict.gradient_descent_mse
,可以在整个训练过程中解析地跟踪在MSE损失上训练的无限宽神经网络的输出。在这里,我们讨论对宽但有限神经网络的含义,并介绍研究它们在权重空间(网络的训练参数)和函数空间(网络的输出)中演化的工具。
权重空间
无限网络中的连续梯度下降已在中证明与训练一个线性(在可训练参数中)模型相对应,这使得线性化神经网络成为理解宽模型中参数行为的重要研究课题。
为此,我们提供了两个方便的函数
nt.linearize
,和nt.taylor_expand
,
允许我们对任何函数apply_fn(params, x)
在初始参数params_0
周围进行线性化或获得任意阶的泰勒展开,如apply_fn_lin = nt.linearize(apply_fn, params_0)
。
可以像使用任何其他函数一样使用apply_fn_lin(params, x)
(包括作为JAX优化器的输入)。这使得比较神经网络的训练轨迹与其线性化变得容易。先前的理论和实验已经检查了从输入到logits或预激活的神经网络的线性化,而不是从输入到后激活,后激活要非线性得多。
示例
import jax.numpy as jnp
import neural_tangents as nt
def apply_fn(params, x):
W, b = params
return jnp.dot(x, W) + b
W_0 = jnp.array([[1., 0.], [0., 1.]])
b_0 = jnp.zeros((2,))
apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = jnp.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2
x = jnp.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x) # (3, 2) jnp.ndarray
函数空间
线性化模型的输出与无限网络的输出完全相同,但具有不同的核——确切地说,是评估在特定apply_fn
上的特定params_0
(网络初始化时使用的参数)的神经切线核。为此,我们提供了nt.empirical_kernel_fn
函数,它接受任何apply_fn
并返回一个kernel_fn(x1, x2, get, params)
,允许在特定的params
上计算经验NTK和/或NNGP(基于get
)核。
示例
import jax.random as random
import jax.numpy as jnp
import neural_tangents as nt
def apply_fn(params, x):
W, b = params
return jnp.dot(x, W) + b
W_0 = jnp.array([[1., 0.], [0., 1.]])
b_0 = jnp.zeros((2,))
params = (W_0, b_0)
key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))
kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)
t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) jnp.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
预期结果
线性近似的成功或失败高度依赖于架构。然而,我们观察到的一些经验法则是
-
随着网络大小的增加而收敛。
-
对于全连接网络,当层宽度达到512时,通常观察到非常强的同意(在训练结束时RMSE约为0.05)。
-
对于卷积网络,当通道数达到512时,通常观察到合理的同意。
-
-
在小的学习率下收敛。
因此,对于新模型,建议使用大宽度在小型数据集上使用小学习率开始。
性能
在下面的表格中,我们测量了计算单个 NTK 条目的时间,该条目是在一个 21 层 CNN(3x3
滤波器,无步长,SAME
填充,ReLU
)上进行的,输入形状为 3x32x32
。精确地
layers = []
for _ in range(21):
layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]
带有池化的 CNN
顶层是 stax.GlobalAvgPool()
_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))
平台 | 精度 | 毫秒 / NTK 条目 | 最大批量大小(NxN ) |
---|---|---|---|
CPU,>56 个核心,>700 Gb RAM | 32 | 112.90 | >= 128 |
CPU,>56 个核心,>700 Gb RAM | 64 | 258.55 | 95(最快 - 72) |
TPU v2 | 32/16 | 3.2550 | 16 |
TPU v3 | 32/16 | 2.3022 | 24 |
NVIDIA P100 | 32 | 5.9433 | 26 |
NVIDIA P100 | 64 | 11.349 | 18 |
NVIDIA V100 | 32 | 2.7001 | 26 |
NVIDIA V100 | 64 | 6.2058 | 18 |
无池化的 CNN
顶层是 stax.Flatten()
_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))
平台 | 精度 | 毫秒 / NTK 条目 | 最大批量大小(NxN ) |
---|---|---|---|
CPU,>56 个核心,>700 Gb RAM | 32 | 0.12013 | 2048 <= N < 4096(最快 - 512) |
CPU,>56 个核心,>700 Gb RAM | 64 | 0.3414 | 2048 <= N < 4096(最快 - 256) |
TPU v2 | 32/16 | 0.0015722 | 512 <= N < 1024 |
TPU v3 | 32/16 | 0.0010647 | 512 <= N < 1024 |
NVIDIA P100 | 32 | 0.015171 | 512 <= N < 1024 |
NVIDIA P100 | 64 | 0.019894 | 512 <= N < 1024 |
NVIDIA V100 | 32 | 0.0046510 | 512 <= N < 1024 |
NVIDIA V100 | 64 | 0.010822 | 512 <= N < 1024 |
使用版本 0.2.1
进行测试。所有 GPU 结果均为单个加速器。请注意,运行时间与您网络的深度成正比。如果您的性能有显著差异,请提交一个错误!
Myrtle 网络
Colab 笔记本 性能基准 展示了如何构建和基准测试内核。为了展示灵活性,我们以 Myrtle 架构 为例。使用 NVIDIA V100
64 位精度,nt
在 Myrtle-5/7/10 内核的完整 60k CIFAR-10 数据集上分别花费了 316/330/508 GPU 小时。
引用
如果您在出版物中使用此代码,请引用我们的论文
# Infinite width NTK/NNGP:
@inproceedings{neuraltangents2020,
title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Learning Representations},
year={2020},
pdf={https://arxiv.org/abs/1912.02803},
url={https://github.com/google/neural-tangents}
}
# Finite width, empirical NTK/NNGP:
@inproceedings{novak2022fast,
title={Fast Finite Width Neural Tangent Kernel},
author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
booktitle={International Conference on Machine Learning},
year={2022},
pdf={https://arxiv.org/abs/2206.08720},
url={https://github.com/google/neural-tangents}
}
# Attention and variable-length inputs:
@inproceedings{hron2020infinite,
title={Infinite attention: NNGP and NTK for deep attention networks},
author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak},
booktitle={International Conference on Machine Learning},
year={2020},
pdf={https://arxiv.org/abs/2006.10540},
url={https://github.com/google/neural-tangents}
}
# Infinite-width "standard" parameterization:
@misc{sohl2020on,
title={On the infinite width limit of neural networks with a standard parameterization},
author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. Schoenholz and Jaehoon Lee},
publisher = {arXiv},
year={2020},
pdf={https://arxiv.org/abs/2001.07301},
url={https://github.com/google/neural-tangents}
}
# Elementwise nonlinearities and sketching:
@inproceedings{han2022fast,
title={Fast Neural Kernel Embeddings for General Activations},
author={Insu Han and Amir Zandieh and Jaehoon Lee and Roman Novak and Lechao Xiao and Amin Karbasi},
booktitle = {Advances in Neural Information Processing Systems},
year={2022},
pdf={https://arxiv.org/abs/2209.04121},
url={https://github.com/google/neural-tangents}
}
项目详情
下载文件
下载适用于您的平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。