JAX + Units
项目描述
JAX + Units
此模块提供了JAX和Pint之间的接口,允许JAX支持具有单位的操作。单位传播发生在跟踪时间,因此jitted函数不应看到任何运行时成本。此库是实验性的,因此请期待一些尖锐的边缘。
例如
>>> import jax
>>> import jax.numpy as jnp
>>> import jpu
>>>
>>> u = jpu.UnitRegistry()
>>>
>>> @jax.jit
... def add_two_lengths(a, b):
... return a + b
...
>>> add_two_lengths(3 * u.m, jnp.array([4.5, 1.2, 3.9]) * u.cm)
<Quantity([3.045 3.012 3.039], 'meter')>
安装
要安装,请使用pip
python -m pip install jpu
唯一的依赖项是jax
和pint
,如果它们尚未在您的环境中,它们也将被安装。请参阅JAX文档以获取有关在不同系统上安装JAX的更多信息。
用法
这是一个稍微更完整的示例
>>> import jax
>>> import numpy as np
>>> from jpu import UnitRegistry, numpy as jnpu
>>>
>>> u = UnitRegistry()
>>>
>>> @jax.jit
... def projectile_motion(v_init, theta, time, g=u.standard_gravity):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu.cos(theta)
... y = v_init * time * jnpu.sin(theta) - 0.5 * g * jnpu.square(time)
... return x.to(u.m), y.to(u.m)
...
>>> x, y = projectile_motion(
... 5.0 * u.km / u.h, 60 * u.deg, np.linspace(0, 1, 50) * u.s
... )
技术细节和限制
此库最显著的限制是,当与具有单位的“量”交互时,用户必须使用jpu.numpy
函数,而不是使用jax.numpy
接口。这是因为JAX尚未提供用于在自定义数组类上分派ufuncs的一般接口。我尝试过非文档化的__jax_array__
接口,但它并不足够灵活,而且目前与Pytree对象不兼容。
到目前为止,只实现了numpy
/jax.numpy
接口的一个子集。欢迎提交拉取请求添加更广泛的支持(包括子模块)!
项目详情
下载文件
下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。
源分布
jpu-0.0.4.tar.gz (19.0 kB 查看哈希值)
构建分布
jpu-0.0.4-py3-none-any.whl (15.9 kB 查看哈希值)
关闭
jpu-0.0.4.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 03314c23504bec25bf95142d3684da5d1e962e2ead7067cd96819f5856bb6d57 |
|
MD5 | 28ef83e5e67571bc5baf18bfeada07ce |
|
BLAKE2b-256 | 7214e28417860c57092f62ff4bd56e5d3f9284ce9488f2f8893fa0db0da7ec50 |
关闭
jpu-0.0.4-py3-none-any.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 805669c2544130863f90b3031ac62d9d1eaa5081dc349c055799cc6e638750a9 |
|
MD5 | eb3c06b8ea3c2fcf4f99f78ed069e379 |
|
BLAKE2b-256 | c4488cf91fc33e11b340fc21e0bcadee1c7cfb4113a275bc972dde1a2a28eb76 |