跳转到主要内容

盆栽:一个用于构建、编辑和可视化神经网络的JAX研究工具包。

项目描述

盆栽

盆 ("pen", 盆) 栽 ("zai", 种植) - 一种古代中国艺术,通过微型盆景和山水来造型,也称为盆景,是日本艺术盆栽的鼻祖。

盆栽是一个JAX库,用于将模型作为可读的、功能性的pytree数据结构编写,并提供用于可视化、修改和分析它们的工具。盆栽专注于 在模型训练后轻松进行操作,使其成为涉及逆向工程或移除模型组件、检查和探索内部激活、执行模型手术、调试架构等研究的优秀选择。(但如果你只是想构建和训练模型,你也能做到!)

使用盆栽,你的神经网络可能看起来像这样

Screenshot of the Gemma model in Penzai

盆栽作为一个模块化工具集合构建,虽然设计在一起,但每个都可以独立使用

  • 一个超级强大的交互式Python美化打印器

    • Treescope (pz.ts):一个普通IPython/Colab渲染器的替代品,最初是盆栽的一部分,但现在作为一个独立包提供。它旨在帮助理解盆栽模型和其他深层嵌套的JAX pytrees,内置对任意维度的NDArrays的可视化支持。
  • 一套JAX树和数组操作工具

    • penzai.core.selectors (pz.select):一个pytree瑞士军刀,将JAX的.at[...].set(...)语法泛化为任意类型驱动的pytree遍历,并使复杂的重写或即时修补Penzai模型和其他数据结构变得简单。

    • penzai.core.named_axes (pz.nx):一个轻量级的命名轴系统,可以将普通JAX函数提升为命名轴的矢量化,并允许您在不学习新的数组API的情况下无缝地在命名和位置编程风格之间切换。

  • 一个声明性组合器神经网络的库,模型以易于修改的数据结构表示

    • penzai.nn (pz.nn):Flax、Haiku、Keras或Equinox等其他神经网络库的替代品,使用声明性组合器公开模型的正向传递的完整结构。像Equinox一样,模型以JAX PyTrees的形式表示,这意味着您可以通过美化打印来查看模型的所有操作,并使用jax.tree_util注入新的运行时逻辑。penzai.nn模型还可能包含树叶处的可变变量,允许它们跟踪可变状态和参数共享。
  • 常见Transformer架构的模块化实现,以支持可解释性、模型手术和训练动态的研究

    • penzai.models.transformer:一个参考Transformer实现,可以加载Gemma、Llama、Mistral和GPT-NeoX / Pythia架构的预训练权重。使用模块化组件和命名轴构建,以简化复杂的模型操作工作流程。

Penzai的文档可以在https://penzai.readthedocs.io找到。

[!重要] Penzai 0.2包含对神经网络API的一些重大更改。这些更改旨在通过引入对可变状态和参数共享的第一类支持以及删除不必要的样板代码来简化常见的工作流程。您可以在"V2 API更改"概述中了解旧"V1" API和当前"V2" API之间的差异。

如果您目前正在使用V1 API并且尚未转换为V2系统,您可以通过从penzai.deprecated.v1子模块导入来保留旧行为,例如:

from penzai.deprecated.v1 import pz
from penzai.deprecated.v1.example_models import simple_mlp

入门指南

如果您还没有安装JAX,您应该首先安装它,因为安装过程取决于您的平台。您可以在JAX文档中找到说明。之后,您可以使用以下命令安装Penzai:

pip install penzai

并使用以下命令导入它:

import penzai
from penzai import pz

(penzai.pz是一个别名命名空间,这使得引用常见的Penzai对象变得更加容易。)

当在Colab或IPython笔记本中工作时,我们建议还将Treescope(Penzai的配套美化打印器)配置为默认美化打印器,并启用一些用于交互式使用的实用工具

import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)

以下是初始化和可视化简单神经网络的示例

from penzai.models import simple_mlp
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[8, 32, 32, 8]
)

# Models and arrays are visualized automatically when you output them from a
# Colab/IPython notebook cell:
mlp

以下是捕获和提取元素非线性后的激活的示例

@pz.pytree_dataclass
class AppendIntermediate(pz.nn.Layer):
  saved: pz.StateVariable[list[Any]]
  def __call__(self, x: Any, **unused_side_inputs) -> Any:
    self.saved.value = self.saved.value + [x]
    return x

var = pz.StateVariable(value=[], label="my_intermediates")

# Make a copy of the model that saves its activations:
saving_model = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(AppendIntermediate(var))
)

output = saving_model(pz.nx.ones({"features": 8}))
intermediates = var.value

要了解更多关于如何使用Penzai构建和操作神经网络的信息,我们建议从"如何在Penzai中思考"教程Penzai文档中的其他教程开始。

引用

如果您发现Penzai对您的研究很有用,请考虑引用以下撰写(也可在arXiv上找到)

@article{johnson2024penzai,
    author={Daniel D. Johnson},
    title={{Penzai} + {Treescope}: A Toolkit for Interpreting, Visualizing, and Editing Models As Data},
    year={2024},
    journal={ICML 2024 Workshop on Mechanistic Interpretability}
}

这不是一个官方支持的Google产品。

项目详情


下载文件

下载适用于您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源代码分发

penzai-0.2.2.tar.gz (36.6 MB 查看哈希值)

上传时间 源代码

构建分发

penzai-0.2.2-py3-none-any.whl (314.5 kB 查看哈希值)

上传时间 Python 3