将概率模型和推理拼接在一起。
项目描述
Bayeux
将模型和采样器拼接在一起
bayeux
允许您在JAX中编写概率模型,并立即访问最先进的推理方法。该API旨在具有 简单、自描述性 和 有帮助。只需提供一个对数密度函数(该函数甚至不需要归一化),以及一个点(指定为一个 pytree),其中该对数密度是有限的。然后让 bayeux
做其余的工作!
安装
pip install bayeux-ml
快速入门
我们通过提供JAX中的对数密度来定义模型。这可以使用概率编程语言(PPL)如 numpyro、PyMC、TFP、distrax、oryx、coix 或直接在 JAX 中定义。
import bayeux as bx
import jax
normal_density = bx.Model(
log_density=lambda x: -x*x,
test_point=1.)
seed = jax.random.key(0)
opt_results = normal_density.optimize.optax_adam(seed=seed)
# OR!
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# OR!
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)
阅读更多
这不是官方支持的Google产品。
项目详情
下载文件
下载适合您平台的文件。如果您不确定选择哪个,请了解有关安装包的更多信息。
源分发
bayeux_ml-0.1.14.tar.gz (27.2 kB 查看哈希值)
构建分发
bayeux_ml-0.1.14-py3-none-any.whl (42.6 kB 查看哈希值)