跳转到主要内容

Python中的灵活快速采样

项目描述

BlackJAX

Continuous integration codecov PyPI version

BlackJAX animation: sampling BlackJAX with BlackJAX

什么是BlackJAX?

BlackJAX是一个用于JAX的采样库,它可以在CPU和GPU上运行。

它不是一个概率编程库。然而,只要它们可以提供与JAX兼容的(可能未归一化)对数概率密度函数,它就可以很好地与PPLs集成。

谁应该使用BlackJAX?

BlackJAX应该对以下人员有吸引力:

  • 有对数概率密度函数,只需要采样;
  • 需要比通用采样器更多的功能;
  • 想在GPU上采样;
  • 想在他们的研究中构建强大的基本块;
  • 在构建概率编程语言;
  • 想了解采样算法是如何工作的。

快速入门

安装

您可以使用pip安装BlackJAX

pip install blackjax

或通过conda-forge

conda install -c conda-forge blackjax

Blackjax的夜间构建(尖端)也可以使用pip安装

pip install blackjax-nightly

BlackJAX 使用纯 Python 编写,但依赖于 JAX 的 XLA。默认情况下,与 BlackJAX 一起安装的 JAX 版本将使您的代码仅在 CPU 上运行。如果您想使用 BlackJAX 在 GPU/TPU 上运行,我们建议您按照 以下说明 安装带有相关硬件加速支持的 JAX。

示例

让我们看看一个使用 NUTS 进行采样的简单示例。

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

observed = np.random.normal(10, 20, size=1_000)
def logdensity_fn(x):
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
for step in range(100):
    nuts_key = jax.random.fold_in(rng_key, step)
    state, _ = nuts.step(nuts_key, state)

有关如何使用该库的更多示例,请参阅 文档:如何为单个或多个链编写推理循环,如何使用 Stan warmup 等。

理念

什么是BlackJAX?

BlackJAX 在“单行”框架和模块化、可自定义的库之间架起了一座桥梁。

用户可以通过几行代码导入库并与健壮、经过良好测试且性能优良的采样器进行交互。这些采样器针对 PPL 开发人员或只需要一个能正常工作的采样器的人士。

但 BlackJAX 的真正优势在于其内部结构和它们如何被用于快速实验现有或新的采样方案。这一低级结构暴露了推理算法的构建块:积分器、建议、动量生成器等,并使得将它们组合起来构建新算法变得容易。它通过提供健壮、高性能和可重用的代码来加速采样算法的研究。

为什么选择 BlackJAX?

采样算法通常被集成到 PPL 中,而没有从框架的其余部分解耦,这使得它们难以被那些不需要建模语言来构建其 logpdf 的人使用。它们的实现通常是单体化的,无法重用算法的某些部分来构建自定义内核。BlackJAX 解决了这两个问题。

它是如何工作的?

BlackJAX 允许构建任意复杂的算法,因为它围绕一个非常通用的模式构建。任何接受状态并返回状态的函数都是一个转换内核,它被实现为

new_state, info =  kernel(rng_key, state)

内核是无状态的函数,并且所有都遵循相同的 API;状态和与转换相关的信息被分别返回。因此,它们可以很容易地组合和交换。我们通过闭包而不是传递参数来专门化这些内核。

贡献

请参阅我们的 简短指南

引用 Blackjax

引用此存储库

@misc{cabezas2024blackjax,
      title={BlackJAX: Composable {B}ayesian inference in {JAX}},
      author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf},
      year={2024},
      eprint={2402.10797},
      archivePrefix={arXiv},
      primaryClass={cs.MS}
}

在上面的 bibtex 条目中,名称按字母顺序排列,版本号应该是 main 分支上的最后一个标签。

致谢

NUTS 实现的一些细节在很大程度上受到了 Numpyro 的启发。

项目详情


发布历史 发布通知 | RSS 源

下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于 安装包 的信息。

源分布

blackjax-nightly-1.1.1.post7.tar.gz (4.6 MB 查看哈希值)

上传时间

构建分布

blackjax_nightly-1.1.1.post7-py3-none-any.whl (4.6 MB 查看哈希值)

上传时间 Python 3

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF 赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误日志 StatusPage StatusPage 状态页面