使用Python子集自然地编写ONNX函数和模型
项目描述
ONNX Script
ONNX Script允许开发者使用Python子集自然地编写ONNX函数和模型。ONNX Script是
- 表现力强: 允许编写所有ONNX函数。
- 简单简洁: 函数代码自然简单。
- 可调试: 允许进行即时模式评估,从而提供更愉悦的ONNX模型调试体验。
此存储库还涵盖
- ONNX IR: 一种内存中的IR,支持完整的ONNX规范,用于图构建、分析和转换。
- ONNX Script优化器: 提供功能,通过执行优化和清理(如常量折叠、死代码消除等)来优化ONNX模型。
- ONNX重写器: 提供功能,根据用户定义的重写规则,用替换模式替换ONNX图中的某些模式。
请注意,ONNX Script 不旨在支持Python语言的全部功能。
设计概述
ONNX Script为编写和调试ONNX模型和函数提供了一些主要功能
-
一种将 Python ONNX Script 函数转换为 ONNX 图的转换器,通过遍历 Python 抽象语法树 来构建函数的等效 ONNX 图。
-
一种逆向操作的转换器,将 ONNX 模型和函数转换为 ONNX Script。这种功能可以用来实现 ONNX Script 与 ONNX 图的完全往返。
-
一个运行时垫片,允许这些函数以(“急切模式”)进行评估。此功能目前依赖于 ONNX Runtime 来执行每个 ONNX 操作符,并且正在进行一个仅适用于 Python 的 ONNX 参考运行时,也将得到支持。
请注意,运行时旨在帮助理解和调试函数定义。性能不是目标。
安装 ONNX Script
pip install --upgrade onnxscript
用于开发安装
git clone https://github.com/microsoft/onnxscript
cd onnxscript
pip install -r requirements-dev.txt
pip install -e .
运行单元测试
pytest .
示例
import onnx
# We use ONNX opset 15 to define the function below.
from onnxscript import FLOAT, script
from onnxscript import opset15 as op
# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def onnx_hardmax(X, axis: int):
"""Hardmax is similar to ArgMax, with the result being encoded OneHot style."""
# The type annotation on X indicates that it is a float tensor of
# unknown rank. The type annotation on axis indicates that it will
# be treated as an int attribute in ONNX.
#
# Invoke ONNX opset 15 op ArgMax.
# Use unnamed arguments for ONNX input parameters, and named
# arguments for ONNX attribute parameters.
argmax = op.ArgMax(X, axis=axis, keepdims=False)
xshape = op.Shape(X, start=axis)
# use the Constant operator to create constant tensors
zero = op.Constant(value_ints=[0])
depth = op.GatherElements(xshape, zero)
empty_shape = op.Constant(value_ints=[0])
depth = op.Reshape(depth, empty_shape)
values = op.Constant(value_ints=[0, 1])
cast_values = op.CastLike(values, X)
return op.OneHot(argmax, depth, cast_values, axis=axis)
# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
matmul = op.MatMul(X, Wt) + Bias
return onnx_hardmax(matmul, axis=1)
# onnx_model is an in-memory ModelProto
onnx_model = sample_model.to_model_proto()
# Save the ONNX model at a given path
onnx.save(onnx_model, "sample_model.onnx")
# Check the model
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
print(f"The model is invalid: {e}")
else:
print("The model is valid!")
装饰器解析函数的代码,将其转换为中间表示。如果失败,则生成错误消息,指示错误检测到的行。如果成功,中间表示可以转换为类型为 FunctionProto
的 ONNX 图结构
Hardmax.to_function_proto()
返回一个FunctionProto
急切模式评估
急切模式主要用于调试和验证中间结果是否符合预期。上面定义的函数可以如下调用,以急切评估模式执行
import numpy as np
v = np.array([[0, 1], [2, 3]], dtype=np.float32)
result = Hardmax(v)
更多示例可以在 docs/examples 目录中找到。
ONNX IR
一个内存中的 IR,支持完整的 ONNX 规范,设计用于图构建、分析和转换。
功能
- 完全支持 ONNX 规范:所有有效的模型都可以通过 ONNX protobuf 表示,以及一部分无效的模型(您可以加载并修复它们)。
- 内存占用低:使用 mmap 的外部张量;统一的 ONNX TensorProto、Numpy 数组和 PyTorch Tensors 等接口。没有张量大小限制。零复制。
- 简单的访问模式:轻松访问值信息和遍历图拓扑。
- 健壮的修改:在修改图的同时,可以创建尽可能多的迭代器。
- 速度:高效的图操作,序列化和反序列化到 Protobuf。
- Pythonic 和熟悉的 API:类定义了 Pythonic api,并且仍然以直观的方式映射到 ONNX protobuf 概念。
ONNX Script 工具
ONNX 优化器
ONNX Script 优化器工具为用户提供通过执行优化和清理(如常数折叠、死代码消除等)来优化 ONNX 模型的功能。为了使用优化器工具
import onnxscript
onnxscript.optimizer.optimize(onnx_model)
有关优化器应用的优化细节的详细总结,请参阅教程 使用优化器优化模型
ONNX 重写器
ONNX 重写器工具为用户提供替换 ONNX 图中某些模式的功能,基于用户定义的重写规则。重写器工具允许两种不同的方法来重写图中的模式。
基于模式的重写
对于这种重写风格,用户提供要替换的 target_pattern
,一个 replacement_pattern
和一个 match_condition
(只有满足匹配条件时,模式重写才会发生)。以下是如何使用基于模式的重写工具的一个简单示例
from onnxscript.rewriter import pattern
# The target pattern
def erf_gelu_pattern(op, x):
return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
def erf_gelu_pattern_2(op, x):
return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
# The replacement pattern
def gelu(op, x: ir.Value):
return op.Gelu(x, domain="com.microsoft")
# Create multiple rules
rule1 = pattern.RewriteRule(
erf_gelu_pattern, # Target Pattern
gelu, # Replacement
)
rule2 = pattern.RewriteRule(
erf_gelu_pattern_2, # Target Pattern
gelu, # Replacement
)
# Create a Rewrite Rule Set with multiple rules.
rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
# Apply rewrites
model_with_rewrite_applied = onnxscript.rewriter.rewrite(
model, # Original ONNX Model
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite_applied
有关如何创建 target_pattern、replacement_pattern 和 match_condition 块的详细教程,请参阅教程 使用规则进行基于模式的重写
基于函数的重写
这种重写风格将用户提供的 FUNCTION_KEYWORD
和 PACKAGE_NAME
与图中现有函数进行匹配,并用用户提供的新的函数替换它。
开发指南
所有影响转换器或贪婪求值的变化都必须使用类 OnnxScriptTestCase
进行单元测试,以确保两个系统在相同的输入下返回相同的结果。
编码风格
我们使用 ruff
、black
、isort
和 mypy
等工具来检查代码格式,并使用 lintrunner
运行所有代码检查器。您可以通过以下命令安装依赖项并初始化:
pip install lintrunner lintrunner-adapters
lintrunner init
这将安装 lintrunner 到您的系统并下载所有必要的依赖项以本地运行代码检查器。如果您想查看 lintrunner init 将安装的内容,请运行 lintrunner init --dry-run
。
要检查本地更改
lintrunner
要格式化文件
lintrunner f
要检查所有文件
lintrunner --all-files
使用 --output oneline
生成紧凑的错误列表,当有很多错误需要修复时非常有用。
使用 lintrunner -h
查看所有可用选项。
要了解更多关于 lintrunner 的信息,请参阅 wiki。要更新现有的代码检查规则或创建一个新的规则,请修改 .lintrunner.toml
或根据 https://github.com/justinchuby/lintrunner-adapters 中的示例创建一个新的适配器。
贡献
我们始终欢迎您的帮助来改进产品(错误修复、新功能、文档等)。目前 ONNX Script 处于早期和快速开发阶段,因此我们鼓励通过 提交问题 向团队提出任何重大更改,首先讨论您的想法。
报告安全漏洞
请不要通过公开的 GitHub 问题报告安全漏洞。
请参阅我们关于提交 安全漏洞 的指南。
许可指南
本项目欢迎贡献和建议。大多数贡献都需要您同意一份贡献者许可协议(CLA),声明您有权,并且实际上确实授予我们使用您的贡献的权利。有关详细信息,请访问 https://cla.microsoft.com。
当您提交拉取请求时,CLA-bot 将自动确定您是否需要提供 CLA,并适当地装饰 PR(例如,标签、注释)。只需遵循机器人提供的说明即可。您只需在整个使用我们的 CLA 的存储库中执行此操作一次。
行为准则
本项目采用 Microsoft Open Source Code of Conduct。有关更多信息,请参阅 行为准则常见问题解答 或联系 opencode@microsoft.com 了解任何其他问题或意见。
商标
本项目可能包含项目、产品或服务的商标或徽标。Microsoft 商标或徽标的授权使用必须遵循并遵守 Microsoft 的商标 & 品牌指南。在此项目的修改版本中使用 Microsoft 商标或徽标不得造成混淆或暗示 Microsoft 赞助。任何第三方商标或徽标的用途均受那些第三方政策的约束。
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。
源代码发行版
构建发行版
onnxscript-0.1.0.dev20241004.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 6efdb1e174bf1ac3e702ecc57a38119a29772aed48446903d9ec05e458d3e947 |
|
MD5 | e1ed85ca943863dc4678b3b9c58c392f |
|
BLAKE2b-256 | ec826638546ea028b611477279b5002449785d4fb125272f017c6c6f473195af |
onnxscript-0.1.0.dev20241004-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | d5aa34d2a8ff2b3fe5bac680dbde4bcbbaddc8b0d15784081f31d4cee32eff54 |
|
MD5 | a690c27a5f2d1f4cf3176699f35f6ed1 |
|
BLAKE2b-256 | be85a5b535a1e981fb154bd09078412703d943e04be86f8a2590e493e5d26f44 |