Stan或PyMC模型示例
项目描述
nutpie:贝叶斯后验的快速采样器
安装
nutpie可以使用Conda或Mamba从conda-forge安装
mamba install -c conda-forge nutpie
或者使用pip
pip install nutpie
要从源安装,安装Rust编译器和maturin,然后
maturin develop --release
如果您想使用某些数学函数的夜间SIMD实现,切换到Rust夜间版本,然后在nutpie目录中安装带有simd_support
功能的包
rustup override set nightly
maturin develop --release --features=simd_support
与PyMC一起使用
首先,需要安装PyMC和Numba,例如使用
mamba install -c conda-forge pymc numba
我们需要创建一个模型
import pymc as pm
import numpy as np
import nutpie
import pandas as pd
import seaborn as sns
# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}
# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as pymc_model:
intercept = pm.Normal("intercept", sigma=10)
# County effects
raw = pm.ZeroSumNormal("county_raw", dims="county")
sd = pm.HalfNormal("county_sd")
county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")
# Global floor effect
floor_effect = pm.Normal("floor_effect", sigma=2)
# County:floor interaction
raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
sd = pm.HalfNormal("county_floor_sd")
county_floor_effect = pm.Deterministic(
"county_floor_effect", raw * sd, dims="county"
)
mu = (
intercept
+ county_effect[county_idx]
+ floor_effect * data.floor.values
+ county_floor_effect[county_idx] * data.floor.values
)
sigma = pm.HalfNormal("sigma", sigma=1.5)
pm.Normal(
"log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
)
然后编译此模型并从后验中采样
compiled_model = nutpie.compile_pymc_model(pymc_model)
trace_pymc = nutpie.sample(compiled_model)
trace_pymc
现在包含一个ArviZ InferenceData
对象,包括采样统计信息和上面定义的变量的后验。
我们还可以以非阻塞方式控制采样器
# The sampler will now run the the background
sampler = nutpie.sample(compiled_model, blocking=False)
# Pause and resume the sampling
sampler.pause()
sampler.resume()
# Wait for the sampler to finish (up to timeout seconds)
# sampler.wait(timeout=0.1)
# or we can also abort the sampler (and return the incomplete trace)
incomplete_trace = sampler.abort()
# or cancel and discard all progress:
sampler.cancel()
与Stan一起使用
为了从Stan模型中采样,需要安装bridgestan
。有一个可用的pip包,但当前无法使用Conda安装。
pip install bridgestan
当我们使用pip安装nutpie时,还可以指定我们希望使用Stan模型的可选依赖项。
pip install 'nutpie[stan]'
此外,还需要一个C++编译器。有关详细信息,请参阅Stan文档。
然后我们可以使用nutpie编译Stan模型并进行采样。
import nutpie
code = """
data {
real mu;
}
parameters {
real x;
}
model {
x ~ normal(mu, 1);
}
"""
compiled = nutpie.compile_stan_model(code=code)
# Provide data
compiled = compiled.with_data(mu=3.)
trace = nutpie.sample(compiled)
优势
nutpie使用nuts-rs
,这是一个用Rust编写的库,实现了PyMC和Stan中的NUTS,但采用了略有不同的质量矩阵调整方法。它通常在每个梯度评估中产生更高的有效样本大小,并且倾向于更快收敛,梯度评估次数更少。
项目详情
关闭
nutpie-0.13.2.tar.gz的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | f14282e2ac045c67a9b262a865b02a243178c55b541b236b21dfcb0c3678bcea |
|
MD5 | 5fdb39f0d38b87faf0348fb49f0f77ab |
|
BLAKE2b-256 | ff72830b4d56961f2759641277476c93715d2b01435a376f68fbed659b2c86d0 |
关闭
nutpie-0.13.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 2100024275ec6ba6de899188a3a2111f4b68aee7bfdbd4e4eb02ed4c922a9f22 |
|
MD5 | b179bbc6e9a6fa70b2cf7130ccac59d1 |
|
BLAKE2b-256 | ff2803b9b8362d55b10bd325b3ed0870f81995239487299f323f4d3664ec9988 |
关闭
nutpie-0.13.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | db240a317b1ded7eddf2ca8e2b4bcfcdbd4624256655aac61625c8f7d5ca39d0 |
|
MD5 | 138cae150e2a303f3b2a0817e4f2c3a6 |
|
BLAKE2b-256 | a7fe67a545daebbe468b8139067dd1172cc145769f5f0784b69768749588a157 |
关闭
nutpie-0.13.2-cp312-none-win_amd64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 5b6f45e2e475eee1519f18b6cbcd56ef225dbcaeb6f35e248d829467097ab385 |
|
MD5 | 010df8bc27e4acf7d2d1a051a1ea43a0 |
|
BLAKE2b-256 | cbb1b270984c0df52991f6e09428c49a8f84df68287720430c410c4942e0f2fd |
关闭
nutpie-0.13.2-cp312-cp312-manylinux_2_28_x86_64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 6d29babf3773544692153799b3579f9de1e084a06fd2dcc851e97bef4c92768b |
|
MD5 | f69b966aea1a36cfa84592aed2425631 |
|
BLAKE2b-256 | 554c0f801fe8d421e5ad89b1430bfad9b57fa53c7df4ad41d943476be2894acf |
关闭
nutpie-0.13.2-cp312-cp312-manylinux_2_28_aarch64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | e1419e53a5ce3bfba39157cb1381eb18f1835bd1b73312d485e1f543f9ce3748 |
|
MD5 | 2f035442125a3b3c9af00dd9f2568783 |
|
BLAKE2b-256 | 102293008b80652837f148369490677198d7bf32a35f33e859df6b5f1a8d4d12 |
关闭
nutpie-0.13.2-cp312-cp312-macosx_11_0_arm64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 57b6f6640996d88b290285acdcf7978bf9f6257c2a80d38eb5d1903e11bb0301 |
|
MD5 | 624be5d3fdb5adce5c7eeb755f0816a4 |
|
BLAKE2b-256 | d364111aa857623e6c6f797a12d03b600681d9f0ba7fe5ecc60f0fb19d405450 |
关闭
nutpie-0.13.2-cp312-cp312-macosx_10_12_x86_64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 1656a4e45981db30d9ca850e889c10ac69c3e327a994607924c2db1dcefb49c7 |
|
MD5 | 6d961cfb4cebf0cdc85de0d428380374 |
|
BLAKE2b-256 | 185f08465d154b674cf89e921f4cfb50aabbc2cc53f40d67883237467c0cf040 |
关闭
nutpie-0.13.2-cp311-none-win_amd64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | d7d297a975737ca997890cae284adca74e429567503596cbf66a37640faf4f10 |
|
MD5 | 87d862759187bd98ac77c9270ab40e17 |
|
BLAKE2b-256 | 6e2632ef7579f02d5f8fec72e7b20d35b7f4badfd18da4ca72225c83e96eaf72 |
关闭
nutpie-0.13.2-cp311-cp311-manylinux_2_28_x86_64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | be1635cdd6ec19cc541e212ee95e11288dda7a234a2ae7f70c2c91fdaa677fe0 |
|
MD5 | 50593ace02a0be2d40071ba1c02d3420 |
|
BLAKE2b-256 | c174de2d67427a2ba1083074b6c6335b2260db03dbff110cffdccbf038c8ff36 |
关闭
nutpie-0.13.2-cp311-cp311-manylinux_2_28_aarch64.whl的哈希值
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 1a7a5e7012976327485349b581ae762cd6e60bb1805f9d323e0eed2d945c73a3 |
|
MD5 | df860d0363ee086925a6bab976fb7670 |
|
BLAKE2b-256 | fe4a813b41a959baa6ad073ca4c4366bfb307c51dc960df5528db0138f7fba7c |
关闭
哈希值 用于 nutpie-0.13.2-cp311-cp311-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 225f17a15e33f731db43c55f821b988df2781568e2dc6f22ae9798e259386009 |
|
MD5 | 3ddd130c2538831d35a8f6b59b0b0143 |
|
BLAKE2b-256 | f274aa80e4cf65f8db0f8cd76e1b1f2c5dea277f01fcdd565e65371e67cd3b12 |
关闭
哈希值 用于 nutpie-0.13.2-cp311-cp311-macosx_10_12_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 024fb04ddcaa2ce8a2cf6864bebe68acfb68518f6199c6d3de0c6b9b49d1ac75 |
|
MD5 | a88d980d1ead14090c6adbf9c203e4a2 |
|
BLAKE2b-256 | c94a4592c00fe811491531f95ec8c5385504689058c471068bb007fc11a4bd96 |
关闭
哈希值 用于 nutpie-0.13.2-cp310-cp310-manylinux_2_28_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 0202a5b2352b065a269dd1467cacd4b9ef4020665373e4d12eede232425eaea8 |
|
MD5 | 62970723be7f567daec4e9ac9afa80c7 |
|
BLAKE2b-256 | c66eb2136595658a43aa5288aa1f48104f805a10973df77395591a1d55a1942a |
关闭
哈希值 用于 nutpie-0.13.2-cp310-cp310-manylinux_2_28_aarch64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | a7cfe73f29769f7185e677587755ba63818e9334d161a69216c8d6cefd9d66b7 |
|
MD5 | be6e57b3b78f34fdae36d4ed186e00a0 |
|
BLAKE2b-256 | 79064358bbe171c31b0c2585202c234386f094cdc80cab9514153ee286bb8460 |
关闭
哈希值 用于 nutpie-0.13.2-cp310-cp310-macosx_11_0_arm64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | b69e62c4d25e62e670ef31244e65556ed562650dfbc56a068972e177c5e5e291 |
|
MD5 | 9ee33e4b1ea7e543c553c2286881a4bf |
|
BLAKE2b-256 | f3b3c5d2bc91948920ba39841ab75f8359df5ca5fe8589a6ad0c61e61c5c530e |
关闭
哈希值 用于 nutpie-0.13.2-cp310-cp310-macosx_10_12_x86_64.whl
算法 | 哈希摘要 | |
---|---|---|
SHA256 | 4c731b6b32f51407ca973aefdcb0241c6dadfebcf47e781557344d28d346c0fa |
|
MD5 | ea136c341409b5c0b03f144b6c851dec |
|
BLAKE2b-256 | 3026629bb17d78728cba3a8089ac15688d17dfa1d65e94ed3d9bc0ccae6f4786 |