JAX的灵活模块
项目描述
🥷 Ninjax: JAX的灵活模块
Ninjax是JAX的一个通用且实用的模块系统。它使用户能够全面且透明地控制每个模块的状态更新,为JAX带来灵活性,并启用新的用例。
概述
Ninjax提供了一个简单且通用的nj.Module
类。
- 模块可以存储诸如模型参数、Adam动量缓冲区、BatchNorm统计信息、循环状态等的状态。
- 模块可以读写它们的状态条目。例如,这允许模块有训练方法,因为它们可以从内部更新它们的参数。
- 任何方法都可以初始化、读取和写入状态条目。这避免了在Flax中使用特殊
build()
方法或@compact
装饰器的需要。 - Ninjax使得轻松混合和匹配来自不同库(如Flax和Haiku)的模块变得容易。
- 与PyTrees不同,Ninjax状态是一个扁平的
dict
,将字符串键(如/net/layer1/weights
)映射到jnp.array
。这使得迭代、修改、保存或加载状态变得容易。 - 模块可以使用dataclass语法指定类型化的超参数。
安装
Ninjax是一个单个文件,因此您可以将其复制到您的项目目录。或者,您可以安装该软件包
pip install ninjax
快速入门
import flax
import jax
import jax.numpy as jnp
import ninjax as nj
import optax
Linear = nj.FromFlax(flax.linen.Dense)
class MyModel(nj.Module):
lr: float = 1e-3
def __init__(self, size):
self.size = size
# Define submodules upfront
self.h1 = Linear(128, name='h1')
self.h2 = Linear(128, name='h2')
self.opt = optax.adam(self.lr)
def predict(self, x):
x = jax.nn.relu(self.h1(x))
x = jax.nn.relu(self.h2(x))
# Define submodules inline
x = self.sub('h3', Linear, self.size, use_bias=False)(x)
# Create state entries inline
x += self.value('bias', jnp.zeros, self.size)
# Update state entries inline
self.write('bias', self.read('bias') + 0.1)
return x
def loss(self, x, y):
return ((self.predict(x) - y) ** 2).mean()
def train(self, x, y):
# Take grads wrt. to submodules or state keys
wrt = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']
loss, params, grads = nj.grad(self.loss, wrt)(x, y)
# Update weights
state = self.sub('optstate', nj.Tree, self.opt.init, params)
updates, new_state = self.opt.update(grads, state.read(), params)
params = optax.apply_updates(params, updates)
nj.context().update(params) # Store the new params
state.write(new_state) # Store new optimizer state
return loss
# Create model and example data
model = MyModel(3, name='model')
x = jnp.ones((64, 32), jnp.float32)
y = jnp.ones((64, 3), jnp.float32)
# Populate initial state from one or more functions
state = {}
state = nj.init(model.train)(state, x, y, seed=0)
print(state['model/bias'])
# Purify for JAX transformations
train = jax.jit(nj.pure(model.train))
# Training loop
for x, y in [(x, y)] * 10:
state, loss = train(state, x, y)
print('Loss:', float(loss))
# Look at the parameters
print(state['model/bias'])
问题
如果您有问题,请提交问题。
项目详情
关闭
ninjax-3.5.1.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | e709dc6aefb71712c1527cd6456fda8afdf235c54b46c00ebd06c5afb0fb2150 |
|
MD5 | 4991920d3b4b65e8c64b1dd10b785cdb |
|
BLAKE2b-256 | 1647d9b741fbde0f5d6d46d17f7e47cd5bb8bb9a82718641cdc48a8fb9cb5a77 |