使用Python子集自然编写ONNX函数和模型
项目描述
ONNX Script
ONNX Script允许开发人员使用Python子集自然地编写ONNX函数和模型。ONNX Script是
- 表达性强: 允许编写所有ONNX函数。
- 简单简洁: 函数代码自然简单。
- 可调试: 允许进行急切模式评估,从而提供更愉悦的ONNX模型调试体验。
然而请注意,ONNX Script不打算支持整个Python语言。
设计概述
ONNX Script提供了一些主要功能,用于编写和调试ONNX模型和函数
-
一个转换器,它将Python ONNX Script函数转换为ONNX图,通过遍历Python抽象语法树来构建函数的等效ONNX图。
-
一个逆转换器,将ONNX模型和函数转换为ONNX Script。此功能可用于完全往返ONNX Script ↔ ONNX图。
-
一个运行时层,允许这些函数以“贪婪模式”进行评估。此功能目前依赖于ONNX 运行时来执行每个ONNX 操作符,并且正在进行一个仅支持 Python 的 ONNX 参考运行时的开发。
请注意,运行时旨在帮助理解和调试函数定义。性能不是此处的目标。
安装 ONNX 脚本
pip install --upgrade onnxscript-preview
开发安装
pip install onnx onnxruntime pytest
git clone https://github.com/microsoft/onnxscript
cd onnxscript
pip install -e .
运行单元测试
pytest onnxscript
示例
import onnx
# We use ONNX opset 15 to define the function below.
from onnxscript import FLOAT
from onnxscript import opset15 as op
from onnxscript import script
# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def onnx_hardmax(X, axis: int):
"""Hardmax is similar to ArgMax, with the result being encoded OneHot style."""
# The type annotation on X indicates that it is a float tensor of
# unknown rank. The type annotation on axis indicates that it will
# be treated as an int attribute in ONNX.
#
# Invoke ONNX opset 15 op ArgMax.
# Use unnamed arguments for ONNX input parameters, and named
# arguments for ONNX attribute parameters.
argmax = op.ArgMax(X, axis=axis, keepdims=False)
xshape = op.Shape(X, start=axis)
# use the Constant operator to create constant tensors
zero = op.Constant(value_ints=[0])
depth = op.GatherElements(xshape, zero)
empty_shape = op.Constant(value_ints=[0])
depth = op.Reshape(depth, empty_shape)
values = op.Constant(value_ints=[0, 1])
cast_values = op.CastLike(values, X)
return op.OneHot(argmax, depth, cast_values, axis=axis)
# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
matmul = op.MatMul(X, Wt) + Bias
return onnx_hardmax(matmul, axis=1)
# onnx_model is an in-memory ModelProto
onnx_model = sample_model.to_model_proto()
# Save the ONNX model at a given path
onnx.save(onnx_model, "sample_model.onnx")
# Check the model
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print(f"The model is invalid: {e}")
else:
print("The model is valid!")
装饰器解析函数代码,将其转换为中间表示形式。如果失败,它会产生一个错误消息,指示错误检测到的行。如果成功,中间表示形式可以转换为类型为 FunctionProto
的 ONNX 图结构
Hardmax.to_function_proto()
返回一个FunctionProto
贪婪模式评估
贪婪模式主要用于调试和验证中间结果是否符合预期。上面定义的函数可以如下调用,以贪婪评估模式执行
import numpy as np
v = np.array([[0, 1], [2, 3]], dtype=np.float32)
result = Hardmax(v)
更多示例可以在 docs/examples 目录中找到。
开发指南
每个影响转换器或贪婪评估的更改都必须使用类 OnnxScriptTestCase
进行单元测试,以确保这两个系统使用相同的输入都返回相同的结果。
编码风格
我们使用 ruff
、black
、isort
和 mypy
等工具来检查代码格式,并使用 lintrunner
来运行所有代码检查器。您可以通过以下命令安装依赖项并初始化:
pip install lintrunner lintrunner-adapters
lintrunner init
这将安装 lintrunner 到您的系统,并下载所有必要的依赖项以在本地运行代码检查器。如果您想查看 lintrunner init 会安装什么,请运行 lintrunner init --dry-run
。
检查本地更改
lintrunner
格式化文件
lintrunner f
检查所有文件
lintrunner --all-files
使用 --output oneline
生成紧凑的代码检查错误列表,这在需要修复许多错误时非常有用。
使用 lintrunner -h
查看所有可用选项。
有关 lintrunner 的更多信息,请参阅 wiki。要更新现有的代码检查规则或创建一个新的规则,请修改 .lintrunner.toml
或根据 https://github.com/justinchuby/lintrunner-adapters 中的示例创建一个新的适配器。
贡献
我们始终欢迎您的帮助来改进产品(错误修复、新功能、文档等)。目前 ONNX 脚本处于早期和快速开发阶段,因此我们鼓励通过 提交问题 与团队首先讨论您的想法。
报告安全漏洞
请不要通过公共 GitHub 问题报告安全漏洞。
请参阅我们关于报告 安全漏洞 的指南。
许可指南
此项目欢迎贡献和建议。大多数贡献都需要您同意贡献者许可协议 (CLA),声明您有权利,并且确实授予我们使用您的贡献的权利。有关详细信息,请访问 https://cla.microsoft.com。
提交拉取请求时,CLA-bot 将自动确定您是否需要提供 CLA,并适当地标记 PR(例如,标签、注释)。只需遵循机器人提供的说明即可。您只需在整个使用我们的 CLA 的存储库中执行此操作一次。
行为准则
本项目采用了微软开源行为准则。更多信息请参阅行为准则常见问题解答,或通过opencode@microsoft.com联系我们,提出任何额外的问题或意见。
商标
本项目可能包含项目、产品或服务的商标或标志。微软商标或标志的授权使用必须遵守并遵循微软的商标与品牌指南。在修改版本项目中的微软商标或标志的使用不得造成混淆或暗示微软的赞助。任何第三方商标或标志的使用均受其相关政策约束。
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。
源代码分发
构建分发
哈希值 for onnxscript_preview-0.1.0.dev20230907-py3-none-any.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 8df14e156494c2dadb15661f0d9ffe264fe58ab1aae37f70a028f2807cd4b57c |
|
MD5 | 14745c23be0782f6099d9749ca778074 |
|
BLAKE2b-256 | 03a91ea9340669268ed343c0db817979ba51be0d5cc9a645cab0fce46def4874 |