跳转到主要内容

Transformer加速库

项目描述

License

Transformer Engine

快速入门 | 安装 | 用户指南 | 示例 | FP8 收敛 | 集成 | 发布说明

最新动态

H200

什么是Transformer Engine?

Transformer Engine (TE) 是一个用于在 NVIDIA GPU 上加速 Transformer 模型的库,包括在 Hopper GPU 上使用 8 位浮点 (FP8) 精度,以在训练和推理过程中提供更好的性能和更低的内存利用率。TE 提供了一组针对流行 Transformer 架构的高度优化的构建块,以及一个类似于自动混合精度 API,可以与您特定的框架代码无缝使用。TE 还包括一个与框架无关的 C++ API,可以与其他深度学习库集成,以支持 Transformer 的 FP8。

随着 Transformer 模型参数数量的持续增长,BERT、GPT 和 T5 等架构的训练和推理变得非常内存和计算密集。大多数深度学习框架默认使用 FP32 进行训练。然而,对于许多深度学习模型来说,这并非必要,以实现完全的精度。使用混合精度训练,即在训练模型时将单精度 (FP32) 与较低精度(例如 FP16)格式结合,与 FP32 训练相比,可以实现显著的加速,同时精度差异很小。Hopper GPU 架构引入了 FP8 精度,与 FP16 相比,性能有所提高,而精度没有下降。尽管所有主要深度学习框架都支持 FP16,但今天框架中并没有原生支持 FP8。

TE 通过提供与流行大型语言模型 (LLM) 库集成的 API 来解决 FP8 支持的问题。它提供了一个 Python API,包括构建 Transformer 层所需的模块,以及包括用于 FP8 支持所需的 struct 和内核的 C++ 库。TE 内部提供的模块维护了用于 FP8 训练所需的缩放因子和其他值,大大简化了用户的混合精度训练。

亮点

  • 支持 FP8 的 Transformer 层构建的易于使用的模块

  • Transformer 模型的优化(例如,融合内核)

  • 支持 NVIDIA Hopper 和 NVIDIA Ada GPU 上的 FP8

  • 支持在 NVIDIA Ampere GPU 架构的各代和后续版本上跨所有精度(FP16、BF16)的优化

示例

PyTorch

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()

JAX

Flax

import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

BATCH = 32
SEQLEN = 128
HIDDEN = 1024

# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)

    def loss_fn(params, other_vars, inp):
      out = model.apply({'params':params, **other_vars}, inp)
      return jnp.mean(out)

    # Initialize models.
    variables = model.init(init_rng, inp)
    other_variables, params = flax.core.pop(variables, 'params')

    # Construct the forward and backward function
    fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

    for _ in range(10):
      loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)

安装

先决条件

  • Linux x86_64

  • Hopper 的 CUDA 12.0+ 和 Ada 的 CUDA 12.1+

  • 支持 CUDA 12.0 或更高版本的 NVIDIA 驱动程序

  • cuDNN 8.1 或更高版本

  • 对于融合注意力,CUDA 12.1 或更高版本,支持 CUDA 12.1 或更高版本的 NVIDIA 驱动程序,以及 cuDNN 8.9 或更高版本。

Docker

通过在 NVIDIA GPU Cloud (NGC) 目录 上使用 Docker 图像,是开始使用 Transformer Engine 的最快方式。例如,要交互式地使用 NGC PyTorch 容器,

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3

其中 23.10 是容器版本。例如,10 月 2023 年发布的 23.10。

pip

要安装 Transformer Engine 的最新稳定版本,

pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

这将自动检测是否安装了任何受支持的深度学习框架,并构建 Transformer Engine 的支持。要显式指定框架,设置环境变量 NVTE_FRAMEWORK 为逗号分隔的列表(例如,NVTE_FRAMEWORK=jax,pytorch)。

从源代码

请参阅安装指南.

使用 FlashAttention-2 编译

Transformer Engine 版本 v0.11.0 为 PyTorch 添加了对 FlashAttention-2 的支持,以改善性能。

已知问题是 FlashAttention-2 编译资源密集且需要大量 RAM(请参阅 错误),这可能导致在安装 Transformer Engine 时出现内存不足错误。请尝试在环境中设置 MAX_JOBS=1 以绕过问题。

请注意,NGC PyTorch 23.08+ 容器包括 FlashAttention-2。

破坏性变化

v1.7:PyTorch 的填充掩码定义

为了统一Transformer Engine中三个框架对注意力掩码的定义和用法,填充掩码已从表示包含对应位置在注意力中的True改为表示排除该位置在PyTorch实现中的False。从v1.7版本开始,所有注意力掩码类型都遵循相同的定义,其中True表示屏蔽对应位置,而False表示在注意力计算中包含该位置。

此变化的例子如下:

# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
 b, b, 0, 0, 0,
 c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True,  True,  True, False, False,
  True,  True, False, False, False,
  True,  True,  True,  True, False]
# and for v1.7 onwards it should be,
[False, False, False,  True,  True,
 False, False,  True,  True,  True,
 False, False, False, False,  True]

FP8 收敛

FP8已在不同的模型架构和配置下进行了广泛测试,我们发现FP8和BF16训练损失曲线之间没有显著差异。FP8在下游LLM任务(例如LAMBADA和WikiText)的准确性方面也得到了验证。以下是不同框架下测试收敛的模型示例。

模型

框架

来源

T5-770M

JAX/T5x

https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance

MPT-1.3B

Mosaic Composer

https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1

GPT-5B

JAX/Paxml

https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results

GPT-5B

NeMo Framework

可请求

LLama2-7B

阿里巴巴Pai

https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ

T5-11B

JAX/T5x

可请求

MPT-13B

Mosaic Composer

https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8

GPT-22B

NeMo Framework

可请求

LLama2-70B

阿里巴巴Pai

https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ

GPT-175B

JAX/Paxml

https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results

集成

Transformer Engine已集成到流行的LLM框架中,例如

贡献

我们欢迎为Transformer Engine做出贡献!要为Transformer Engine做出贡献并提交pull请求,请遵循CONTRIBUTING.rst指南中概述的准则。

论文

视频

项目详情


下载文件

下载适用于您平台的应用程序文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

此版本没有可用的源分布文件。请参阅有关生成分发存档的教程。

构建分布

transformer_engine_cu12-1.10.0-py3-none-manylinux_2_28_x86_64.whl (119.3 MB 查看哈希值)

上传时间 Python 3 manylinux: glibc 2.28+ x86_64

transformer_engine_cu12-1.10.0-py3-none-manylinux_2_28_aarch64.whl (119.2 MB 查看哈希值)

上传于 Python 3 manylinux: glibc 2.28+ ARM64

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误日志 StatusPage StatusPage 状态页面