跳转到主要内容

使用Python子集自然地编写ONNX函数和模型

项目描述

ONNX Script

CI Dev Release PyPI - Version PyPI - Python Version Ruff Black

ONNX Script允许开发者使用Python子集自然地编写ONNX函数和模型。ONNX Script是

  • 表现力强: 允许编写所有ONNX函数。
  • 简单简洁: 函数代码自然简单。
  • 可调试: 允许进行即时模式评估,从而提供更愉悦的ONNX模型调试体验。

此存储库还涵盖

  • ONNX IR: 一种内存中的IR,支持完整的ONNX规范,用于图构建、分析和转换。
  • ONNX Script优化器: 提供功能,通过执行优化和清理(如常量折叠、死代码消除等)来优化ONNX模型。
  • ONNX重写器: 提供功能,根据用户定义的重写规则,用替换模式替换ONNX图中的某些模式。

请注意,ONNX Script 旨在支持Python语言的全部功能。

网站: https://onnxscript.ai/

设计概述

ONNX Script为编写和调试ONNX模型和函数提供了一些主要功能

  • 一种将 Python ONNX Script 函数转换为 ONNX 图的转换器,通过遍历 Python 抽象语法树 来构建函数的等效 ONNX 图。

  • 一种逆向操作的转换器,将 ONNX 模型和函数转换为 ONNX Script。这种功能可以用来实现 ONNX Script 与 ONNX 图的完全往返。

  • 一个运行时垫片,允许这些函数以(“急切模式”)进行评估。此功能目前依赖于 ONNX Runtime 来执行每个 ONNX 操作符,并且正在进行一个仅适用于 Python 的 ONNX 参考运行时,也将得到支持。

    请注意,运行时旨在帮助理解和调试函数定义。性能不是目标。

安装 ONNX Script

pip install --upgrade onnxscript

用于开发安装

git clone https://github.com/microsoft/onnxscript
cd onnxscript
pip install -r requirements-dev.txt
pip install -e .

运行单元测试

pytest .

示例

import onnx

# We use ONNX opset 15 to define the function below.
from onnxscript import FLOAT, script
from onnxscript import opset15 as op


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def onnx_hardmax(X, axis: int):
    """Hardmax is similar to ArgMax, with the result being encoded OneHot style."""

    # The type annotation on X indicates that it is a float tensor of
    # unknown rank. The type annotation on axis indicates that it will
    # be treated as an int attribute in ONNX.
    #
    # Invoke ONNX opset 15 op ArgMax.
    # Use unnamed arguments for ONNX input parameters, and named
    # arguments for ONNX attribute parameters.
    argmax = op.ArgMax(X, axis=axis, keepdims=False)
    xshape = op.Shape(X, start=axis)
    # use the Constant operator to create constant tensors
    zero = op.Constant(value_ints=[0])
    depth = op.GatherElements(xshape, zero)
    empty_shape = op.Constant(value_ints=[0])
    depth = op.Reshape(depth, empty_shape)
    values = op.Constant(value_ints=[0, 1])
    cast_values = op.CastLike(values, X)
    return op.OneHot(argmax, depth, cast_values, axis=axis)


# We use the script decorator to indicate that
# this is meant to be translated to ONNX.
@script()
def sample_model(X: FLOAT[64, 128], Wt: FLOAT[128, 10], Bias: FLOAT[10]) -> FLOAT[64, 10]:
    matmul = op.MatMul(X, Wt) + Bias
    return onnx_hardmax(matmul, axis=1)


# onnx_model is an in-memory ModelProto
onnx_model = sample_model.to_model_proto()

# Save the ONNX model at a given path
onnx.save(onnx_model, "sample_model.onnx")

# Check the model
try:
    onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
    print(f"The model is invalid: {e}")
else:
    print("The model is valid!")

装饰器解析函数的代码,将其转换为中间表示。如果失败,则生成错误消息,指示错误检测到的行。如果成功,中间表示可以转换为类型为 FunctionProto 的 ONNX 图结构

  • Hardmax.to_function_proto() 返回一个 FunctionProto

急切模式评估

急切模式主要用于调试和验证中间结果是否符合预期。上面定义的函数可以如下调用,以急切评估模式执行

import numpy as np

v = np.array([[0, 1], [2, 3]], dtype=np.float32)
result = Hardmax(v)

更多示例可以在 docs/examples 目录中找到。

ONNX IR

一个内存中的 IR,支持完整的 ONNX 规范,设计用于图构建、分析和转换。

功能

  • 完全支持 ONNX 规范:所有有效的模型都可以通过 ONNX protobuf 表示,以及一部分无效的模型(您可以加载并修复它们)。
  • 内存占用低:使用 mmap 的外部张量;统一的 ONNX TensorProto、Numpy 数组和 PyTorch Tensors 等接口。没有张量大小限制。零复制。
  • 简单的访问模式:轻松访问值信息和遍历图拓扑。
  • 健壮的修改:在修改图的同时,可以创建尽可能多的迭代器。
  • 速度:高效的图操作,序列化和反序列化到 Protobuf。
  • Pythonic 和熟悉的 API:类定义了 Pythonic api,并且仍然以直观的方式映射到 ONNX protobuf 概念。

ONNX Script 工具

ONNX 优化器

ONNX Script 优化器工具为用户提供通过执行优化和清理(如常数折叠、死代码消除等)来优化 ONNX 模型的功能。为了使用优化器工具

import onnxscript

onnxscript.optimizer.optimize(onnx_model)

有关优化器应用的优化细节的详细总结,请参阅教程 使用优化器优化模型

ONNX 重写器

ONNX 重写器工具为用户提供替换 ONNX 图中某些模式的功能,基于用户定义的重写规则。重写器工具允许两种不同的方法来重写图中的模式。

基于模式的重写

对于这种重写风格,用户提供要替换的 target_pattern,一个 replacement_pattern 和一个 match_condition(只有满足匹配条件时,模式重写才会发生)。以下是如何使用基于模式的重写工具的一个简单示例

from onnxscript.rewriter import pattern

# The target pattern
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))

def erf_gelu_pattern_2(op, x):
    return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5

# The replacement pattern
def gelu(op, x: ir.Value):
    return op.Gelu(x, domain="com.microsoft")

# Create multiple rules
rule1 = pattern.RewriteRule(
    erf_gelu_pattern,  # Target Pattern
    gelu,  # Replacement
)
rule2 = pattern.RewriteRule(
    erf_gelu_pattern_2,  # Target Pattern
    gelu,  # Replacement
)
# Create a Rewrite Rule Set with multiple rules.
rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
# Apply rewrites
model_with_rewrite_applied = onnxscript.rewriter.rewrite(
    model,  # Original ONNX Model
    pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite_applied

有关如何创建 target_pattern、replacement_pattern 和 match_condition 块的详细教程,请参阅教程 使用规则进行基于模式的重写

基于函数的重写

这种重写风格将用户提供的 FUNCTION_KEYWORDPACKAGE_NAME 与图中现有函数进行匹配,并用用户提供的新的函数替换它。

开发指南

所有影响转换器或贪婪求值的变化都必须使用类 OnnxScriptTestCase 进行单元测试,以确保两个系统在相同的输入下返回相同的结果。

编码风格

我们使用 ruffblackisortmypy 等工具来检查代码格式,并使用 lintrunner 运行所有代码检查器。您可以通过以下命令安装依赖项并初始化:

pip install lintrunner lintrunner-adapters
lintrunner init

这将安装 lintrunner 到您的系统并下载所有必要的依赖项以本地运行代码检查器。如果您想查看 lintrunner init 将安装的内容,请运行 lintrunner init --dry-run

要检查本地更改

lintrunner

要格式化文件

lintrunner f

要检查所有文件

lintrunner --all-files

使用 --output oneline 生成紧凑的错误列表,当有很多错误需要修复时非常有用。

使用 lintrunner -h 查看所有可用选项。

要了解更多关于 lintrunner 的信息,请参阅 wiki。要更新现有的代码检查规则或创建一个新的规则,请修改 .lintrunner.toml 或根据 https://github.com/justinchuby/lintrunner-adapters 中的示例创建一个新的适配器。

贡献

我们始终欢迎您的帮助来改进产品(错误修复、新功能、文档等)。目前 ONNX Script 处于早期和快速开发阶段,因此我们鼓励通过 提交问题 向团队提出任何重大更改,首先讨论您的想法。

报告安全漏洞

请不要通过公开的 GitHub 问题报告安全漏洞。

请参阅我们关于提交 安全漏洞 的指南。

许可指南

本项目欢迎贡献和建议。大多数贡献都需要您同意一份贡献者许可协议(CLA),声明您有权,并且实际上确实授予我们使用您的贡献的权利。有关详细信息,请访问 https://cla.microsoft.com

当您提交拉取请求时,CLA-bot 将自动确定您是否需要提供 CLA,并适当地装饰 PR(例如,标签、注释)。只需遵循机器人提供的说明即可。您只需在整个使用我们的 CLA 的存储库中执行此操作一次。

行为准则

本项目采用 Microsoft Open Source Code of Conduct。有关更多信息,请参阅 行为准则常见问题解答 或联系 opencode@microsoft.com 了解任何其他问题或意见。

商标

本项目可能包含项目、产品或服务的商标或徽标。Microsoft 商标或徽标的授权使用必须遵循并遵守 Microsoft 的商标 & 品牌指南。在此项目的修改版本中使用 Microsoft 商标或徽标不得造成混淆或暗示 Microsoft 赞助。任何第三方商标或徽标的用途均受那些第三方政策的约束。

项目详情


发行历史 发行通知 | RSS 源

下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源代码发行版

onnxscript-0.1.0.dev20241004.tar.gz (557.6 kB 查看哈希值)

上传时间 源代码

构建发行版

onnxscript-0.1.0.dev20241004-py3-none-any.whl (670.5 kB 查看哈希值)

上传时间 Python 3

由以下组织支持