跳转到主要内容

将机器学习模型转换为ONNX格式,以便在Windows ML中使用

项目描述

简介

keras2onnx模型转换器允许用户将Keras模型转换为ONNX模型格式。最初,Keras转换器是在onnxmltools项目中开发的。keras2onnx转换器开发被迁移到一个独立的存储库,以支持更多种类的Keras模型并减少混合多个转换器的复杂性。

大多数常见的Keras层都支持转换。请参阅Keras文档tf.keras文档以获取Keras层的详细信息。

Windows机器学习(WinML)用户可以使用WinMLTools,该工具封装了keras2onnx的调用以转换Keras模型。如果您想使用keras2onnx转换器,请参考WinML版本说明以确定您WinML版本的相应ONNX opset编号。

keras2onnx已在Python 3.5、3.6和3.7上进行了测试,支持tensorflow 1.x/2.0/2.1(CI构建)。它不支持Python 2.x

安装

您可以从PyPi安装Keras2ONNX的最新版本:由于某些原因,软件包发布已暂停,请从源代码安装,并且仅在源代码中支持keras或tf.keras对tensorflow 2.x的支持。

pip install keras2onnx

或从源代码安装

pip install -U git+https://github.com/microsoft/onnxconverter-common
pip install -U git+https://github.com/onnx/keras-onnx

在运行转换器之前,请注意,TensorFlow必须安装到您的Python环境中,您可以选择tensorflow/tensorflow-cpu(CPU版本)或tensorflow-gpu(GPU版本)软件包。

注意

自版本1.6.5以来,Keras2ONNX支持TensorFlow 2.0中引入的新Keras子类模型。一些典型的子类模型,如huggingface/transformers,已被转换为ONNX并由ONNXRuntime验证。

自版本2.3以来,多后端Keras(keras.io)停止支持2.0以上的TensorFlow版本。作者建议切换到tf.keras以获取新功能。

多后端Keras和tf.keras

现在keras2onnx转换器支持两种Keras模型类型。如果用户Python环境中的Keras软件包是从Keras.io安装的,并且TensorFlow软件包版本为1.x,则转换器将以keras.io软件包创建的模型的方式转换模型。否则,它将通过tf.keras转换。

如果您想覆盖此行为,请在调用转换器Python API之前指定环境变量TF_KERAS=1。

开发

Keras2ONNX依赖于onnxconverter-common。实际上,此转换器的最新代码需要onnxconverter-common的最新版本,因此如果您从其源代码安装此转换器,请在安装keras2onnx之前以源代码模式安装onnxconverter-common。

验证的预训练Keras模型

大多数Keras模型可以通过调用keras2onnx.convert_keras成功转换,包括CV、GAN、NLP、语音等。请参阅此处的教程。但是,一些具有大量自定义操作的模型需要自定义转换,以下是一些示例,例如YOLOv3Mask RCNN

脚本

从Python脚本转换Keras模型到ONNX非常有用。您可以使用以下API

import keras2onnx
keras2onnx.convert_keras(model, name=None, doc_string='', target_opset=None, channel_first_inputs=None):
    # type: (keras.Model, str, str, int, []) -> onnx.ModelProto
    """
    :param model: keras model
    :param name: the converted onnx model internal name
    :param doc_string:
    :param target_opset:
    :param channel_first_inputs: A list of channel first input.
    :return:
    """

使用以下脚本将keras应用程序模型转换为ONNX,然后执行推理

import numpy as np
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
import keras2onnx
import onnxruntime

# image preprocessing
img_path = 'street.jpg'   # make sure the image is in img_path
img_size = 224
img = image.load_img(img_path, target_size=(img_size, img_size))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

# load keras model
from keras.applications.resnet50 import ResNet50
model = ResNet50(include_top=True, weights='imagenet')

# convert to onnx model
onnx_model = keras2onnx.convert_keras(model, model.name)

# runtime prediction
content = onnx_model.SerializeToString()
sess = onnxruntime.InferenceSession(content)
x = x if isinstance(x, list) else [x]
feed = dict([(input.name, x[n]) for n, input in enumerate(sess.get_inputs())])
pred_onnx = sess.run(None, feed)

推理结果是与keras模型预测结果model.predict()相对应的列表。将ONNX模型加载到运行时会话的另一种方法是首先保存模型

temp_model_file = 'model.onnx'
keras2onnx.save_model(onnx_model, temp_model_file)
sess = onnxruntime.InferenceSession(temp_model_file)

贡献

我们欢迎以反馈、想法或代码的形式进行贡献。

许可协议

MIT许可证

项目详情


下载文件

下载适合您平台的文件。如果您不确定该选择哪个,请了解更多关于安装包的信息。

源代码分发

此版本没有可用的源代码分发文件。请参阅生成分发存档的教程。

构建分发

keras2onnx-1.7.0-py3-none-any.whl (96.3 kB 查看哈希值)

上传时间 Python 3

由以下提供支持

AWSAWS云计算和安全赞助商DatadogDatadog监控FastlyFastlyCDNGoogleGoogle下载分析MicrosoftMicrosoftPSF赞助商PingdomPingdom监控SentrySentry错误日志StatusPageStatusPage状态页面