跳转到主要内容

ONNX的数组(和numpy)API

项目描述

https://dev.azure.com/xavierdupre3/onnx-array-api/_apis/build/status/sdpython.onnx-array-api https://badge.fury.io/py/onnx-array-api.svg GitHub Issues MIT License size https://img.shields.io/badge/code%20style-black-000000.svg https://codecov.io/gh/sdpython/onnx-array-api/branch/main/graph/badge.svg?token=Wb9ZGDta8J

onnx-array-api 实现了创建自定义ONNX图的API。目标是加快转换库的实现速度。该库已发布在 pypi/onnx-array-api,其文档发布在 创建ONNX图的API

Numpy API

第一个与 Numpy API 匹配。它使用户能够将遵循Numpy API编写的函数转换为ONNX,并执行它。

import numpy as np
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

def l1_loss(x, y):
    return absolute(x - y).sum()


def l2_loss(x, y):
    return ((x - y) ** 2).sum()


def myloss(x, y):
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])


jitted_myloss = jit_onnx(myloss)

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

res = jitted_myloss(x, y)
print(res)

print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
[0.042]
opset: domain='' version=18
input: name='x0' type=dtype('float32') shape=['', '']
input: name='x1' type=dtype('float32') shape=['', '']
Sub(x0, x1) -> r__0
  Abs(r__0) -> r__1
    ReduceSum(r__1, keepdims=0) -> r__2
output: name='r__2' type=dtype('float32') shape=None

它还支持即时模式

import numpy as np
from onnx_array_api.npx import absolute, eager_onnx


def l1_loss(x, y):
    err = absolute(x - y).sum()
    print(f"l1_loss={err.numpy()}")
    return err


def l2_loss(x, y):
    err = ((x - y) ** 2).sum()
    print(f"l2_loss={err.numpy()}")
    return err


def myloss(x, y):
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])


eager_myloss = eager_onnx(myloss)

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

res = eager_myloss(x, y)
print(res)
l1_loss=[0.04]
l2_loss=[0.002]
[0.042]

轻量API

第二个API或 轻量API 倾向于在一行内完成所有操作。它受到 逆波兰表示法 的启发。欧几里得距离看起来如下所示

import numpy as np
from onnx_array_api.light_api import start
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

model = (
    start()
    .vin("X")
    .vin("Y")
    .bring("X", "Y")
    .Sub()
    .rename("dxy")
    .cst(np.array([2], dtype=np.int64), "two")
    .bring("dxy", "two")
    .Pow()
    .ReduceSum()
    .rename("Z")
    .vout()
    .to_onnx()
)

GraphBuilder API

几乎每个转换库(将机器学习模型转换为ONNX)都在实现自己的图构建器,并为其需求定制它。它处理一些常见任务,如为中间结果命名,加载、保存ONNX模型。它也可以用来扩展现有图。

import numpy as np
from onnx_array_api.graph_api  import GraphBuilder

g = GraphBuilder()
g.make_tensor_input("X", np.float32, (None, None))
g.make_tensor_input("Y", np.float32, (None, None))
r1 = g.make_node("Sub", ["X", "Y"])  # the name given to the output is given by the class,
                                     # it ensures the name is unique
init = g.make_initializer(np.array([2], dtype=np.int64))  # the class automatically
                                                          # converts the array to a tensor
r2 = g.make_node("Pow", [r1, init])
g.make_node("ReduceSum", [r2], outputs=["Z"])  # the output name is given because
                                               # the user wants to choose the name
g.make_tensor_output("Z", np.float32, (None, None))

onx = g.to_onnx()  # final conversion to onnx

项目详细信息


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分发

onnx-array-api-0.2.0.tar.gz (207.2 kB 查看哈希值)

上传时间

构建分发

onnx_array_api-0.2.0-py3-none-any.whl (229.4 kB 查看哈希值)

上传时间 Python 3