跳转到主要内容

MLgym,一个用于在研究中进行分布式和可重复的机器学习模型训练的Python框架。

项目描述


一个功能丰富的深度学习框架,提供实验的全可重复性。

CircleCI

可重复性是深度学习(研究)中一个常见问题,模型通常在Jupyter笔记本中实现,或者每个新项目从头开始实现整个训练和评估流程。实验设置的缺乏标准化和重复的样板代码阻碍了可重复性。

MLgym旨在通过将实验设置与代码分离,并提供模型训练、模型评估、实验日志记录、检查点以及实验分析的整个基础设施来提高可重复性。

具体来说,MLgym提供了一套可扩展的机器学习组件(例如,训练器、评估器、损失函数等)。该框架根据在配置文件中指定的参数动态实例化这些组件,该配置文件描述了整个实验设置(即训练和评估流程)。实验设置与代码的分离最大化了ML实验的复现性和可解释性。机器学习组件显著减少了实施工作,让您能够专注于自己的想法。

此外,MLgym还提供以下关键特性

  • 组件注册以注册自定义组件及其依赖项。

  • 热启动,允许在崩溃后恢复训练

  • 可自定义的检查点策略

  • MLboard webservice通过订阅WebSocket日志环境进行实验跟踪/分析(在线和离线)

  • 支持大规模、多GPU训练,支持网格搜索、嵌套交叉验证和交叉验证

  • 通过WebSocket和事件源进行分布式日志记录,允许位置无关的日志记录和完全可重复性

  • 在配置文件中定义训练和评估流程,实现实验设置与代码的分离。

请注意,目前此代码应被视为实验性的,并且尚未准备好投入生产。

安装

安装 MLgym 有两种选择,最简单的方法是从 pip 仓库安装框架。

pip install mlgym

对于最新版本,可以直接通过在根目录下运行 cd 并执行以下操作来从源安装:

pip install src/

使用方法

我们提供了一个易于使用的示例,让您可以运行 MLgym 的 实验设置

在运行实验之前,我们需要设置 MLboard 记录环境,即 WebSocket 服务和 RESTful Web 服务。MLgym 通过 WebSocket API 记录训练/评估进度和评估结果,允许 MLboard 前端接收实时更新。RESTful Web 服务提供接收检查点和实验设置的端点。有关这两个 API 的完整规范,请参阅此处

我们分别在端口 5001 和 5002 上启动 WebSocket 服务和 RESTful Web 服务。如果需要,可以自由选择不同的端口。同样,我们将 event_storage 文件夹指定为本地事件存储文件夹。请注意,要从不同的端口访问 WebSocket 服务,我们需要指定允许的 CORS 原因。在这个例子中,我们只使用 MLboard 前端从 127.0.0.1:8080 本地使用 WebSocket 服务。

ml_board_ws_endpoint --host 127.0.0.1 --port 5002 --event_storage_path event_storage --cors_allowed_origins http://127.0.0.1:8080

ml_board_rest_endpoint --port 5001 --event_storage_path event_storage

接下来,我们运行实验设置。我们 cd 到示例文件夹,并使用带有相应路径的参数 gs_config_path 运行 run.py。参数 process_count 指定要并行运行的实验数量。num_epochs 限制训练模型的最大轮数。如果模型性能在一段时间内没有显著提高,则 gs_config.yml 中定义的检查点策略将提前停止训练。

cd mlgym/example/grid_search_example

python run.py --process_count 3 \
              --text_logging_path general_logging/ \
              --num_epochs 10 \
              --websocket_logging_servers http://127.0.0.1:5002 \
              --gs_rest_api_endpoint http://127.0.0.1:5001 \
              train \
              --gs_config_path gs_config.yml

要可视化实时更新,我们运行 MLboard 前端。我们指定提供前端和 REST Web 服务以及 WebSocket 服务的端点的服务器主机和端口。参数 run_id 指的是我们想要分析的实验运行,具体取决于您的案例。每个实验运行都存储在 event_storage 路径下的单独文件夹中。文件夹名称指的是相应的实验运行 ID。

ml_board --ml_board_host 127.0.0.1 --ml_board_port 8080 --rest_endpoint http://127.0.0.1:5001 --ws_endpoint http://127.0.0.1:5002 --run_id 2022-11-06--17-59-10

该脚本返回指向相应实验运行的参数化 URL。

====> ACCESS MLBOARD VIA http://127.0.0.1:8080?rest_endpoint=http%3A//127.0.0.1%3A5001&ws_endpoint=http%3A//127.0.0.1%3A5002&run_id=2022-11-06--17-59-10

请注意,Flask Web 服务以静态方式提供编译后的 React 文件,因此对前端代码的任何更改都不会自动反映出来。作为解决方案,您可以直接通过 yarn 启动 MLboard React 应用程序,并在浏览器中使用相应的 URL 搜索参数调用该 URL。

cd mlgym/src/ml_board/frontend/dashboard

yarn start

由于 MLboard 前端仍在开发中,并且尚未实现所有功能,因此您可以直接在事件存储中分析日志文件。所有消息都按照在WebSocket API 中指定的方式记录。

要实时查看消息,请 cd 到事件存储目录并 tail event_storage.log 文件。

cd event_storage/2022-11-06--17-59-10/
tail -f event_storage.log

MLboard

由于 MLboard 仍在大力开发中,我们想给您一个关于未来即将到来内容的预览。

版权

版权(C)2020 Max Lübbering

有关许可信息,请参阅:https://github.com/mlgym/mlgym/blob/master/LICENSE

项目详情


发布历史 发布通知 | RSS 源

下载文件

下载适用于您平台的应用程序。如果您不确定选择哪一个,请了解更多关于安装包的信息。

源代码分发

mlgym-0.0.75.tar.gz (2.5 MB 查看哈希值)

上传时间 源代码

构建分发

mlgym-0.0.75-py3-none-any.whl (2.5 MB 查看哈希值)

上传时间 Python 3

支持者