Haiku是一个用于在JAX中构建神经网络库。
项目描述
Haiku: Sonnet for JAX
概述 | 为什么选择Haiku? | 快速入门 | 安装 | 示例 | 用户手册 | 文档 | 引用Haiku
[!IMPORTANT] 📣 截至2023年7月 Google DeepMind 建议新项目采用 Flax 而不是Haiku。 Flax 是一个由 Google Brain 原先开发并由 Google DeepMind 现在维护的神经网络库。 📣
在撰写本文时,Flax 已具有 Haiku 所具备的所有功能,并且拥有一个更大的、更活跃的 开发团队,以及更多来自 Alphabet 之外的用户采用。 Flax 拥有更详尽的文档、示例,以及一个活跃的社区,它们正在创建端到端的示例。
Haiku 将继续尽力支持,但项目将进入 维护模式,这意味着开发工作将集中在错误修复和与 JAX 新版本兼容性上。
将推出新版本以保持 Haiku 与 Python 和 JAX 的新版本兼容,但我们不会添加(或接受关于)新功能。
在 Google DeepMind 内部有大量的 Haiku 使用,目前计划无限期地以这种模式支持 Haiku。
什么是 Haiku?
Haiku 是由 Sonnet 的作者之一开发的用于 JAX 的简单神经网络库,Sonnet 是一个用于 TensorFlow 的神经网络库。
有关 Haiku 的文档可在 https://haiku.jax.net.cn/ 找到。
辨析:如果您在寻找操作系统 Haiku,请参阅 https://haiku-os.org/。
概述
JAX 是一个数值计算库,它结合了 NumPy、自动微分以及一流的 GPU/TPU 支持。
Haiku 是一个简单的 JAX 神经网络库,它允许用户在熟悉的对象导向编程模型的同时,允许完全访问 JAX 的纯函数转换。
Haiku 提供了两个核心工具:模块抽象 hk.Module
和简单函数转换 hk.transform
。
hk.Module
是包含对其自身参数、其他模块以及应用在用户输入上的函数的引用的 Python 对象。
hk.transform
将使用这些面向对象、功能上“不纯”的模块的函数转换为可由 jax.jit
、jax.grad
、jax.pmap
等使用的纯函数。
为什么选择 Haiku?
存在许多 JAX 的神经网络库。为什么您应该选择 Haiku?
Haiku 已由 DeepMind 的研究人员进行规模测试。
- DeepMind 已经在 Haiku 和 JAX 中相对容易地重现了许多实验,包括图像和语言处理、生成模型和强化学习的大规模结果。
Haiku 是一个库,而不是一个框架。
- Haiku 设计来简化特定的事情:管理模型参数和其他模型状态。
- Haiku 预期可以与其他库协同工作,并与 JAX 的其余部分良好协作。
- Haiku 否则设计来避免麻烦 - 它不定义自定义优化器、检查点格式或复制 API。
Haiku 不会重造轮子。
- Haiku 基于 Sonnet 的编程模型和 API 构建起来,Sonnet 是一个在 DeepMind 中具有广泛采用的神经网络库。它保留了 Sonnet 的基于
Module
的编程模型以管理状态,同时保留了访问 JAX 的函数转换。 - Haiku API 和抽象尽可能接近 Sonnet。许多用户已经发现 Sonnet 是 TensorFlow 中的高效编程模型;Haiku 在 JAX 中实现了同样的体验。
过渡到 Haiku 很容易。
- 按照设计,从 TensorFlow 和 Sonnet 过渡到 JAX 和 Haiku 很容易。
- 除了新功能(例如
hk.transform
)之外,Haiku 致力于匹配 Sonnet 2 的 API。模块、方法、参数名、默认值和初始化方案应该匹配。
Haiku 使 JAX 的其他方面更加简单。
- Haiku 提供了一个用于处理随机数的简单模型。在一个变换函数中,
hk.next_rng_key()
返回一个唯一的 rng 键。 - 这些唯一键是从传递给顶级变换函数的初始随机键中确定的,因此可以安全地与 JAX 程序转换一起使用。
快速入门
让我们看一下一个示例神经网络、损失函数和训练循环。(有关更多示例,请参阅我们的 示例目录。MNIST 示例是一个良好的起点。)
import haiku as hk
import jax.numpy as jnp
def softmax_cross_entropy(logits, labels):
one_hot = jax.nn.one_hot(labels, logits.shape[-1])
return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)
def loss_fn(images, labels):
mlp = hk.Sequential([
hk.Linear(300), jax.nn.relu,
hk.Linear(100), jax.nn.relu,
hk.Linear(10),
])
logits = mlp(images)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)
def update_rule(param, update):
return param - 0.01 * update
for images, labels in input_dataset:
grads = jax.grad(loss_fn_t.apply)(params, images, labels)
params = jax.tree_util.tree_map(update_rule, params, grads)
Haiku 的核心是 hk.transform
。transform
函数允许您编写依赖于参数(这里为 Linear
层的权重)的神经网络函数,而不需要您显式编写初始化这些参数的样板代码。transform
通过将函数转换为两个函数来实现这一点,这两个函数是 纯(由 JAX 所需)init
和 apply
。
init
init
函数,其签名为 params = init(rng, ...)
(其中 ...
是未变换函数的参数),允许您 收集 网络中任何参数的初始值。Haiku 通过运行您的函数,跟踪通过 hk.get_parameter
(例如由 hk.Linear
调用)请求的任何参数,并将它们返回给您来实现这一点。
返回的 params
对象是一个嵌套的数据结构,包含您的网络中所有的参数,旨在供您检查和操作。具体而言,它是一个将模块名称映射到模块参数的映射,其中模块参数是参数名称到参数值的映射。例如
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
apply
函数,其签名为 result = apply(params, rng, ...)
,允许您 注入 参数值到您的函数中。每当调用 hk.get_parameter
时,返回的值将来自您传递给 apply
的 params
。
loss = loss_fn_t.apply(params, rng, images, labels)
请注意,由于我们的损失函数所执行的实际上计算不依赖于随机数,因此传递随机数生成器是不必要的,因此我们也可以为 rng
参数传递 None
。(注意,如果您的计算确实使用随机数,则传递 None
给 rng
将引发错误。)在我们的示例中,我们通过以下方式让 Haiku 自动执行此操作:
loss_fn_t = hk.without_apply_rng(loss_fn_t)
由于 apply
是一个纯函数,我们可以将其传递给 jax.grad
(或 JAX 的其他任何转换)
grads = jax.grad(loss_fn_t.apply)(params, images, labels)
训练
本例中的训练循环非常简单。需要注意的一个细节是使用 jax.tree_util.tree_map
在 params
和 grads
中所有匹配的条目上应用 sgd
函数。结果具有与之前相同的结构,并且可以再次与 apply
一起使用。
安装
Haiku 使用纯 Python 编写,但通过 JAX 依赖于 C++ 代码。
由于 JAX 安装取决于您的 CUDA 版本,因此 Haiku 未在 requirements.txt
中列出 JAX 作为依赖项。
首先,遵循 这些说明 使用相关加速器支持安装 JAX。
然后,使用 pip 安装 Haiku
$ pip install git+https://github.com/deepmind/dm-haiku
或者,您也可以通过 PyPI 进行安装
$ pip install -U dm-haiku
我们的示例依赖于其他库(例如 bsuite)。您可以使用 pip 安装完整的附加要求集
$ pip install -r examples/requirements.txt
用户手册
编写自己的模块
在Haiku中,所有模块都是hk.Module
的子类。您可以实现任何喜欢的函数(没有任何特殊处理),但通常模块会实现__init__
和__call__
。
让我们通过实现一个线性层来探讨实现过程。
class MyLinear(hk.Module):
def __init__(self, output_size, name=None):
super().__init__(name=name)
self.output_size = output_size
def __call__(self, x):
j, k = x.shape[-1], self.output_size
w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
return jnp.dot(x, w) + b
所有模块都有一个名称。当没有传递name
参数给模块时,其名称将从Python类的名称推断出来(例如MyLinear
变为my_linear
)。模块可以有命名参数,这些参数可以通过hk.get_parameter(param_name, ...)
访问。我们使用此API(而不是仅使用对象属性),以便可以使用hk.transform
将您的代码转换为纯函数。
使用模块时,您需要定义函数,并使用hk.transform
将它们转换为纯函数对。有关transform
返回的函数的更多详细信息,请参阅我们的快速入门。
def forward_fn(x):
model = MyLinear(10)
return model(x)
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk.transform(forward_fn)
x = jnp.ones([1, 1])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward.apply(params, None, x)
与随机模型一起工作
某些模型在计算过程中可能需要随机抽样。例如,在具有重参数化技巧的变分自编码器中,需要一个从标准正态分布的随机样本。对于dropout,我们需要一个随机掩码来丢弃输入中的单元。在JAX中使这工作的主要障碍在于PRNG密钥的管理。
在Haiku中,我们提供了一个简单的API来维护与模块关联的PRNG密钥序列:hk.next_rng_key()
(或next_rng_keys()
用于多个密钥)。
class MyDropout(hk.Module):
def __init__(self, rate=0.5, name=None):
super().__init__(name=name)
self.rate = rate
def __call__(self, x):
key = hk.next_rng_key()
p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
return x * p / (1.0 - self.rate)
forward = hk.transform(lambda x: MyDropout()(x))
key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)
有关与随机模型一起工作的更完整说明,请参阅我们的VAE示例。
注意:hk.next_rng_key()
不是纯函数,这意味着您应避免在hk.transform
内的JAX转换中使用它。有关更多信息以及可能的解决方案,请参阅Haiku转换文档以及可用Haiku网络内JAX转换的包装器。
与不可训练的状态一起工作
某些模型可能希望维护一些内部、可变的状态。例如,在批量归一化中,维护训练期间遇到的值的移动平均值。
在Haiku中,我们提供了一个简单的API来维护与模块关联的可变状态:hk.set_state
和hk.get_state
。当使用这些函数时,您需要使用hk.transform_with_state
来转换您的函数,因为返回的函数对的签名不同。
def forward(x, is_training):
net = hk.nets.ResNet50(1000)
return net(x, is_training)
forward = hk.transform_with_state(forward)
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)
如果您忘记使用hk.transform_with_state
,不要担心,我们将打印一个清晰的错误,指向hk.transform_with_state
,而不是默默地丢弃您的状态。
使用jax.pmap
进行分布式训练
从hk.transform
(或hk.transform_with_state
)返回的纯函数与jax.pmap
完全兼容。有关使用jax.pmap
进行SPMD编程的更多详细信息,请参阅此处。
使用Haiku的jax.pmap
的一个常见用途是在许多加速器上进行数据并行训练,可能跨越多个主机。使用Haiku,这可能会看起来像这样
def loss_fn(inputs, labels):
logits = hk.nets.MLP([8, 4, 2])(x)
return jnp.mean(softmax_cross_entropy(logits, labels))
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)
# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)
# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)
def make_superbatch():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [next(input_dataset) for _ in range(num_devices)]
superbatch_images, superbatch_labels = zip(*superbatch)
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np.stack(superbatch_images)
superbatch_labels = np.stack(superbatch_labels)
return superbatch_images, superbatch_labels
def update(params, inputs, labels, axis_name='i'):
"""Updates params based on performance on inputs and labels."""
grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
# Take the mean of the gradients across all data-parallel replicas.
grads = jax.lax.pmean(grads, axis_name)
# Update parameters using SGD or Adam or ...
new_params = my_update_rule(params, grads)
return new_params
# Run several training updates.
for _ in range(10):
superbatch_images, superbatch_labels = make_superbatch()
params = jax.pmap(update, axis_name='i')(params, superbatch_images,
superbatch_labels)
有关分布式Haiku训练的更完整说明,请查看我们的ResNet-50在ImageNet示例。
引用Haiku
要引用此存储库
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.10},
year = {2020},
}
在此bibtex条目中,版本号旨在从haiku/__init__.py
获取,年份对应于项目的开源发布年份。
项目详细信息
下载文件
下载适用于您平台的应用程序文件。如果您不确定选择哪个,请了解更多关于安装包的信息。