跳转到主要内容

微分、编译和转换Numpy代码。

项目描述

logo

可变换的规模数值计算

Continuous integration PyPI version

快速入门 | 变换 | 安装指南 | 神经网络库 | 变更日志 | 参考文档

JAX是什么?

JAX是一个Python库,用于面向加速器的数组计算和程序转换,旨在进行高性能数值计算和大规模机器学习。

凭借其更新的Autograd版本,JAX可以自动微分原生的Python和NumPy函数。它可以微分循环、分支、递归和闭包,并且可以求出多级导数。它支持通过grad进行的逆模式微分(即反向传播)以及前向模式微分,并且两者可以任意组合到任何级别。

新功能是JAX使用XLA来编译和运行您的NumPy程序在GPU和TPU上。默认情况下,编译在幕后进行,库调用得到即时编译和执行。但是JAX还允许您使用一个函数API jit将您的Python函数即时编译成XLA优化的内核。编译和自动微分可以任意组合,因此您可以在不离开Python的情况下表达复杂的算法并获得最佳性能。您甚至可以使用pmap同时编程多个GPU或TPU核心,并对整个过程进行微分。

深入了解,你会发现JAX实际上是一个可扩展的系统,用于可组合函数转换gradjit都是这种转换的实例。其他包括vmap用于自动向量化以及pmap用于多加速器的单程序多数据(SPMD)并行编程,未来还有更多。

这是一个研究项目,不是官方的Google产品。请期待可能出现错误和尖锐边缘。请通过试用、报告错误以及告诉我们您的想法来帮助我们!

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads

内容

快速入门:云中的Colab

直接使用浏览器中的笔记本开始,连接到Google Cloud GPU。以下是一些入门笔记本

JAX现在可以在Cloud TPUs上运行。要试用预览,请参阅Cloud TPU Colabs

深入了解JAX

变换

在核心上,JAX是一个用于转换数值函数的可扩展系统。以下四个转换是主要感兴趣的:gradjitvmappmap

使用grad进行自动微分

JAX与Autograd大致具有相同的API。最流行的函数是grad,用于逆模式梯度

from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743

您可以使用grad进行任意阶的微分。

print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673

对于更高级的自动微分,您可以使用 jax.vjp 进行逆模式向量-雅可比乘积,以及 jax.jvp 进行前模式雅可比-向量乘积。这两个可以任意组合,也可以与其他 JAX 转换一起使用。以下是一种组合方法,用于高效地计算 完整的 Hessian 矩阵

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

Autograd 一样,您可以使用 Python 控制结构进行微分

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

请参阅有关自动微分的 参考文档JAX 自动微分烹饪书 以获取更多信息。

使用 jit 进行编译

您可以使用 XLA 使用 jit 编译您的函数,jit 可以用作 @jit 装饰器或作为高阶函数。

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)

您可以根据喜好混合 jitgrad 以及任何其他 JAX 转换。

使用 jit 对函数可以使用的 Python 控制流类型施加约束;请参阅 常见错误笔记本 以获取更多信息。

使用 vmap 进行自动向量化

vmap 是向量化的映射。它具有将函数映射到数组轴的熟悉语义,但与将循环保留在外部不同,它将循环推入函数的基本操作以获得更好的性能。

使用 vmap 可以让您在代码中省略批量维度。例如,考虑这个简单的 未批处理 神经网络预测函数

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)        # inputs to the next layer
  return outputs                           # no activation on last layer

我们通常编写 jnp.dot(activations, W) 以允许在 activations 的左侧具有批量维度,但我们已经编写了这个特定的预测函数,仅适用于单个输入向量。如果我们想要一次应用此函数到一批输入,在语义上,我们可以简单地编写

from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

但是,一次通过一个示例通过网络会很慢!更好的做法是向量化计算,这样在每一层我们都在执行矩阵-矩阵乘法,而不是矩阵-向量乘法。

vmap 函数会为我们进行这种转换。也就是说,如果我们编写

from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

那么 vmap 函数将把外循环推入函数内部,并且我们的机器将执行矩阵-矩阵乘法,就像我们手动批处理一样。

手动批处理简单的神经网络没有 vmap 已经足够简单,但在其他情况下,手动向量化可能不切实际或不可能。考虑高效计算每个示例梯度的问题:也就是说,对于一组固定的参数,我们想要计算在批处理中的每个示例分别评估的损失函数的梯度。使用 vmap 很容易

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

当然,vmap 可以任意组合 jitgrad 和任何其他 JAX 转换!我们在 jax.jacfwdjax.jacrevjax.hessian 中使用 vmap 进行前向和反向模式自动微分,以快速计算雅可比和 Hessian 矩阵。

使用 pmap 进行 SPMD 编程

对于多个加速器的并行编程,如多个 GPU,请使用 pmap。使用 pmap,您可以编写单程序多数据(SPMD)程序,包括快速并行集体通信操作。应用 pmap 意味着您编写的函数将由 XLA(类似于 jit)编译,然后在设备之间并行复制和执行。

以下是在 8-GPU 机器上的示例

from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]

除了表达纯映射之外,您还可以在设备之间使用快速 集体通信操作

from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]

您甚至可以将 pmap 函数嵌套,以实现更复杂的通信模式。

所有这些都可以组合,因此您可以通过并行计算进行自由区分。

from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]

在反向模式微分 pmap 函数(例如使用 grad)时,计算的逆传播与正向传播一样并行化。

有关更多信息,请参阅 SPMD 美食手册从头开始构建 SPMD MNIST 分类器的示例

当前问题

为了更全面地了解当前的一些常见问题,包括示例和解释,我们强烈建议您阅读 常见问题笔记本。以下是一些亮点:

  1. JAX 变换仅在 纯函数 上工作,这些函数没有副作用并尊重 引用透明性(即使用 is 进行对象身份测试不被保留)。如果您在纯 Python 函数上使用 JAX 变换,您可能会看到如 Exception: Can't lift Traced...Exception: Different traces at same level 这样的错误。
  2. 例如,像 x[i] += y 这样的 就地更新数组 不受支持,但 存在功能替代方法。在 jit 下,这些功能替代方法将自动就地重用缓冲区。
  3. 随机数有所不同,但 出于良好原因
  4. 如果您正在寻找 卷积算子,它们位于 jax.lax 包中。
  5. JAX 默认强制使用单精度(32 位,例如 float32)值,并且要启用双精度(64 位,例如 float64),需要在启动时设置 jax_enable_x64 变量(或设置环境变量 JAX_ENABLE_X64=True)。在 TPU 上,JAX 默认使用 32 位值用于所有操作 'matmul-like' 操作的内部临时变量之外,例如 jax.numpy.dotlax.conv。这些操作具有 precision 参数,可以通过三次 bfloat16 传递来近似 32 位操作,这可能会降低运行时间。TPU 上的非 matmul 操作降低到实现,这些实现通常强调速度而牺牲精度,因此在实际操作中,TPU 上的计算精度将低于其他后端上的类似计算。
  6. NumPy 的某些涉及 Python 标量与 NumPy 类型混合的 dtype 提升语义没有被保留,例如 np.add(1, np.array([2], np.float32)).dtypefloat64 而不是 float32
  7. 某些转换,如 jit限制了您如何使用 Python 控制流。如果出现错误,您将始终得到响亮的错误。您可能需要使用 jitstatic_argnums 参数结构化控制流原语,如 lax.scan,或者只需在较小的子函数上使用 jit

安装

支持的平台

Linux x86_64 Linux aarch64 Mac x86_64 Mac ARM Windows x86_64 Windows WSL2 x86_64
CPU
NVIDIA GPU n/a 实验性
Google TPU n/a n/a n/a n/a n/a
AMD GPU 实验性 n/a
Apple GPU n/a 实验性 实验性 n/a n/a

说明

硬件 说明
CPU pip install -U jax
NVIDIA GPU pip install -U "jax[cuda12]"
Google TPU 使用pip安装 -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
AMD GPU 使用Docker从源码构建
Apple GPU 遵循Apple的说明

有关替代安装策略的信息,请参阅文档。这些包括从源码编译、使用Docker安装、使用其他版本的CUDA、社区支持的conda构建以及一些常见问题的答案。

神经网络库

多个谷歌研究小组开发和共享JAX中用于训练神经网络的库。如果您需要一个功能齐全的神经网络训练库,包括示例和指南,请尝试Flax。查看新的NNX API,以获得简化的开发体验。

Google X维护着神经网络库Equinox。这是JAX生态系统中其他几个库的基础。

此外,DeepMind开源了围绕JAX的库生态系统,包括用于梯度处理和优化的Optax、用于RL算法的RLax以及用于可靠代码和测试的chex。(观看NeurIPS 2020 DeepMind关于JAX生态系统的演讲此处

引用JAX

引用此存储库

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}

在上面的bibtex条目中,名称按字母顺序排列,版本号是指jax/version.py中的版本,年份对应于项目的开源发布年份。

一篇关于JAX早期版本(仅支持自动微分和编译到XLA)的论文在SysML 2018上发表。我们目前正在撰写一篇更全面、更及时的论文,涵盖JAX的思想和能力。

参考文档

有关JAX API的详细信息,请参阅参考文档

有关作为JAX开发者的入门指南,请参阅开发者文档

项目详情


发布历史 发布通知 | RSS源

下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源代码分发

jax-0.4.34.tar.gz (1.8 MB 查看哈希值)

上传时间 源代码

构建分发

jax-0.4.34-py3-none-any.whl (2.1 MB 查看哈希值)

上传时间 Python 3

支持者