跳转到主要内容

PyTorch中的LARS实现

项目描述

torchlars

PyPI Build Status

PyTorch中的LARS实现。

from torchlars import LARS
optimizer = LARS(optim.SGD(model.parameters(), lr=0.1))

什么是LARS?

LARS(层自适应速率缩放)是一种用于大型批量训练的优化算法,由You、Gitman和Ginsburg提出,它在每个优化步骤计算每个层的本地学习率。根据论文,当使用LARS在ImageNet ILSVRC(2016)分类任务上训练ResNet-50时,即使批量大小增加到32K,学习曲线和最佳top-1准确率仍然与基线(批量大小为256且未使用LARS的训练)相似。

卷积神经网络的批量训练

最初,LARS是以SGD优化器来表述的,论文中没有提及将其扩展到其他优化器。相比之下,torchlars将LARS实现为一个包装器,可以接受任何优化器,包括SGD作为基础。

此外,torchlars中的LARS在考虑CUDA环境中的操作方面比现有实现更加全面。因此,在CPU到GPU同步不会发生的环境中,与仅使用SGD相比,您只能看到微小的速度损失。

使用方法

目前,torchlars需要以下环境

  • Linux
  • Python 3.6+
  • PyTorch 1.1+
  • CUDA 10+

要使用torchlars,请通过PyPI进行安装

$ pip install torchlars

要使用LARS,只需将您的基优化器包装在torchlars.LARS中。LARS继承自torch.optim.Optimizer,因此您可以将LARS直接用作代码中的优化器。然后,当您调用LARS的step方法时,LARS会自动在运行基优化器(如SGD或Adam)之前计算局部学习率。

以下示例代码展示了如何使用SGD作为基优化器来使用LARS。

from torchlars import LARS

base_optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

output = model(input)
loss = loss_fn(output, target)
loss.backward()

optimizer.step()

基准测试

在ImageNet分类中使用ResNet-50

批量大小 学习率策略 学习率 预热 轮次 最佳Top-1准确率,%
256 多项式(2) 0.2 不适用 90 73.79
8k LARS+多项式(2) 12.8 5 90 73.78
16K LARS+多项式(2) 25.0 5 90 73.36
32K LARS+多项式(2) 29.0 5 90 72.26

上图和表格显示了在ResNet-50上重现的性能基准,如论文中的表4和图5所述。

蓝色线代表基线结果,即批量大小为256的训练结果,其他代表8K、16K、32K的训练结果。如您所见,每个结果都显示出相似的学习曲线和最佳Top-1准确率。

大多数实验条件与论文中使用的一致,但我们略微更改了一些条件(如学习率),以观察LARS论文中提出的可比结果。

注意:我们参考了论文提供的日志文件以获得上述超参数。

作者和许可

torchlars项目由Chunmyong ParkKakao Brain开发,并得到了Heungsub LeeMyungryong JeongWoonhyuk BaekChiheon Kim的帮助。该项目在Apache License 2.0下发行。

引用

如果您将此库应用于任何项目和研究,请引用我们的代码

@misc{torchlars,
  author       = {Park, Chunmyong and Lee, Heungsub and Jeong, Myungryong and
                  Baek, Woonhyuk and Kim, Chiheon},
  title        = {torchlars, {A} {LARS} implementation in {PyTorch}},
  howpublished = {\url{https://github.com/kakaobrain/torchlars}},
  year         = {2019}
}

项目详情


下载文件

下载适用于您的平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

torchlars-0.1.2.tar.gz (6.5 kB 查看哈希)

上传时间 源代码

由以下机构支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面