跳转到主要内容

PyTorch LLM微调的本地库

项目描述

torchtune

Unit Test Recipe Integration Test

介绍 | 安装 | 开始使用 | 文档 | 社区 | 许可协议

[!IMPORTANT] 2024年9月25日更新:torchtune 支持了 Llama 3.2 11B VisionLlama 3.2 3BLlama 3.2 1B 模型!请按照我们的安装说明此处进行操作,然后运行此处的任何文本配置或此处的视觉配置。

 

介绍

torchtune 是一个 PyTorch 库,用于轻松编写、微调和实验 LLM。

torchtune 提供以下内容:

  • Llama、Gemma、Mistral、Phi 和 Qwen 模型家族中流行 LLM 的 PyTorch 实现
  • 可黑客攻击的训练食谱,包括全面微调、LoRA、QLoRA、DPO、PPO、QAT、知识蒸馏等
  • 开箱即用的内存效率、性能改进和与最新 PyTorch API 的扩展性
  • YAML 配置,用于轻松配置训练、评估、量化或推理食谱
  • 内置对许多流行数据集格式和提示模板的支持

 

模型

torchtune 目前支持以下模型。

模型 大小
Llama3.2-Vision 11B [模型, 配置]
Llama3.2 1B, 3B [模型, 配置]
Llama3.1 8B, 70B, 405B [模型, 配置]
Llama3 8B, 70B [模型, 配置]
Llama2 7B, 13B, 70B [模型, 配置]
Code-Llama2 7B, 13B, 70B [模型, 配置]
Mistral 7B [模型, 配置]
Gemma 2B, 7B [模型, 配置]
Microsoft Phi3 Mini [模型, 配置]
Qwen2 0.5B, 1.5B, 7B [模型, 配置]

我们一直在添加新的模型,如果您希望看到 torchtune 中有新的模型,请提交问题

 

微调食谱

torchtune 为在一台或多台设备上训练提供了以下微调食谱。

微调方法 设备 食谱 示例配置
全面微调 1-8 full_finetune_single_device
full_finetune_distributed
Llama3.1 8B 单设备
Llama 3.1 70B 分布式
LoRA 微调 1-8 lora_finetune_single_device
lora_finetune_distributed
Qwen2 0.5B 单设备
Gemma 7B 分布式
QLoRA 微调 1-8 lora_finetune_single_device
lora_finetune_distributed
Phi3 Mini 单设备
Llama 3.1 405B 分布式
DoRA/QDoRA 微调 1-8 lora_finetune_single_device
lora_finetune_distributed
Llama3 8B QDoRA 单设备
Llama3 8B DoRA 分布式
量化感知训练 4-8 qat_distributed Llama3 8B QAT
直接偏好优化 1-8 lora_dpo_single_device
lora_dpo_distributed
Llama2 7B 单设备
Llama2 7B 分布式
近端策略优化 1 ppo_full_finetune_single_device Mistral 7B
知识蒸馏 1 knowledge_distillation_single_device Qwen2 1.5B -> 0.5B

上述配置仅作为示例,以帮助您开始。如果您在此处未列出模型,我们可能仍然支持它。如果您不确定是否支持某些内容,请在仓库中提交问题。

 

内存和训练速度

以下为不同 Llama 3.1 模型的内存需求和训练速度示例。

[!NOTE] 为了便于比较,以下所有数值均针对批量大小 2(无梯度累积)、数据集打包到序列长度 2048,且 torch compile 启用。

如果您对在不同硬件或不同模型上运行感兴趣,请查看我们关于内存优化的文档 此处,以找到适合您的正确设置。

模型 微调方法 可运行于 每个 GPU 的峰值内存 每秒 token 数 *
Llama 3.1 8B 完全微调 1x 4090 18.9 GiB 1650
Llama 3.1 8B 完全微调 1x A6000 37.4 GiB 2579
Llama 3.1 8B LoRA 1x 4090 16.2 GiB 3083
Llama 3.1 8B LoRA 1x A6000 30.3 GiB 4699
Llama 3.1 8B QLoRA 1x 4090 7.4 GiB 2413
Llama 3.1 70B 完全微调 8x A100 13.9 GiB ** 1568
Llama 3.1 70B LoRA 8x A100 27.6 GiB 3497
Llama 3.1 405B QLoRA 8x A100 44.8 GB 653

*= 测量一个完整训练周期

**= 使用 CPU 卸载和融合优化器

 

安装

torchtune 与最新的稳定 PyTorch 版本以及预览版 nightly 版本进行了测试。torchtune 利用 torchvision 进行多模态 LLM 的微调,并使用 torchao 进行最新的量化技术;您也应该安装这些。

 

安装稳定版本

# Install stable PyTorch, torchvision, torchao stable releases
pip install torch torchvision torchao
pip install torchtune

 

安装 nightly 版本

# Install PyTorch, torchvision, torchao nightlies
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
pip install --pre --upgrade torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu

您还可以查看我们的 安装文档,以获取更多信息,包括从源安装 torchtune。

 

要确认已正确安装包,您可以运行以下命令

tune --help

并且应该看到以下输出

usage: tune [-h] {ls,cp,download,run,validate} ...

Welcome to the torchtune CLI!

options:
  -h, --help            show this help message and exit

...

 

开始使用

要开始使用 torchtune,请参阅我们的 第一次微调教程。我们的 端到端工作流程教程 将向您展示如何评估、量化和运行 Llama 模型的推理。本节其余部分将简要介绍这些步骤,以 Llama3.1 为例。

下载模型

按照官方 meta-llama 仓库上的说明,确保您有权访问官方 Llama 模型权重。一旦确认访问权限,您就可以运行以下命令将权重下载到您的本地机器。这还将下载分词器模型和负责任的使用指南。

要下载 Llama3.1,您可以运行

tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
--output-dir /tmp/Meta-Llama-3.1-8B-Instruct \
--hf-token <HF_TOKEN> \

[!Tip] 请设置环境变量 HF_TOKEN 或将 --hf-token 传递给命令以验证您的访问权限。您可以在 https://hugging-face.cn/settings/tokens 找到您的令牌

 

运行微调配方

您可以使用以下命令在单个 GPU 上使用 LoRA 对 Llama3.1 8B 进行微调

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device

对于分布式训练,tune CLI 与 torchrun 集成。要运行 Llama3.1 8B 的完整微调,在两个 GPU 上

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full

[!Tip] 请确保将任何 torchrun 命令放置在配方指定之前。此之后的所有 CLI 参数将覆盖配置,不会影响分布式训练。

 

修改配置

您可以通过两种方式修改配置

配置覆盖

您可以直接从命令行覆盖配置字段

tune run lora_finetune_single_device \
--config llama2/7B_lora_single_device \
batch_size=8 \
enable_activation_checkpointing=True \
max_steps_per_epoch=128

更新本地副本

您也可以将配置复制到您的本地目录,并直接修改内容

tune cp llama3_1/8B_full ./my_custom_config.yaml
Copied to ./my_custom_config.yaml

然后,您可以通过将tune run命令指向您的本地文件来运行您自定义的配方

tune run full_finetune_distributed --config ./my_custom_config.yaml

 

查看tune --help以获取所有可能的CLI命令和选项。有关使用和更新配置的更多信息,请参阅我们的配置深入探讨

 

自定义数据集

torchtune支持在多种不同数据集上微调,包括指令风格聊天风格偏好数据集等。如果您想了解更多关于如何将这些组件应用到您自己的自定义数据集上进行微调的信息,请参阅提供的链接以及我们的API文档

 

社区

torchtune专注于与生态系统中的流行工具和库集成。以下只是几个例子,还有更多正在开发中

 

社区贡献

我们非常重视我们的社区和用户做出的贡献。我们将使用本节来突出一些这些贡献。如果您也想帮忙,请参阅贡献指南

 

致谢

本仓库中的Llama2代码受到了原始Llama2代码的启发。

我们要向EleutherAI、Hugging Face和Weights & Biases致以崇高的敬意,感谢他们作为优秀的合作伙伴,并与我们共同在torchtune中完成了一些集成工作。

我们还想感谢生态系统中的某些出色的库和工具

  • gpt-fast,我们直接采用了其高性能的LLM推理技术
  • llama recipes,为llama2社区提供了起飞的平台
  • bitsandbytes,为PyTorch生态系统带来了多个基于内存和性能的技术
  • @winglianaxolotl,他们在torchtune的设计和功能集合上提供了早期反馈和头脑风暴。
  • lit-gpt,推动LLM微调社区向前发展。
  • HF TRL,使奖励建模对PyTorch社区更加易于访问。

 

许可证

torchtune在BSD 3许可证下发布。然而,您可能还有其他法律义务来规范您对其他内容的用途,例如第三方模型的条款和服务。

项目详情


下载文件

下载您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源代码分发

本发行版没有可用的源代码分发文件。请参阅生成分发存档的教程。

构建分发

torchtune-0.3.1-py3-none-any.whl (596.6 kB 查看散列)

上传时间 Python 3

支持者

AWSAWS 云计算和安全赞助商 DatadogDatadog 监控 FastlyFastly CDN GoogleGoogle 下载分析 MicrosoftMicrosoft PSF 赞助商 PingdomPingdom 监控 SentrySentry 错误日志 StatusPageStatusPage 状态页面