自动调整模型
项目描述
麻省理工学院数据到AI实验室的开源项目。
ATM - 自动调整模型
概述
自动调整模型(ATM)是一个易于使用的AutoML系统。简而言之,你给ATM一个分类问题和一个CSV文件数据集,ATM将尝试构建最佳的模型。ATM基于同名论文,该项目是麻省理工学院人机数据交互(HDI)项目的一部分。
安装
要求
ATM已在Python 2.7、3.5和3.6上进行开发和测试
此外,尽管并非强制要求,但强烈建议使用virtualenv,以避免与在运行ATM的系统上安装的其他软件发生冲突。
以下是用python3.6创建虚拟环境的必要命令,用于ATM
pip install virtualenv
virtualenv -p $(which python3.6) atm-venv
之后,您需要执行此命令以激活虚拟环境
source atm-venv/bin/activate
请记住,每次启动新的控制台进行ATM工作时都要执行它!
使用pip安装
创建并激活虚拟环境后,我们建议使用pip来安装ATM
pip install atm
这将从PyPi拉取并安装最新稳定版本。
从源代码安装
或者,在激活虚拟环境后,您可以通过在stable
分支上运行make install
来克隆存储库并从源代码安装它
git clone git@github.com:HDI-Project/ATM.git
cd ATM
git checkout stable
make install
为开发安装
如果您想为项目做出贡献,还需要执行一些额外步骤来使项目准备好进行开发。
首先,请前往项目的GitHub页面,并点击页面右上角的fork
按钮,用自己的用户名创建项目的副本。
之后,克隆您的副本,并从主分支创建一个以您要解决的问题编号命名的分支。
git clone git@github.com:{your username}/ATM.git
cd ATM
git branch issue-xx-cool-new-feature master
git checkout issue-xx-cool-new-feature
最后,使用以下命令安装项目,这将安装一些用于代码检查和测试的附加依赖项。
make install-develop
请确保在开发过程中定期运行make lint
和make test
命令。
数据格式
ATM输入始终是具有以下特征的CSV文件
- 它使用单个逗号
,
作为分隔符。 - 它的第一行是标题行,包含列名。
- 其中有一个包含需要预测的目标变量的列。
- 其余的列都是将用于预测目标列的变量或特征。
- 每一行对应于单个完整的学习样本。
以下是一个有效CSV文件的前5行示例,具有4个特征和一个名为class
的目标列
feature_01,feature_02,feature_03,feature_04,class
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
此CSV可以作为本地文件系统路径传递给ATM,也可以作为完整的AWS S3 Bucket和路径规范传递,或作为URL。
您可以在AWS中的atm-data S3 Bucket中找到一些演示数据集。
快速入门
在这个简短的教程中,我们将通过一系列步骤引导您通过探索其Python API来入门ATM。
1. 获取示例数据
运行ATM的第一步是获取在教程其余部分中将使用的示例数据集。
对于这个示例,我们将使用atm-data存储桶中的污染csv,您可以通过浏览器从这里下载,或者使用以下命令
atm download_demo pollution_1.csv
2. 创建ATM实例
获得示例数据后,首先要做的就是创建ATM实例。
from atm import ATM
atm = ATM()
默认情况下,如果没有提供任何参数,ATM实例将在当前工作目录中创建一个名为atm.db
的SQLite数据库。
如果您想连接到SQL数据库或更改SQLite数据库的位置,请参阅API参考以获取完整选项列表。
3. 寻找最佳模型
一旦您准备好了ATM实例,您可以使用atm.run
方法开始搜索能更好地预测您CSV文件目标列的模型。
此函数必须提供您的CSV文件路径,可以是本地文件系统路径、HTTP或S3资源的URL。
例如,如果我们已经在我们当前工作目录中下载了pollution_1.csv文件,我们可以这样调用run
results = atm.run(train_path='pollution_1.csv')
或者,我们可以使用文件的HTTPS URL,让ATM为我们下载CSV文件
results = atm.run(train_path='https://atm-data.s3.amazonaws.com/pollution_1.csv')
作为最后一个选项,如果文件在S3 Bucket中,我们可以通过传递s3://{bucket}/{key}
格式的URI来下载它
results = atm.run(train_path='s3://atm-data/pollution_1.csv')
为了使私有S3 Bucket能够使用,请确保您已经配置了AWS凭证文件,或者已经创建了包含access_key
和secret_key
参数的ATM
实例。
这个run
调用将启动一个称为Datarun
的过程,在测试和调整不同的模型时将显示进度条。
Processing dataset demos/pollution_1.csv
100%|##########################| 100/100 [00:10<00:00, 6.09it/s]
此过程结束后,将打印一条消息,说明Datarun
已结束。然后我们可以探索results
对象。
4. 探索结果
一旦Datarun完成,我们可以通过几种方式探索results
对象
a. 获取Datarun摘要
describe
方法将返回Datarun执行摘要
results.describe()
这将打印出类似于以下内容的Datarun简短描述
Datarun 1 summary:
Dataset: 'demos/pollution_1.csv'
Column Name: 'class'
Judgment Metric: 'f1'
Classifiers Tested: 100
Elapsed Time: 0:00:07.638668
b. 获取最佳分类器摘要
get_best_classifier
方法将打印出关于在Datarun期间找到的最佳分类器的信息,包括所使用的方法和找到的最佳超参数
results.get_best_classifier()
输出将类似于以下内容
Classifier id: 94
Classifier type: knn
Params chosen:
n_neighbors: 13
leaf_size: 38
weights: uniform
algorithm: kd_tree
metric: manhattan
_scale: True
Cross Validation Score: 0.858 +- 0.096
Test Score: 0.714
c. 探索分数
get_scores
方法将返回一个包含所有在Datarun期间测试的分类器的信息的pandas.DataFrame
,包括它们的交叉验证分数和它们的pickle模型的位置。
scores = results.get_scores()
分数数据框的内容应类似于以下内容
cv_judgment_metric cv_judgment_metric_stdev id test_judgment_metric rank
0 0.8584126984 0.0960095737 94 0.7142857143 1.0
1 0.8222222222 0.0623609564 12 0.6250000000 2.0
2 0.8147619048 0.1117618135 64 0.8750000000 3.0
3 0.8139393939 0.0588721670 68 0.6086956522 4.0
4 0.8067754468 0.0875180564 50 0.6250000000 5.0
...
5. 进行预测
一旦我们找到并探索了最佳分类器,我们就会想要用它来进行预测。
为了做到这一点,我们需要遵循几个步骤
a. 导出最佳分类器
export_best_classifier
方法可以用来使用pickle将最佳分类器模型序列化并保存到指定的位置
results.export_best_classifier('path/to/model.pkl')
如果分类器已正确保存,将打印一条消息表示如此
Classifier 94 saved as path/to/model.pkl
如果您提供的路径已存在,可以通过添加参数force=True
来覆盖它。
b. 加载导出的模型
一旦导出,您可以通过调用atm.Model
类的load
方法将其加载回来,并传递模型已保存的位置
from atm import Model
model = Model.load('path/to/model.pkl')
一旦您加载了您的模型,您可以将新数据传递给其predict
方法来进行预测
import pandas as pd
data = pd.read_csv(demo_datasets['pollution'])
predictions = model.predict(data.head())
接下来是什么?
有关ATM及其所有可能性和功能的更多详细信息,请访问文档网站。
在那里您可以了解其命令行界面和其REST API,以及如何为ATM做出贡献以帮助我们开发新功能或酷炫的想法。
致谢
ATM是麻省理工学院数据到人工智能实验室的一个开源项目,多年来由以下团队构建和维护:
- Bennett Cyphers bcyphers@mit.edu
- Thomas Swearingen swearin3@msu.edu
- Carles Sala csala@csail.mit.edu
- Plamen Valentinov plamen@pythiac.com
- Kalyan Veeramachaneni kalyan@mit.edu
- Micah Smith micahjsmith@gmail.com
- Laura Gustafson lgustaf@mit.edu
- Kiran Karra kiran.karra@gmail.com
- Max Kanter kmax12@gmail.com
- Alfredo Cuesta-Infante alfredo.cuesta@urjc.es
- Favio André Vázquez favio.vazquezp@gmail.com
- Matteo Hoch minime@hochweb.com
引用 ATM
如果您使用 ATM,请考虑引用以下论文
Thomas Swearingen, Will Drevo, Bennett Cyphers, Alfredo Cuesta-Infante, Arun Ross, Kalyan Veeramachaneni. ATM: A distributed, collaborative, scalable system for automated machine learning. IEEE BigData 2017, 151-162
BibTeX条目
@inproceedings{DBLP:conf/bigdataconf/SwearingenDCCRV17,
author = {Thomas Swearingen and
Will Drevo and
Bennett Cyphers and
Alfredo Cuesta{-}Infante and
Arun Ross and
Kalyan Veeramachaneni},
title = {{ATM:} {A} distributed, collaborative, scalable system for automated
machine learning},
booktitle = {2017 {IEEE} International Conference on Big Data, BigData 2017, Boston,
MA, USA, December 11-14, 2017},
pages = {151--162},
year = {2017},
crossref = {DBLP:conf/bigdataconf/2017},
url = {https://doi.org/10.1109/BigData.2017.8257923},
doi = {10.1109/BigData.2017.8257923},
timestamp = {Tue, 23 Jan 2018 12:40:42 +0100},
biburl = {https://dblp.org/rec/bib/conf/bigdataconf/SwearingenDCCRV17},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
相关项目
BTB
BTB,用于贝叶斯调优和探索,是HD项目下开发的核心AutoML库。BTB通过通用API公开了多个用于超参数选择和调优的方法。它允许领域专家轻松扩展现有方法并添加新的方法。BTB是ATM的一个核心部分,这两个项目是同时开发的,但它旨在实现无关性,应该对广泛的超参数选择任务有用。
Featuretools
Featuretools 是一个用于自动特征工程的Python库。它可以用于准备ATM的原始事务和关系数据集。它由 Feature Labs 创建和维护,也是 人类数据交互项目 的一部分。
历史
0.2.2 (2019-07-30)
新功能
0.2.1 (2019-06-24)
新功能
- Rest API跨源资源共享(CORS) - 问题 #146 由 @pvk-developer
0.2.0 (2019-05-29)
新Python API
新功能
- 用于Python中ATM使用的新API - 问题 #142 由 @pvk-developer 和 @csala
- 改进文档 - 问题 #142 由 @pvk-developer 和 @csala
- 代码清理 - 问题 #102 由 @csala
- 确保可以从S3下载数据集 - 问题 #137 由 @pvk-developer
- 改为PyMySQL以删除libmysqlclient-dev系统依赖 - 问题 #136 由 @pvk-developer 和 @csala
0.1.2 (2019-05-07)
REST API和集群管理。
新功能
- REST API服务器 - 问题 #82 和 #132 由 @RogerTangos, @pvk-developer 和 @csala
- 添加集群管理命令以作为后台进程启动和停止服务器和多个工作进程 - 问题 #130 由 @pvk-developer 和 @csala
- 添加TravisCI并将文档迁移到GitHub Pages - 问题 #129 由 @pvk-developer
0.1.1 (2019-04-02)
首次发布于PyPi。
新功能
- 升级到最新BTB。
- 新的命令行界面。
0.1.0 (2018-05-04)
- 首次发布。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。
源码分发
构建分发
atm-0.2.2.tar.gz 的哈希
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 9a5ab26f14bff3f08665b1d6ae0bfa5e8f0848860a3977664b362ea03ceb9352 |
|
MD5 | b9842eb23a4074495d3f138a232e5e59 |
|
BLAKE2b-256 | 5b5b77e50b27ef2b0b2f4be0c9dfe6f4cee9c37749a393e3e1c01c9aa33c892b |
atm-0.2.2-py2.py3-none-any.whl 的哈希
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 805243fb6b1e3c990114f26a78f1373b56cb169a2dbf656cefc41c4348860d6c |
|
MD5 | 81670be6eab3c2be495ec8c9ca5a4df0 |
|
BLAKE2b-256 | 1a234ff31ca9695473342140369274f5d11ed0fec8204dac3a6e469797ad046f |