Probabilistic inference focused on fun
Pangolin's goal is to be the world's friendliest probabilistic programming language and to make probabilistic inference fun. It is now usable, but is still something of a research project.
Changelog
See CHANGELOG.md
API Docs
See justindomke.github.io/pangolin.
Installation / Quickstart
If you have uv installed, you can test pangolin in a temporary environment by just using --with pangolin at the command line. For example, say m ~ normal(0,1) and s ~ lognormal(0,1) and x ~ normal(m, s) is observed to have value x=3.5. Then you can calculate the posterior mean of m and s like so:
$ uv run --with pangolin python
Python 3.14.3
Type "help", "copyright", "credits" or "license" for more information.
>>> from pangolin import interface as pi
>>> from pangolin.blackjax import E
>>> m = pi.normal(0,1)
>>> s = pi.lognormal(0,1)
>>> x = pi.normal(m, s)
>>> E([m, s], x, 3.5)
[Array(0.70249635, dtype=float32), Array(2.8581674, dtype=float32)]
More broadly, pangolin is on PyPI, so you can install it by using pip install pangolin or uv add pangolin or whatever. See INSTALL.md for details.
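Because the quickstart model is so small, its posterior means can also be computed deterministically, which makes a handy sanity check on sampler output. The sketch below is plain Python and not part of pangolin. It assumes that pangolin's normal(loc, scale) takes a standard deviation and that lognormal(0,1) means log(s) ~ normal(0,1). Given s, the model is normal-normal conjugate, so m is integrated out in closed form, and the remaining one-dimensional integral over u = log(s) is approximated by a grid sum. The function name is hypothetical.

```python
import math


def quickstart_posterior_means(x_obs=3.5, lo=-10.0, hi=10.0, n=20001):
    """Approximate E[m | x] and E[s | x] for m ~ N(0,1), s ~ LogNormal(0,1), x ~ N(m, s).

    Grid sum over u = log(s), whose prior is N(0, 1). Given s, m is conjugate,
    so the m-integral is done analytically.
    """
    du = (hi - lo) / (n - 1)
    z = em = es = 0.0
    for i in range(n):
        u = lo + i * du
        var_x = 1.0 + math.exp(2.0 * u)  # x | s ~ N(0, 1 + s^2) once m is integrated out
        # unnormalized p(u | x): N(0,1) prior on u times the marginal likelihood of x
        w = math.exp(-0.5 * u * u - 0.5 * x_obs**2 / var_x) / math.sqrt(var_x)
        z += w
        em += w * x_obs / var_x  # E[m | s, x] = x / (1 + s^2)
        es += w * math.exp(u)  # s = exp(u)
    return em / z, es / z


m_mean, s_mean = quickstart_posterior_means(3.5)
print(m_mean, s_mean)
```

Up to Monte Carlo error, we would expect these values to land near the E([m, s], x, 3.5) output shown above.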
Why?
At a high level, Pangolin has two goals:
- To make things simple for end users who just want to do inference, while still taking full advantage of modern hardware (GPUs).
- To make things simple for researchers who want to develop new inference algorithms, develop new ways of specifying probabilistic models, share models between different backends (JAX / PyTorch), benchmark inference algorithms written in different languages, etc.
Why (for end users)?
For end-users, Pangolin tries to provide an interface that is simple and explicit. In particular:
- Gradual enhancement. Easy things should be really easy. More complex features should be easy to discover. Steep learning curves should be avoided.
- Small API surface. The set of abstractions the user needs to learn should be as small as possible.
- Explicitness. Many modern PPLs (e.g. Pyro / NumPyro / PyMC / Oryx) lean heavily on NumPy's broadcasting semantics. This looks very nice in simple cases, but becomes confusing in complex cases. In Pangolin, by default, only a very limited amount of broadcasting is allowed (though this is configurable). Instead of relying on implicit broadcasting, Pangolin users apply an explicit vmap transformation, inspired by jax.vmap. If you see x = vmap(normal, [0, None])(a, b) or x = vfor(lambda i: normal(a[i], b)), that means that a must be one-dimensional, b must be scalar, and x must be one-dimensional. Similarly, many modern PPLs inherit their indexing behavior from NumPy, which combines broadcasting with lots of other special cases and is legendarily complicated. Pangolin uses ultra-simple and ultra-legible full-orthogonal indexing. If you see u = z[x,y] then you know that z.ndim == 2 and u.ndim == x.ndim + y.ndim, always. More complex cases can still be handled with vmap. All this makes code more self-documenting and predictable.
- Graceful interop. As much as possible, the system should feel like a natural part of the broader ecosystem, rather than a "new language". In particular, Pangolin tries to avoid several oddities common in other modern PPLs:
  - No "sample" statements or string labels. In Pyro or NumPyro you write z = sample('z', Normal(0, 1)). In Pangolin you just write z = normal(0, 1). If you want to refer to z later, you use a reference to the resulting RV object, e.g. by writing E(z) to get the expected value of z. You can organize random variables into (recursive) lists or tuples or dictionaries however you want. For example, if x, y, and z are scalar RVs, then E([x, {'alice': y, 'bob': z}]) will return a list where the first element is a float, and the second element is a dictionary with keys 'alice' and 'bob', each of which maps to a float.
  - No attaching data to random variables with "obs" statements. In Pyro or NumPyro or PyMC, if a random variable x is observed, you need to write something like x = sample('x', Normal(z, 1), obs=x_obs). In Pangolin, you always just write x = normal(z, 1). You decide whether to condition on x at the inference stage, e.g. by using E(z, x, x_obs) to estimate the expected value of z conditioning on x = x_obs. This is how it works in math, after all.
  - No "model" objects. In most (all?) other PPLs, you create a "model" object, and then you query it to extract information about random variables. In Pangolin, you just manipulate random variables, with no additional layer of abstraction. This is also how it works in math.
- In Pangolin you can see the internal representation. After building a model, you can call print_upstream to see the internal representation, with the parents and shapes of all random variables.
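To make the shape rules above concrete, here is a toy, pure-Python model of them. toy_vmap, toy_normal, and orthogonal_index are hypothetical helpers written for illustration only; they are not pangolin's implementation. toy_vmap maps over the leading axis of arguments whose in_axes entry is 0 and passes arguments marked None through unchanged, while orthogonal_index applies each index independently, so the result's shape is always shape(x) + shape(y).

```python
import random


def toy_vmap(f, in_axes):
    """Map f over axis 0 of args marked 0; pass args marked None through unchanged."""
    def mapped(*args):
        if len(args) != len(in_axes):
            raise ValueError("need one in_axes entry per argument")
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("mapped arguments must share a single leading length")
        (n,) = sizes
        return [f(*[a[i] if ax == 0 else a for a, ax in zip(args, in_axes)]) for i in range(n)]
    return mapped


def toy_normal(loc, scale):
    # Scalars only: no implicit broadcasting is allowed.
    if not all(isinstance(v, (int, float)) for v in (loc, scale)):
        raise TypeError("toy_normal takes scalar loc and scale")
    return random.gauss(loc, scale)


def orthogonal_index(z, x, y):
    """u = z[x, y] with full-orthogonal semantics: shape(u) == shape(x) + shape(y)."""
    if isinstance(x, list):
        return [orthogonal_index(z, xi, y) for xi in x]
    if isinstance(y, list):
        return [orthogonal_index(z, x, yj) for yj in y]
    return z[x][y]


random.seed(0)
a = [0.0, 1.0, 2.0]  # one-dimensional
b = 0.5  # scalar
x = toy_vmap(toy_normal, [0, None])(a, b)  # one-dimensional, same length as a

z = [[0, 1, 2], [10, 11, 12]]
u = orthogonal_index(z, [1, 0], [2, 2, 0])
print(u)  # [[12, 12, 10], [2, 2, 0]]
```

For comparison, NumPy would reject np.array(z)[[1, 0], [2, 2, 0]] with an IndexError, because it tries to broadcast the two index arrays against each other instead of applying them independently.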
Why (for researchers)?
Pangolin is extremely modular. It's built around a simple internal representation (IR) in which there are only two types of objects: an Op represents a conditional distribution or deterministic function, while an RV contains a single Op and a list of parent RVs. Primitives that make evaluation efficient on modern hardware (e.g. VMap or Scan) wrap individual Ops. That's basically all there is to it.
All other parts of Pangolin are decoupled: they only depend on the IR, not on each other. For example, the interface offers a friendly way for users to specify models, with optional broadcasting, program transformations, and so on. Internally, this is quite complicated. But it just produces models in the IR. The different backends only look at the IR, and don't even know that the interface layer exists.
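As a rough illustration of why such an IR is easy to work with, here is a toy version in plain Python. ToyOp, ToyRV, toy_constant, toy_normal, upstream, and log_density are all hypothetical names invented for this sketch; pangolin's real classes are richer. The point is only the structure: each RV holds one Op plus its parents, a topological walk recovers the whole graph, and a tiny "backend" can compute a joint log-density by reading nothing but that graph.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyOp:
    """Toy stand-in for an Op: a conditional distribution or deterministic function."""
    name: str
    random: bool
    value: Optional[float] = None  # only used by the "constant" op


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so RVs work as dict keys
class ToyRV:
    """Toy stand-in for an RV: exactly one Op plus a tuple of parent RVs."""
    op: ToyOp
    parents: tuple = ()


def toy_constant(v):
    return ToyRV(ToyOp("constant", random=False, value=v))


def toy_normal(loc, scale):
    return ToyRV(ToyOp("normal", random=True), (loc, scale))


def upstream(*rvs):
    """Every RV the given ones depend on, parents before children (topological order)."""
    order, seen = [], set()

    def visit(rv):
        if id(rv) in seen:
            return
        seen.add(id(rv))
        for parent in rv.parents:
            visit(parent)
        order.append(rv)

    for rv in rvs:
        visit(rv)
    return order


def log_density(rvs, values):
    """Toy 'backend': joint log-density of all random nodes, reading only the IR."""
    env, total = {}, 0.0
    for rv in upstream(*rvs):
        args = [env[p] for p in rv.parents]
        if rv.op.name == "constant":
            env[rv] = rv.op.value
        elif rv.op.name == "normal":
            loc, scale = args
            v = values[rv]
            total += -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)
            env[rv] = v
        else:
            raise NotImplementedError(rv.op.name)
    return total


z = toy_normal(toy_constant(0.0), toy_constant(2.0))
x = toy_normal(z, toy_constant(6.0))
print(len(upstream(x)))  # 5: two constants, z, one constant, x
print(log_density([x], {z: 0.0, x: -10.0}))
```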
This makes many things easy that are typically quite difficult in modern PPLs:
- Say you want to create a new inference algorithm that programmatically inspects the model. That's easy, because the IR is just a static graph of random variables.
- Say you want to create a probabilistic model and share it with collaborators, some of whom use JAX and some of whom use PyTorch. That's fine. The former group can use the JAX backend while the latter group uses the torch backend.
- Say you want to create a new "backend" that will do inference using a different array computing framework instead of JAX or PyTorch. This is pretty easy. The torch backend is around 1000 lines. The (more capable) JAX backend is around 2000 lines. The blackjax interface is 400 lines.
- Say you hate Pangolin's interface. That's fine. Make a new one! As long as your interface produces models in the Pangolin IR, you can still use the existing backends.
- In the future, we hope to make the IR language-independent, so interfaces and backends could be written in other languages, e.g. R or Julia. (This is possible in principle now, but could be made easier.)
Examples and comparisons
Simple "probabilistic calculator"
If z ~ normal(0,2) and x ~ normal(0,6) then what is E[z | x = -10]?
from pangolin import interface as pi
from pangolin.blackjax import E
z = pi.normal(0,2)
x = pi.normal(z,6)
print(E(z, x, -10.0))
Here is the same model in other PPLs. (Or see calculator-ppls.ipynb.)
PyMC
import numpy as np
import pymc as pm
with pm.Model():
    z = pm.Normal('z', 0, 2)
    x = pm.Normal('x', z, 6, observed=-10)
    trace = pm.sample(chains=1)
z_samps = trace.posterior['z'].values
print(np.mean(z_samps))
Pyro
import numpy as np
import pyro
import torch
def model():
    z = pyro.sample('z', pyro.distributions.Normal(0, 2))
    x = pyro.sample('x', pyro.distributions.Normal(z, 6), obs=torch.tensor(-10.0))
nuts_kernel = pyro.infer.mcmc.NUTS(model)
mcmc = pyro.infer.mcmc.MCMC(nuts_kernel, warmup_steps=500, num_samples=1000, num_chains=1)
mcmc.run()
z_samps = mcmc.get_samples()['z'].numpy()
print(np.mean(z_samps))
NumPyro
import numpy as np
import numpyro
import jax
def model():
    z = numpyro.sample('z', numpyro.distributions.Normal(0, 2))
    x = numpyro.sample('x', numpyro.distributions.Normal(z, 6), obs=-10)
nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(42))
z_samps = mcmc.get_samples()['z']
print(np.mean(z_samps))
JAGS
import numpy as np
import pyjags
model_code = """
model {
  z ~ dnorm(0, 1/2^2)
  x ~ dnorm(z, 1/6^2)
}
"""
model = pyjags.Model(
    code=model_code,
    data={'x': -10},
    chains=1,
    adapt=500
)
samples = model.sample(1000, ['z'])
z_samps = samples['z'].flatten()
print(np.mean(z_samps))
Stan
import numpy as np
import cmdstanpy
import tempfile
from pathlib import Path
stan_code = """
data {
  real x;
}
parameters {
  real z;
}
model {
  z ~ normal(0, 2);
  x ~ normal(z, 6);
}
"""
with tempfile.TemporaryDirectory() as tmpdir:
    stan_file = Path(tmpdir) / "calculator_model.stan"
    stan_file.write_text(stan_code)
    model = cmdstanpy.CmdStanModel(stan_file=str(stan_file))
    fit = model.sample(
        data={'x': -10.0},
        chains=1,
        iter_warmup=500,
        iter_sampling=1000,
        seed=42
    )
    z_samps = fit.stan_variable('z')
print(np.mean(z_samps))
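Because this model is conjugate, the exact answer is available for checking all of the sampler outputs above. With prior variance 2^2 = 4 and noise variance 6^2 = 36, the posterior mean is 4 / (4 + 36) * (-10) = -1, and the posterior standard deviation is sqrt(3.6), about 1.90. The standard-library sketch below computes both; normal_normal_posterior is a hypothetical helper, not part of any of the libraries above.

```python
import math


def normal_normal_posterior(x, prior_sd, noise_sd, prior_mean=0.0):
    """Exact posterior of z for z ~ N(prior_mean, prior_sd) and x ~ N(z, noise_sd)."""
    prior_prec = 1.0 / prior_sd**2
    noise_prec = 1.0 / noise_sd**2
    post_var = 1.0 / (prior_prec + noise_prec)  # precisions add
    post_mean = post_var * (prior_prec * prior_mean + noise_prec * x)  # precision-weighted mean
    return post_mean, math.sqrt(post_var)


mean, sd = normal_normal_posterior(x=-10.0, prior_sd=2.0, noise_sd=6.0)
print(mean, sd)
```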
Beta-Bernoulli model
This is arguably the simplest Bayesian model. If you've seen a bunch of coinflips from a bent coin, what is the true bias? To start, generate synthetic data.
# synthetic data
import numpy as np
np.random.seed(67)
z_true = 0.7
N = 20
x_obs = np.random.binomial(1, z_true, N)
# create model
import pangolin.blackjax  # importing the submodule makes pangolin.blackjax.sample available
from pangolin import interface as pi
z = pi.beta(2,2)
x = pi.vmap(pi.bernoulli, None, N)(z)
# do inference
z_samps = pangolin.blackjax.sample(z, x, x_obs) # p(z | x = x_obs)
# plot
import seaborn as sns
sns.histplot(z_samps, binrange=[0,1])
Here is the same model in other PPLs. (Or see beta-bernoulli-ppls.ipynb.)
PyMC
import pymc as pm
with pm.Model() as coin_model:
    z = pm.Beta('z', alpha=2, beta=2)
    x = pm.Bernoulli('x', z, observed=x_obs)
    trace = pm.sample(chains=1)
z_samps = trace.posterior['z'].values
print(np.mean(z_samps), np.std(z_samps))
Pyro
import pyro
import torch
x_obs_torch = torch.tensor(x_obs, dtype=torch.float)
def model():
    z = pyro.sample('z', pyro.distributions.Beta(2.0, 2.0))
    with pyro.plate('N', N):
        x = pyro.sample('x', pyro.distributions.Bernoulli(z), obs=x_obs_torch)
nuts_kernel = pyro.infer.mcmc.NUTS(model)
mcmc = pyro.infer.mcmc.MCMC(nuts_kernel, warmup_steps=500, num_samples=1000, num_chains=1)
mcmc.run()
z_samps = mcmc.get_samples()['z'].numpy()
print(np.mean(z_samps), np.std(z_samps))
NumPyro
import numpyro
import jax
import jax.numpy as jnp
def model():
    z = numpyro.sample('z', numpyro.distributions.Beta(2.0, 2.0))
    with numpyro.plate('data', N):
        x = numpyro.sample('x', numpyro.distributions.Bernoulli(z), obs=x_obs)
nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(42))
z_samps = mcmc.get_samples()['z']
print(np.mean(z_samps), np.std(z_samps))
JAGS
import pyjags
model_code = """
model {
  z ~ dbeta(2, 2)
  for (i in 1:N) {
    x[i] ~ dbern(z)
  }
}
"""
model = pyjags.Model(
    code=model_code,
    data={'N': N, 'x': x_obs.tolist()},
    chains=1,
    adapt=500
)
samples = model.sample(1000, ['z'])
z_samps = samples['z'].flatten()
print(np.mean(z_samps), np.std(z_samps))
Stan
import cmdstanpy
import tempfile
from pathlib import Path
stan_code = """
data {
  int<lower=0> N;
  array[N] int<lower=0, upper=1> x;
}
parameters {
  real<lower=0, upper=1> z;
}
model {
  z ~ beta(2, 2);
  x ~ bernoulli(z);
}
"""
with tempfile.TemporaryDirectory() as tmpdir:
    stan_file = Path(tmpdir) / "coin_model.stan"
    stan_file.write_text(stan_code)
    model = cmdstanpy.CmdStanModel(stan_file=str(stan_file))
    fit = model.sample(
        data={'N': N, 'x': x_obs},
        chains=1,
        iter_warmup=500,
        iter_sampling=1000,
        seed=42
    )
    z_samps = fit.stan_variable('z')
print(np.mean(z_samps), np.std(z_samps))
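The Beta prior is conjugate to the Bernoulli likelihood, so with k heads out of N flips the exact posterior is Beta(2 + k, 2 + N - k). The small sketch below computes that posterior's mean and standard deviation, which the printed sample statistics above should approximate. beta_bernoulli_posterior is a hypothetical helper written for this check.

```python
import math


def beta_bernoulli_posterior(x_obs, a=2.0, b=2.0):
    """Mean and standard deviation of Beta(a + k, b + N - k) for 0/1 observations x_obs."""
    k = sum(int(v) for v in x_obs)  # number of heads
    n = len(x_obs)
    a_post, b_post = a + k, b + n - k
    total = a_post + b_post
    mean = a_post / total
    var = a_post * b_post / (total**2 * (total + 1))  # Beta variance formula
    return mean, math.sqrt(var)


print(beta_bernoulli_posterior([1, 1, 0, 1]))
```

Passing the synthetic x_obs from above (a NumPy array of 0s and 1s also works, since each entry goes through int) gives the exact values to compare against.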
Eight-schools
Bayesian inference on the classic 8-schools model:
# setup
import numpy as np
N = 8
stddevs = np.array([15, 10, 16, 11, 9, 11, 10, 18])
x_obs = np.array([28, 8, -3, 7, -1, 1, 18, 12])
# inference
import pangolin.blackjax  # importing the submodule makes pangolin.blackjax.sample available
from pangolin import interface as pi
mu = pi.normal(0,10)
tau = pi.lognormal(0,5)
z = pi.vmap(pi.normal, None, N)(mu, tau)
x = pi.vmap(pi.normal)(z, stddevs)
z_samps = pangolin.blackjax.sample(z, x, x_obs, niter=10000)
# plot
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
sns.swarmplot(np.array(z_samps)[::50], s=2, zorder=0)  # every 50th draw; one column per school
plt.xlabel('school')
plt.ylabel('treatment effect')
Here is the same model in other PPLs. (Or see eight-schools-ppls.ipynb.)
PyMC
import pymc as pm
with pm.Model():
    mu = pm.Normal('mu', 0, 10)
    tau = pm.LogNormal('tau', 0, 5)
    z = pm.Normal('z', mu, tau, size=N)
    x = pm.Normal('x', z, stddevs, observed=x_obs)
    trace = pm.sample(draws=10000, chains=1)
z_samps = trace.posterior['z'].values[0,:,:]
Pyro
import pyro
import torch
stddevs_torch = torch.tensor(stddevs, dtype=torch.float)
x_obs_torch = torch.tensor(x_obs, dtype=torch.float)
def model():
    mu = pyro.sample('mu', pyro.distributions.Normal(0, 10))
    tau = pyro.sample('tau', pyro.distributions.LogNormal(0, 5))
    with pyro.plate('N', N):
        z = pyro.sample('z', pyro.distributions.Normal(mu, tau))
        x = pyro.sample('x', pyro.distributions.Normal(z, stddevs_torch), obs=x_obs_torch)
nuts_kernel = pyro.infer.mcmc.NUTS(model)
mcmc = pyro.infer.mcmc.MCMC(nuts_kernel, warmup_steps=500, num_samples=1000, num_chains=1)
mcmc.run()
z_samps = mcmc.get_samples()['z'].numpy()
NumPyro
import numpyro
import jax
def model():
    mu = numpyro.sample('mu', numpyro.distributions.Normal(0, 10))
    tau = numpyro.sample('tau', numpyro.distributions.LogNormal(0, 5))
    with numpyro.plate('N', N):
        z = numpyro.sample('z', numpyro.distributions.Normal(mu, tau))
        x = numpyro.sample('x', numpyro.distributions.Normal(z, stddevs), obs=x_obs)
nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(42))
z_samps = mcmc.get_samples()['z']
JAGS
import pyjags
model_code = """
model {
  mu ~ dnorm(0, 1/10^2)
  tau ~ dlnorm(0, 1/5^2)
  for (i in 1:N) {
    z[i] ~ dnorm(mu, 1/tau^2)
    x[i] ~ dnorm(z[i], 1/stddevs[i]^2)
  }
}
"""
model = pyjags.Model(
    code=model_code,
    data={'N': N, 'stddevs': stddevs.tolist(), 'x': x_obs.tolist()},
    chains=1,
    adapt=5000
)
samples = model.sample(100000, ['z'])
z_samps = np.array(samples['z'])[:,:,0].T
Stan
import cmdstanpy
import tempfile
from pathlib import Path
stan_code = """
data {
  int<lower=0> N;
  array[N] real x;
  array[N] real stddevs;
}
parameters {
  real mu;
  real<lower=0> tau;
  array[N] real z;
}
model {
  mu ~ normal(0, 10);
  tau ~ lognormal(0, 5);
  for (i in 1:N) {
    z[i] ~ normal(mu, tau);
    x[i] ~ normal(z[i], stddevs[i]);
  }
}
"""
with tempfile.TemporaryDirectory() as tmpdir:
    stan_file = Path(tmpdir) / "8schools_model.stan"
    stan_file.write_text(stan_code)
    model = cmdstanpy.CmdStanModel(stan_file=str(stan_file))
    fit = model.sample(
        data={'N': N, 'stddevs': stddevs, 'x': x_obs},
        chains=1,
        iter_warmup=500,
        iter_sampling=1000,
        seed=42
    )
    z_samps = fit.stan_variable('z')
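One way to understand the shrinkage visible in the pangolin swarm plot: for fixed mu and tau, each school's effect z_i has a closed-form conditional posterior, a precision-weighted compromise between its own observation x_i and the shared mean mu. The plain-Python sketch below computes it. school_conditional is a hypothetical helper, and the mu = 5.0 and tau values in the loop are arbitrary illustrative inputs, not posterior estimates.

```python
import math


def school_conditional(x_i, sigma_i, mu, tau):
    """Mean and sd of z_i | mu, tau, x_i for z_i ~ N(mu, tau) and x_i ~ N(z_i, sigma_i)."""
    prec = 1.0 / tau**2 + 1.0 / sigma_i**2  # precisions add
    mean = (mu / tau**2 + x_i / sigma_i**2) / prec  # precision-weighted average of mu and x_i
    return mean, 1.0 / math.sqrt(prec)


stddevs = [15, 10, 16, 11, 9, 11, 10, 18]
x_obs = [28, 8, -3, 7, -1, 1, 18, 12]
for tau in (1.0, 5.0, 50.0):  # illustrative values only
    means = [school_conditional(x, s, mu=5.0, tau=tau)[0] for x, s in zip(x_obs, stddevs)]
    print(tau, [round(m, 1) for m in means])
```

Small tau pulls every school toward mu (heavy pooling), while large tau leaves each estimate near its own x_i; the full posterior averages over both regimes.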
More examples
For more examples, take a look at the demos. Here's a recommended order:
- IR.ipynb demonstrates pangolin's internal representation of probabilistic models.
- interface.ipynb demonstrates pangolin's friendly interface and what internal representation it produces.
- 8schools.ipynb is the classic 8-schools model.
- regression.ipynb is Bayesian linear regression.
- timeseries.ipynb is a simple timeseries model.
- scan.ipynb is a Kalman-filter-esque model.
- GP-regression.ipynb is Gaussian Process regression.
- 1PL.ipynb is a simple item-response-theory model.
- 2PL.ipynb is a slightly more complex item-response-theory model.
See also
An earlier version of Pangolin, based on much the same ideas but supporting only JAGS as a backend, is available with documentation in the pangolin-jags directory.