跳转到主要内容

JAX的灵活模块

项目描述

PyPI

🥷 Ninjax: JAX的灵活模块

Ninjax是JAX的一个通用且实用的模块系统。它使用户能够全面且透明地控制每个模块的状态更新,为JAX带来灵活性,并启用新的用例。

概述

Ninjax提供了一个简单且通用的nj.Module类。

  • 模块可以存储诸如模型参数、Adam动量缓冲区、BatchNorm统计信息、循环状态等的状态。
  • 模块可以读写它们的状态条目。例如,这允许模块有训练方法,因为它们可以从内部更新它们的参数。
  • 任何方法都可以初始化、读取和写入状态条目。这避免了在Flax中使用特殊build()方法或@compact装饰器的需要。
  • Ninjax使得轻松混合和匹配来自不同库(如FlaxHaiku)的模块变得容易。
  • 与PyTrees不同,Ninjax状态是一个扁平的dict,将字符串键(如/net/layer1/weights)映射到jnp.array。这使得迭代、修改、保存或加载状态变得容易。
  • 模块可以使用dataclass语法指定类型化的超参数。

安装

Ninjax是一个单个文件,因此您可以将其复制到您的项目目录。或者,您可以安装该软件包

pip install ninjax

快速入门

import flax
import jax
import jax.numpy as jnp
import ninjax as nj
import optax

Linear = nj.FromFlax(flax.linen.Dense)


class MyModel(nj.Module):

  lr: float = 1e-3

  def __init__(self, size):
    self.size = size
    # Define submodules upfront
    self.h1 = Linear(128, name='h1')
    self.h2 = Linear(128, name='h2')
    self.opt = optax.adam(self.lr)

  def predict(self, x):
    x = jax.nn.relu(self.h1(x))
    x = jax.nn.relu(self.h2(x))
    # Define submodules inline
    x = self.sub('h3', Linear, self.size, use_bias=False)(x)
    # Create state entries inline
    x += self.value('bias', jnp.zeros, self.size)
    # Update state entries inline
    self.write('bias', self.read('bias') + 0.1)
    return x

  def loss(self, x, y):
    return ((self.predict(x) - y) ** 2).mean()

  def train(self, x, y):
    # Take grads wrt. to submodules or state keys
    wrt = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']
    loss, params, grads = nj.grad(self.loss, wrt)(x, y)
    # Update weights
    state = self.sub('optstate', nj.Tree, self.opt.init, params)
    updates, new_state = self.opt.update(grads, state.read(), params)
    params = optax.apply_updates(params, updates)
    nj.context().update(params)  # Store the new params
    state.write(new_state)       # Store new optimizer state
    return loss


# Create model and example data
model = MyModel(3, name='model')
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)

# Populate initial state from one or more functions
state = {}
state = nj.init(model.train)(state, x, y, seed=0)
print(state['model/bias'])

# Purify for JAX transformations
train = jax.jit(nj.pure(model.train))

# Training loop
for x, y in [(x, y)] * 10:
  state, loss = train(state, x, y)
  print('Loss:', float(loss))

# Look at the parameters
print(state['model/bias'])

问题

如果您有问题,请提交问题

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源分布

ninjax-3.5.1.tar.gz (17.1 kB 查看哈希值)

由以下提供支持