将机器学习模型转换为ONNX格式,以便在Windows ML中使用
项目描述
简介
keras2onnx模型转换器允许用户将Keras模型转换为ONNX模型格式。最初,Keras转换器是在onnxmltools项目中开发的。keras2onnx转换器开发被迁移到一个独立的存储库,以支持更多种类的Keras模型并减少混合多个转换器的复杂性。
大多数常见的Keras层都支持转换。请参阅Keras文档或tf.keras文档以获取Keras层的详细信息。
Windows机器学习(WinML)用户可以使用WinMLTools,该工具封装了keras2onnx的调用以转换Keras模型。如果您想使用keras2onnx转换器,请参考WinML版本说明以确定您WinML版本的相应ONNX opset编号。
keras2onnx已在Python 3.5、3.6和3.7上进行了测试,支持tensorflow 1.x/2.0/2.1(CI构建)。它不支持Python 2.x。
安装
您可以从PyPi安装Keras2ONNX的最新版本:由于某些原因,软件包发布已暂停,请从源代码安装,并且仅在源代码中支持keras或tf.keras对tensorflow 2.x的支持。
pip install keras2onnx
或从源代码安装
pip install -U git+https://github.com/microsoft/onnxconverter-common
pip install -U git+https://github.com/onnx/keras-onnx
在运行转换器之前,请注意,TensorFlow必须安装到您的Python环境中,您可以选择tensorflow/tensorflow-cpu(CPU版本)或tensorflow-gpu(GPU版本)软件包。
注意
自版本1.6.5以来,Keras2ONNX支持TensorFlow 2.0中引入的新Keras子类模型。一些典型的子类模型,如huggingface/transformers,已被转换为ONNX并由ONNXRuntime验证。
自版本2.3以来,多后端Keras(keras.io)停止支持2.0以上的TensorFlow版本。作者建议切换到tf.keras以获取新功能。
多后端Keras和tf.keras
现在keras2onnx转换器支持两种Keras模型类型。如果用户Python环境中的Keras软件包是从Keras.io安装的,并且TensorFlow软件包版本为1.x,则转换器将以keras.io软件包创建的模型的方式转换模型。否则,它将通过tf.keras转换。
如果您想覆盖此行为,请在调用转换器Python API之前指定环境变量TF_KERAS=1。
开发
Keras2ONNX依赖于onnxconverter-common。实际上,此转换器的最新代码需要onnxconverter-common的最新版本,因此如果您从其源代码安装此转换器,请在安装keras2onnx之前以源代码模式安装onnxconverter-common。
验证的预训练Keras模型
大多数Keras模型可以通过调用keras2onnx.convert_keras
成功转换,包括CV、GAN、NLP、语音等。请参阅此处的教程。但是,一些具有大量自定义操作的模型需要自定义转换,以下是一些示例,例如YOLOv3和Mask RCNN。
脚本
从Python脚本转换Keras模型到ONNX非常有用。您可以使用以下API
import keras2onnx
keras2onnx.convert_keras(model, name=None, doc_string='', target_opset=None, channel_first_inputs=None):
# type: (keras.Model, str, str, int, []) -> onnx.ModelProto
"""
:param model: keras model
:param name: the converted onnx model internal name
:param doc_string:
:param target_opset:
:param channel_first_inputs: A list of channel first input.
:return:
"""
使用以下脚本将keras应用程序模型转换为ONNX,然后执行推理
import numpy as np
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
import keras2onnx
import onnxruntime
# image preprocessing
img_path = 'street.jpg' # make sure the image is in img_path
img_size = 224
img = image.load_img(img_path, target_size=(img_size, img_size))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
# load keras model
from keras.applications.resnet50 import ResNet50
model = ResNet50(include_top=True, weights='imagenet')
# convert to onnx model
onnx_model = keras2onnx.convert_keras(model, model.name)
# runtime prediction
content = onnx_model.SerializeToString()
sess = onnxruntime.InferenceSession(content)
x = x if isinstance(x, list) else [x]
feed = dict([(input.name, x[n]) for n, input in enumerate(sess.get_inputs())])
pred_onnx = sess.run(None, feed)
推理结果是与keras模型预测结果model.predict()
相对应的列表。将ONNX模型加载到运行时会话的另一种方法是首先保存模型
temp_model_file = 'model.onnx'
keras2onnx.save_model(onnx_model, temp_model_file)
sess = onnxruntime.InferenceSession(temp_model_file)
贡献
我们欢迎以反馈、想法或代码的形式进行贡献。
许可协议
项目详情
下载文件
下载适合您平台的文件。如果您不确定该选择哪个,请了解更多关于安装包的信息。