Jax或Numpy的共同后端。
项目描述
- 如果已安装Jax并且提供了Jax输入,则运行
jax.numpy
函数 - 如果已安装Jax并且函数被jitted,则运行
jax.numpy
函数 - 否则,jumpy函数返回NumPy输出
有几个函数(例如vmap
、scan
)在安装了jax
的情况下可用。
Jumpy允许您编写与框架无关的代码,通过作为原始Numpy运行易于调试,但在jitted时与JAX一样高效。
我们主要维护这个仓库是为了能够编写适用于标准NumPy或基于Jax硬件加速环境的Gymnasium和PettingZoo包装器,但这个包可以用在很多其他的事情上。
安装Jumpy
要从PyPI安装Jumpy,使用pip install jax-jumpy[jax]
将包括Jax,而使用pip install jax-jumpy
则不会包括Jax。
或者,要从源安装Jumpy,请克隆此仓库,进入目录,然后:pip install .
贡献
Jumpy没有实现所有numpy
或jax.numpy
函数。如果您缺少某些函数,请创建一个issue或pull request,我们将很乐意添加它们。
将来,我们希望添加对PyTorch的可选支持,并寻求pull request来完成此功能。
项目详情
下载文件
下载适用于您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。
源分布
jax-jumpy-1.0.0.tar.gz (19.4 kB 查看哈希值)
构建分布
jax_jumpy-1.0.0-py3-none-any.whl (20.4 kB 查看哈希值)