Chex:在JAX中让测试变得有趣!
项目描述
Chex
Chex 是一个帮助编写可靠 JAX 代码的实用程序库。
这包括以下实用程序:
- 对代码进行仪表化(例如,断言、警告)
- 调试(例如,在上下文管理器中将
pmaps
转换为vmaps
)。 - 在许多
变体
中测试JAX代码(例如,编译过的与未编译过的)。
安装
您可以通过PyPI安装Chex的最新发布版本:
pip install chex
或者您可以从GitHub安装最新开发版本
pip install git+https://github.com/deepmind/chex.git
模块概述
Dataclass (dataclass.py)
Dataclass是Python 3.7引入的一种流行结构,允许您通过最少的样板代码轻松指定类型化数据结构。然而,它们与JAX和dm-tree并不完全兼容。
在Chex中,我们提供了一个JAX友好的dataclass实现,重新使用了Python dataclasses。
Chex的dataclass
实现将数据类注册为内部PyTree节点,以确保与JAX数据结构的兼容性。
此外,我们还提供了一个类包装器,将数据类公开为collections.Mapping
的子类,这使得它们可以在dm-tree
方法中像常规Python字典一样进行处理(例如,(un-)flatten)。有关更多详细信息,请参阅@mappable_dataclass
文档字符串。
示例
@chex.dataclass
class Parameters:
x: chex.ArrayDevice
y: chex.ArrayDevice
parameters = Parameters(
x=jnp.ones((2, 2)),
y=jnp.ones((1, 2)),
)
# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)
# and as mappings by dm-tree
tree.flatten(parameters)
注意:与标准Python 3.7数据类不同,Chex数据类不能使用位置参数构造。它们支持以Python字典构造函数相同的格式提供的构造参数。如果需要,数据类可以使用from_tuple
和to_tuple
方法转换为元组。
parameters = Parameters(
jnp.ones((2, 2)),
jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.
断言(《asserts.py》
PyType注解对JAX的一个限制是,它们不支持指定DeviceArray
的秩、形状或dtypes。Chex包含一些函数,允许灵活且简洁地指定这些属性。
例如,假设您想确保所有张量t1
、t2
、t3
具有相同的形状,并且张量t4
、t5
的秩分别为2和(3或4)。
chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])
更多示例
from chex import assert_shape, assert_rank, ...
assert_shape(x, (2, 3)) # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)]) # x is scalar and y has shape (2, 3)
assert_rank(x, 0) # x is scalar
assert_rank([x, y], [0, 2]) # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2}) # x and y are scalar OR rank-2 arrays
assert_type(x, int) # x has type `int` (x can be an array)
assert_type([x, y], [int, float]) # x has type `int` and y has type `float`
assert_equal_shape([x, y, z]) # x, y, and z have equal shapes
assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x) # all tree_x leaves are finite
assert_devices_available(2, 'gpu') # 2 GPUs available
assert_tpu_available() # at least 1 TPU available
assert_numerical_grads(f, (x, y), j) # f^{(j)}(x, y) matches numerical grads
请参阅asserts.py
文档以找到所有支持的断言。
如果您找不到特定的断言,请考虑在问题跟踪器上提交拉取请求或打开问题。
可选参数
所有chex断言都支持以下可选kwargs来操作生成的异常消息:
custom_message
:一个字符串,将包含在生成的异常消息中。include_default_message
:是否将默认Chex消息包含在生成的异常消息中。exception_type
:要使用的异常类型。默认为AssertionError
。
例如,以下代码
dataset = load_dataset()
params = init_params()
for i in range(num_steps):
params = update_params(params, dataset.sample())
chex.assert_tree_all_finite(params,
custom_message=f'Failed at iteration {i}.',
exception_type=ValueError)
将在params
被污染为NaNs
或None
时引发一个包含步骤号的ValueError
。
静态和值(又称Runtime)断言
Chex将所有断言分为两类:static和value断言。
-
static断言使用除了张量的具体值之外的所有内容。示例:
assert_shape
、assert_trees_all_equal_dtypes
、assert_max_traces
。 -
value断言需要访问张量值,这些值在JAX跟踪期间不可用(请参阅JAX原语如何工作),因此这些断言需要在jitted代码中进行特殊处理。
要启用jitted函数中的值断言,它可以由chex.chexify()
包装器装饰。示例
@chex.chexify
@jax.jit
def logp1_abs_safe(x: chex.Array) -> chex.Array:
chex.assert_tree_all_finite(x)
return jnp.log(jnp.abs(x) + 1)
logp1_abs_safe(jnp.ones(2)) # OK
logp1_abs_safe(jnp.array([jnp.nan, 3])) # FAILS (in async mode)
# The error will be raised either at the next line OR at the next
# `logp1_abs_safe` call. See the docs for more detain on async mode.
logp1_abs_safe.wait_checks() # Wait for the (async) computation to complete.
请参阅此文档字符串以获取有关chex.chexify()
的更多详细信息。
JAX跟踪断言
JAX 在每次传入参数结构发生变化时都会重新追踪JIT编译的函数。这种行为往往是无意中发生的,并且会导致性能显著下降,难以调试。请参阅@chex.assert_max_traces装饰器,该装饰器确保函数在程序执行期间不会超过n
次重新追踪。
可以通过调用chex.clear_trace_counter()
来清除全局追踪计数器。此函数可用于隔离依赖于@chex.assert_max_traces
的unittest。
示例
@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
return x + y
fn_sum_jitted(jnp.zeros(3), jnp.zeros(3)) # tracing for the 1st time - OK
fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7])) # AssertionError!
也可以与jax.pmap()
一起使用
def fn_sub(x, y):
return x - y
fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
有关追踪的更多信息,请参阅JAX原语如何工作部分。
警告(warnings.py)
除了硬断言外,Chex还提供了一些实用工具来添加常见的警告,例如特定的弃用警告。
测试变体(variants.py)
JAX大量依赖代码转换和编译,这意味着确保代码得到适当测试可能很困难。例如,仅使用JAX代码测试Python函数将不会覆盖JIT执行时实际执行的代码路径,而这个路径也会因代码是针对CPU、GPU还是TPU进行JIT编译而有所不同。这可能是XLA更改导致的不希望的行为的来源,但这些行为只在特定的代码转换中表现出来。
变体通过提供一个简单的装饰器,可以重复测试任何测试在所有(或相关代码转换的子集)下,从而使得确保单元测试覆盖函数的不同“变体”变得容易。
例如,假设你想测试函数fn
是否有或没有jit。你可以使用chex.variants
通过简单地装饰一个测试方法来运行带有函数的jitted和非jitted版本的测试,然后在测试方法的主体中使用self.variant(fn)
代替fn
。
def fn(x, y):
return x + y
...
class ExampleTest(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
def test(self):
var_fn = self.variant(fn)
self.assertEqual(fn(1, 2), 3)
self.assertEqual(var_fn(1, 2), fn(1, 2))
如果你在测试方法中定义了函数,你还可以在函数定义中使用self.variant
作为装饰器。例如
class ExampleTest(chex.TestCase):
@chex.variants(with_jit=True, without_jit=True)
def test(self):
@self.variant
def var_fn(x, y):
return x + y
self.assertEqual(var_fn(1, 2), 3)
参数化测试示例
from absl.testing import parameterized
# Could also be:
# `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
# `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):
@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
('case_positive', 1, 2, 3),
('case_negative', -1, -2, -3),
)
def test(self, arg_1, arg_2, expected):
@self.variant
def var_fn(x, y):
return x + y
self.assertEqual(var_fn(arg_1, arg_2), expected)
Chex目前支持以下变体
with_jit
-- 将jax.jit()
转换应用于函数。without_jit
-- 使用原始函数,即恒等转换。with_device
-- 在应用函数之前,将所有参数(除了在ignore_argnums
参数中指定的参数)放入设备内存。without_device
-- 在应用函数之前,将所有参数放入RAM中。with_pmap
-- 将jax.pmap()
转换应用于函数(见以下说明)。
有关支持的变体的详细信息,请参阅variants.py文档。更多示例可以在variants_test.py中找到。
变体说明
-
使用
@chex.variants
的测试类必须从chex.TestCase
(或任何其他在TestCase
内部展开测试生成器的基类,例如absl.testing.parameterized.TestCase
)继承。 -
[
jax.vmap
] 所有变体都可以应用于vmapped函数;请参阅variants_test.py中的示例(test_vmapped_fn_named_params
和test_pmap_vmapped_fn
)。 -
[
@chex.all_variants
] 你可以使用装饰器@chex.all_variants
来获取所有支持的变体。 -
[
with_pmap
变体]jax.pmap(fn)
(文档) 在多个设备上对fn
执行并行映射。由于大多数测试都在单设备环境中运行(即只能访问单个 CPU 或 GPU),在这种情况下,jax.pmap
与jax.jit
功能相同,因此默认跳过with_pmap
变体(尽管它与单个设备配合工作也很正常)。以下我们将描述一种在多设备环境(TPU 或多个 CPU/GPU)中正确测试fn
的方法。如果要在单设备的情况下禁用跳过with_pmap
变体,请在测试命令中添加--chex_skip_pmap_variant_if_single_device=false
。
模拟 (fake.py)
JAX 中的调试由于代码转换(如 jit
和 pmap
)而变得更加困难,这些转换引入了优化,使得代码难以检查和跟踪。在调试期间禁用这些转换也可能很困难,因为它们可以在底层代码的多个位置调用。Chex 提供了工具,可以全局地将 jax.jit
替换为无操作转换,并将 jax.pmap
替换为(非并行)jax.vmap
,以便更容易地在单设备环境中调试代码。
例如,您可以使用 Chex 模拟 pmap
并将其替换为 vmap
。这可以通过将您的代码包裹在上下文管理器中实现
with chex.fake_pmap():
@jax.pmap
def fn(inputs):
...
# Function will be vmapped over inputs
fn(inputs)
同样,您也可以使用 start
和 stop
来调用相同的功能
fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()
此外,您还可以使用多线程 CPU 模拟真实的多个设备测试环境。有关更多详细信息,请参阅 模拟多个设备测试环境 部分。
有关更多详细信息,请参阅 fake.py 和 fake_test.py 中的示例。
模拟多个设备测试环境
在不方便访问多个设备的情况下,您仍然可以使用单设备多线程测试并行计算。
特别是,可以强制 XLA 使用单个 CPU 的线程作为单独的设备,即使用多线程模拟真实的多个设备环境。从 XLA 角度来看,这两个选项在理论上是等效的,因为它们公开了相同的接口并使用相同的抽象。
Chex 有一个标志 chex_n_cpu_devices
,用于指定要用作 XLA 设备的 CPU 线程数。
要为 absl
测试设置多线程 XLA 环境,在您的测试模块中定义 setUpModule
函数
def setUpModule():
chex.set_n_cpu_devices()
现在您可以使用 python test.py --chex_n_cpu_devices=N
启动您的测试,以多设备模式运行。请注意,模块中的所有测试都将访问 N
个设备。
更多示例请参阅 variants_test.py、fake_test.py 和 fake_set_n_cpu_devices_test.py。
使用命名维度大小。
Chex 提供了一个小工具,允许您将一组维度大小打包成一个单一的对象。基本思想是
dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
chex.assert_shape(arr, dims['BTE'])
字符串查找被转换为整数元组。例如,假设 batch_size == 3
、sequence_len = 5
和 embedding_dim = 7
,那么
dims['BTE'] == (3, 5, 7)
dims['B'] == (3,)
dims['TTBEE'] == (5, 5, 3, 7, 7)
...
您也可以按如下方式动态分配维度大小
dims['XY'] = some_matrix.shape
dims.Z = 13
有关更多示例,请参阅 chex.Dimensions 文档。
引用 Chex
此存储库是 DeepMind JAX 生态系统 的一部分,要引用 Chex,请使用 DeepMind JAX 生态系统引用。
项目详情
下载文件
下载您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。