跳转到主要内容

Stan或PyMC模型示例

项目描述

nutpie:贝叶斯后验的快速采样器

安装

nutpie可以使用Conda或Mamba从conda-forge安装

mamba install -c conda-forge nutpie

或者使用pip

pip install nutpie

要从源安装,安装Rust编译器和maturin,然后

maturin develop --release

如果您想使用某些数学函数的夜间SIMD实现,切换到Rust夜间版本,然后在nutpie目录中安装带有simd_support功能的包

rustup override set nightly
maturin develop --release --features=simd_support

与PyMC一起使用

首先,需要安装PyMC和Numba,例如使用

mamba install -c conda-forge pymc numba

我们需要创建一个模型

import pymc as pm
import numpy as np
import nutpie
import pandas as pd
import seaborn as sns

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as pymc_model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )

然后编译此模型并从后验中采样

compiled_model = nutpie.compile_pymc_model(pymc_model)
trace_pymc = nutpie.sample(compiled_model)

trace_pymc现在包含一个ArviZ InferenceData对象,包括采样统计信息和上面定义的变量的后验。

我们还可以以非阻塞方式控制采样器

# The sampler will now run the the background
sampler = nutpie.sample(compiled_model, blocking=False)

# Pause and resume the sampling
sampler.pause()
sampler.resume()

# Wait for the sampler to finish (up to timeout seconds)
# sampler.wait(timeout=0.1)

# or we can also abort the sampler (and return the incomplete trace)
incomplete_trace = sampler.abort()

# or cancel and discard all progress:
sampler.cancel()

与Stan一起使用

为了从Stan模型中采样,需要安装bridgestan。有一个可用的pip包,但当前无法使用Conda安装。

pip install bridgestan

当我们使用pip安装nutpie时,还可以指定我们希望使用Stan模型的可选依赖项。

pip install 'nutpie[stan]'

此外,还需要一个C++编译器。有关详细信息,请参阅Stan文档

然后我们可以使用nutpie编译Stan模型并进行采样。

import nutpie

code = """
data {
    real mu;
}
parameters {
    real x;
}
model {
    x ~ normal(mu, 1);
}
"""

compiled = nutpie.compile_stan_model(code=code)
# Provide data
compiled = compiled.with_data(mu=3.)
trace = nutpie.sample(compiled)

优势

nutpie使用nuts-rs,这是一个用Rust编写的库,实现了PyMC和Stan中的NUTS,但采用了略有不同的质量矩阵调整方法。它通常在每个梯度评估中产生更高的有效样本大小,并且倾向于更快收敛,梯度评估次数更少。

项目详情


下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源代码分发

nutpie-0.13.2.tar.gz (184.0 kB 查看哈希值)

上传时间 源代码

构建的发行版

nutpie-0.13.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl (1.5 MB 查看哈希值)

上传时间 PyPy manylinux: glibc 2.28+ x86-64

nutpie-0.13.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl (1.6 MB 查看哈希值)

上传时间 PyPy manylinux: glibc 2.28+ ARM64

nutpie-0.13.2-cp312-none-win_amd64.whl (1.3 MB 查看哈希值)

上传时间 CPython 3.12 Windows x86-64

nutpie-0.13.2-cp312-cp312-manylinux_2_28_x86_64.whl (1.5 MB 查看哈希值)

上传时间 CPython 3.12 manylinux: glibc 2.28+ x86-64

nutpie-0.13.2-cp312-cp312-manylinux_2_28_aarch64.whl (1.6 MB 查看哈希值)

上传时间 CPython 3.12 manylinux: glibc 2.28+ ARM64

nutpie-0.13.2-cp312-cp312-macosx_11_0_arm64.whl (1.4 MB 查看哈希值)

上传时间 CPython 3.12 macOS 11.0+ ARM64

nutpie-0.13.2-cp312-cp312-macosx_10_12_x86_64.whl (1.4 MB 查看哈希值)

上传时间 CPython 3.12 macOS 10.12+ x86-64

nutpie-0.13.2-cp311-none-win_amd64.whl (1.3 MB 查看哈希值)

上传时间 CPython 3.11 Windows x86-64

nutpie-0.13.2-cp311-cp311-manylinux_2_28_x86_64.whl (1.5 MB 查看哈希值)

上传时间 CPython 3.11 manylinux: glibc 2.28+ x86-64

nutpie-0.13.2-cp311-cp311-manylinux_2_28_aarch64.whl (1.6 MB 查看哈希值)

上传时间 CPython 3.11 manylinux: glibc 2.28+ ARM64

nutpie-0.13.2-cp311-cp311-macosx_11_0_arm64.whl (1.4 MB 查看哈希值)

上传时间 CPython 3.11 macOS 11.0+ ARM64

nutpie-0.13.2-cp311-cp311-macosx_10_12_x86_64.whl (1.4 MB 查看哈希值)

上传时间 CPython 3.11 macOS 10.12+ x86-64

nutpie-0.13.2-cp310-none-win_amd64.whl (1.3 MB 查看哈希值)

上传时间 CPython 3.10 Windows x86-64

nutpie-0.13.2-cp310-cp310-manylinux_2_28_x86_64.whl (1.5 MB 查看哈希值)

上传时间 CPython 3.10 manylinux: glibc 2.28+ x86-64

nutpie-0.13.2-cp310-cp310-manylinux_2_28_aarch64.whl (1.6 MB 查看哈希值)

上传时间 CPython 3.10 manylinux: glibc 2.28+ ARM64

nutpie-0.13.2-cp310-cp310-macosx_11_0_arm64.whl (1.4 MB 查看哈希值)

上传时间 CPython 3.10 macOS 11.0+ ARM64

nutpie-0.13.2-cp310-cp310-macosx_10_12_x86_64.whl (1.4 MB 查看哈希值)

上传时间 CPython 3.10 macOS 10.12+ x86_64

由以下支持