JMP is a Mixed Precision library for JAX.
Mixed precision training in JAX
Installation | Examples | Policies | Loss scaling | Citing JMP | References
Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.
This library implements support for mixed precision training in JAX by providing two key abstractions (mixed precision "policies" and loss scaling). Neural network libraries (such as Haiku) can integrate with jmp
and provide "Automatic Mixed Precision (AMP)" support (automating or simplifying applying policies to modules).
All code examples below assume the following
import jax
import jax.numpy as jnp
import jmp
half = jnp.float16 # On TPU this should be jnp.bfloat16.
full = jnp.float32
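If you prefer not to hard-code the half precision dtype per backend, jmp also provides jmp.half_dtype() (it appears later in this document); to our understanding it returns the half dtype appropriate for the current default JAX backend. For example:
# Alternative to the hard-coded `half` above: let jmp pick the half dtype for
# the current backend (to our understanding, bfloat16 on TPU, float16 otherwise).
backend_half = jmp.half_dtype()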
Installation
JMP is written in pure Python, but depends on C++ code via JAX and NumPy.
Because the installation of JAX differs depending on your CUDA version, JMP does not list JAX as a dependency in requirements.txt.
First, follow these instructions to install JAX with the relevant accelerator support.
Then, install JMP using pip:
$ pip install git+https://github.com/deepmind/jmp
Examples
You can find a fully worked JMP example in Haiku, which shows how to use mixed f32/f16 precision to halve training time on GPU and mixed f32/bf16 precision to reduce training time on TPU by a third.
Policies
Mixed precision policies encapsulate the configuration used in your mixed precision experiment.
# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)
The policy object can be used to cast pytrees:
def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params
  y = x @ w + b
  return my_policy.cast_to_output(y)

params = (jnp.ones([3, 3], dtype=my_policy.param_dtype),  # w
          jnp.zeros([3], dtype=my_policy.param_dtype))    # b
x = jnp.ones([1, 3], dtype=full)
y = layer(params, x)
assert y.dtype == half
You can replace the output type of a given policy:
my_policy = my_policy.with_output_dtype(full)
You can also define policies via a string, which may be useful for specifying a policy as a command line argument or as a hyperparameter for your experiment:
my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
float16 = jmp.get_policy("float16") # Everything in f16.
half = jmp.get_policy("half") # Everything in half (f16 or bf16).
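For example, a policy string can be read from a (hypothetical) command line flag and applied to the parameters of the toy layer above; this sketch assumes the policy also exposes cast_to_param alongside the cast_to_compute and cast_to_output methods shown earlier:
import argparse

parser = argparse.ArgumentParser()
# --mp_policy is a hypothetical flag name, not part of jmp.
parser.add_argument("--mp_policy",
                    default="params=float32,compute=half,output=float32")
args = parser.parse_args()

my_policy = jmp.get_policy(args.mp_policy)

# Parameters can be created in any dtype and then cast to the policy's
# param dtype (cast_to_param is assumed here).
params = (jnp.ones([3, 3]), jnp.zeros([3]))
params = my_policy.cast_to_param(params)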
Loss scaling
When training with reduced precision, consider whether gradients will need to be shifted into the representable range of the format that you are using. This is particularly important when training with float16 and less important for bfloat16. See the NVIDIA mixed precision user guide [1] for more details.
The easiest way to shift gradients is with loss scaling, which scales your loss and gradients by S and 1/S respectively.
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)
The appropriate value for S depends on your model, loss, batch size and potentially other factors. You can determine this with trial and error. As a rule of thumb you want the largest value of S that does not introduce overflow during backprop. NVIDIA [1] recommend computing statistics about the gradients of your model (in full precision) and picking S such that its product with the maximum norm of your gradients is below 65,504 (the maximum value representable in float16).
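As a concrete illustration of that recommendation, the sketch below inspects full precision gradients from a representative batch and picks the largest power-of-two S whose product with the largest gradient magnitude stays below 65,504. The helper name, the margin argument and the gradient pytree are illustrative, not part of jmp:
import math

def pick_static_loss_scale(grads, margin=2.0):
  # `grads` is a pytree of full precision gradients from a representative batch.
  max_abs = max(float(jnp.max(jnp.abs(g)))
                for g in jax.tree_util.tree_leaves(grads))
  # Largest power of two S such that S * max_abs stays below float16's maximum
  # representable value (65,504), with `margin` as headroom.
  s = 2.0 ** math.floor(math.log2(65504.0 / (max_abs * margin)))
  return jmp.StaticLossScale(jnp.float32(s))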
We provide a dynamic loss scale, which adjusts the loss scale periodically during training to find the largest value of S that produces finite gradients. This is more convenient and robust than picking a static loss scale, but it has a small performance impact (between 1 and 5%).
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaN's when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)

    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)

    # Only apply our optimizer if grads are finite, if any element of any
    # gradient is non-finite the whole update is discarded.
    params = jmp.select_tree(grads_finite, apply_optimizer(params, grads),
                             params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)
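Because every loss scale is a PyTree (as noted in the comments above), it can be passed into and returned from jit-compiled functions like any other value. A minimal self-contained sketch with a toy quadratic loss (illustrative only, not part of jmp):
@jax.jit
def scaled_grads(w, x, loss_scale):
  def loss_fn(w):
    return loss_scale.scale(jnp.sum((x @ w) ** 2))
  # Gradients come back multiplied by S; unscale them before use.
  return loss_scale.unscale(jax.grad(loss_fn)(w))

w = jnp.ones([3, 3], dtype=full)
x = jnp.ones([2, 3], dtype=full)
grads = scaled_grads(w, x, jmp.StaticLossScale(jnp.float32(2 ** 15)))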
In general using a static loss scale should offer the best speed, but we have optimised dynamic loss scaling to make it competitive. We recommend you start with dynamic loss scaling and move to static loss scaling if performance is an issue.
Finally, we provide a no-op loss scale which you can use as a drop-in replacement. It does nothing (apart from implementing the jmp.LossScale API):
loss_scale = jmp.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1
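Because all three loss scales share the same API, this makes it easy to toggle loss scaling from configuration without touching the training loop; a sketch assuming a hypothetical use_mixed_precision flag and the train_step defined above:
# `use_mixed_precision` is a hypothetical configuration flag, not part of jmp.
if use_mixed_precision:
  loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
else:
  loss_scale = jmp.NoOpLossScale()

# The dynamic-loss-scale `train_step` above works unchanged with either value.
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)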
Citing JMP
This repository is part of the DeepMind JAX Ecosystem. To cite JMP, please use the DeepMind JAX Ecosystem citation.
References
[0] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740 https://arxiv.org/abs/1710.03740.
[1] "Training With Mixed Precision :: NVIDIA Deep Learning Performance Documentation". Docs.Nvidia.Com, 2020, https://docs.nvda.net.cn/deeplearning/performance/mixed-precision-training/.