跳转到主要内容

将scikit-learn模型转换为ONNX

项目描述

介绍

sklearn-onnxscikit-learn 模型转换为 ONNX 格式。一旦转换为 ONNX 格式,您可以使用如 ONNX Runtime 之类的工具进行高性能评分。所有转换器均通过 onnxruntime 进行测试。任何外部转换器都可以注册以转换 scikit-learn 流程,包括来自外部库的模型或转换器。

文档

完整文档包括教程可在 https://onnx.org.cn/sklearn-onnx/ 找到。 支持的 scikit-learn 模型 最后支持的 opset 是 21。

您也可以在 现有问题 中找到答案或提交新问题。

安装

您可以从 PyPi 安装。

pip install skl2onnx

或者您可以从源代码安装,以获取最新的更改。

pip install git+https://github.com/onnx/sklearn-onnx.git

入门

# Train a model.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)

# Convert into ONNX format.
from skl2onnx import to_onnx

onx = to_onnx(clr, X[:1])
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

# Compute the prediction with onnxruntime.
import onnxruntime as rt

sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]

贡献

我们欢迎以反馈、想法或代码的形式进行贡献。

许可协议

Apache License v2.0

项目详情


下载文件

下载适合您平台的应用程序。如果您不确定选择哪个,请了解更多关于 安装软件包 的信息。

源代码发行版

skl2onnx-1.17.0.tar.gz (932.0 kB 查看哈希值)

上传时间 源代码

构建发行版

skl2onnx-1.17.0-py2.py3-none-any.whl (298.4 kB 查看哈希值)

上传时间 Python 2 Python 3

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面