跳转到主要内容

深度学习最小求解器

项目描述

Flashy

tests badge linter badge docs badge

动机

我们注意到,在我们的所有研究项目中,我们都重复使用了相同的结构。PyTorch-Lightning过于复杂,由于复杂性,它不允许达到相同的可修改性。Flashy旨在成为一种替代品。我们不声称它将适合所有用例,我们的首要目标是让它适合我们。我们希望保持代码足够简单,以至于您可以简单地继承和覆盖行为,或者甚至将您想要的复制粘贴到您的项目中。

定义

Flashy的核心是求解器。求解器只负责两件事

  • 将指标记录到多个后端(文件日志、tensorboard或WanDB),并使用自定义格式,
  • 检查点和自动跟踪求解器的有状态部分。

除了这些核心功能之外,Flashy还提供分布式训练工具,特别是替代DistributedDataParallel的方案,它可以与复杂的流程冲突,以及围绕DataLoader的简单包装以支持分布式训练。

Flashy基于epoch,这可能会让一些人觉得过时。不要将epochs视为对数据集的单次遍历,而应将其视为工作流程管理的时间原子单位。每个epochs的结束由对flashy.BaseSolver.commit(save_checkpoint=True)的调用标记。

每个epochs由多个阶段组成,例如trainvalidtest等,并且每次不一定相同。阶段是一种便利性,有助于自动报告带有适当元数据的指标。

依赖项和安装

Flashy 假设使用 PyTorch 同时使用 Dora。您可以在对 flashy/state.py 进行一些修改后,不使用 PyTorch 来使用它。Dora 在几个地方内置,应该不难移除,尽管我们强烈推荐使用它。Flashy 至少需要 Python 3.8。

要安装 Flashy,请运行以下命令

# For the moment we recommend having bleeding edge versions of Dora and Submitit
pip install -U git+https://github.com/facebookincubator/submitit@main#egg=submitit
pip install -U git+https://git@github.com/facebookresearch/dora#egg=dora-search
# Now let's install Flashy!
pip install git+ssh://git@github.com/facebookresearch/flashy.git#egg=flashy

要为开发安装 Flashy,您可以克隆此存储库并运行

make install

入门指南

我们将假设您正在使用 Hydra。您需要熟悉 Dora。让我们构建一个非常基础的名为 basic 的项目,其结构如下

basic/
  conf/
    config.yaml
  train.py
  __init__.py

此项目位于 examples 文件夹中。对于 config.yaml,我们可以从基本开始

epochs: 10
lr: 0.1

dora:
  # Output folder for all the artifacts of an experiment.
  dir: /tmp/flashy_basic_${oc.env:USER}/outputs

__init__.py 只是空的。 train.py 包含大部分逻辑

import torch
from dora import hydra_main
import flashy


class Solver(flashy.BaseSolver):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = torch.nn.Linear(32, 1)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=cfg.lr)
        self.best_state = {}
        # register_stateful supports any attribute. On checkpoints loading,
        # it will try to use inplace method when possible (i.e. Modules, lists, dicts).
        self.register_stateful('model', 'optim', 'best_state')
        self.init_tensorboard()  # all metrics will be reported to stderr and tensorboard.

    def run(self):
        self.restore()  # load checkpoint
        for epoch in range(self.epoch, self.cfg.epochs):
            # Stages are used for automatic metric reporting to Dora, and it also
            # allows tuning how metrics are formatted.
            self.run_stage('train', self.train)
            # Commit will send the metrics to Dora and save checkpoints by default.
            self.commit(save_checkpoint=epoch % 2 == 1)

    def train(self):
        # this is super dumb, checkout `examples/cifar/solver.py` for more advance usage!
        x = torch.randn(4, 32)
        y = self.model(x)
        loss = y.abs().mean()
        loss.backward()
        self.optim.step()
        self.optim.zero_grad()
        return {'loss': loss.item()}


@hydra_main(config_path='config', config_name='config', version_base='1.1')
def main(cfg):
    # Setup logging both to XP specific folder, and to stderr.
    flashy.setup_logging()
    # Initialize distributed training, no need to specify anything when using Dora.
    flashy.distrib.init()
    solver = Solver(cfg)
    solver.run()


if __name__ == '__main__':
    main()

从包含 basic 的文件夹中,您可以使用以下命令启动训练

dora -P basic run
dora run  # if no other package contains a train.py file in the current folder.

示例

有关更高级的示例,请参阅 examples/cifar/solver.py,其中包含真实训练和分布式。当从 examples/ 文件夹运行示例时,您必须将您要运行的包传递给 Dora,因为有多种可能性

dora -P [basic|cifar] run

API

请查阅 Flashy API 文档

许可证

Flashy 在 MIT 许可证下提供,可在存储库根目录下的 LICENSE 文件中找到。 flashy.loggers.utils 的一些部分改编自 PyTorch-Lightning,最初在 Apache 2.0 许可证下,有关详细信息,请参阅 flashy/loggers/utils.py

项目详情


下载文件

下载适合您平台的文件。如果您不确定要选择哪个,请了解更多关于 安装软件包 的信息。

源分发

flashy-0.0.2.tar.gz (72.4 kB 查看哈希)

上传时间

由以下支持

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误日志 StatusPage StatusPage 状态页面