跳转到主要内容

扩展onnx参考实现和onnxruntime支持的算子列表,或在C++中实现更快的版本。

项目描述

https://github.com/sdpython/onnx-extended/raw/main/_doc/_static/logo.png

onnx-extended: onnx和onnxruntime的扩展

https://dev.azure.com/xavierdupre3/onnx-extended/_apis/build/status/sdpython.onnx-extended https://badge.fury.io/py/onnx-extended.svg GitHub Issues MIT License size https://img.shields.io/badge/code%20style-black-000000.svg

onnx-extended 扩展了 onnx 参考实现和 onnxruntime 支持的操作符列表,或者使用 C++ 实现更快版本的函数。文档请参阅 onnx-extended。源代码可在 github/onnx-extended 上找到。

使用现有操作符的 C++ 实现

import timeit
import numpy as np
from onnx import TensorProto
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_opsetid,
    make_tensor_value_info,
)
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from onnx_extended.ext_test_case import measure_time
from onnx_extended.reference import CReferenceEvaluator


X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None, None, None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None, None])
B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None, None, None])
W = make_tensor_value_info("W", TensorProto.FLOAT, [None, None, None, None])
node = make_node(
    "Conv",
    ["X", "W", "B"],
    ["Y"],
    pads=[1, 1, 1, 1],
    dilations=[1, 1],
    strides=[2, 2],
)
graph = make_graph([node], "g", [X, W, B], [Y])
onnx_model = make_model(graph, opset_imports=[make_opsetid("", 16)])

sH, sW = 64, 64
X = np.arange(sW * sH).reshape((1, 1, sH, sW)).astype(np.float32)
W = np.ones((1, 1, 3, 3), dtype=np.float32)
B = np.array([[[[0]]]], dtype=np.float32)

sess1 = ReferenceEvaluator(onnx_model)
sess2 = CReferenceEvaluator(onnx_model)  # 100 times faster

expected = sess1.run(None, {"X": X, "W": W, "B": B})[0]
got = sess2.run(None, {"X": X, "W": W, "B": B})[0]
diff = np.abs(expected - got).max()
print(f"difference: {diff}")

f1 = lambda: sess1.run(None, {"X": X, "W": W, "B": B})[0]
f2 = lambda: sess2.run(None, {"X": X, "W": W, "B": B})[0]
print("onnx:", timeit.timeit(f1, globals=globals(), number=5))
print("onnx-extended:", timeit.timeit(f2, globals=globals(), number=5))
difference: 0.0
onnx: 0.024006774998269975
onnx-extended: 0.0002316169993719086

使用 CUDA、openmp、eigen、onnxruntime 编译

该软件包还包含一些示例,说明如何使用 C++ 函数(pybind11cython)以及 openmpeigen(带或不带 CUDA)进行编译。它还展示了如何在 C++ 中为 onnxruntime 创建自定义操作符。

pypi/onnx-extended 上发布的版本仅适用于 CPU。需要手动编译才能启用使用 CUDA 的代码。如果找到 CUDA,则构建将自动链接 CUDA。如果没有,某些扩展可能不可用。

python setup.py build_ext --inplace
# pip install -e .

可以使用特定的 CUDA 版本

python setup.py build_ext --inplace --cuda-version=11.8
# or (not working yet)
# pip install -e . --config-settings="--cuda-version=11.8"
# pip install -e . --global-option="--cuda-version=11.8"
export USE_CUDA=11.8
pip install -e .

NVTX 可以通过以下命令启用

python setup.py build_ext --inplace --use_nvtx 1
# or (not working yet)
# pip install -e . --config-settings="--use_nvtx=1"
pip install -e . --global-option "--use_nvtx=1"

onnxruntime 的实验性 cython 绑定

Python onnxruntime 软件包依赖于 pybind11 来公开其功能。 onnx-extended 尝试在 onnxruntime 的 C/C++ API 周围构建一个 cython 包装器。cython 依赖于 Python C API,并且比 pybind11 快。当 onnxruntime 用于小型图和张量时,这种差异可能很大。

为 onnxruntime 定制内核

onnxruntime 提供了一个 API,用于为现有或新的 onnx 操作符添加自定义实现。以下是一个 CPU 的示例。

from onnxruntime import InferenceSession, SessionOptions
from onnx_extended.ortops.optim.cpu import get_ort_ext_libs

r = get_ort_ext_libs()
opts = SessionOptions()
if r is not None:
    opts.register_custom_ops_library(r[0])

sess_cus = InferenceSession(
    onx_modified.SerializeToString(), opts, providers=["CPUExecutionProvider"]
)

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源代码分发

onnx_extended-0.3.0.tar.gz (14.7 MB 查看哈希值)

上传时间 源代码

构建分发

onnx_extended-0.3.0-cp311-cp311-win_amd64.whl (67.1 MB 查看哈希值)

上传时间 CPython 3.11 Windows x86-64

onnx_extended-0.3.0-cp311-cp311-manylinux_2_28_x86_64.whl (25.8 MB 查看哈希值)

上传时间 CPython 3.11 manylinux: glibc 2.28+ x86-64

onnx_extended-0.3.0-cp310-cp310-win_amd64.whl (67.0 MB 查看哈希值)

上传时间 CPython 3.10 Windows x86-64

onnx_extended-0.3.0-cp310-cp310-manylinux_2_28_x86_64.whl (25.7 MB 查看哈希值)

上传于 CPython 3.10 manylinux: glibc 2.28+ x86-64

由以下机构支持