跳转到主要内容

Python中快速且易于使用的无限神经网络

项目描述

支持乌克兰! 🇺🇦

思想自由是所有科学的基石。现在,我们的自由正被乌克兰平民轰炸所压制。《不要反对战争——反对战争!supportukrainenow.org

神经 tangent

ICLR 2020 视频 | 论文 | 快速入门 | 安装指南 | 参考文档 | 发行说明

PyPI PyPI - Python Version Linux macOS Pytype Coverage Readthedocs

概述

神经 tangent 是一个高级神经网络 API,用于指定复杂、分层、有限和 无限 宽度的神经网络。神经 tangent 允许研究人员像定义、训练和评估有限网络一样轻松地定义、训练和评估无限网络。该库已被用于 >100 篇 研究论文。

(宽度或通道数)无限神经网络是高斯过程(GPs),其核函数由其架构决定。请参阅 此处 列出的由神经 tangent 的创造者撰写的关于神经网络无限宽度极限的论文。

神经 tangent 允许您从卷积、池化、残差连接、非线性等常见构建块构建神经网络模型,并获得有限模型,还可以获得相应的 GP 的核函数。

该库是用 Python 编写的,使用 JAX 并利用 XLA 来在 CPU、GPU 或 TPU 上运行。核计算高度优化,以速度和内存效率为目标,并且可以自动跨多个加速器进行近似完美的扩展。

神经 tangent 仍在开发中。我们非常欢迎贡献!

内容

Colab 笔记本

通过在 Colaboratory 中玩以下交互式笔记本开始使用神经 tangent 是一种简单的方法。它们展示了神经 tangent 的主要功能,并展示了如何在研究中使用它。

安装

要使用 GPU,请首先遵循 JAX 的 GPU 安装说明。否则,通过运行以下命令在 CPU 上安装 JAX:

pip install jax jaxlib --upgrade

安装 JAX 后,运行以下命令安装神经 tangent:

pip install neural-tangents

或者,要从 GitHub 源使用最新版本:

git clone https://github.com/google/neural-tangents; cd neural-tangents
pip install -e .

现在,您可以通过调用以下命令运行示例和测试:

pip install .[testing]
set -e; for f in examples/*.py; do python $f; done  # Run examples
set -e; for f in tests/*.py; do python $f; done  # Run tests

5 分钟入门

请参阅此 Colab 获取详细的教程。下面是一个非常快速的介绍。

我们的库紧密遵循JAX的API来指定神经网络,stax。在stax中,一个网络由一对函数(init_fn, apply_fn)定义,分别初始化可训练参数和计算网络的输出。下面是一个定义3层网络并计算其输出y(给定输入x)的示例。

from jax import random
from jax.example_libraries import stax

init_fn, apply_fn = stax.serial(
    stax.Dense(512), stax.Relu,
    stax.Dense(512), stax.Relu,
    stax.Dense(1)
)

key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)

y = apply_fn(params, x)  # (10, 1) jnp.ndarray outputs of the neural network

Neural Tangents被设计为stax的即插即用替代品,将(init_fn, apply_fn)元组扩展为三元组(init_fn, apply_fn, kernel_fn),其中kernel_fn是给定架构的无穷网络(GP)的核函数。下面是一个计算两个输入批次x1x2之间GP协方差的示例。

from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

kernel = kernel_fn(x1, x2, 'nngp')

请注意,kernel_fn可以计算对应于神经网络高斯过程(NNGP)神经网络切线(NT)核的两个协方差矩阵。NNGP核对应于贝叶斯无穷神经网络。NTK对应于(连续)梯度下降训练的无穷网络。在上面的例子中,我们计算了NNGP核,但我们也可以计算NT或两者。

# Get kernel of a single type
nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) jnp.ndarray
ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) jnp.ndarray

# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
both.nngp == nngp  # True
both.ntk == ntk  # True

# Unpack the kernels namedtuple
nngp, ntk = kernel_fn(x1, x2, ('nngp', 'ntk'))

此外,如果没有指定第三个参数,则kernel_fn将返回一个包含附加元数据的Kernel命名元组。这对于以下方式组合kernel_fn的应用程序可能很有用

kernel = kernel_fn(x1, x2)
kernel = kernel_fn(kernel)
print(kernel.nngp)

使用在MSE损失上训练的无穷网络进行推理简化为经典的GP推理,为此我们也提供了方便的工具

import neural_tangents as nt

x_train, x_test = x1, x2
y_train = random.uniform(key1, shape=(10, 1))  # training targets

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                      y_train)

y_test_nngp = predict_fn(x_test=x_test, get='nngp')
# (20, 1) jnp.ndarray test predictions of an infinite Bayesian network

y_test_ntk = predict_fn(x_test=x_test, get='ntk')
# (20, 1) jnp.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)

# Get predictions as a namedtuple
both = predict_fn(x_test=x_test, get=('nngp', 'ntk'))
both.nngp == y_test_nngp  # True
both.ntk == y_test_ntk  # True

# Unpack the predictions namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'))

无穷宽Resnet

我们可以使用相同的nt.stax构建块定义一个更复杂、(无穷)宽残差网络

from neural_tangents import stax

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  Shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = []
  blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
  for _ in range(n - 1):
    blocks += [WideResnetBlock(channels, (1, 1))]
  return stax.serial(*blocks)

def WideResnet(block_size, k, num_classes):
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_classes, 1., 0.))

init_fn, apply_fn, kernel_fn = WideResnet(block_size=4, k=1, num_classes=10)

包描述

neural_tangentsnt)包包含以下模块和函数

  • stax - 构建神经网络(如ConvReluserialparallel等)的原始函数

  • predict - 无穷网络的预测

    • predict.gradient_descent_mse - 使用单无穷宽度/线性化网络在MSE损失上进行训练,并通过连续梯度下降进行任意有限或无穷(t=None)时间的推理。以闭式形式计算。

    • predict.gradient_descent - 使用单无穷宽度/线性化网络在任意损失上进行训练,并通过连续(动量)梯度下降进行任意有限时间的推理。使用常微分方程求解器进行计算。

    • predict.gradient_descent_mse_ensemble - 使用无穷无穷宽度网络集合的推理,可以是完全贝叶斯(get='nngp')或使用MSE损失和连续梯度下降进行推理(get='ntk')。有限时间的贝叶斯推理(例如,t=1., get='nngp')解释为仅对顶层进行梯度下降,因为它收敛到精确的高斯过程推理与NNGP(t=None, get='nngp')。以闭式形式计算。

    • predict.gp_inference - 使用NNGP(get='nngp')、NTK(get='ntk')或两者(get=('nngp', 'ntk'))的精确闭式形式高斯过程推理。相当于predict.gradient_descent_mse_ensemble使用t=None(无穷训练时间),但API略有不同(接受预计算的核矩阵k_train_train而不是kernel_fnx_train)。

  • monte_carlo_kernel_fn - 计算任何(init_fn, apply_fn)的蒙特卡洛核估计,不一定通过nt.stax指定,使无穷网络的核计算无需闭式表达式。

  • 用于研究具有“宽但有限”的神经网络训练动态的工具,如linearizetaylor_expandempirical_kernel_fn等。有关详细信息,请参阅宽但有限网络的训练动态

技术难题

nt.staxjax.example_libraries.stax

我们指出以下我们库与JAX库之间的区别。

  • 所有nt.stax层都是通过函数调用实例化的,即nt.stax.Relu()jax.example_libraries.stax.Relu
  • 所有具有可训练参数的层默认使用NTK参数化。然而,DenseConv层也支持通过parameterization关键字参数使用标准参数化
  • nt.staxjax.example_libraries.stax可能有不同的层和选项可用(例如,nt.stax层支持CIRCULAR填充,有LayerNorm,但没有BatchNorm)。

CPU和TPU性能

对于带有池化的CNN,由于核心利用率低(10-20%,看起来像是XLA:CPU问题)以及过度填充,我们的CPU和TPU性能不佳。我们将研究提高性能,但在此期间建议使用NVIDIA GPU。见性能

宽但有限网络的训练动态

无限网络的核kernel_fn(x1, x2).ntk结合nt.predict.gradient_descent_mse,可以在整个训练过程中解析地跟踪在MSE损失上训练的无限宽神经网络的输出。在这里,我们讨论对宽但有限神经网络的含义,并介绍研究它们在权重空间(网络的训练参数)和函数空间(网络的输出)中演化的工具。

权重空间

无限网络中的连续梯度下降已在中证明与训练一个线性(在可训练参数中)模型相对应,这使得线性化神经网络成为理解宽模型中参数行为的重要研究课题。

为此,我们提供了两个方便的函数

  • nt.linearize,和
  • nt.taylor_expand,

允许我们对任何函数apply_fn(params, x)在初始参数params_0周围进行线性化或获得任意阶的泰勒展开,如apply_fn_lin = nt.linearize(apply_fn, params_0)

可以像使用任何其他函数一样使用apply_fn_lin(params, x)(包括作为JAX优化器的输入)。这使得比较神经网络的训练轨迹与其线性化变得容易。先前的理论和实验已经检查了从输入到logits或预激活的神经网络的线性化,而不是从输入到后激活,后激活要非线性得多。

示例

import jax.numpy as jnp
import neural_tangents as nt

def apply_fn(params, x):
  W, b = params
  return jnp.dot(x, W) + b

W_0 = jnp.array([[1., 0.], [0., 1.]])
b_0 = jnp.zeros((2,))

apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
W = jnp.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2

x = jnp.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
logits = apply_fn_lin((W, b), x)  # (3, 2) jnp.ndarray

函数空间

线性化模型的输出与无限网络的输出完全相同,但具有不同的核——确切地说,是评估在特定apply_fn上的特定params_0(网络初始化时使用的参数)的神经切线核。为此,我们提供了nt.empirical_kernel_fn函数,它接受任何apply_fn并返回一个kernel_fn(x1, x2, get, params),允许在特定的params上计算经验NTK和/或NNGP(基于get)核。

示例

import jax.random as random
import jax.numpy as jnp
import neural_tangents as nt


def apply_fn(params, x):
  W, b = params
  return jnp.dot(x, W) + b


W_0 = jnp.array([[1., 0.], [0., 1.]])
b_0 = jnp.zeros((2,))
params = (W_0, b_0)

key1, key2 = random.split(random.PRNGKey(1), 2)
x_train = random.normal(key1, (3, 2))
x_test = random.normal(key2, (4, 2))
y_train = random.uniform(key1, shape=(3, 2))

kernel_fn = nt.empirical_kernel_fn(apply_fn)
ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params)
mse_predictor = nt.predict.gradient_descent_mse(ntk_train_train, y_train)

t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
# (3, 2) and (4, 2) jnp.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent

预期结果

线性近似的成功或失败高度依赖于架构。然而,我们观察到的一些经验法则是

  1. 随着网络大小的增加而收敛。

    • 对于全连接网络,当层宽度达到512时,通常观察到非常强的同意(在训练结束时RMSE约为0.05)。

    • 对于卷积网络,当通道数达到512时,通常观察到合理的同意。

  2. 在小的学习率下收敛。

因此,对于新模型,建议使用大宽度在小型数据集上使用小学习率开始。

性能

在下面的表格中,我们测量了计算单个 NTK 条目的时间,该条目是在一个 21 层 CNN(3x3 滤波器,无步长,SAME 填充,ReLU)上进行的,输入形状为 3x32x32。精确地

layers = []
for _ in range(21):
  layers += [stax.Conv(1, (3, 3), (1, 1), 'SAME'), stax.Relu()]

带有池化的 CNN

顶层是 stax.GlobalAvgPool()

_, _, kernel_fn = stax.serial(*(layers + [stax.GlobalAvgPool()]))
平台 精度 毫秒 / NTK 条目 最大批量大小(NxN
CPU,>56 个核心,>700 Gb RAM 32 112.90 >= 128
CPU,>56 个核心,>700 Gb RAM 64 258.55 95(最快 - 72)
TPU v2 32/16 3.2550 16
TPU v3 32/16 2.3022 24
NVIDIA P100 32 5.9433 26
NVIDIA P100 64 11.349 18
NVIDIA V100 32 2.7001 26
NVIDIA V100 64 6.2058 18

无池化的 CNN

顶层是 stax.Flatten()

_, _, kernel_fn = stax.serial(*(layers + [stax.Flatten()]))
平台 精度 毫秒 / NTK 条目 最大批量大小(NxN
CPU,>56 个核心,>700 Gb RAM 32 0.12013 2048 <= N < 4096(最快 - 512)
CPU,>56 个核心,>700 Gb RAM 64 0.3414 2048 <= N < 4096(最快 - 256)
TPU v2 32/16 0.0015722 512 <= N < 1024
TPU v3 32/16 0.0010647 512 <= N < 1024
NVIDIA P100 32 0.015171 512 <= N < 1024
NVIDIA P100 64 0.019894 512 <= N < 1024
NVIDIA V100 32 0.0046510 512 <= N < 1024
NVIDIA V100 64 0.010822 512 <= N < 1024

使用版本 0.2.1 进行测试。所有 GPU 结果均为单个加速器。请注意,运行时间与您网络的深度成正比。如果您的性能有显著差异,请提交一个错误

Myrtle 网络

Colab 笔记本 性能基准 展示了如何构建和基准测试内核。为了展示灵活性,我们以 Myrtle 架构 为例。使用 NVIDIA V100 64 位精度,nt 在 Myrtle-5/7/10 内核的完整 60k CIFAR-10 数据集上分别花费了 316/330/508 GPU 小时。

引用

如果您在出版物中使用此代码,请引用我们的论文

# Infinite width NTK/NNGP:
@inproceedings{neuraltangents2020,
    title={Neural Tangents: Fast and Easy Infinite Neural Networks in Python},
    author={Roman Novak and Lechao Xiao and Jiri Hron and Jaehoon Lee and Alexander A. Alemi and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
    booktitle={International Conference on Learning Representations},
    year={2020},
    pdf={https://arxiv.org/abs/1912.02803},
    url={https://github.com/google/neural-tangents}
}

# Finite width, empirical NTK/NNGP:
@inproceedings{novak2022fast,
    title={Fast Finite Width Neural Tangent Kernel},
    author={Roman Novak and Jascha Sohl-Dickstein and Samuel S. Schoenholz},
    booktitle={International Conference on Machine Learning},
    year={2022},
    pdf={https://arxiv.org/abs/2206.08720},
    url={https://github.com/google/neural-tangents}
}

# Attention and variable-length inputs:
@inproceedings{hron2020infinite,
    title={Infinite attention: NNGP and NTK for deep attention networks},
    author={Jiri Hron and Yasaman Bahri and Jascha Sohl-Dickstein and Roman Novak},
    booktitle={International Conference on Machine Learning},
    year={2020},
    pdf={https://arxiv.org/abs/2006.10540},
    url={https://github.com/google/neural-tangents}
}

# Infinite-width "standard" parameterization:
@misc{sohl2020on,
    title={On the infinite width limit of neural networks with a standard parameterization},
    author={Jascha Sohl-Dickstein and Roman Novak and Samuel S. Schoenholz and Jaehoon Lee},
    publisher = {arXiv},
    year={2020},
    pdf={https://arxiv.org/abs/2001.07301},
    url={https://github.com/google/neural-tangents}
}

# Elementwise nonlinearities and sketching:
@inproceedings{han2022fast,
    title={Fast Neural Kernel Embeddings for General Activations},
    author={Insu Han and Amir Zandieh and Jaehoon Lee and Roman Novak and Lechao Xiao and Amin Karbasi},
    booktitle = {Advances in Neural Information Processing Systems},
    year={2022},
    pdf={https://arxiv.org/abs/2209.04121},
    url={https://github.com/google/neural-tangents}
}

下载文件

下载适用于您的平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源分发

neural-tangents-0.6.5.tar.gz (215.5 kB 查看哈希值)

上传时间

构建分发

neural_tangents-0.6.5-py2.py3-none-any.whl (248.7 kB 查看哈希值)

上传时间 Python 2 Python 3

由以下支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面