Flash Attention:快速且内存高效的精确注意力
项目描述
FlashAttention
此存储库提供了以下论文中FlashAttention的官方实现。
FlashAttention:具有I/O感知的快速且内存高效的精确注意力
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
论文:https://arxiv.org/abs/2205.14135
关于我们使用FlashAttention提交给MLPerf 2.0基准测试的文章 https://spectrum.ieee.org/mlperf-rankings-2022。
使用方法
我们很高兴看到FlashAttention在其发布后不久就被广泛采用。此 页面 包含了FlashAttention正在使用的部分地点列表。
完整模型代码和训练脚本
我们已发布完整的GPT模型 实现。我们还提供了其他层(例如,MLP、LayerNorm、交叉熵损失、旋转嵌入)的优化实现。总的来说,与Huggingface的基线实现相比,这使训练速度提高了3-5倍,达到每A100高达189 TFLOPs/sec,相当于60.6%的模型FLOPs利用率(我们不需要任何激活检查点)。
我们还包含了一个训练 脚本,用于在Openwebtext上训练GPT2和在The Pile上训练GPT3。
FlashAttention的Triton实现
Phil Tillet(OpenAI)在Triton中有一个FlashAttention的实验性实现: https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
由于Triton是高于CUDA的高级语言,因此可能更容易理解和实验。Triton实现中的符号也更接近我们论文中使用的符号。
我们在Triton中也有一个支持注意力偏置(例如ALiBi)的实验性实现: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py
安装和功能
要求
- CUDA 11.4以上。
- PyTorch 1.12以上。
我们推荐使用Nvidia的Pytorch容器,该容器包含安装FlashAttention所需的所有工具。
安装方法
pip install flash-attn
如果您看到有关ModuleNotFoundError: No module named 'torch'
的错误,这可能是由于pypi的安装隔离。
要修复,您可以运行
pip install flash-attn --no-build-isolation
或者您也可以从源代码编译
python setup.py install
接口:src/flash_attention.py
针对PyTorch标准注意力进行基准测试
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
FlashAttention目前支持
- Turing、Ampere、Ada或Hopper GPU(例如,H100、A100、RTX 3090、T4、RTX 2080)。
- fp16和bf16(bf16需要Ampere、Ada或Hopper GPU)。
- 头维度是8的倍数,最多为128(例如,8、16、24、...、128)。头维度>64反向需要A100或H100。
我们的路线图
[Jun 2022] 使包可pip安装[完成,感谢lucidrains]。[Jun 2022] 支持SM86 GPU(例如,RTX 3080、3090)[完成]。[Jun 2022] 支持SM75 GPU(例如T4)[完成]。[Jun 2022] 支持bf16[完成]。[Jul 2022] 实现交叉注意力[完成]。[Jul 2022] 支持头维度128[完成]。[Aug 2022] 合并旋转嵌入[完成]。[Mar 2023] 支持SM90 GPU(H100)[完成]。
如何使用FlashAttention
这里有一个简单的示例
import torch
from flash_attn.flash_attention import FlashMHA
# Replace this with your correct GPU device
device = "cuda:0"
# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
embed_dim=128, # total channels (= num_heads * head_dim)
num_heads=8, # number of heads
device=device,
dtype=torch.float16,
)
# Run forward pass with dummy data
x = torch.randn(
(64, 256, 128), # (batch, seqlen, embed_dim)
device=device,
dtype=torch.float16
)
output = flash_mha(x)[0]
或者,您只能导入内部注意力层(这样就不包括输入和输出线性层)
from flash_attn.flash_attention import FlashAttention
# Create the nn.Module
flash_attention = FlashAttention()
或者,如果您需要更细粒度的控制,您可以导入其中一个低级函数(这与torch.nn.functional
风格更相似)
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
# or
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
# etc.
还有单独的Python文件,包含各种FlashAttention扩展
# Import the triton implementation (torch.nn.functional version only)
from flash_attn.flash_attn_triton import flash_attn_func
# Import block sparse attention (nn.Module version)
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
# Import block sparse attention (torch.nn.functional version)
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
速度提升和内存节省
我们展示了使用FlashAttention代替PyTorch标准注意力时,根据序列长度和不同GPU的预期速度提升(和内存节省),速度提升取决于内存带宽(我们在较慢的GPU内存上看到更多的速度提升)。
我们目前对这些GPU有基准测试
A100
我们使用这些参数显示FlashAttention的速度提升(类似于BERT-base)
- 批量大小8
- 头维度64
- 12个注意力头
我们的图表显示了128到4096之间的序列长度(当标准注意力在A100上运行出内存时),但FlashAttention可以扩展到序列长度64K。
速度提升
我们通常在128到4K的序列长度之间看到2-4倍的速度提升,并且由于我们融合了内核,当使用dropout和掩码时,我们看到了更多的速度提升。在语言模型常用的512和1K序列长度上,我们使用dropout和掩码时看到的速度提升可达4倍。
内存
我们在该图表中显示了内存节省(请注意,无论您使用dropout还是掩码,内存占用都是相同的)。内存节省与序列长度成比例--因为标准注意力具有与序列长度平方成正比的内存,而FlashAttention具有与序列长度线性相关的内存。我们在序列长度2K时看到10倍的内存节省,在4K时看到20倍的节省。因此,FlashAttention可以扩展到更长的序列长度。
头维度128
我们显示了头维度128的速度提升。这里我们显示批量大小16和12个头的速度提升。由于我们需要在分块中减小块大小,所以与较小的头大小相比,速度提升较小。但是,速度提升仍然显著,特别是当使用因果掩码时。
RTX 3090
对于RTX 3090,我们使用批大小为12,注意力头数为12。内存节省与A100相同,所以我们这里只展示加速效果。
在GTX 3090上,我们看到了稍微更高的加速效果(2.5-4.5倍),因为GDDR6X的内存带宽低于A100 HBM(约900 GB/s与约1.5 TB/s)。
T4
我们再次使用批大小为12,注意力头数为12。
T4 SRAM小于较新的GPU(64 KB),所以我们看到更少的加速效果(我们需要减小块大小,因此最终需要更多的读写操作)。这与第3.2节中我们论文的IO复杂度分析相匹配。
T4 GPU通常用于推理,所以我们也只测量了正向传播的加速效果(请注意,这些结果与上面的图表不可直接比较)
正向传播的加速效果在2.5倍-4.5倍之间。
测试
我们测试FlashAttention生成的输出和梯度与参考实现相同,达到一定的数值容忍度。具体来说,我们检查FlashAttention的最大数值误差不超过Pytorch基准实现的数值误差的两倍(对于不同的头维度、输入数据类型、序列长度、因果/非因果)。
运行测试
pytest -q -s tests/test_flash_attn.py
当您遇到问题时
FlashAttention的此alpha版本包含为研究项目编写的代码,以验证加快注意力速度的想法。我们已在多个模型(BERT、GPT2、ViT)上进行了测试。然而,实施中可能仍存在一些错误,我们希望在接下来的几个月内解决。
如果您遇到这些错误中的任何一个,请打开相应的GitHub问题!
致谢
我们的实现以Apex的FMHA代码作为起点。
我们感谢Young-Jun Ko对我们关于CUDA问题的深入解释和他对CUDA问题的深思熟虑的回答。
引用
如果您使用此代码库,或认为我们的工作有价值,请引用
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
项目详情
flash_attn_wheels-1.0.9.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 6f4ed8e9c2ea8dedd085b50ce377e7a90310ef26bb0880afd789beece502598c |
|
MD5 | b54cdae2066074f1b362b0790e6f99f6 |
|
BLAKE2b-256 | f92e0ff829365bcffb2a215f1d5d95016b073c2dc90e359333c4a3e3b60d6698 |