将ao技术应用于GPU模型的应用程序
项目描述
torchao:PyTorch架构优化
介绍 | 推理 | 训练 | 可组合性 | 自定义内核 | Alpha功能 | 安装 | 集成 | 视频 | 许可
介绍
torchao:PyTorch库,用于自定义数据类型和优化。量化并稀疏化权重、梯度、优化器和激活函数,以用于推理和训练。
来自为您带来快速系列团队
torchao默认与Huggingface上的大多数PyTorch模型配合使用torch.compile()
和FSDP2
。
推理
训练后量化
对模型进行量化和稀疏化只需要一行代码,它应该适用于任何包含nn.Linear
的模型,包括您喜欢的HuggingFace模型。更全面的用法说明请参阅这里,稀疏化这里,以及HuggingFace推理示例这里
对于推理,我们有以下选项:
- 仅量化权重:最适合内存受限的模型
- 量化权重和激活:最适合计算受限的模型
- 量化激活和权重,并稀疏化权重
from torchao.quantization.quant_api import (
quantize_,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
int4_weight_only,
int8_weight_only
)
quantize_(m, int4_weight_only())
对于gpt-fast,在bs=1时,int4_weight_only()
是最佳选项,因为它比torch.compiled基线提高了2x
的tok/s,并将VRAM需求减少了约65%
。
如果您在使用这些技术时遇到减速,或者不确定使用哪个选项,请考虑使用autoquant,它将自动分析层并选择量化每层的最佳方式。
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
我们还提供了一组面向开发者的API,以便您可以实现自己的量化算法,请以HQQ算法作为激励示例。
KV缓存量化
我们添加了kv缓存量化和其他功能,以启用长上下文长度的推理(并必然是内存高效的)。
实际上,这些功能以及int4权重仅量化使我们能够将峰值内存减少约55%
,这意味着我们可以使用130k
的上下文长度在只有18.9 GB
的峰值内存的情况下进行Llama3.1-8B推理。更多详细信息请参阅这里
量化感知训练
训练后量化可以生成快速紧凑的模型,但可能会导致精度下降。我们建议探索量化感知训练(QAT)以克服这一限制。与Torchtune合作,我们开发了一个QAT配方,与传统PTQ相比,在Llama3上实现了显著的精度提升,在hellaswag上恢复了96%
的精度下降,在wikitext上恢复了68%
的困惑度下降。我们提供了一个完整的配方这里
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# Insert "fake quantize" operations into linear layers.
# These operations simulate quantization numerics
model = qat_quantizer.prepare(model)
# Run Training...
# Convert fake quantize to actual quantize operations
model = qat_quantizer.convert(model)
训练
Float8
torchao.float8实现了使用缩放浮点8 dtypes的训练配方,如https://arxiv.org/abs/2209.05433所述。
启用torch.compile
时,当前结果显示吞吐量加快了最多1.5x
(在128个H100 GPU上的LLaMa 3 70B预训练作业中)详情
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m, module_filter_fn=...)
有关使用浮点8进行预训练的最小化训练配方,请参阅torchtitan
稀疏训练
我们已添加对半结构化2:4稀疏性的支持,在ViT-L上实现了6%
的端到端速度提升。完整博客这里
代码更改是一行代码,完整示例请参阅这里
swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
内存高效优化器
ADAM的内存消耗是模型参数的两倍,因此我们可以将优化器状态量化为8位或4位,有效地将优化器VRAM需求降低了2倍或4倍。
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions
实际上,我们比精心编写的内核慢一点,但这些优化器的实现仅用几行PyTorch代码编写,并进行了编译,因此请使用它们或复制粘贴它们以用于您的量化优化器。基准测试这里
我们同样支持 单GPU CPU卸载,其中梯度(与权重大小相同)和优化器将被高效地发送到CPU。仅此一项就可以 将您的VRAM需求降低60%
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])
可组合性
torch.compile
:我们设计的关键原则是可组合性,因为我们提供的任何新的数据类型或布局都需要与我们的编译器兼容。无论是使用纯PyTorch、CUDA、C++还是Triton编写的内核,事情都应该按预期工作!因此,我们在纯PyTorch中编写数据类型、布局或位打包逻辑,并生成高效的内核。- FSDP2:历史上大多数量化都是用于推理,现在有一个研究热点是将分布式算法和量化相结合。
将低比特数据类型与编译和fsdp的可组合性结合的最佳例子是 NF4,我们用它实现了 QLoRA 算法。所以如果你在这个领域的交叉研究,我们很乐意听取你的意见。
自定义内核
我们增加了对编写和发布 自定义操作符 的支持,这些操作符不与 torch.compile()
冲突,所以如果你喜欢编写内核但讨厌打包它们,以便在所有操作系统和CUDA版本上都能正常工作,我们很乐意接受你对自定义操作符的贡献。我们有几个示例可以参考
- fp6,它提供了2倍于fp16的推理速度,并具有易于使用的API
quantize_(model, fp6_llm_weight_only())
- 2:4 Sparse Marlin GEMM,对于FP16xINT4内核,即使批处理大小高达256,也能实现2倍的加速
- int4 tinygemm unpacker,它使得切换量化后端进行推理变得更加容易
如果你认为我们应该更仔细地查看其他CUDA内核,请在此 问题 上留言
Alpha功能
我们非常兴奋但需要更多时间来完善的事情
- MX 训练和推理支持,使用 OCP MX规范 数据类型,这些数据类型可以描述为组内缩放的float8/float6/float4/int8,缩放因子被限制为2的幂。这项工作处于原型阶段,因为硬件支持尚未可用。
- Int8量化训练:我们正在尝试进行完整的int8训练。使用
quantize_(model, int8_weight_only_quantized_training())
很容易使用。这项工作处于原型阶段,因为内存基准测试还不够令人信服。 - IntX:我们通过在纯PyTorch中进行一些巧妙的位打包并编译它来支持所有的整数。这项工作处于原型阶段,因为不幸的是,没有更多的投资于编译器或低比特内核,int4比任何更小的数据类型更有吸引力。
- Bitnet:这对团队中的大多数人来说都非常酷。这项工作处于原型阶段,因为这些内核的有效性高度依赖于更好的硬件和内核支持。
安装
torchao
充分利用了PyTorch中的几个新功能,建议与当前的夜间版本或最新的稳定版本一起使用。
从PyPI的稳定版本,默认为CUDA 12.1
pip install torchao
从PyTorch索引的稳定版本
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
夜间版本
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
对于大多数开发者来说,你可能希望跳过构建自定义C++/CUDA扩展以实现更快迭代。
USE_CPP=0 pip install -e .
集成
我们还幸运地集成到一些领先的开源库中,包括
- Hugging Face transformers,内置推理后端和低比特优化器
- Hugging Face diffusers,多亏了Sayak Paul的简化示例
- Mobius HQQ后端利用我们的int4内核在4090上实现了195 tok/s
视频
许可证
torchao
以BSD 3许可证发布。
项目详情
哈希值 for torchao-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 1a14c6d5c6e8a1b03eded529bec9306f271ce59de43fa2e4699fd83f464bb5cd |
|
MD5 | d61a86268ed576e722c96374698844ed |
|
BLAKE2b-256 | 8cf75e8fd7eaab81f0ceefd7caf58647d92684977abca7a38bce6e6891f86a6e |
哈希值 for torchao-0.5.0-cp312-cp312-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | d1dacb8b899d76ea97f166d421c16140016cebed2090f04b17eee7e15b69969a |
|
MD5 | c05c6f93807f5d46b1446e7bca8f3b54 |
|
BLAKE2b-256 | 91d087470c59148ed296418a3533953803f7dec1084acb1ad43e1ac9a110a1cd |
哈希值 for torchao-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 2aede6d89481ccda6bfd81f1666707765244de97697cb42b57c1001d9f928492 |
|
MD5 | 4b5633ad78a069086b585ac7b37366d7 |
|
BLAKE2b-256 | 13cb94636b2639d0a227d130863e87192da7e8caee5463cf7c678740f83c1d48 |
哈希值 for torchao-0.5.0-cp311-cp311-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 3b8be5c3dcb641688397501d55cab4804688c3d27bdb9f8e0abcba1f2810678e |
|
MD5 | cd82297e3bf64bc171615dd1231268da |
|
BLAKE2b-256 | b445e7dfdabe99427db6417ba96e0e6bd0234ed2cf713fe1b1526292c11871b3 |
哈希值 for torchao-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 30a4c5c6ef7e3f5fa9a8dae3e2b9bb82c34d7c61a55f008e120303e22dd82cb6 |
|
MD5 | 44de4ba9237216635ccdc53e45c72352 |
|
BLAKE2b-256 | ff0b5d0bf43aed2548a6788ac4ca434d46652e3976f582671eab28496dc77ca5 |
哈希值 for torchao-0.5.0-cp310-cp310-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 6daff53790532d48e6b023bdd34030e8f87075f75f4206f3dd6577e6d99d7132 |
|
MD5 | 7d53201f66181059f3ec03aa02e2a9de |
|
BLAKE2b-256 | 5315f6ffe100392b4a1dc63fdb5787cfe4ff228dd3dddb57fec70e2f5547d384 |
哈希值 for torchao-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 575974f5aea245cdb9ecc30b6d7385de043c3ac874f5b7701ceb5362e521b7f1 |
|
MD5 | 54fec02a380eeeef7b4a459e16709d39 |
|
BLAKE2b-256 | c3cfe89b4c7c1627885ab64f55e0ed1097d58853a720488aa88aaa201bd32f3b |
哈希值 for torchao-0.5.0-cp39-cp39-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | f3ed3c5a55609c0051d0e1b12a609f416d896afd2d6a5a1ac73ee14c0230801c |
|
MD5 | 02872ee0753ff238913bafc3ce5d801a |
|
BLAKE2b-256 | e0abb7e5ea3921175d54aa840b694b06d2f5ddfc3ac617e1e0a4e7d10d1ecb3d |