深度学习最小求解器
项目描述
Flashy
动机
我们注意到,在我们的所有研究项目中,我们都重复使用了相同的结构。PyTorch-Lightning过于复杂,由于复杂性,它不允许达到相同的可修改性。Flashy旨在成为一种替代品。我们不声称它将适合所有用例,我们的首要目标是让它适合我们。我们希望保持代码足够简单,以至于您可以简单地继承和覆盖行为,或者甚至将您想要的复制粘贴到您的项目中。
定义
Flashy的核心是求解器。求解器只负责两件事
- 将指标记录到多个后端(文件日志、tensorboard或WanDB),并使用自定义格式,
- 检查点和自动跟踪求解器的有状态部分。
除了这些核心功能之外,Flashy还提供分布式训练工具,特别是替代DistributedDataParallel的方案,它可以与复杂的流程冲突,以及围绕DataLoader的简单包装以支持分布式训练。
Flashy基于epoch,这可能会让一些人觉得过时。不要将epochs视为对数据集的单次遍历,而应将其视为工作流程管理的时间原子单位。每个epochs的结束由对flashy.BaseSolver.commit(save_checkpoint=True)
的调用标记。
每个epochs由多个阶段组成,例如train
、valid
、test
等,并且每次不一定相同。阶段是一种便利性,有助于自动报告带有适当元数据的指标。
依赖项和安装
Flashy 假设使用 PyTorch 同时使用 Dora。您可以在对 flashy/state.py
进行一些修改后,不使用 PyTorch 来使用它。Dora 在几个地方内置,应该不难移除,尽管我们强烈推荐使用它。Flashy 至少需要 Python 3.8。
要安装 Flashy,请运行以下命令
# For the moment we recommend having bleeding edge versions of Dora and Submitit
pip install -U git+https://github.com/facebookincubator/submitit@main#egg=submitit
pip install -U git+https://git@github.com/facebookresearch/dora#egg=dora-search
# Now let's install Flashy!
pip install git+ssh://git@github.com/facebookresearch/flashy.git#egg=flashy
要为开发安装 Flashy,您可以克隆此存储库并运行
make install
入门指南
我们将假设您正在使用 Hydra。您需要熟悉 Dora。让我们构建一个非常基础的名为 basic
的项目,其结构如下
basic/
conf/
config.yaml
train.py
__init__.py
此项目位于 examples 文件夹中。对于 config.yaml,我们可以从基本开始
epochs: 10
lr: 0.1
dora:
# Output folder for all the artifacts of an experiment.
dir: /tmp/flashy_basic_${oc.env:USER}/outputs
__init__.py
只是空的。 train.py 包含大部分逻辑
import torch
from dora import hydra_main
import flashy
class Solver(flashy.BaseSolver):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.model = torch.nn.Linear(32, 1)
self.optim = torch.optim.Adam(self.model.parameters(), lr=cfg.lr)
self.best_state = {}
# register_stateful supports any attribute. On checkpoints loading,
# it will try to use inplace method when possible (i.e. Modules, lists, dicts).
self.register_stateful('model', 'optim', 'best_state')
self.init_tensorboard() # all metrics will be reported to stderr and tensorboard.
def run(self):
self.restore() # load checkpoint
for epoch in range(self.epoch, self.cfg.epochs):
# Stages are used for automatic metric reporting to Dora, and it also
# allows tuning how metrics are formatted.
self.run_stage('train', self.train)
# Commit will send the metrics to Dora and save checkpoints by default.
self.commit(save_checkpoint=epoch % 2 == 1)
def train(self):
# this is super dumb, checkout `examples/cifar/solver.py` for more advance usage!
x = torch.randn(4, 32)
y = self.model(x)
loss = y.abs().mean()
loss.backward()
self.optim.step()
self.optim.zero_grad()
return {'loss': loss.item()}
@hydra_main(config_path='config', config_name='config', version_base='1.1')
def main(cfg):
# Setup logging both to XP specific folder, and to stderr.
flashy.setup_logging()
# Initialize distributed training, no need to specify anything when using Dora.
flashy.distrib.init()
solver = Solver(cfg)
solver.run()
if __name__ == '__main__':
main()
从包含 basic
的文件夹中,您可以使用以下命令启动训练
dora -P basic run
dora run # if no other package contains a train.py file in the current folder.
示例
有关更高级的示例,请参阅 examples/cifar/solver.py,其中包含真实训练和分布式。当从 examples/
文件夹运行示例时,您必须将您要运行的包传递给 Dora,因为有多种可能性
dora -P [basic|cifar] run
API
请查阅 Flashy API 文档
许可证
Flashy 在 MIT 许可证下提供,可在存储库根目录下的 LICENSE 文件中找到。 flashy.loggers.utils
的一些部分改编自 PyTorch-Lightning,最初在 Apache 2.0 许可证下,有关详细信息,请参阅 flashy/loggers/utils.py