微分、编译和转换Numpy代码。
项目描述
可变换的规模数值计算
快速入门 | 变换 | 安装指南 | 神经网络库 | 变更日志 | 参考文档
JAX是什么?
JAX是一个Python库,用于面向加速器的数组计算和程序转换,旨在进行高性能数值计算和大规模机器学习。
凭借其更新的Autograd版本,JAX可以自动微分原生的Python和NumPy函数。它可以微分循环、分支、递归和闭包,并且可以求出多级导数。它支持通过grad
进行的逆模式微分(即反向传播)以及前向模式微分,并且两者可以任意组合到任何级别。
新功能是JAX使用XLA来编译和运行您的NumPy程序在GPU和TPU上。默认情况下,编译在幕后进行,库调用得到即时编译和执行。但是JAX还允许您使用一个函数API jit
将您的Python函数即时编译成XLA优化的内核。编译和自动微分可以任意组合,因此您可以在不离开Python的情况下表达复杂的算法并获得最佳性能。您甚至可以使用pmap
同时编程多个GPU或TPU核心,并对整个过程进行微分。
深入了解,你会发现JAX实际上是一个可扩展的系统,用于可组合函数转换。grad
和jit
都是这种转换的实例。其他包括vmap
用于自动向量化以及pmap
用于多加速器的单程序多数据(SPMD)并行编程,未来还有更多。
这是一个研究项目,不是官方的Google产品。请期待可能出现错误和尖锐边缘。请通过试用、报告错误以及告诉我们您的想法来帮助我们!
import jax.numpy as jnp
from jax import grad, jit, vmap
def predict(params, inputs):
for W, b in params:
outputs = jnp.dot(inputs, W) + b
inputs = jnp.tanh(outputs) # inputs to the next layer
return outputs # no activation on last layer
def loss(params, inputs, targets):
preds = predict(params, inputs)
return jnp.sum((preds - targets)**2)
grad_loss = jit(grad(loss)) # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example grads
内容
快速入门:云中的Colab
直接使用浏览器中的笔记本开始,连接到Google Cloud GPU。以下是一些入门笔记本
JAX现在可以在Cloud TPUs上运行。要试用预览,请参阅Cloud TPU Colabs。
深入了解JAX
变换
在核心上,JAX是一个用于转换数值函数的可扩展系统。以下四个转换是主要感兴趣的:grad
,jit
,vmap
和pmap
。
使用grad
进行自动微分
JAX与Autograd大致具有相同的API。最流行的函数是grad
,用于逆模式梯度
from jax import grad
import jax.numpy as jnp
def tanh(x): # Define a function
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh) # Obtain its gradient function
print(grad_tanh(1.0)) # Evaluate it at x = 1.0
# prints 0.4199743
您可以使用grad
进行任意阶的微分。
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
对于更高级的自动微分,您可以使用 jax.vjp
进行逆模式向量-雅可比乘积,以及 jax.jvp
进行前模式雅可比-向量乘积。这两个可以任意组合,也可以与其他 JAX 转换一起使用。以下是一种组合方法,用于高效地计算 完整的 Hessian 矩阵
from jax import jit, jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
与 Autograd 一样,您可以使用 Python 控制结构进行微分
def abs_val(x):
if x > 0:
return x
else:
return -x
abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0)) # prints 1.0
print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
请参阅有关自动微分的 参考文档 和 JAX 自动微分烹饪书 以获取更多信息。
使用 jit
进行编译
您可以使用 XLA 使用 jit
编译您的函数,jit
可以用作 @jit
装饰器或作为高阶函数。
import jax.numpy as jnp
from jax import jit
def slow_f(x):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
您可以根据喜好混合 jit
和 grad
以及任何其他 JAX 转换。
使用 jit
对函数可以使用的 Python 控制流类型施加约束;请参阅 常见错误笔记本 以获取更多信息。
使用 vmap
进行自动向量化
vmap
是向量化的映射。它具有将函数映射到数组轴的熟悉语义,但与将循环保留在外部不同,它将循环推入函数的基本操作以获得更好的性能。
使用 vmap
可以让您在代码中省略批量维度。例如,考虑这个简单的 未批处理 神经网络预测函数
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side!
activations = jnp.tanh(outputs) # inputs to the next layer
return outputs # no activation on last layer
我们通常编写 jnp.dot(activations, W)
以允许在 activations
的左侧具有批量维度,但我们已经编写了这个特定的预测函数,仅适用于单个输入向量。如果我们想要一次应用此函数到一批输入,在语义上,我们可以简单地编写
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
但是,一次通过一个示例通过网络会很慢!更好的做法是向量化计算,这样在每一层我们都在执行矩阵-矩阵乘法,而不是矩阵-向量乘法。
vmap
函数会为我们进行这种转换。也就是说,如果我们编写
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
那么 vmap
函数将把外循环推入函数内部,并且我们的机器将执行矩阵-矩阵乘法,就像我们手动批处理一样。
手动批处理简单的神经网络没有 vmap
已经足够简单,但在其他情况下,手动向量化可能不切实际或不可能。考虑高效计算每个示例梯度的问题:也就是说,对于一组固定的参数,我们想要计算在批处理中的每个示例分别评估的损失函数的梯度。使用 vmap
很容易
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
当然,vmap
可以任意组合 jit
、grad
和任何其他 JAX 转换!我们在 jax.jacfwd
、jax.jacrev
和 jax.hessian
中使用 vmap
进行前向和反向模式自动微分,以快速计算雅可比和 Hessian 矩阵。
使用 pmap
进行 SPMD 编程
对于多个加速器的并行编程,如多个 GPU,请使用 pmap
。使用 pmap
,您可以编写单程序多数据(SPMD)程序,包括快速并行集体通信操作。应用 pmap
意味着您编写的函数将由 XLA(类似于 jit
)编译,然后在设备之间并行复制和执行。
以下是在 8-GPU 机器上的示例
from jax import random, pmap
import jax.numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
除了表达纯映射之外,您还可以在设备之间使用快速 集体通信操作
from functools import partial
from jax import lax
@partial(pmap, axis_name='i')
def normalize(x):
return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
您甚至可以将 pmap 函数嵌套,以实现更复杂的通信模式。
所有这些都可以组合,因此您可以通过并行计算进行自由区分。
from jax import grad
@pmap
def f(x):
y = jnp.sin(x)
@pmap
def g(z):
return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
return grad(lambda w: jnp.sum(g(w)))(x)
print(f(x))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
在反向模式微分 pmap
函数(例如使用 grad
)时,计算的逆传播与正向传播一样并行化。
有关更多信息,请参阅 SPMD 美食手册 和 从头开始构建 SPMD MNIST 分类器的示例。
当前问题
为了更全面地了解当前的一些常见问题,包括示例和解释,我们强烈建议您阅读 常见问题笔记本。以下是一些亮点:
- JAX 变换仅在 纯函数 上工作,这些函数没有副作用并尊重 引用透明性(即使用
is
进行对象身份测试不被保留)。如果您在纯 Python 函数上使用 JAX 变换,您可能会看到如Exception: Can't lift Traced...
或Exception: Different traces at same level
这样的错误。 - 例如,像
x[i] += y
这样的 就地更新数组 不受支持,但 存在功能替代方法。在jit
下,这些功能替代方法将自动就地重用缓冲区。 - 随机数有所不同,但 出于良好原因。
- 如果您正在寻找 卷积算子,它们位于
jax.lax
包中。 - JAX 默认强制使用单精度(32 位,例如
float32
)值,并且要启用双精度(64 位,例如float64
),需要在启动时设置jax_enable_x64
变量(或设置环境变量JAX_ENABLE_X64=True
)。在 TPU 上,JAX 默认使用 32 位值用于所有操作 除 'matmul-like' 操作的内部临时变量之外,例如jax.numpy.dot
和lax.conv
。这些操作具有precision
参数,可以通过三次 bfloat16 传递来近似 32 位操作,这可能会降低运行时间。TPU 上的非 matmul 操作降低到实现,这些实现通常强调速度而牺牲精度,因此在实际操作中,TPU 上的计算精度将低于其他后端上的类似计算。 - NumPy 的某些涉及 Python 标量与 NumPy 类型混合的 dtype 提升语义没有被保留,例如
np.add(1, np.array([2], np.float32)).dtype
是float64
而不是float32
。 - 某些转换,如
jit
,限制了您如何使用 Python 控制流。如果出现错误,您将始终得到响亮的错误。您可能需要使用jit
的static_argnums
参数,结构化控制流原语,如lax.scan
,或者只需在较小的子函数上使用jit
。
安装
支持的平台
Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac ARM | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
CPU | 是 | 是 | 是 | 是 | 是 | 是 |
NVIDIA GPU | 是 | 是 | 否 | n/a | 否 | 实验性 |
Google TPU | 是 | n/a | n/a | n/a | n/a | n/a |
AMD GPU | 实验性 | 否 | 否 | n/a | 否 | 否 |
Apple GPU | n/a | 否 | 实验性 | 实验性 | n/a | n/a |
说明
硬件 | 说明 |
---|---|
CPU | pip install -U jax |
NVIDIA GPU | pip install -U "jax[cuda12]" |
Google TPU | 使用pip安装 -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD GPU | 使用Docker或从源码构建。 |
Apple GPU | 遵循Apple的说明。 |
有关替代安装策略的信息,请参阅文档。这些包括从源码编译、使用Docker安装、使用其他版本的CUDA、社区支持的conda构建以及一些常见问题的答案。
神经网络库
多个谷歌研究小组开发和共享JAX中用于训练神经网络的库。如果您需要一个功能齐全的神经网络训练库,包括示例和指南,请尝试Flax。查看新的NNX API,以获得简化的开发体验。
Google X维护着神经网络库Equinox。这是JAX生态系统中其他几个库的基础。
此外,DeepMind开源了围绕JAX的库生态系统,包括用于梯度处理和优化的Optax、用于RL算法的RLax以及用于可靠代码和测试的chex。(观看NeurIPS 2020 DeepMind关于JAX生态系统的演讲此处)
引用JAX
引用此存储库
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
在上面的bibtex条目中,名称按字母顺序排列,版本号是指jax/version.py中的版本,年份对应于项目的开源发布年份。
一篇关于JAX早期版本(仅支持自动微分和编译到XLA)的论文在SysML 2018上发表。我们目前正在撰写一篇更全面、更及时的论文,涵盖JAX的思想和能力。
参考文档
有关JAX API的详细信息,请参阅参考文档。
有关作为JAX开发者的入门指南,请参阅开发者文档。
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。
源代码分发
构建分发
jax-0.4.34.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db |
|
MD5 | 8f3ff80a4b43a9dd99193e82991a6902 |
|
BLAKE2b-256 | 196acacfcdf77841a4562e555ef35e0dbc5f8ca79c9f1010aaa4cf3973e79c69 |
jax-0.4.34-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e |
|
MD5 | 755b5ca20f3236ae4e5c753f074f776d |
|
BLAKE2b-256 | 06f3c499d358dd7f267a63d7d38ef54aadad82e28d2c28bafff15360c3091946 |