跳转到主要内容

Objax是一个机器学习框架,为JAX提供面向对象的层。

项目描述

Objax

教程 | 安装 | 文档 | 哲学

这不是一个官方支持的谷歌产品。

Objax 是一个开源机器学习框架,它通过最小化的面向对象设计和可读的代码库加速研究和学习。其名称来自“对象”和 JAX(一个流行的性能框架)的缩写。Objax 是由研究人员为研究人员设计的,注重简洁性和可理解性。用户应该能够轻松阅读、理解、扩展和修改它以满足自己的需求。

这是 Objax 的开发者仓库,这里几乎没有用户文档,完整的文档请访问 objax.readthedocs.io

您可以在项目的子目录中找到 README 文件,例如

用户安装指南

您可以使用以下方式使用 pip 安装 Objax

pip install --upgrade objax

Objax 支持 GPU,但假设您已经安装了某种版本的 CUDA。以下是安装启用 CUDA 的 jaxlib(jaxlib 版本需要 CUDA 11.2 或更高版本)所需的额外步骤

RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
JAX_VERSION=`python3 -c 'import jax; print(jax.__version__)'`
pip uninstall -y jaxlib
pip install -f $RELEASE_URL jax[cuda]==$JAX_VERSION

有关更多安装选项,请参阅 https://github.com/google/jax#pip-installation-gpu-cuda

有用的环境配置

以下是一些有用的选项

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false

测试您的安装

您可以通过运行以下代码来测试您的安装

import jax
import objax

print(f'Number of GPUs {jax.device_count()}')

x = objax.random.normal(shape=(100, 4))
m = objax.nn.Linear(nin=4, nout=5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal(shape=(100, 3, 32, 32))
m = objax.nn.Conv2D(nin=3, nout=4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

通常,如果您在运行此代码时使用 CUDA 出现错误,这通常意味着您的 CUDA 或 CuDNN 安装存在问题。

运行代码示例

克隆代码仓库

git clone https://github.com/google/objax.git
cd objax/examples

引用 Objax

要引用此仓库

@software{objax2020github,
  author = {{Objax Developers}},
  title = {{Objax}},
  url = {https://github.com/google/objax},
  version = {1.2.0},
  year = {2020},
}

开发者文档

以下是有关 开发设置添加新代码的指南 的信息。

项目详情


下载文件

下载适用于您的平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源代码分布

objax-1.8.0.tar.gz (59.9 kB 查看哈希值)

上传 源代码

构建分布

objax-1.8.0-py3-none-any.whl (86.7 kB 查看哈希值)

上传 Python 3

支持