跳转到主要内容

Chex:在JAX中让测试变得有趣!

项目描述

Chex

CI status docs pypi

Chex 是一个帮助编写可靠 JAX 代码的实用程序库。

这包括以下实用程序:

  • 对代码进行仪表化(例如,断言、警告)
  • 调试(例如,在上下文管理器中将 pmaps 转换为 vmaps)。
  • 在许多变体中测试JAX代码(例如,编译过的与未编译过的)。

安装

您可以通过PyPI安装Chex的最新发布版本:

pip install chex

或者您可以从GitHub安装最新开发版本

pip install git+https://github.com/deepmind/chex.git

模块概述

Dataclass (dataclass.py)

Dataclass是Python 3.7引入的一种流行结构,允许您通过最少的样板代码轻松指定类型化数据结构。然而,它们与JAX和dm-tree并不完全兼容。

在Chex中,我们提供了一个JAX友好的dataclass实现,重新使用了Python dataclasses

Chex的dataclass实现将数据类注册为内部PyTree节点,以确保与JAX数据结构的兼容性。

此外,我们还提供了一个类包装器,将数据类公开为collections.Mapping的子类,这使得它们可以在dm-tree方法中像常规Python字典一样进行处理(例如,(un-)flatten)。有关更多详细信息,请参阅@mappable_dataclass文档字符串。

示例

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)

注意:与标准Python 3.7数据类不同,Chex数据类不能使用位置参数构造。它们支持以Python字典构造函数相同的格式提供的构造参数。如果需要,数据类可以使用from_tupleto_tuple方法转换为元组。

parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.

断言(《asserts.py》

PyType注解对JAX的一个限制是,它们不支持指定DeviceArray的秩、形状或dtypes。Chex包含一些函数,允许灵活且简洁地指定这些属性。

例如,假设您想确保所有张量t1t2t3具有相同的形状,并且张量t4t5的秩分别为2和(3或4)。

chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])

更多示例

from chex import assert_shape, assert_rank, ...

assert_shape(x, (2, 3))                # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)])      # x is scalar and y has shape (2, 3)

assert_rank(x, 0)                      # x is scalar
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar OR rank-2 arrays

assert_type(x, int)                    # x has type `int` (x can be an array)
assert_type([x, y], [int, float])      # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])          # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite

assert_devices_available(2, 'gpu')     # 2 GPUs available
assert_tpu_available()                 # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)   # f^{(j)}(x, y) matches numerical grads

请参阅asserts.py 文档以找到所有支持的断言。

如果您找不到特定的断言,请考虑在问题跟踪器上提交拉取请求或打开问题。

可选参数

所有chex断言都支持以下可选kwargs来操作生成的异常消息:

  • custom_message:一个字符串,将包含在生成的异常消息中。
  • include_default_message:是否将默认Chex消息包含在生成的异常消息中。
  • exception_type:要使用的异常类型。默认为AssertionError

例如,以下代码

dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)

将在params被污染为NaNsNone时引发一个包含步骤号的ValueError

静态和值(又称Runtime)断言

Chex将所有断言分为两类:staticvalue断言。

  1. static断言使用除了张量的具体值之外的所有内容。示例:assert_shapeassert_trees_all_equal_dtypesassert_max_traces

  2. value断言需要访问张量值,这些值在JAX跟踪期间不可用(请参阅JAX原语如何工作),因此这些断言需要在jitted代码中进行特殊处理。

要启用jitted函数中的值断言,它可以由chex.chexify()包装器装饰。示例

  @chex.chexify
  @jax.jit
  def logp1_abs_safe(x: chex.Array) -> chex.Array:
    chex.assert_tree_all_finite(x)
    return jnp.log(jnp.abs(x) + 1)

  logp1_abs_safe(jnp.ones(2))  # OK
  logp1_abs_safe(jnp.array([jnp.nan, 3]))  # FAILS (in async mode)

  # The error will be raised either at the next line OR at the next
  # `logp1_abs_safe` call. See the docs for more detain on async mode.
  logp1_abs_safe.wait_checks()  # Wait for the (async) computation to complete.

请参阅此文档字符串以获取有关chex.chexify()的更多详细信息。

JAX跟踪断言

JAX 在每次传入参数结构发生变化时都会重新追踪JIT编译的函数。这种行为往往是无意中发生的,并且会导致性能显著下降,难以调试。请参阅@chex.assert_max_traces装饰器,该装饰器确保函数在程序执行期间不会超过n次重新追踪。

可以通过调用chex.clear_trace_counter()来清除全局追踪计数器。此函数可用于隔离依赖于@chex.assert_max_traces的unittest。

示例

  @jax.jit
  @chex.assert_max_traces(n=1)
  def fn_sum_jitted(x, y):
    return x + y

  fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))  # tracing for the 1st time - OK
  fn_sum_jitted(jnp.zeros([6, 7]), jnp.zeros([6, 7]))  # AssertionError!

也可以与jax.pmap()一起使用

  def fn_sub(x, y):
    return x - y

  fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

有关追踪的更多信息,请参阅JAX原语如何工作部分。

警告(warnings.py

除了硬断言外,Chex还提供了一些实用工具来添加常见的警告,例如特定的弃用警告。

测试变体(variants.py

JAX大量依赖代码转换和编译,这意味着确保代码得到适当测试可能很困难。例如,仅使用JAX代码测试Python函数将不会覆盖JIT执行时实际执行的代码路径,而这个路径也会因代码是针对CPU、GPU还是TPU进行JIT编译而有所不同。这可能是XLA更改导致的不希望的行为的来源,但这些行为只在特定的代码转换中表现出来。

变体通过提供一个简单的装饰器,可以重复测试任何测试在所有(或相关代码转换的子集)下,从而使得确保单元测试覆盖函数的不同“变体”变得容易。

例如,假设你想测试函数fn是否有或没有jit。你可以使用chex.variants通过简单地装饰一个测试方法来运行带有函数的jitted和非jitted版本的测试,然后在测试方法的主体中使用self.variant(fn)代替fn

def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))

如果你在测试方法中定义了函数,你还可以在函数定义中使用self.variant作为装饰器。例如

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
       return x + y

    self.assertEqual(var_fn(1, 2), 3)

参数化测试示例

from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
       return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex目前支持以下变体

  • with_jit -- 将jax.jit()转换应用于函数。
  • without_jit -- 使用原始函数,即恒等转换。
  • with_device -- 在应用函数之前,将所有参数(除了在ignore_argnums参数中指定的参数)放入设备内存。
  • without_device -- 在应用函数之前,将所有参数放入RAM中。
  • with_pmap -- 将jax.pmap()转换应用于函数(见以下说明)。

有关支持的变体的详细信息,请参阅variants.py文档。更多示例可以在variants_test.py中找到。

变体说明

  • 使用@chex.variants的测试类必须从chex.TestCase(或任何其他在TestCase内部展开测试生成器的基类,例如absl.testing.parameterized.TestCase)继承。

  • [jax.vmap] 所有变体都可以应用于vmapped函数;请参阅variants_test.py中的示例(test_vmapped_fn_named_paramstest_pmap_vmapped_fn)。

  • [@chex.all_variants] 你可以使用装饰器@chex.all_variants来获取所有支持的变体。

  • [with_pmap 变体] jax.pmap(fn) (文档) 在多个设备上对 fn 执行并行映射。由于大多数测试都在单设备环境中运行(即只能访问单个 CPU 或 GPU),在这种情况下,jax.pmapjax.jit 功能相同,因此默认跳过 with_pmap 变体(尽管它与单个设备配合工作也很正常)。以下我们将描述一种在多设备环境(TPU 或多个 CPU/GPU)中正确测试 fn 的方法。如果要在单设备的情况下禁用跳过 with_pmap 变体,请在测试命令中添加 --chex_skip_pmap_variant_if_single_device=false

模拟 (fake.py)

JAX 中的调试由于代码转换(如 jitpmap)而变得更加困难,这些转换引入了优化,使得代码难以检查和跟踪。在调试期间禁用这些转换也可能很困难,因为它们可以在底层代码的多个位置调用。Chex 提供了工具,可以全局地将 jax.jit 替换为无操作转换,并将 jax.pmap 替换为(非并行)jax.vmap,以便更容易地在单设备环境中调试代码。

例如,您可以使用 Chex 模拟 pmap 并将其替换为 vmap。这可以通过将您的代码包裹在上下文管理器中实现

with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)

同样,您也可以使用 startstop 来调用相同的功能

fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()

此外,您还可以使用多线程 CPU 模拟真实的多个设备测试环境。有关更多详细信息,请参阅 模拟多个设备测试环境 部分。

有关更多详细信息,请参阅 fake.pyfake_test.py 中的示例。

模拟多个设备测试环境

在不方便访问多个设备的情况下,您仍然可以使用单设备多线程测试并行计算。

特别是,可以强制 XLA 使用单个 CPU 的线程作为单独的设备,即使用多线程模拟真实的多个设备环境。从 XLA 角度来看,这两个选项在理论上是等效的,因为它们公开了相同的接口并使用相同的抽象。

Chex 有一个标志 chex_n_cpu_devices,用于指定要用作 XLA 设备的 CPU 线程数。

要为 absl 测试设置多线程 XLA 环境,在您的测试模块中定义 setUpModule 函数

def setUpModule():
  chex.set_n_cpu_devices()

现在您可以使用 python test.py --chex_n_cpu_devices=N 启动您的测试,以多设备模式运行。请注意,模块中的所有测试都将访问 N 个设备。

更多示例请参阅 variants_test.pyfake_test.pyfake_set_n_cpu_devices_test.py

使用命名维度大小。

Chex 提供了一个小工具,允许您将一组维度大小打包成一个单一的对象。基本思想是

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
chex.assert_shape(arr, dims['BTE'])

字符串查找被转换为整数元组。例如,假设 batch_size == 3sequence_len = 5embedding_dim = 7,那么

dims['BTE'] == (3, 5, 7)
dims['B'] == (3,)
dims['TTBEE'] == (5, 5, 3, 7, 7)
...

您也可以按如下方式动态分配维度大小

dims['XY'] = some_matrix.shape
dims.Z = 13

有关更多示例,请参阅 chex.Dimensions 文档。

引用 Chex

此存储库是 DeepMind JAX 生态系统 的一部分,要引用 Chex,请使用 DeepMind JAX 生态系统引用

项目详情


下载文件

下载您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。

源分布

chex-0.1.87.tar.gz (90.1 kB 查看哈希值)

上传时间

构建分布

chex-0.1.87-py3-none-any.whl (99.4 kB 查看哈希值)

上传时间 Python 3