JAX中的数量
项目描述
unxt
JAX中的有量纲数量
Unxt是JAX中的单位量和计算,基于JAX、Equinox和Quax构建。
是的,它支持自动微分(grad
、jacobian
、hessian
)和向量化(vmap
等)。
安装
pip install unxt
文档
快速示例
from unxt import Quantity
x = Quantity(jnp.arange(1, 5, dtype=float), "kpc")
print(x)
# Quantity['length'](Array([1., 2., 3., 4.], dtype=float64), unit='kpc')
# Addition / Subtraction
print(x + x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='kpc')
# Multiplication / Division
print(2 * x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='kpc')
y = Quantity(jnp.arange(4, 8, dtype=float), "Gyr")
print(x / y)
# Quantity['speed'](Array([0.25 , 0.4 , 0.5 , 0.57142857], dtype=float64), unit='kpc / Gyr')
# Exponentiation
print(x**2)
# Quantity['area'](Array([0., 1., 4., 9.], dtype=float64), unit='kpc2')
# Unit Checking on operations
try:
x + y
except Exception as e:
print(e)
# 'Gyr' (time) and 'kpc' (length) are not convertible
unxt
基于quax
构建,它使得JAX中可以存在自定义的类似数组的对象。为了方便起见,我们使用了quaxed
库,它只是对jax
的quax.quaxify
包装,以避免不必要的代码。
from quaxed import grad, vmap
import quaxed.numpy as jnp
print(jnp.square(x))
# Quantity['area'](Array([ 1., 4., 9., 16.], dtype=float64), unit='kpc2')
print(qnp.power(x, 3))
# Quantity['volume'](Array([ 1., 8., 27., 64.], dtype=float64), unit='kpc3')
print(vmap(grad(lambda x: x**3))(x))
# Quantity['area'](Array([ 3., 12., 27., 48.], dtype=float64), unit='kpc2')
由于Quantity
是参数化的,它可以进行运行时维度检查!
LengthQuantity = Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](Array(2, dtype=int64, weak_type=True), unit='km')
try:
LengthQuantity(2, "s")
except ValueError as e:
print(e)
# Physical type mismatch.
引用
如果您认为这个库很有用,并希望支持科学界低级代码库的开发和维护,请考虑引用这项工作。
开发
我们欢迎贡献!
项目详情
下载文件
下载适用于您平台的文件。如果您不确定该选择哪一个,请了解更多关于安装软件包的信息。
源分布
unxt-0.17.0.tar.gz (540.3 kB 查看散列值)
构建分布
unxt-0.17.0-py3-none-any.whl (50.0 kB 查看散列值)