Python中的灵活快速抽样
项目描述
BlackJAX
什么是BlackJAX?
BlackJAX是一个JAX的采样库,可以在CPU以及GPU上运行。
它不是一个概率编程库。然而,只要它们可以提供一个与JAX兼容的(可能未归一化的)对数概率密度函数,它就可以很好地与PPLs集成。
谁应该使用BlackJAX?
BlackJAX应该对以下人群有吸引力
- 有一个logpdf并且只需要一个采样器;
- 需要比通用采样器更多;
- 希望在GPU上采样;
- 希望基于稳健的基本块进行研究;
- 正在构建一个概率编程语言;
- 想了解采样算法是如何工作的。
快速入门
安装
您可以使用pip
安装BlackJAX
pip install blackjax
或通过conda-forge
conda install -c conda-forge blackjax
BlackJAX是用纯Python编写的,但依赖于JAX的XLA。默认情况下,与BlackJAX一起安装的JAX版本将使您的代码仅在CPU上运行。如果您想在GPU/TPU上使用BlackJAX,我们建议您按照这些说明安装具有相关硬件加速支持的JAX。
示例
让我们来看一个简单的自包含示例,使用NUTS进行采样
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
import blackjax
observed = np.random.normal(10, 20, size=1_000)
def logdensity_fn(x):
logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
return jnp.sum(logpdf)
# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)
# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)
# Iterate
rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for i in range(100):
nuts_key = jax.random.fold_in(rng_key, i)
state, _ = step(nuts_key, state)
请参阅文档了解更多如何使用库的示例:如何编写一个或多个链的推理循环,如何使用Stan预热等。
哲学
什么是BlackJAX?
BlackJAX弥合了“单行”框架和模块化、可定制的库之间的差距。
用户可以通过几行代码导入库并与健壮、经过良好测试和性能良好的采样器进行交互。这些采样器面向PPL开发者,或者只需要一个能工作的采样器的人。
但BlackJAX真正的优势在于其内部结构和如何快速实验现有或新的采样方案。这一层揭示了推理算法的构建块:积分器、建议、动量生成器等,并使其易于组合以构建新的算法。它通过提供健壮、性能良好和可重用的代码来加速采样算法的研究。
为什么选择BlackJAX?
采样算法通常被集成到PPL中,而没有从框架的其他部分解耦,这使得不需要建模语言来构建其logpdf的人难以使用。它们的实现通常是单体化的,不可能重用算法的一部分来构建定制的内核。BlackJAX解决了这两个问题。
它如何工作?
BlackJAX允许构建任意复杂的算法,因为它围绕一个非常通用的模式构建。任何接受状态并返回状态的东西都是一个转换内核,并且实现如下
new_state, info = kernel(rng_key, state)
内核是无状态的函数,并且所有都遵循相同的API;状态和与转换相关的信息被分别返回。因此,它们可以轻松组合和交换。我们通过闭包而不是传递参数来专门化这些内核。
贡献
请参阅我们的简要指南。
引用Blackjax
要引用此存储库
@misc{cabezas2024blackjax,
title={BlackJAX: Composable {B}ayesian inference in {JAX}},
author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf},
year={2024},
eprint={2402.10797},
archivePrefix={arXiv},
primaryClass={cs.MS}
}
在上面的bibtex条目中,名称按字母顺序排列,版本号应该是main
分支上的最后一个标签。
致谢
NUTS实现的某些细节在很大程度上受到了Numpyro的启发。
项目详情
下载文件
下载您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。