跳转到主要内容

PyTorch的GPipe

项目描述

这是一个GPipePyTorch中的实现。

from torchgpipe import GPipe

model = nn.Sequential(a, b, c, d)
model = GPipe(model, balance=[1, 1, 1, 1], chunks=8)

for input in data_loader:
    output = model(input)

什么是GPipe?

GPipe是Google Brain发布的一个可扩展的管道并行库,它允许高效地训练大型、内存消耗高的模型。根据论文,GPipe可以使用8倍的设备(TPU)训练出25倍大的模型,并使用4倍的设备将模型训练速度提高3.5倍。

GPipe:使用管道并行高效训练巨型神经网络

Google使用GPipe在AmoebaNet-B上训练了557M参数。这个模型在ImageNet分类基准测试中实现了84.3%的top-1和97.0%的top-5准确率(截至2019年5月的最佳性能)。

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源代码分发

本发布版本没有可用的源代码分发文件。请参阅生成分发存档的教程

构建分发

torchgpipe-0.0.7-py3-none-any.whl (39.1 kB 查看哈希值)

上传时间 Python 3

由以下支持