跳转到主要内容

JAX + Units

项目描述

JAX + Units

使用JAXPint构建!

此模块提供了JAX和Pint之间的接口,允许JAX支持具有单位的操作。单位传播发生在跟踪时间,因此jitted函数不应看到任何运行时成本。此库是实验性的,因此请期待一些尖锐的边缘。

例如

>>> import jax
>>> import jax.numpy as jnp
>>> import jpu
>>>
>>> u = jpu.UnitRegistry()
>>>
>>> @jax.jit
... def add_two_lengths(a, b):
...     return a + b
...
>>> add_two_lengths(3 * u.m, jnp.array([4.5, 1.2, 3.9]) * u.cm)
<Quantity([3.045 3.012 3.039], 'meter')>

安装

要安装,请使用pip

python -m pip install jpu

唯一的依赖项是jaxpint,如果它们尚未在您的环境中,它们也将被安装。请参阅JAX文档以获取有关在不同系统上安装JAX的更多信息

用法

这是一个稍微更完整的示例

>>> import jax
>>> import numpy as np
>>> from jpu import UnitRegistry, numpy as jnpu
>>>
>>> u = UnitRegistry()
>>>
>>> @jax.jit
... def projectile_motion(v_init, theta, time, g=u.standard_gravity):
...     """Compute the motion of a projectile with support for units"""
...     x = v_init * time * jnpu.cos(theta)
...     y = v_init * time * jnpu.sin(theta) - 0.5 * g * jnpu.square(time)
...     return x.to(u.m), y.to(u.m)
...
>>> x, y = projectile_motion(
...     5.0 * u.km / u.h, 60 * u.deg, np.linspace(0, 1, 50) * u.s
... )

技术细节和限制

此库最显著的限制是,当与具有单位的“量”交互时,用户必须使用jpu.numpy函数,而不是使用jax.numpy接口。这是因为JAX尚未提供用于在自定义数组类上分派ufuncs的一般接口。我尝试过非文档化的__jax_array__接口,但它并不足够灵活,而且目前与Pytree对象不兼容。

到目前为止,只实现了numpy/jax.numpy接口的一个子集。欢迎提交拉取请求添加更广泛的支持(包括子模块)!

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

jpu-0.0.4.tar.gz (19.0 kB 查看哈希值)

上传时间

构建分布

jpu-0.0.4-py3-none-any.whl (15.9 kB 查看哈希值)

上传时间 Python 3

由以下组织支持