使用可检查点数据加载来增强您的LLM训练
项目描述
Epochraft
简介
Epochraft 是一个针对LLM简化训练优化的数据加载库,具有从云存储中**流式传输**、**即时分词**和**迭代器检查点**的功能。该名称来源于“epoch”和“craft”的结合。
从云存储中流式传输
在本地磁盘上存储为预训练LLM所需的大量数据集可能会令人望而却步。即使在可行的情况下,在训练之前转移数据可能会很繁琐且耗时。
Epochraft提供了一系列存储解决方案,包括S3、GCS、Azure Blob Storage、HDFS、WebHDFS、HTTP、HTTPS、SFTP以及本地文件系统(由smart-open提供)。其显著特性之一是在下载数据的同时进行训练。由于其基于流的架构,无法进行完整的数据洗牌。然而,Epochraft通过同时访问多个数据分片、混合传入数据,并在预定的缓冲区大小内进行额外的洗牌,实现了某种程度的数据洗牌。
此外,它还支持 Python 的顺序或可迭代接口。例如,它可以利用 Hugging Face Datasets。虽然使用 Epochraft 与如此小的数据集似乎没有多少好处,但这使得可以共用相同的代码库进行 SFT 和预训练。
即时分词
一些先前框架需要预分词。这意味着必须先对训练数据进行分词,然后再在预训练前存储。这很麻烦。训练必须等到这一步完成才能开始。此外,如果数据集或分词器发生变化,您必须再次重复此步骤。此外,还需要管理分词数据的责任。
现在,您可能想知道,“即时分词不会太慢吗?” 答案是绝对的不会。
例如,Llama2-7B 的训练速度约为每 GPU 3K 个标记/秒(如在 表 2 中所示)。Llama2 的分词器可以以近 1M 个标记/秒的速率处理,使用单个 CPU 进程。这意味着即使在实时分词的情况下,GPU 也可以充分利用,没有瓶颈。对于更大的模型,情况变得更加有利。对于 13B 模型,1.5K 个标记/秒的速率足以饱和每个 GPU,而对于 70B 模型,只需 300 个标记/秒。
数据加载器检查点
除了模型和优化器的状态字典之外,我们是否应该考虑保存数据加载器的状态字典呢?
在训练 ResNets 90 个周期是常态的时候,这不是一个问题。每个周期的末尾的检查点就足够了。然而,在当前的 LLM 时代,训练通常围绕一个周期进行。
当仅训练 1 个周期时,确保数据加载器可以从周期的中间继续变得至关重要。在恢复训练时,必须处理到中断点之前尚未使用的所有数据。鉴于数据的庞大,高效的恢复机制是必不可少的。
快速开始
安装
pip install epochraft
示例
这是一个构建典型预训练数据集的示例。我们将很快添加其他示例,例如 SFT。
from epochraft import CheckpointableDataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# `{00..99}` will be expanded (see `braceexpand`)
url = "s3://.../cc-100/cc-100_{00..99}.jsonl"
train_dataset = (
CheckpointableDataset
.from_files(url, repeat=True, shuffle_shards=True)
.tokenize(tokenizer) # Tokenize the texts
.ensure_bos_eos(tokenizer) # Add BOS and EOS tokens where necessary
.concat_chunk(1024) # Concatenate and chunk the tokens into a fixed length of 1024 tokens
.shuffle(1000) # Shuffle the sequences using a buffer of size 1000
.batch(8) # Group the data into mini-batches with a batch size of 8
)
for batch in train_dataset:
input_ids = batch["input_ids"] # Input data for this iteration (torch.Tensor)
# Implement the training iteration using `input_ids` here
...
检查点
通常,您会获取并保存模型和优化器的 state_dict
。除此之外,请还获取并保存迭代器的 state_dict
。
train_iter = train_dataset.iter() # Same meaning as `iter(train_dataset)`
for batch in train_iter:
step = batch["step"]
...
if step % ckpt_freq == 0:
state_dict = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"iter": train_iter.state_dict(),
}
torch.save(state_dict, ckpt_path)
恢复
您可以通过将 state_dict
传递给 CheckpointableDataset
实例的 iter 方法来恢复迭代器的状态。
state_dict = torch.load(ckpt_path)
train_iter = train_dataset.iter(state_dict=state_dict["iter"])
开发
pip install -e .[development]
mypy .; black .; flake8 .; isort .
pytest tests
项目详情
下载文件
下载适合您平台的应用程序。如果您不确定选择哪一个,请了解有关 安装包 的更多信息。
源分发
构建发行版
epochraft-0.1.0.dev20231107.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | ddf8e4f96661b851c66daf3960dd6f919afc441c6288e7830f48013fe2b02a51 |
|
MD5 | a0bb49f6f2c39cf89b6e52e2a10f8dfc |
|
BLAKE2b-256 | 78917dcc54220bf6b6f1d9a3f007541268a0de715f0816184910aadb18e9bdac |
epochraft-0.1.0.dev20231107-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | c007e78581e0b1931b6a7354a6befd533cd3f701e8ffa8ba9f5a1c104316f1d6 |
|
MD5 | ffeebe545e246f25f96e7ffa884769ef |
|
BLAKE2b-256 | 2244a1c5958f609bd10d8728a138570bfa8eda0c90f1c3d6d857191a0464d065 |