Transformer加速库
项目描述
Transformer Engine
最新动态
[03/2024] Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8
[03/2024] FP8 Training Support in SageMaker Model Parallelism Library
[12/2023] New NVIDIA NeMo Framework Features and NVIDIA H200
[11/2023] Inflection-2: The Next Step Up
[11/2023] Unleashing The Power Of Transformers With NVIDIA Transformer Engine
[09/2023] Transformer Engine added to AWS DL Container for PyTorch Training
[06/2023] Breaking MLPerf Training Records with NVIDIA H100 GPUs
[04/2023] Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1)
什么是Transformer Engine?
Transformer Engine (TE) 是一个用于在 NVIDIA GPU 上加速 Transformer 模型的库,包括在 Hopper GPU 上使用 8 位浮点 (FP8) 精度,以在训练和推理过程中提供更好的性能和更低的内存利用率。TE 提供了一组针对流行 Transformer 架构的高度优化的构建块,以及一个类似于自动混合精度 API,可以与您特定的框架代码无缝使用。TE 还包括一个与框架无关的 C++ API,可以与其他深度学习库集成,以支持 Transformer 的 FP8。
随着 Transformer 模型参数数量的持续增长,BERT、GPT 和 T5 等架构的训练和推理变得非常内存和计算密集。大多数深度学习框架默认使用 FP32 进行训练。然而,对于许多深度学习模型来说,这并非必要,以实现完全的精度。使用混合精度训练,即在训练模型时将单精度 (FP32) 与较低精度(例如 FP16)格式结合,与 FP32 训练相比,可以实现显著的加速,同时精度差异很小。Hopper GPU 架构引入了 FP8 精度,与 FP16 相比,性能有所提高,而精度没有下降。尽管所有主要深度学习框架都支持 FP16,但今天框架中并没有原生支持 FP8。
TE 通过提供与流行大型语言模型 (LLM) 库集成的 API 来解决 FP8 支持的问题。它提供了一个 Python API,包括构建 Transformer 层所需的模块,以及包括用于 FP8 支持所需的 struct 和内核的 C++ 库。TE 内部提供的模块维护了用于 FP8 训练所需的缩放因子和其他值,大大简化了用户的混合精度训练。
亮点
支持 FP8 的 Transformer 层构建的易于使用的模块
Transformer 模型的优化(例如,融合内核)
支持 NVIDIA Hopper 和 NVIDIA Ada GPU 上的 FP8
支持在 NVIDIA Ampere GPU 架构的各代和后续版本上跨所有精度(FP16、BF16)的优化
示例
PyTorch
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = model(inp)
loss = out.sum()
loss.backward()
JAX
Flax
import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
model = te_flax.DenseGeneral(features=HIDDEN)
def loss_fn(params, other_vars, inp):
out = model.apply({'params':params, **other_vars}, inp)
return jnp.mean(out)
# Initialize models.
variables = model.init(init_rng, inp)
other_variables, params = flax.core.pop(variables, 'params')
# Construct the forward and backward function
fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
for _ in range(10):
loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
安装
先决条件
Linux x86_64
Hopper 的 CUDA 12.0+ 和 Ada 的 CUDA 12.1+
支持 CUDA 12.0 或更高版本的 NVIDIA 驱动程序
cuDNN 8.1 或更高版本
对于融合注意力,CUDA 12.1 或更高版本,支持 CUDA 12.1 或更高版本的 NVIDIA 驱动程序,以及 cuDNN 8.9 或更高版本。
Docker
通过在 NVIDIA GPU Cloud (NGC) 目录 上使用 Docker 图像,是开始使用 Transformer Engine 的最快方式。例如,要交互式地使用 NGC PyTorch 容器,
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3
其中 23.10 是容器版本。例如,10 月 2023 年发布的 23.10。
pip
要安装 Transformer Engine 的最新稳定版本,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
这将自动检测是否安装了任何受支持的深度学习框架,并构建 Transformer Engine 的支持。要显式指定框架,设置环境变量 NVTE_FRAMEWORK 为逗号分隔的列表(例如,NVTE_FRAMEWORK=jax,pytorch)。
从源代码
使用 FlashAttention-2 编译
Transformer Engine 版本 v0.11.0 为 PyTorch 添加了对 FlashAttention-2 的支持,以改善性能。
已知问题是 FlashAttention-2 编译资源密集且需要大量 RAM(请参阅 错误),这可能导致在安装 Transformer Engine 时出现内存不足错误。请尝试在环境中设置 MAX_JOBS=1 以绕过问题。
请注意,NGC PyTorch 23.08+ 容器包括 FlashAttention-2。
破坏性变化
v1.7:PyTorch 的填充掩码定义
为了统一Transformer Engine中三个框架对注意力掩码的定义和用法,填充掩码已从表示包含对应位置在注意力中的True改为表示排除该位置在PyTorch实现中的False。从v1.7版本开始,所有注意力掩码类型都遵循相同的定义,其中True表示屏蔽对应位置,而False表示在注意力计算中包含该位置。
此变化的例子如下:
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]
FP8 收敛
FP8已在不同的模型架构和配置下进行了广泛测试,我们发现FP8和BF16训练损失曲线之间没有显著差异。FP8在下游LLM任务(例如LAMBADA和WikiText)的准确性方面也得到了验证。以下是不同框架下测试收敛的模型示例。
模型 |
框架 |
来源 |
---|---|---|
T5-770M |
JAX/T5x |
|
MPT-1.3B |
Mosaic Composer |
|
GPT-5B |
JAX/Paxml |
https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
GPT-5B |
NeMo Framework |
可请求 |
LLama2-7B |
阿里巴巴Pai |
|
T5-11B |
JAX/T5x |
可请求 |
MPT-13B |
Mosaic Composer |
https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8 |
GPT-22B |
NeMo Framework |
可请求 |
LLama2-70B |
阿里巴巴Pai |
|
GPT-175B |
JAX/Paxml |
https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
集成
Transformer Engine已集成到流行的LLM框架中,例如
贡献
我们欢迎为Transformer Engine做出贡献!要为Transformer Engine做出贡献并提交pull请求,请遵循CONTRIBUTING.rst指南中概述的准则。
论文
视频
项目详情
散列值 for transformer_engine_cu12-1.10.0-py3-none-manylinux_2_28_x86_64.whl
算法 | 散列摘要 | |
---|---|---|
SHA256 | dd5d16585ca7feb1e47d56c0417d28b3288a34eb31bd9ef076c39c2c302f3405 |
|
MD5 | c7ca6821da330271b211d65fc8e18ca7 |
|
BLAKE2b-256 | b0b472d6b7ffaaf56ee30dd89dc7b8fdc44a438e1756bc667b7a8083191d1ab7 |
散列值 for transformer_engine_cu12-1.10.0-py3-none-manylinux_2_28_aarch64.whl
算法 | 散列摘要 | |
---|---|---|
SHA256 | ea9f006787071d24f3aecd362e1e056377bcf93515caa8f8e868dd92ee466bb6 |
|
MD5 | 36fb82e09c8f1f5ebff159fcb861515f |
|
BLAKE2b-256 | 7406df0a3d7970fda57fee49e84d8ececaca648f3c1509ae19766a2e98314b49 |