MLgym,一个用于在研究中进行分布式和可重复的机器学习模型训练的Python框架。
项目描述
一个功能丰富的深度学习框架,提供实验的全可重复性。
可重复性是深度学习(研究)中一个常见问题,模型通常在Jupyter笔记本中实现,或者每个新项目从头开始实现整个训练和评估流程。实验设置的缺乏标准化和重复的样板代码阻碍了可重复性。
MLgym旨在通过将实验设置与代码分离,并提供模型训练、模型评估、实验日志记录、检查点以及实验分析的整个基础设施来提高可重复性。
具体来说,MLgym提供了一套可扩展的机器学习组件(例如,训练器、评估器、损失函数等)。该框架根据在配置文件中指定的参数动态实例化这些组件,该配置文件描述了整个实验设置(即训练和评估流程)。实验设置与代码的分离最大化了ML实验的复现性和可解释性。机器学习组件显著减少了实施工作,让您能够专注于自己的想法。
此外,MLgym还提供以下关键特性
-
组件注册以注册自定义组件及其依赖项。
-
热启动,允许在崩溃后恢复训练
-
可自定义的检查点策略
-
MLboard webservice通过订阅WebSocket日志环境进行实验跟踪/分析(在线和离线)
-
支持大规模、多GPU训练,支持网格搜索、嵌套交叉验证和交叉验证
-
通过WebSocket和事件源进行分布式日志记录,允许位置无关的日志记录和完全可重复性
-
在配置文件中定义训练和评估流程,实现实验设置与代码的分离。
请注意,目前此代码应被视为实验性的,并且尚未准备好投入生产。
安装
安装 MLgym 有两种选择,最简单的方法是从 pip 仓库安装框架。
pip install mlgym
对于最新版本,可以直接通过在根目录下运行 cd
并执行以下操作来从源安装:
pip install src/
使用方法
我们提供了一个易于使用的示例,让您可以运行 MLgym 的 实验设置。
在运行实验之前,我们需要设置 MLboard 记录环境,即 WebSocket 服务和 RESTful Web 服务。MLgym 通过 WebSocket API 记录训练/评估进度和评估结果,允许 MLboard 前端接收实时更新。RESTful Web 服务提供接收检查点和实验设置的端点。有关这两个 API 的完整规范,请参阅此处。
我们分别在端口 5001 和 5002 上启动 WebSocket 服务和 RESTful Web 服务。如果需要,可以自由选择不同的端口。同样,我们将 event_storage
文件夹指定为本地事件存储文件夹。请注意,要从不同的端口访问 WebSocket 服务,我们需要指定允许的 CORS 原因。在这个例子中,我们只使用 MLboard 前端从 127.0.0.1:8080 本地使用 WebSocket 服务。
ml_board_ws_endpoint --host 127.0.0.1 --port 5002 --event_storage_path event_storage --cors_allowed_origins http://127.0.0.1:8080
ml_board_rest_endpoint --port 5001 --event_storage_path event_storage
接下来,我们运行实验设置。我们 cd
到示例文件夹,并使用带有相应路径的参数 gs_config_path
运行 run.py
。参数 process_count
指定要并行运行的实验数量。num_epochs
限制训练模型的最大轮数。如果模型性能在一段时间内没有显著提高,则 gs_config.yml
中定义的检查点策略将提前停止训练。
cd mlgym/example/grid_search_example
python run.py --process_count 3 \
--text_logging_path general_logging/ \
--num_epochs 10 \
--websocket_logging_servers http://127.0.0.1:5002 \
--gs_rest_api_endpoint http://127.0.0.1:5001 \
train \
--gs_config_path gs_config.yml
要可视化实时更新,我们运行 MLboard 前端。我们指定提供前端和 REST Web 服务以及 WebSocket 服务的端点的服务器主机和端口。参数 run_id
指的是我们想要分析的实验运行,具体取决于您的案例。每个实验运行都存储在 event_storage
路径下的单独文件夹中。文件夹名称指的是相应的实验运行 ID。
ml_board --ml_board_host 127.0.0.1 --ml_board_port 8080 --rest_endpoint http://127.0.0.1:5001 --ws_endpoint http://127.0.0.1:5002 --run_id 2022-11-06--17-59-10
该脚本返回指向相应实验运行的参数化 URL。
====> ACCESS MLBOARD VIA http://127.0.0.1:8080?rest_endpoint=http%3A//127.0.0.1%3A5001&ws_endpoint=http%3A//127.0.0.1%3A5002&run_id=2022-11-06--17-59-10
请注意,Flask Web 服务以静态方式提供编译后的 React 文件,因此对前端代码的任何更改都不会自动反映出来。作为解决方案,您可以直接通过 yarn 启动 MLboard React 应用程序,并在浏览器中使用相应的 URL 搜索参数调用该 URL。
cd mlgym/src/ml_board/frontend/dashboard
yarn start
由于 MLboard 前端仍在开发中,并且尚未实现所有功能,因此您可以直接在事件存储中分析日志文件。所有消息都按照在WebSocket API 中指定的方式记录。
要实时查看消息,请 cd
到事件存储目录并 tail
event_storage.log
文件。
cd event_storage/2022-11-06--17-59-10/
tail -f event_storage.log
MLboard
由于 MLboard 仍在大力开发中,我们想给您一个关于未来即将到来内容的预览。
版权
版权(C)2020 Max Lübbering
有关许可信息,请参阅:https://github.com/mlgym/mlgym/blob/master/LICENSE
项目详情
下载文件
下载适用于您平台的应用程序。如果您不确定选择哪一个,请了解更多关于安装包的信息。