PPL tools for Aesara
Project description
aeppl provides tools for a[e]PPL (probabilistic programming language) written in Aesara.
Features
- Convert graphs containing Aesara RandomVariables into joint log-probability graphs
- Transforms for RandomVariables that map constrained support spaces to unconstrained spaces (e.g. the extended real numbers), and a rewrite that automatically applies these transformations throughout a graph
- Tools for traversing and transforming graphs containing RandomVariables
- RandomVariable-aware pretty printing and LaTeX output
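The transform feature rests on the change-of-variables formula. If S is positive and U = log S, then log p_U(u) = log p_S(exp(u)) + u, where the extra u is the log-Jacobian of exp. The following plain-Python sketch illustrates only that arithmetic for an inverse-gamma variable with shape alpha and scale beta. It does not use aeppl's own transform classes or rewrite, and the function names are hypothetical.

```python
import math


def invgamma_logpdf(s: float, alpha: float, beta: float) -> float:
    """Log-density of an inverse-gamma variable with shape alpha and scale beta."""
    if s <= 0.0:
        return -math.inf  # outside the (0, inf) support
    return (
        alpha * math.log(beta)
        - math.lgamma(alpha)
        - (alpha + 1.0) * math.log(s)
        - beta / s
    )


def log_transformed_logpdf(u: float, alpha: float, beta: float) -> float:
    """Log-density of U = log(S): log p_S(exp(u)) + log|d exp(u) / du|."""
    s = math.exp(u)  # map the unconstrained value back to the positive support
    return invgamma_logpdf(s, alpha, beta) + u  # log-Jacobian of exp(u) is u


# Mathematically, any real u maps back to a positive s, so u needs no support check.
print(log_transformed_logpdf(0.0, 0.5, 0.5))
print(log_transformed_logpdf(-5.0, 0.5, 0.5))
```

The Jacobian term keeps the transformed density normalized on the whole real line. This is what lets a sampler or optimizer work on u without any boundary handling.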
Examples
Using aeppl, one can create a joint log-probability graph from a graph containing Aesara RandomVariables:
import aesara
from aesara import tensor as at
from aeppl import joint_logprob, pprint
srng = at.random.RandomStream()
# A simple scale mixture model
S_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0.0, at.sqrt(S_rv))
# Compute the joint log-probability
logprob, (y, s) = joint_logprob(Y_rv, S_rv)
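For this model, the graph returned by joint_logprob should be the sum of the two conditional log-densities, evaluated at the value variables y and s. The expression below assumes that invgamma(0.5, 0.5) denotes shape α = 0.5 and scale β = 0.5. This reading is consistent with the -1.5 * log(s) and 0.5 / s terms in the simplified graph printed further down:

```latex
\begin{aligned}
\log p(y, s) &= \log \operatorname{invgamma}(s \mid 0.5,\, 0.5) + \log \operatorname{N}\!\left(y \mid 0,\, {\sqrt{s}}^{2}\right) \\
\log \operatorname{invgamma}(s \mid \alpha, \beta) &= \alpha \log \beta - \log \Gamma(\alpha) - (\alpha + 1) \log s - \frac{\beta}{s}, \qquad s > 0 \\
\log \operatorname{N}(y \mid \mu, \sigma^{2}) &= -\tfrac{1}{2} \log(2\pi) - \tfrac{1}{2} \left(\frac{y - \mu}{\sigma}\right)^{2} - \log \sigma
\end{aligned}
```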
Log-probability graphs are standard Aesara graphs, so we can compute values with them:
logprob_fn = aesara.function([y, s], logprob)
logprob_fn(-0.5, 1.0)
# array(-2.46287705)
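As a sanity check, the value above can be reproduced from the closed-form densities. This sketch uses only Python's math module and assumes that invgamma(0.5, 0.5) means shape 0.5 and scale 0.5. The function names are illustrative. Agreement should only be expected to about seven digits, since parts of the Aesara graph may be evaluated in float32.

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    """Log-density of N(mu, sigma**2) at x."""
    z = (x - mu) / sigma
    return -LOG_SQRT_2PI - 0.5 * z * z - math.log(sigma)


def invgamma_logpdf(s: float, alpha: float, beta: float) -> float:
    """Log-density of an inverse-gamma variable with shape alpha and scale beta."""
    if s <= 0.0:
        return -math.inf
    return (
        alpha * math.log(beta)
        - math.lgamma(alpha)
        - (alpha + 1.0) * math.log(s)
        - beta / s
    )


def joint_logp(y: float, s: float) -> float:
    """log p(y, s) for S ~ InvGamma(0.5, 0.5) and Y | S ~ N(0, sqrt(S)**2)."""
    log_p_s = invgamma_logpdf(s, 0.5, 0.5)
    if log_p_s == -math.inf:
        return -math.inf  # s outside the support; the normal scale sqrt(s) is undefined
    return log_p_s + normal_logpdf(y, 0.0, math.sqrt(s))


print(joint_logp(-0.5, 1.0))  # approximately -2.4628771
```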
Graphs can also be pretty-printed:
from aeppl import pprint, latex_pprint
# Print the original graph
print(pprint(Y_rv))
# b ~ invgamma(0.5, 0.5) in R, a ~ N(0.0, sqrt(b)**2) in R
# a
print(latex_pprint(Y_rv))
# \begin{equation}
# \begin{gathered}
# b \sim \operatorname{invgamma}\left(0.5, 0.5\right)\, \in \mathbb{R}
# \\
# a \sim \operatorname{N}\left(0.0, {\sqrt{b}}^{2}\right)\, \in \mathbb{R}
# \end{gathered}
# \\
# a
# \end{equation}
# Simplify the graph so that it's easier to read
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.rewriting.basic import topo_constant_folding
logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)
print(pprint(logprob))
# s in R, y in R
# (switch(s >= 0.0,
# ((-0.9189385175704956 +
# switch(s == 0, -inf, (-1.5 * log(s)))) - (0.5 / s)),
# -inf) +
# ((-0.9189385332046727 + (-0.5 * ((y / sqrt(s)) ** 2))) - log(sqrt(s))))
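The folded constants in this output can be recovered by hand. Assume the same shape/scale reading of invgamma(0.5, 0.5). Then -0.9189385332046727 is -log(sqrt(2*pi)) from the normal density. The inverse-gamma constant 0.5*log(0.5) - lgamma(0.5) happens to equal the same number, and the printed -0.9189385175704956 appears to be its float32-rounded form. The hypothetical printed_logprob below transcribes the expression for s > 0, where neither switch yields -inf.

```python
import math

# Normal term constant: -log(sqrt(2 * pi))
NORMAL_CONST = -0.5 * math.log(2.0 * math.pi)

# Inverse-gamma term constant: alpha * log(beta) - lgamma(alpha), with alpha = beta = 0.5
INVGAMMA_CONST = 0.5 * math.log(0.5) - math.lgamma(0.5)

# Coefficient of log(s) in the inverse-gamma term: -(alpha + 1)
S_POWER = -(0.5 + 1.0)


def printed_logprob(y: float, s: float) -> float:
    """Hand transcription of the simplified graph above, valid only for s > 0."""
    if s <= 0.0:
        raise ValueError("this transcription only covers s > 0")
    invgamma_term = (INVGAMMA_CONST + S_POWER * math.log(s)) - 0.5 / s
    normal_term = (NORMAL_CONST + (-0.5 * (y / math.sqrt(s)) ** 2)) - math.log(math.sqrt(s))
    return invgamma_term + normal_term


print(NORMAL_CONST, INVGAMMA_CONST)
print(printed_logprob(-0.5, 1.0))
```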
Joint log-probabilities can be computed for some terms that are derived from RandomVariables, as well:
# Create a switching model from a Bernoulli distributed index
Z_rv = srng.normal([-100, 100], 1.0, name="Z")
I_rv = srng.bernoulli(0.5, name="I")
M_rv = Z_rv[I_rv]
M_rv.name = "M"
# Compute the joint log-probability of the mixture M and the index I;
# Z_rv is not given a value variable, since it only enters through M_rv
logprob, (m, i) = joint_logprob(M_rv, I_rv)
logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)
print(pprint(logprob))
# i in Z, m in R, a in Z
# (switch((0 <= i and i <= 1), -0.6931472, -inf) +
# ((-0.9189385332046727 + (-0.5 * (((m - [-100 100][a]) / [1. 1.][a]) ** 2))) -
# log([1. 1.][a])))
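Term by term, the printed graph has two parts. The first is the Bernoulli(0.5) log-mass of the index: log 0.5 (about -0.6931472), or -inf outside {0, 1}. The second is the normal log-density of m under the component the index selects, with means [-100, 100] and standard deviations [1, 1]. The hypothetical mixture_logp below is a plain-Python sketch of that expression. It treats the index in both terms as the same value i, whereas the printer labels the one inside the normal term a.

```python
import math

MEANS = [-100.0, 100.0]  # location of each component of Z
SIGMAS = [1.0, 1.0]      # scale of each component of Z
P_ONE = 0.5              # Bernoulli probability that I == 1


def mixture_logp(m: float, i: int) -> float:
    """log p(M = m, I = i) for M = Z[I], Z ~ N(MEANS, SIGMAS**2), I ~ Bernoulli(P_ONE)."""
    if i not in (0, 1):
        return -math.inf  # index outside the Bernoulli support
    log_p_i = math.log(P_ONE if i == 1 else 1.0 - P_ONE)
    z = (m - MEANS[i]) / SIGMAS[i]
    log_p_m = -0.5 * math.log(2.0 * math.pi) - 0.5 * z * z - math.log(SIGMAS[i])
    return log_p_i + log_p_m


print(mixture_logp(100.0, 1))  # m sits at the mean of the selected component
print(mixture_logp(100.0, 0))  # the same m is 200 standard deviations from component 0
print(mixture_logp(0.0, 2))    # invalid index
```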
Installation
The latest release of aeppl can be installed from PyPI using pip:
pip install aeppl
The current development branch of aeppl can be installed from GitHub, also using pip:
pip install git+https://github.com/aesara-devs/aeppl
Download files
Source Distribution
File details
Details for the file aeppl-nightly-0.1.0.dev20230120.tar.gz.
File metadata
- Download URL: aeppl-nightly-0.1.0.dev20230120.tar.gz
- Upload date:
- Size: 66.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.11.1
File hashes
Algorithm | Hash digest
---|---
SHA256 | a85998e5dbaf75ed4ed51bc1fda87b2fd9b024e44c7fc2baf843591f1e8addb7
MD5 | 330c0c133a66f4121adf492877898e28
BLAKE2b-256 | b80c866aa0009d30255dedef78e94dbe7058f26e30011acad2accf477168a51a
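To check a downloaded archive against the SHA256 digest listed above, the standard library is enough. This is a sketch. The local path is an assumption: it presumes the archive was saved under its original name in the current directory.

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "a85998e5dbaf75ed4ed51bc1fda87b2fd9b024e44c7fc2baf843591f1e8addb7"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    archive = Path("aeppl-nightly-0.1.0.dev20230120.tar.gz")  # assumed download location
    if archive.exists():
        ok = sha256_of(archive) == EXPECTED_SHA256
        print("hash matches" if ok else "hash MISMATCH")
    else:
        print(f"{archive} not found")
```

Reading in chunks keeps memory use constant regardless of archive size.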