AI工具包
项目描述
ai-toolkit
动机
当我们在机器学习项目上工作时,尤其是在监督学习方面,往往存在大量的重复代码,因为在每个项目中,我们总是想找到一种方法来检查我们的工作,在tensorboard中可视化损失曲线,添加额外的指标,并查看示例输出。有些项目我们能够做得比其他项目更好。理想情况下,我们希望有一种方法将所有这些代码合并到一个地方。
问题是PyTorch示例通常不太相似。像大多数数据探索一样,我们希望能够修改代码库中的任何部分来处理不同的损失度量、不同类型的数据或基于数据维度的不同可视化。将所有内容合并到一个存储库中通常会使底层逻辑过于复杂(例如,使训练循环难以阅读)。我们希望在极其简洁/可读的代码和易于添加额外功能之间取得平衡。
该项目是为希望从开始就拥有完整功能机器学习管道功能的开发人员或机器学习科学家而设计的。每个项目都包含一致的样式、对日志记录、指标和检查点/从检查点恢复训练有偏好的处理方式。它还无缝集成到Google Colab和AWS/Google Cloud GPU。
试试看!
你应该做的第一件事是进入output_*/文件夹中的一个,并尝试训练一个模型。我们目前有以下模型
显著特性
- 在 train.py 中,代码对所有模型执行一些验证检查,以确保您没有混淆批维度。
- 尝试停止它,然后经过几个训练周期后重新启动 - 它应该从相同的位置继续训练。
- 在 tensorboard 上,损失曲线应该已经无缝地在运行之间绘制。
- 所有检查点都应可在 checkpoints/ 目录中找到,其中包含激活层、输入数据和最佳模型。
- 通过指定 configs/ 目录中的文件来轻松安排运行。
评估标准
目标是让这个存储库包含一系列清洁的机器学习示例,不同层次的理解,我可以从中提取并用作示例、测试模型等。我基本上想要收集我找到或过去使用过的所有最佳实践代码片段,并将它们模块化,以便可以轻松导入或导出以供以后使用。
目标不是将其构建为基于 PyTorch 的机器学习框架,而是专注于单个研究人员/开发者工作流程,使其非常容易开始工作。非常适合 Kaggle 比赛、简单的数据探索或尝试不同的模型。
本存储库成功的粗略评估指标是下载数据后我能够多快开始处理 Kaggle 挑战:对数据进行洞察、分析数据分布、运行基线和微调模型、获取损失曲线和图表。
当前工作流程
- 将数据添加到您的
data/
文件夹,并编辑datasets/
中的相应 DataasetLoader。 - 将您的配置和模型添加到
configs/
和models/
。 - 运行
train.py
,它将在同一文件夹中保存模型检查点、输出预测和 tensorboards。 - 使用
tensorboard --logdir=checkpoints/
在checkpoints/
文件夹中启动 tensorboard。 - 使用
python train.py --checkpoint=<checkpoint name>
开始和停止训练。代码应自动从上一个训练周期恢复并继续向之前的 tensorboard 记录。 - 运行
python test.py --checkpoint=<checkpoint name>
获取最终预测。
目录结构
- checkpoints/ (仅在运行 train.py 后创建)
- data/
- configs/
- ai_toolkit/
- datasets/
- losses/
- metrics/
- models/
- layers/
- ...
- visualizations/
- args.py (手动修改默认超参数)
- metric_tracker.py
- test.py
- train.py
- util.py
- verify.py
- viz.py (占位符,如有必要创建更多可视化)
- tests/
目标工作流程
- 将数据移动到
data/
。 - 填写
preprocess.py
和dataset.py
。 (通过运行python viz.py
探索数据) - 更改
args.py
以指定输入/输出维度、批量大小等。 - 运行
train.py
,它将在同一文件夹中保存模型检查点、输出预测和 tensorboards。此外,还在 tmux 会话中自动启动 tensorboard 服务器。在任何时候都可以恢复训练。 - 运行
test.py
获取最终预测。
项目详情
下载文件
下载您平台上的文件。如果您不确定选择哪个,请了解有关 安装包 的更多信息。
源分布
构建分发
ai_toolkit-0.0.2.tar.gz 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 5665e6ea0f0f20aef43c3e9d4a4d2edb5bac64df9fe56e794a44b6ed3ee93bea |
|
MD5 | 3d557a5cfafbf269d2b006ab3a81010b |
|
BLAKE2b-256 | 21c1a8b8c6acd94b82e20eb7c7f2d1f267dd1c55efb901bb1f731f696cb13e54 |
ai_toolkit-0.0.2-py3-none-any.whl 的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 4597ff6e5e5245e66677dce45f549030f81d5878480bd1936e42f356f2bee3ba |
|
MD5 | 877d01b287e9c26246501d0c564aeb3a |
|
BLAKE2b-256 | 52746714bb0a32c1e08890d016cd26c63199ed427881476cf0721233b60e4ec2 |