跳转到主要内容

未提供项目描述

项目描述

ml_dtypes

Unittests Wheel Build PyPI version

ml_dtypes是机器学习库中使用的几个NumPy数据类型扩展的独立实现,包括

  • bfloat16:标准float16格式的替代品
  • float8_*:包括以下几种实验性8位浮点表示
    • float8_e3m4
    • float8_e4m3
    • float8_e4m3b11fnuz
    • float8_e4m3fn
    • float8_e4m3fnuz
    • float8_e5m2
    • float8_e5m2fnuz
  • Microscaling (MX) 子字节浮点表示包括
    • float4_e2m1fn
    • float6_e2m3fn
    • float6_e3m2fn
  • int2int4uint2uint4:低精度整数类型。

以下为这些数字格式的说明。

安装

ml_dtypes 包在 Python 3.9-3.12 版本上进行了测试,可以使用以下命令进行安装

pip install ml_dtypes

要测试您的安装,您可以运行以下命令

pip install absl-py pytest
pytest --pyargs ml_dtypes

从源码构建,请克隆仓库并运行

git submodule init
git submodule update
pip install .

示例用法

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)

导入 ml_dtypes 也会将数据类型注册到 numpy 中,这样就可以通过它们的字符串名称来引用它们

>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)

实现浮点格式的规格说明

bfloat16

bfloat16 数字是一个16位截断的单精度浮点数。

指数:8,尾数:7,指数偏移量:127。IEEE 754,有 NaN 和 inf。

float4_e2m1fn

指数:2,尾数:1,偏移量:1。

扩展范围:没有 inf,没有 NaN。

微缩格式,4位(编码:0bSEEM),使用字节存储(最高4位未使用)。NaN 表示法未定义。

可能的绝对值:[00.511.52346]

float6_e2m3fn

指数:2,尾数:3,偏移量:1。

扩展范围:没有 inf,没有 NaN。

微缩格式,6位(编码:0bSEEMMM),使用字节存储(最高2位未使用)。NaN 表示法未定义。

可能的值范围:[-7.57.5]

float6_e3m2fn

指数:3,尾数:2,偏移量:3。

扩展范围:没有 inf,没有 NaN。

微缩格式,4位(编码:0bSEEEMM),使用字节存储(最高2位未使用)。NaN 表示法未定义。

可能的值范围:[-2828]

float8_e3m4

指数:3,尾数:4,偏移量:3。IEEE 754,有 NaN 和 inf。

float8_e4m3

指数:4,尾数:3,偏移量:7。IEEE 754,有 NaN 和 inf。

float8_e4m3b11fnuz

指数:4,尾数:3,偏移量:11。

扩展范围:没有 inf,NaN 由 0b1000'0000 表示。

float8_e4m3fn

指数:4,尾数:3,偏移量:7。

扩展范围:没有 inf,NaN 由 0bS111'1111 表示。

fn 后缀是为了与相应的 LLVM/MLIR 类型保持一致,表示此类型与 IEEE-754 不一致。f 表示它是有限值。n 表示它包括 NaN,但只在外部范围内。

float8_e4m3fnuz

3位尾数的8位浮点数。

一个带有1位符号位、4位指数和3位尾数的8位浮点类型。后缀 fnuz 与 LLVM/MLIR 命名一致,并来自与 IEEE 浮点约定的差异。《F》表示“有限”(没有无穷大),《N》表示具有特殊 NaN 编码,《UZ》表示无符号零。

此类型具有以下特性

  • 位编码:S1E4M3 - 0bSEEEEMMM
  • 指数偏移量:8
  • 不支持无穷大
  • NaN:支持,当符号位设置为 1,指数位和尾数位设置为全 0 时 - 0b10000000
  • 当指数为 0 时为非规格化数

float8_e5m2

指数:5,尾数:2,偏移量:15。IEEE 754,有 NaN 和 inf。

float8_e5m2fnuz

2位尾数的8位浮点数。

一个带有1位符号位、5位指数和2位尾数的8位浮点类型。后缀 fnuz 与 LLVM/MLIR 命名一致,并来自与 IEEE 浮点约定的差异。《F》表示“有限”(没有无穷大),《N》表示具有特殊 NaN 编码,《UZ》表示无符号零。

此类型具有以下特性

  • 位编码:S1E5M2 - 0bSEEEEEMM
  • 指数偏移量:16
  • 不支持无穷大
  • NaN:支持,当符号位设置为 1,指数位和尾数位设置为全 0 时 - 0b10000000
  • 当指数为 0 时为非规格化数

float8_e8m0fnu

OpenCompute MX 缩放格式 E8M0,具有以下属性

  • 无符号格式
  • 8位指数
  • 指数范围从 -127 到 127
  • 没有零和无穷大
  • 单个 NaN 值(0xFF)。

int2int4uint2uint4

2位和4位整数类型,其中每个元素以未打包的形式表示(即,填充到内存中的字节)。

NumPy 不支持小于单字节的类型:例如,数组中相邻元素之间的距离(.strides)以字节数的整数表示。放宽此限制将是一项重大的工程项目。因此,这些类型使用非打包表示法,其中数组的每个元素在内存中填充至一个字节。每个字节的两个或四个低位包含数字的表示,而其余的高位被忽略。

低精度算术的怪癖

如果您正在探索在代码中使用低精度 dtypes,您应小心预测精度损失可能导致意外结果的情况。一个例子是聚合函数 sum 的行为;考虑以下 NumPy 中的 bfloat16 求和(使用版本 1.24.2 运行)

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256

真实总和应接近 5000,但 numpy 返回的是 exactly 256:这是因为 bfloat16 没有足够的精度来通过小于 1 的值增加 256

>>> bfloat16(256) + bfloat16(1)
256

256 之后,bfloat16 的下一个可表示的值是 258

>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258

为了获得更好的结果,您可以指定累加应使用更高精度的类型,如 float32

>>> vals.sum(dtype='float32').astype(bfloat16)
4992

与 NumPy 不同,像 JAX 这样的项目,它们更原生地支持低精度算术,通常会自动执行这类高精度累加

>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)

许可

这不是一个官方支持的 Google 产品。

ml_dtypes 源代码根据 Apache 2.0 许可证授权(见 LICENSE)。预编译的 wheel 文件使用 EIGEN 项目构建,该项目根据 MPL 2.0 许可证发布(见 LICENSE.eigen)。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关 安装包 的更多信息。

源代码分发

ml_dtypes-0.5.0.tar.gz (699.4 kB 查看哈希值)

上传时间 源代码

构建的分发

ml_dtypes-0.5.0-cp313-cp313-win_amd64.whl (213.2 kB 查看哈希值)

上传时间 CPython 3.13 Windows x86-64

ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB 查看哈希值)

上传时间 CPython 3.13 manylinux: glibc 2.17+ x86-64

ml_dtypes-0.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (4.4 MB 查看哈希值)

上传时间 CPython 3.13 manylinux: glibc 2.17+ ARM64

ml_dtypes-0.5.0-cp313-cp313-macosx_10_13_universal2.whl (753.3 kB 查看哈希值)

上传时间 CPython 3.13 macOS 10.13+ universal2 (ARM64, x86-64)

ml_dtypes-0.5.0-cp312-cp312-win_amd64.whl (213.2 kB 查看哈希值)

上传时间 CPython 3.12 Windows x86-64

ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB 查看哈希值)

上传时间 CPython 3.12 manylinux: glibc 2.17+ x86-64

ml_dtypes-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (4.4 MB 查看哈希值)

上传时间 CPython 3.12 manylinux: glibc 2.17+ ARM64

ml_dtypes-0.5.0-cp312-cp312-macosx_10_9_universal2.whl (750.2 kB 查看哈希值)

上传时间 CPython 3.12 macOS 10.9+ universal2 (ARM64, x86-64)

ml_dtypes-0.5.0-cp311-cp311-win_amd64.whl (211.9 kB 查看哈希值)

上传时间 CPython 3.11 Windows x86-64

ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB 查看哈希值)

上传时间 CPython 3.11 manylinux: glibc 2.17+ x86-64

ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (4.4 MB 查看哈希值)

上传时间 CPython 3.11 manylinux: glibc 2.17+ ARM64

ml_dtypes-0.5.0-cp311-cp311-macosx_10_9_universal2.whl (736.8 kB 查看哈希值)

上传于 CPython 3.11 macOS 10.9+ universal2 (ARM64, x86-64)

ml_dtypes-0.5.0-cp310-cp310-win_amd64.whl (211.9 kB 查看哈希值)

上传于 CPython 3.10 Windows x86-64

ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB 查看哈希值)

上传于 CPython 3.10 manylinux: glibc 2.17+ x86-64

ml_dtypes-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (4.4 MB 查看哈希值)

上传于 CPython 3.10 manylinux: glibc 2.17+ ARM64

ml_dtypes-0.5.0-cp310-cp310-macosx_10_9_universal2.whl (736.8 kB 查看哈希值)

上传于 CPython 3.10 macOS 10.9+ universal2 (ARM64, x86-64)

ml_dtypes-0.5.0-cp39-cp39-win_amd64.whl (211.3 kB 查看哈希值)

上传于 CPython 3.9 Windows x86-64

ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB 查看哈希值)

上传于 CPython 3.9 manylinux: glibc 2.17+ x86-64

ml_dtypes-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (4.4 MB 查看哈希值)

上传于 CPython 3.9 manylinux: glibc 2.17+ ARM64

ml_dtypes-0.5.0-cp39-cp39-macosx_10_9_universal2.whl (732.2 kB 查看哈希值)

上传于 CPython 3.9 macOS 10.9+ universal2 (ARM64, x86-64)

由以下支持