跳转到主要内容

Flash Attention:快速且内存高效的精确注意力

项目描述

FlashAttention

此存储库提供了以下论文中FlashAttention的官方实现。

FlashAttention:具有I/O感知的快速且内存高效的精确注意力
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
论文:https://arxiv.org/abs/2205.14135
关于我们使用FlashAttention提交给MLPerf 2.0基准测试的文章 https://spectrum.ieee.org/mlperf-rankings-2022FlashAttention

使用方法

我们很高兴看到FlashAttention在其发布后不久就被广泛采用。此 页面 包含了FlashAttention正在使用的部分地点列表。

完整模型代码和训练脚本

我们已发布完整的GPT模型 实现。我们还提供了其他层(例如,MLP、LayerNorm、交叉熵损失、旋转嵌入)的优化实现。总的来说,与Huggingface的基线实现相比,这使训练速度提高了3-5倍,达到每A100高达189 TFLOPs/sec,相当于60.6%的模型FLOPs利用率(我们不需要任何激活检查点)。

我们还包含了一个训练 脚本,用于在Openwebtext上训练GPT2和在The Pile上训练GPT3。

FlashAttention的Triton实现

Phil Tillet(OpenAI)在Triton中有一个FlashAttention的实验性实现: https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

由于Triton是高于CUDA的高级语言,因此可能更容易理解和实验。Triton实现中的符号也更接近我们论文中使用的符号。

我们在Triton中也有一个支持注意力偏置(例如ALiBi)的实验性实现: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py

安装和功能

要求

  • CUDA 11.4以上。
  • PyTorch 1.12以上。

我们推荐使用Nvidia的Pytorch容器,该容器包含安装FlashAttention所需的所有工具。

安装方法

pip install flash-attn

如果您看到有关ModuleNotFoundError: No module named 'torch'的错误,这可能是由于pypi的安装隔离。

要修复,您可以运行

pip install flash-attn --no-build-isolation

或者您也可以从源代码编译

python setup.py install

接口:src/flash_attention.py

针对PyTorch标准注意力进行基准测试

PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py

FlashAttention目前支持

  1. Turing、Ampere、Ada或Hopper GPU(例如,H100、A100、RTX 3090、T4、RTX 2080)。
  2. fp16和bf16(bf16需要Ampere、Ada或Hopper GPU)。
  3. 头维度是8的倍数,最多为128(例如,8、16、24、...、128)。头维度>64反向需要A100或H100。

我们的路线图

  1. [Jun 2022] 使包可pip安装[完成,感谢lucidrains]。
  2. [Jun 2022] 支持SM86 GPU(例如,RTX 3080、3090)[完成]。
  3. [Jun 2022] 支持SM75 GPU(例如T4)[完成]。
  4. [Jun 2022] 支持bf16[完成]。
  5. [Jul 2022] 实现交叉注意力[完成]。
  6. [Jul 2022] 支持头维度128[完成]。
  7. [Aug 2022] 合并旋转嵌入[完成]。
  8. [Mar 2023] 支持SM90 GPU(H100)[完成]。

如何使用FlashAttention

这里有一个简单的示例

import torch
from flash_attn.flash_attention import FlashMHA

# Replace this with your correct GPU device
device = "cuda:0"

# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
    embed_dim=128, # total channels (= num_heads * head_dim)
    num_heads=8, # number of heads
    device=device,
    dtype=torch.float16,
)

# Run forward pass with dummy data
x = torch.randn(
    (64, 256, 128), # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16
)

output = flash_mha(x)[0]

或者,您只能导入内部注意力层(这样就不包括输入和输出线性层)

from flash_attn.flash_attention import FlashAttention

# Create the nn.Module
flash_attention = FlashAttention()

或者,如果您需要更细粒度的控制,您可以导入其中一个低级函数(这与torch.nn.functional风格更相似)

from flash_attn.flash_attn_interface import flash_attn_unpadded_func

# or

from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func

# etc.

还有单独的Python文件,包含各种FlashAttention扩展

# Import the triton implementation (torch.nn.functional version only)
from flash_attn.flash_attn_triton import flash_attn_func

# Import block sparse attention (nn.Module version)
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention

# Import block sparse attention (torch.nn.functional version)
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func

速度提升和内存节省

我们展示了使用FlashAttention代替PyTorch标准注意力时,根据序列长度和不同GPU的预期速度提升(和内存节省),速度提升取决于内存带宽(我们在较慢的GPU内存上看到更多的速度提升)。

我们目前对这些GPU有基准测试

A100

我们使用这些参数显示FlashAttention的速度提升(类似于BERT-base)

  • 批量大小8
  • 头维度64
  • 12个注意力头

我们的图表显示了128到4096之间的序列长度(当标准注意力在A100上运行出内存时),但FlashAttention可以扩展到序列长度64K。

速度提升

FlashAttention speedup

我们通常在128到4K的序列长度之间看到2-4倍的速度提升,并且由于我们融合了内核,当使用dropout和掩码时,我们看到了更多的速度提升。在语言模型常用的512和1K序列长度上,我们使用dropout和掩码时看到的速度提升可达4倍。

内存

FlashAttention memory

我们在该图表中显示了内存节省(请注意,无论您使用dropout还是掩码,内存占用都是相同的)。内存节省与序列长度成比例--因为标准注意力具有与序列长度平方成正比的内存,而FlashAttention具有与序列长度线性相关的内存。我们在序列长度2K时看到10倍的内存节省,在4K时看到20倍的节省。因此,FlashAttention可以扩展到更长的序列长度。

头维度128

FlashAttention speedup, head dimension 128

我们显示了头维度128的速度提升。这里我们显示批量大小16和12个头的速度提升。由于我们需要在分块中减小块大小,所以与较小的头大小相比,速度提升较小。但是,速度提升仍然显著,特别是当使用因果掩码时。

RTX 3090

对于RTX 3090,我们使用批大小为12,注意力头数为12。内存节省与A100相同,所以我们这里只展示加速效果。

FlashAttention speedup GTX 3090

在GTX 3090上,我们看到了稍微更高的加速效果(2.5-4.5倍),因为GDDR6X的内存带宽低于A100 HBM(约900 GB/s与约1.5 TB/s)。

T4

我们再次使用批大小为12,注意力头数为12。

Flashattention speedup T4

T4 SRAM小于较新的GPU(64 KB),所以我们看到更少的加速效果(我们需要减小块大小,因此最终需要更多的读写操作)。这与第3.2节中我们论文的IO复杂度分析相匹配。

T4 GPU通常用于推理,所以我们也只测量了正向传播的加速效果(请注意,这些结果与上面的图表不可直接比较)

FlashAttention speedup T4 fwd

正向传播的加速效果在2.5倍-4.5倍之间。

测试

我们测试FlashAttention生成的输出和梯度与参考实现相同,达到一定的数值容忍度。具体来说,我们检查FlashAttention的最大数值误差不超过Pytorch基准实现的数值误差的两倍(对于不同的头维度、输入数据类型、序列长度、因果/非因果)。

运行测试

pytest -q -s tests/test_flash_attn.py

当您遇到问题时

FlashAttention的此alpha版本包含为研究项目编写的代码,以验证加快注意力速度的想法。我们已在多个模型(BERT、GPT2、ViT)上进行了测试。然而,实施中可能仍存在一些错误,我们希望在接下来的几个月内解决。

如果您遇到这些错误中的任何一个,请打开相应的GitHub问题!

致谢

我们的实现以Apex的FMHA代码作为起点。

我们感谢Young-Jun Ko对我们关于CUDA问题的深入解释和他对CUDA问题的深思熟虑的回答。

引用

如果您使用此代码库,或认为我们的工作有价值,请引用

@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}

项目详情


下载文件

下载适用于您平台的项目文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源代码分布

flash_attn_wheels-1.0.9.tar.gz (210.1 kB 查看哈希值)

上传时间 源代码

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面