跳转到主要内容

使用可检查点数据加载来增强您的LLM训练

项目描述

Epochraft

Python GitHub license Checks status Tests status pypi

简介

Epochraft 是一个针对LLM简化训练优化的数据加载库,具有从云存储中**流式传输**、**即时分词**和**迭代器检查点**的功能。该名称来源于“epoch”和“craft”的结合。

从云存储中流式传输

在本地磁盘上存储为预训练LLM所需的大量数据集可能会令人望而却步。即使在可行的情况下,在训练之前转移数据可能会很繁琐且耗时。

Epochraft提供了一系列存储解决方案,包括S3、GCS、Azure Blob Storage、HDFS、WebHDFS、HTTP、HTTPS、SFTP以及本地文件系统(由smart-open提供)。其显著特性之一是在下载数据的同时进行训练。由于其基于流的架构,无法进行完整的数据洗牌。然而,Epochraft通过同时访问多个数据分片、混合传入数据,并在预定的缓冲区大小内进行额外的洗牌,实现了某种程度的数据洗牌。

此外,它还支持 Python 的顺序或可迭代接口。例如,它可以利用 Hugging Face Datasets。虽然使用 Epochraft 与如此小的数据集似乎没有多少好处,但这使得可以共用相同的代码库进行 SFT 和预训练。

即时分词

一些先前框架需要预分词。这意味着必须先对训练数据进行分词,然后再在预训练前存储。这很麻烦。训练必须等到这一步完成才能开始。此外,如果数据集或分词器发生变化,您必须再次重复此步骤。此外,还需要管理分词数据的责任。

现在,您可能想知道,“即时分词不会太慢吗?” 答案是绝对的不会。

例如,Llama2-7B 的训练速度约为每 GPU 3K 个标记/秒(如在 表 2 中所示)。Llama2 的分词器可以以近 1M 个标记/秒的速率处理,使用单个 CPU 进程。这意味着即使在实时分词的情况下,GPU 也可以充分利用,没有瓶颈。对于更大的模型,情况变得更加有利。对于 13B 模型,1.5K 个标记/秒的速率足以饱和每个 GPU,而对于 70B 模型,只需 300 个标记/秒。

数据加载器检查点

除了模型和优化器的状态字典之外,我们是否应该考虑保存数据加载器的状态字典呢?

在训练 ResNets 90 个周期是常态的时候,这不是一个问题。每个周期的末尾的检查点就足够了。然而,在当前的 LLM 时代,训练通常围绕一个周期进行。

当仅训练 1 个周期时,确保数据加载器可以从周期的中间继续变得至关重要。在恢复训练时,必须处理到中断点之前尚未使用的所有数据。鉴于数据的庞大,高效的恢复机制是必不可少的。

快速开始

安装

pip install epochraft

示例

这是一个构建典型预训练数据集的示例。我们将很快添加其他示例,例如 SFT。

from epochraft import CheckpointableDataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# `{00..99}` will be expanded (see `braceexpand`)
url = "s3://.../cc-100/cc-100_{00..99}.jsonl"

train_dataset = (
    CheckpointableDataset
    .from_files(url, repeat=True, shuffle_shards=True)
    .tokenize(tokenizer)        # Tokenize the texts
    .ensure_bos_eos(tokenizer)  # Add BOS and EOS tokens where necessary
    .concat_chunk(1024)         # Concatenate and chunk the tokens into a fixed length of 1024 tokens
    .shuffle(1000)              # Shuffle the sequences using a buffer of size 1000
    .batch(8)                   # Group the data into mini-batches with a batch size of 8
)

for batch in train_dataset:
    input_ids = batch["input_ids"]  # Input data for this iteration (torch.Tensor)

    # Implement the training iteration using `input_ids` here
    ...

检查点

通常,您会获取并保存模型和优化器的 state_dict。除此之外,请还获取并保存迭代器的 state_dict

train_iter = train_dataset.iter()  # Same meaning as `iter(train_dataset)`

for batch in train_iter:
    step = batch["step"]
    ...

    if step % ckpt_freq == 0:
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "iter": train_iter.state_dict(),
        }
        torch.save(state_dict, ckpt_path)

恢复

您可以通过将 state_dict 传递给 CheckpointableDataset 实例的 iter 方法来恢复迭代器的状态。

state_dict = torch.load(ckpt_path)
train_iter = train_dataset.iter(state_dict=state_dict["iter"])

开发

pip install -e .[development]
mypy .; black .; flake8 .; isort .
pytest tests

项目详情


下载文件

下载适合您平台的应用程序。如果您不确定选择哪一个,请了解有关 安装包 的更多信息。

源分发

epochraft-0.1.0.dev20231107.tar.gz (29.7 kB 查看哈希值)

上传时间 源代码

构建发行版

epochraft-0.1.0.dev20231107-py3-none-any.whl (40.1 kB 查看哈希值)

上传时间 Python 3

由以下机构支持

AWSAWS云计算和安全赞助商DatadogDatadog监控FastlyFastlyCDNGoogleGoogle下载分析MicrosoftMicrosoftPSF赞助商PingdomPingdom监控SentrySentry错误日志StatusPageStatusPage状态页面