distax:JAX中的概率分布。
项目描述
distax
distax是一个轻量级的概率分布和反函数库。它作为JAX原生的TensorFlow概率(TFP)子集的重实现,并添加了一些新功能,注重可扩展性。
安装
您可以通过PyPI安装distax的最新发布版本:
pip install distrax
或者您可以从GitHub安装最新的开发版本
pip install git+https://github.com/deepmind/distrax.git
设计原则
DeepMind JAX生态系统的通用设计原则在这篇博客中有所阐述。此外,Distrax还强调了以下几点:
- 可读性。 Distrax的实现旨在尽可能独立,并尽可能接近底层的数学。
- 可扩展性。 我们已经尽可能简化了用户定义自己的分布或双射器的操作。这在强化学习中很有用,例如,用户可能希望为概率代理策略定义自定义行为。
- 兼容性。 Distrax并非旨在取代TFP,TFP包含许多我们不打算复制的先进功能。为此,我们已尽可能地使分布和双射器的API实现交叉兼容,并提供转换等效Distrax和TFP类别的实用工具。
功能
分布
Distrax中的分布易于定义和使用,尤其是如果您习惯了TFP。让我们将两者并排比较
import distrax
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)
dist_tfp = tfd.MultivariateNormalDiag(mu, sigma)
samples = dist_distrax.sample(seed=key)
# Both print 1.775
print(dist_distrax.log_prob(samples))
print(dist_tfp.log_prob(samples))
除了行为一致外,Distrax分布和TFP分布也是交叉兼容的。例如
mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu_0, sigma_0)
mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.MultivariateNormalDiag(mu_1, sigma_1)
# Both print 85.237
print(dist_distrax.kl_divergence(dist_tfp))
print(tfd.kl_divergence(dist_distrax, dist_tfp))
Distrax分布实现了sample_and_log_prob
方法,它在一行中提供了样本及其对数概率。对于某些分布,这比单独调用sample
和log_prob
更有效。
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])
dist_distrax = distrax.MultivariateNormalDiag(mu, sigma)
samples = dist_distrax.sample(seed=key, sample_shape=())
log_prob = dist_distrax.log_prob(samples)
# A one-line equivalent of the above is:
samples, log_prob = dist_distrax.sample_and_log_prob(seed=key, sample_shape=())
TFP分布可以作为输入传递给Distrax元分布。例如
key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.2, 0.3, 0.4])
dist_tfp = tfd.Normal(mu, sigma)
metadist_distrax = distrax.Independent(dist_tfp, reinterpreted_batch_ndims=1)
samples = metadist_distrax.sample(seed=key)
print(metadist_distrax.log_prob(samples)) # Prints 0.38871175
要在TFP元分布中使用Distrax分布,Distrax提供了包装器to_tfp
。包装后的Distrax分布可以直接在TFP中使用。
key = jax.random.PRNGKey(1234)
distrax_dist = distrax.Normal(0., 1.)
wrapped_dist = distrax.to_tfp(distrax_dist)
metadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])
samples = metadist_tfp.sample(seed=key)
print(metadist_tfp.log_prob(samples)) # Prints -3.3409896
双射器
Distrax中的“双射器”是一个可逆函数,它知道如何计算其雅可比行列式。双射器可以通过转换更简单的分布来创建复杂的分布。Distrax双射器在功能上类似于TFP双射器,但有少量API差异。以下是一个比较示例
import distrax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfd = tfp.distributions
# Same distribution.
distrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), tfb.Tanh())
此外,Distrax双射器可以组合和求逆
bij_distrax = distrax.Tanh()
bij_tfp = tfb.Tanh()
# Same bijector.
inv_bij_distrax = distrax.Inverse(bij_distrax)
inv_bij_tfp = tfb.Invert(bij_tfp)
# These are both the identity bijector.
distrax.Chain([bij_distrax, inv_bij_distrax])
tfb.Chain([bij_tfp, inv_bij_tfp])
所有TFP双射器都可以传递给Distrax,可以与Distrax双射器自由组合。例如,以下所有操作都将正常工作
distrax.Inverse(tfb.Tanh())
distrax.Chain([tfb.Tanh(), distrax.Tanh()])
distrax.Transformed(tfd.Normal(loc=0., scale=1.), tfb.Tanh())
Distrax双射器也可以传递给TFP,但首先必须使用to_tfp
进行转换
bij_distrax = distrax.to_tfp(distrax.Tanh())
tfb.Invert(bij_distrax)
tfb.Chain([tfb.Tanh(), bij_distrax])
tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij_distrax)
Distrax还提供了Lambda
,这是一个方便的包装器,可以将简单的JAX函数转换为双射器。以下是一些Lambda
示例及其TFP等价物
distrax.Lambda(lambda x: x)
# tfb.Identity()
distrax.Lambda(lambda x: 2*x + 3)
# tfb.Chain([tfb.Shift(3), tfb.Scale(2)])
distrax.Lambda(jnp.sinh)
# tfb.Sinh()
distrax.Lambda(lambda x: jnp.sinh(2*x + 3))
# tfb.Chain([tfb.Sinh(), tfb.Shift(3), tfb.Scale(2)])
与TFP不同,Distrax中的双射器在计算雅可比行列式时不将event_ndims
作为参数。相反,Distrax假设事件维度的数量对于每个双射器都是静态已知的,并使用Block
将双射器提升到不同的维度数。例如
x = jnp.zeros([2, 3, 4])
# In TFP, `event_ndims` can be passed to the bijector.
bij_tfp = tfb.Tanh()
ld_1 = bij_tfp.forward_log_det_jacobian(x, event_ndims=0) # Shape = [2, 3, 4]
# Distrax assumes `Tanh` is a scalar bijector by default.
bij_distrax = distrax.Tanh()
ld_2 = bij_distrax.forward_log_det_jacobian(x) # ld_1 == ld_2
# With `event_ndims=2`, TFP sums the last 2 dimensions of the log det.
ld_3 = bij_tfp.forward_log_det_jacobian(x, event_ndims=2) # Shape = [2]
# Distrax treats the number of dimensions statically.
bij_distrax = distrax.Block(bij_distrax, ndims=2)
ld_4 = bij_distrax.forward_log_det_jacobian(x) # ld_3 == ld_4
Distrax双射器实现了forward_and_log_det
方法(某些双射器还实现了inverse_and_log_det
),它可以在一行中获取前向映射及其对数雅可比行列式。对于某些双射器,这比单独调用forward
和forward_log_det_jacobian
更有效。(类似地,当可用时,inverse_and_log_det
比inverse
和inverse_log_det_jacobian
更有效。)
x = jnp.zeros([2, 3, 4])
bij_distrax = distrax.Tanh()
y = bij_distrax.forward(x)
ld = bij_distrax.forward_log_det_jacobian(x)
# A one-line equivalent of the above is:
y, ld = bij_distrax.forward_and_log_det(x)
对Distrax进行Jitting
Distrax分布和双射器可以作为参数传递给Jitted函数。用户定义的分布和双射器通过分别继承distrax.Distribution
和distrax.Bijector
分别获得此属性。例如
mu_0 = jnp.array([-1., 0., 1.])
sigma_0 = jnp.array([0.1, 0.2, 0.3])
dist_0 = distrax.MultivariateNormalDiag(mu_0, sigma_0)
mu_1 = jnp.array([1., 2., 3.])
sigma_1 = jnp.array([0.2, 0.3, 0.4])
dist_1 = distrax.MultivariateNormalDiag(mu_1, sigma_1)
jitted_kl = jax.jit(lambda d_0, d_1: d_0.kl_divergence(d_1))
# Both print 85.237
print(jitted_kl(dist_0, dist_1))
print(dist_0.kl_divergence(dist_1))
关于vmap
和pmap
的注意事项
使Distrax对象可以作为参数传递给Jitted函数的序列化逻辑,还允许使用jax.vmap
和jax.pmap
在它们上执行映射操作。
然而,对这种行为的支持是实验性和不完整的。在将 jax.vmap
或 jax.pmap
应用到以 Distrax 对象为参数或返回 Distrax 对象的函数时,请谨慎行事。
简单的对象,如 distrax.Categorical
,在这些变换下可能表现正常,但更复杂的对象,如 distrax.MultivariateNormalDiag
,在用作 vmap
-ed 或 pmap
-ed 函数的输入或输出时可能会引发异常。
分布和反函数的子类化
可以通过子类化 distrax.Distribution
创建用户定义的分布。这可以通过仅实现几个方法来实现。
class MyDistribution(distrax.Distribution):
def __init__(self, ...):
...
def _sample_n(self, key, n):
samples = ...
return samples
def log_prob(self, value):
log_prob = ...
return log_prob
def event_shape(self):
event_shape = ...
return event_shape
def _sample_n_and_log_prob(self, key, n):
# Optional. Only when more efficient implementation is possible.
samples, log_prob = ...
return samples, log_prob
同样,可以通过子类化 distrax.Bijector
创建更复杂的反函数。这可以通过仅实现一个或两个类方法来实现。
class MyBijector(distrax.Bijector):
def __init__(self, ...):
super().__init__(...)
def forward_and_log_det(self, x):
y = ...
logdet = ...
return y, logdet
def inverse_and_log_det(self, y):
# Optional. Can be omitted if inverse methods are not needed.
x = ...
logdet = ...
return x, logdet
示例
在 examples
目录中包含了一些使用 Distrax 的完整程序的代表性示例。
hmm.py
展示了如何使用 distrax.HMM
将模型初始状态、转移和观测分布的分布组合起来,并在变化的噪声信号中推断潜在率和状态转移。
vae.py
包含了一个变分自动编码器的示例实现,该编码器被训练来将二进制化的 MNIST 数据集作为像素上的联合 distrax.Bernoulli
分布进行建模。
flow.py
说明了使用 distrax.MaskedCoupling
层建模 MNIST 数据的简单示例,并使用梯度下降训练模型。
致谢
我们非常感谢 TensorFlow Probability 作者在协助设计 Distrax 和跨兼容性方面提供的持续支持。
特别感谢 Aleyna Kara 和 Kevin Murphy 为基于该代码的隐马尔可夫模型和相关示例所做的贡献。
引用 Distrax
此存储库是 DeepMind JAX 生态系统的一部分,要引用 Distrax,请使用以下引用。
@software{deepmind2020jax,
title = {The {D}eep{M}ind {JAX} {E}cosystem},
author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
url = {http://github.com/deepmind},
year = {2020},
}
项目详情
下载文件
下载适合您平台的应用程序。如果您不确定选择哪个,请了解有关 安装包 的更多信息。