未提供项目描述
项目描述
JAX AI Stack
JAX 是一个用于数组计算的Python包和程序转换。围绕它的是一个不断增长的专门数值计算包生态系统;这些项目的最新列表可以在 Awesome JAX 中找到。
尽管JAX经常与pytorch等神经网络库进行比较,但JAX核心包本身包含的与神经网络模型特定相关的功能非常少。相反,JAX鼓励模块化,即特定领域的库是从核心包中独立开发的:这有助于推动创新,因为研究人员和其他用户探索可能实现的内容。
在这个更大的分布式生态系统中,有一些项目是谷歌研究人员和工程师在实现和部署如 Imagen、Gemini 等生成式AI工具背后的模型时发现的非常有用。JAX AI栈是这个库套件的单一入口点,因此您可以安装并开始使用许多与谷歌开发者日常工作中使用的相同开源包。
要开始使用JAX AI堆栈,您可以查看使用JAX入门。这仍然是一个正在进行中的项目,请在未来几周内查看更多文档和教程!
安装堆栈
可以使用以下命令安装堆栈
pip install jax-ai-stack
此命令固定了特定版本的项目组件,这些版本通过此存储库中的集成测试被证明可以正确协同工作。包包括
- JAX:核心JAX包,包括数组操作和程序转换,如
jit
、vmap
、grad
等。 - flax:使用JAX构建神经网络
- ml_dtypes:机器学习的NumPy数据类型扩展
- optax:JAX中的梯度处理和优化
- orbax:JAX的检查点和持久化工具
可选包
此外,您还可以使用pip
的额外功能安装可选包。以下命令
pip install jax-ai-stack[grain]
将安装兼容版本的grain数据加载器(目前仅限linux)。
同样,以下命令
pip install jax-ai-stack[tfds]
将安装兼容版本的tensorflow和tensorflow-datasets。
项目详情
下载文件
下载适用于您的平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。
源分布
jax_ai_stack-2024.10.1.tar.gz (8.2 kB 查看哈希值)
构建分布
jax_ai_stack-2024.10.1-py3-none-any.whl (11.0 kB 查看哈希值)