PyTorch中的LARS实现
项目描述
torchlars
PyTorch中的LARS实现。
from torchlars import LARS
optimizer = LARS(optim.SGD(model.parameters(), lr=0.1))
什么是LARS?
LARS(层自适应速率缩放)是一种用于大型批量训练的优化算法,由You、Gitman和Ginsburg提出,它在每个优化步骤计算每个层的本地学习率。根据论文,当使用LARS在ImageNet ILSVRC(2016)分类任务上训练ResNet-50时,即使批量大小增加到32K,学习曲线和最佳top-1准确率仍然与基线(批量大小为256且未使用LARS的训练)相似。
最初,LARS是以SGD优化器来表述的,论文中没有提及将其扩展到其他优化器。相比之下,torchlars
将LARS实现为一个包装器,可以接受任何优化器,包括SGD作为基础。
此外,torchlars中的LARS在考虑CUDA环境中的操作方面比现有实现更加全面。因此,在CPU到GPU同步不会发生的环境中,与仅使用SGD相比,您只能看到微小的速度损失。
使用方法
目前,torchlars需要以下环境
- Linux
- Python 3.6+
- PyTorch 1.1+
- CUDA 10+
要使用torchlars,请通过PyPI进行安装
$ pip install torchlars
要使用LARS,只需将您的基优化器包装在torchlars.LARS
中。LARS继承自torch.optim.Optimizer
,因此您可以将LARS直接用作代码中的优化器。然后,当您调用LARS的step方法时,LARS会自动在运行基优化器(如SGD或Adam)之前计算局部学习率。
以下示例代码展示了如何使用SGD作为基优化器来使用LARS。
from torchlars import LARS
base_optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
基准测试
在ImageNet分类中使用ResNet-50
批量大小 | 学习率策略 | 学习率 | 预热 | 轮次 | 最佳Top-1准确率,% |
---|---|---|---|---|---|
256 | 多项式(2) | 0.2 | 不适用 | 90 | 73.79 |
8k | LARS+多项式(2) | 12.8 | 5 | 90 | 73.78 |
16K | LARS+多项式(2) | 25.0 | 5 | 90 | 73.36 |
32K | LARS+多项式(2) | 29.0 | 5 | 90 | 72.26 |
上图和表格显示了在ResNet-50上重现的性能基准,如论文中的表4和图5所述。
蓝色线代表基线结果,即批量大小为256的训练结果,其他代表8K、16K、32K的训练结果。如您所见,每个结果都显示出相似的学习曲线和最佳Top-1准确率。
大多数实验条件与论文中使用的一致,但我们略微更改了一些条件(如学习率),以观察LARS论文中提出的可比结果。
注意:我们参考了论文提供的日志文件以获得上述超参数。
作者和许可
torchlars项目由Chunmyong Park在Kakao Brain开发,并得到了Heungsub Lee、Myungryong Jeong、Woonhyuk Baek和Chiheon Kim的帮助。该项目在Apache License 2.0下发行。
引用
如果您将此库应用于任何项目和研究,请引用我们的代码
@misc{torchlars,
author = {Park, Chunmyong and Lee, Heungsub and Jeong, Myungryong and
Baek, Woonhyuk and Kim, Chiheon},
title = {torchlars, {A} {LARS} implementation in {PyTorch}},
howpublished = {\url{https://github.com/kakaobrain/torchlars}},
year = {2019}
}
项目详情
torchlars-0.1.2.tar.gz的哈希
算法 | 哈希摘要 | |
---|---|---|
SHA256 | efef03da466de95b34c736e6d19469478b98f74105572f5a816949dc104fe299 |
|
MD5 | c43fa383faf23082fc69ff4ac8a4e201 |
|
BLAKE2b-256 | f512633c1822dc87d72ad2a80ba40706c7a77056c68d6211351313ff0e96bda0 |