跳转到主要内容

一组杂项辅助函数、自定义分布和其他实用工具,这些工具在我使用NumPyro进行工作时非常有用。

项目描述

NumPyro扩展

此库包括一组杂项辅助函数、自定义分布和其他实用工具,这些工具在我使用NumPyro进行工作时非常有用。

安装

由于NumPyro和此库都是基于JAX构建的,因此通常建议首先按照安装说明安装JAX。然后,您可以使用pip安装此库

python -m pip install numpyro-ext

用法

由于此README使用doctest进行验证,让我们首先导入我们将在所有示例中需要的常见模块

>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro_ext

分布

传统上,将numpyro_ext.distributions导入为distx,以区别于导入为distnumpyro.distributions

>>> from numpyro import distributions as dist
>>> from numpyro_ext import distributions as distx
>>> key = jax.random.PRNGKey(0)

角度

在弧度内对角度进行均匀分布。实际的采样是在与(sin(theta), cos(theta))成比例的两个维度的向量空间中进行的,这样采样器就不会在π处看到不连续性。

>>> angle = distx.Angle()
>>> print(angle.sample(key, (2, 3)))
[[ 0.4...]
 [ 2.4...]]

单位圆盘

在半径为1的圆盘内的二维点上的均匀分布。这意味着从该分布生成的随机变量的最后维度的平方和总是小于1。

>>> unit_disk = distx.UnitDisk()
>>> u = unit_disk.sample(key, (5,))
>>> print(jnp.sum(u**2, axis=-1))
[0.07...]

非中心卡方

非中心卡方分布。要使用此分布,您需要安装可选的tensorflow-probability依赖项。

>>> ncx2 = distx.NoncentralChi2(df=3, nc=2.)
>>> print(ncx2.sample(key, (5,)))
[2.19...]

边际化线性

两个(可能为多元)正态分布的边际化乘积,它们之间存在线性关系。这些模型的数学细节在这篇笔记中进行了详细讨论,并且此分布以计算效率的方式实现了那里呈现的数学,假设边际参数的数量相对于数据集的大小较小。

以下示例显示了一个将直线拟合到数据的完全边际化模型的特别简单示例

>>> def model(x, y=None):
...     design_matrix = jnp.vander(x, 2)
...     prior = dist.Normal(0.0, 1.0)
...     data = dist.Normal(0.0, 2.0)
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...

当设计矩阵和/或分布是非线性参数的函数时,事情会变得更有趣。例如,如果我们想找到正弦信号的周期,同时拟合一些未知的过剩测量不确定性(通常称为“抖动”),我们可以使用以下模型

>>> def model(x, y_err, y=None):
...     period = numpyro.sample("period", dist.Uniform(1.0, 250.0))
...     ln_jitter = numpyro.sample("ln_jitter", dist.Normal(0.0, 2.0))
...     design_matrix = jnp.stack(
...         [
...             jnp.sin(2 * jnp.pi * x / period),
...             jnp.cos(2 * jnp.pi * x / period),
...             jnp.ones_like(x),
...         ],
...         axis=-1,
...     )
...     prior = dist.Normal(0.0, 10.0).expand([3])
...     data = dist.Normal(0.0, jnp.sqrt(y_err**2 + jnp.exp(2*ln_jitter)))
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> samples = numpyro.infer.Predictive(model, num_samples=2)(key, x, 0.1)
>>> print(samples["period"])
[... ...]
>>> print(samples["y"])
[[... ... ...]
 [... ... ...]]

在推理过程中跟踪边际参数的条件样本通常很有用。可以使用MarginalizedLinear上的conditional方法访问条件分布

>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # just some fake data
>>> design_matrix = jnp.vander(x, 2)
>>> prior = dist.Normal(0.0, 1.0)
>>> data = dist.Normal(0.0, 2.0)
>>> marg = distx.MarginalizedLinear(design_matrix, prior, data)
>>> cond = marg.conditional(y)
>>> print(type(cond).__name__)
MultivariateNormal
>>> print(cond.sample(key, (3,)))
[[...]
 [...]
 [...]]

优化

关于优化作为MCMC初始化工具的益处,推断知识略有分歧,但我在许多天文学应用中发现,至少在这些应用中,初始优化可以在性能上产生巨大差异。即使您不想使用优化结果作为初始化,有时也可以在数值上搜索模型的最大后验参数。然而,NumPyro接口对这些类型的优化并不是非常用户友好,因此这个库提供了一些辅助工具,使其更加直接。

默认情况下,此优化使用由JAXopt库提供的scipy优化例程的包装器,因此您需要在运行这些示例之前安装JAXopt

python -m pip install jaxopt

以下示例显示了对具有单个参数的模型进行简单优化的示例

>>> from numpyro_ext import optim as optimx
>>>
>>> def model(y=None):
...     x = numpyro.sample("x", dist.Normal(0.0, 1.0))
...     numpyro.sample("y", dist.Normal(x, 2.0), obs=y)
...
>>> soln = optimx.optimize(model)(key, y=0.5)

默认情况下,优化从先验样本开始,但您可以提供以下自定义初始坐标

>>> soln = optimx.optimize(model, start={"x": 12.3})(key, y=0.5)

同样,如果您只想优化参数的子集,您可以提供要针对的参数列表

>>> soln = optimx.optimize(model, sites=["x"])(key, y=0.5)

信息矩阵计算

具有高斯似然函数的模型的信息矩阵计算简单,并且此库提供了一个辅助函数来自动化此计算

>>> from numpyro_ext import information
>>>
>>> def model(x, y=None):
...     a = numpyro.sample("a", dist.Normal(0.0, 1.0))
...     b = numpyro.sample("b", dist.Normal(0.0, 1.0))
...     log_alpha = numpyro.sample("log_alpha", dist.Normal(0.0, 1.0))
...     cov = jnp.exp(log_alpha - 0.5 * (x[:, None] - x[None, :])**2)
...     cov += 0.1 * jnp.eye(len(x))
...     numpyro.sample(
...         "y",
...         dist.MultivariateNormal(loc=a * x + b, covariance_matrix=cov),
...         obs=y,
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # the input data just needs to have the right shape
>>> params = {"a": 0.5, "b": -0.2, "log_alpha": -0.5}
>>> info = information(model)(params, x, y=y)
>>> print(info)
{'a': {'a': ..., 'b': ... 'log_alpha': ...}, 'b': ...}

返回的信息矩阵是一个嵌套字典的字典,由参数名称的配对索引,其中值是信息矩阵的相应块。

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源分布

numpyro-ext-0.0.4.tar.gz (29.4 kB 查看哈希值)

上传时间

构建分布

numpyro_ext-0.0.4-py3-none-any.whl (21.1 kB 查看哈希值)

上传时间 Python 3

由以下支持