跳转到主要内容

Aesara的PPL工具

项目描述

Tests Status Coverage Join the chat at https://gitter.im/aesara-devs/aeppl

aeppl为用Aesara编写的提供工具。

特性

  • 将包含Aesara RandomVariable的图转换为联合对数似然图

  • RandomVariable映射到约束支持空间到非约束空间(例如,扩展实数)的转换,以及自动在整个图中应用这些转换的重写

  • 用于遍历和转换包含RandomVariable的图的工具

  • RandomVariable感知的格式化和LaTeX输出

示例

使用aeppl,可以从包含Aesara RandomVariable的图创建联合对数似然图

import aesara
from aesara import tensor as at

from aeppl import joint_logprob, pprint


# A simple scale mixture model
S_rv = at.random.invgamma(0.5, 0.5)
Y_rv = at.random.normal(0.0, at.sqrt(S_rv))

# Compute the joint log-probability
logprob, (y, s) = joint_logprob(Y_rv, S_rv)

对数概率图是标准的Aesara图,因此我们可以使用它们来计算值

logprob_fn = aesara.function([y, s], logprob)

logprob_fn(-0.5, 1.0)
# array(-2.46287705)

图也可以进行美化打印

from aeppl import pprint, latex_pprint


# Print the original graph
print(pprint(Y_rv))
# b ~ invgamma(0.5, 0.5) in R, a ~ N(0.0, sqrt(b)**2) in R
# a

print(latex_pprint(Y_rv))
# \begin{equation}
#   \begin{gathered}
#     b \sim \operatorname{invgamma}\left(0.5, 0.5\right)\,  \in \mathbb{R}
#     \\
#     a \sim \operatorname{N}\left(0.0, {\sqrt{b}}^{2}\right)\,  \in \mathbb{R}
#   \end{gathered}
#   \\
#   a
# \end{equation}

# Simplify the graph so that it's easier to read
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.rewriting.basic import topo_constant_folding


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)


print(pprint(logprob))
# s in R, y in R
# (switch(s >= 0.0,
#         ((-0.9189385175704956 +
#           switch(s == 0, -inf, (-1.5 * log(s)))) - (0.5 / s)),
#         -inf) +
#  ((-0.9189385332046727 + (-0.5 * ((y / sqrt(s)) ** 2))) - log(sqrt(s))))

还可以计算一些从随机变量派生出来的项的联合对数概率

# Create a switching model from a Bernoulli distributed index
Z_rv = at.random.normal([-100, 100], 1.0, name="Z")
I_rv = at.random.bernoulli(0.5, name="I")

M_rv = Z_rv[I_rv]
M_rv.name = "M"

# Compute the joint log-probability for the mixture
logprob, (m, z, i) = joint_logprob(M_rv, Z_rv, I_rv)


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)

print(pprint(logprob))
# i in Z, m in R, a in Z
# (switch((0 <= i and i <= 1), -0.6931472, -inf) +
#  ((-0.9189385332046727 + (-0.5 * (((m - [-100  100][a]) / [1. 1.][a]) ** 2))) -
#   log([1. 1.][a])))

安装

可以通过PyPI使用pip安装最新版本的

pip install aeppl

可以通过GitHub使用pip安装的当前开发分支

pip install git+https://github.com/aesara-devs/aeppl

项目详情


发布历史 发布通知 | RSS订阅

下载文件

下载适合您平台的文件。如果您不确定选择哪个,请了解更多关于安装包的信息。

源分布

aeppl-nightly-0.0.40.tar.gz (68.8 kB 查看散列值)

上传时间

构建分布

aeppl_nightly-0.0.40-py3-none-any.whl (58.2 kB 查看散列值)

上传时间 Python 3

支持者

AWS AWS 云计算和安全赞助商 Datadog Datadog 监控 Fastly Fastly CDN Google Google 下载分析 Microsoft Microsoft PSF赞助商 Pingdom Pingdom 监控 Sentry Sentry 错误记录 StatusPage StatusPage 状态页面