
Probabilistic inference focused on fun


pangolin

Pangolin's goal is to be the world's friendliest probabilistic programming language and to make probabilistic inference fun. It is now usable, but is still something of a research project.

Changelog

See CHANGELOG.md

API Docs

See justindomke.github.io/pangolin.

Installation / Quickstart

If you have uv installed, you can test pangolin in a temporary environment by just using --with pangolin at the command line. For example, say m ~ normal(0,1) and s ~ lognormal(0,1) and x ~ normal(m, s) is observed to have value x=3.5. Then you can calculate the posterior mean of m and s like so:

$ uv run --with pangolin python
Python 3.14.3 
Type "help", "copyright", "credits" or "license" for more information.
>>> from pangolin import interface as pi
>>> from pangolin.blackjax import E
>>> m = pi.normal(0,1)
>>> s = pi.lognormal(0,1)
>>> x = pi.normal(m, s)
>>> E([m, s], x, 3.5)
[Array(0.70249635, dtype=float32), Array(2.8581674, dtype=float32)]

More broadly, pangolin is on pypi so you can install it by using pip install pangolin or uv add pangolin or whatever. See INSTALL.md for details.

Why?

At a high level, Pangolin has two goals:

  1. To make things simple for end users who just want to do inference, while still taking full advantage of modern hardware (GPUs).

  2. To make things simple for researchers who want to develop new inference algorithms, develop new ways of specifying probabilistic models, share models between different backends (JAX / PyTorch), benchmark inference algorithms written in different languages, etc.

Why (for end users)?

For end-users, Pangolin tries to provide an interface that is simple and explicit. In particular:

  • Gradual enhancement. Easy things should be really easy. More complex features should be easy to discover. Steep learning curves should be avoided.

  • Small API surface. The set of abstractions the user needs to learn should be as small as possible.

  • Explicitness. Many modern PPLs (e.g. Pyro / NumPyro / PyMC / Oryx) lean heavily on NumPy's broadcasting semantics. This looks very nice in simple cases, but becomes confusing in complex cases. In Pangolin, by default, only a very limited amount of broadcasting is allowed (though this is configurable). Instead of relying on implicit broadcasting, Pangolin users should use an explicit vmap transformation, inspired by jax.vmap. If you see x = vmap(normal, [0, None])(a, b) or x = vfor(lambda i: normal(a[i], b)), that means that a must be one-dimensional, b must be scalar, and x must be one-dimensional. Similarly, many modern PPLs inherit their indexing behavior from NumPy, which combines broadcasting with lots of other special cases and is legendarily complicated. Pangolin uses ultra-simple and ultra-legible full-orthogonal indexing. If you see u = z[x,y] then you know that z.ndim == 2 and u.ndim == x.ndim + y.ndim always. More complex cases can still be handled with vmap. All this makes code more self-documenting and predictable.

  • Graceful interop. As much as possible, the system should feel like a natural part of the broader ecosystem, rather than a "new language". In particular, Pangolin tries to avoid several oddities common in other modern PPLs:

    • No "sample" statements or string labels. In Pyro or NumPyro you write z = sample('z', Normal(0, 1)). In Pangolin you just write z = normal(0, 1). If you want to refer to z later, you use a reference to the resulting RV object, e.g. by writing E(z) to get the expected value of z. You can organize random variables into (recursive) lists, tuples, or dictionaries however you want. For example, if x, y, and z are scalar RVs, then E([x, {'alice': y, 'bob': z}]) will return a list whose first element is a float and whose second element is a dictionary with keys 'alice' and 'bob', each of which maps to a float.

    • No attaching data to random variables with "obs" statements. In Pyro or NumPyro or PyMC, if a random variable x is observed, you need to write something like x = sample('x', Normal(z, 1), obs=x_obs). In Pangolin, you always just write x = normal(z, 1). You decide whether to condition on x at the inference stage, e.g. by using E(z, x, x_obs) to estimate the expected value of z conditional on x = x_obs. This is how it works in math, after all.

    • No "model" objects. In most (all?) other PPLs, you create a "model" object, and then you query it to extract information about random variables. In Pangolin, you just manipulate random variables, with no additional layer of abstraction. This is also how it works in math.

  • In Pangolin you can see the internal representation. After building a model, you can call print_upstream to see the internal representation, with the parents and shapes of all random variables.
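For concreteness, here is a NumPy sketch of what full-orthogonal indexing computes for a 2-D array, next to NumPy's own advanced indexing. The helper `orthogonal_index` is a hypothetical name written for illustration; it is not part of pangolin, whose implementation may differ.

```python
import numpy as np


def orthogonal_index(z, x, y):
    """Full-orthogonal indexing of a 2-D array: u[i..., j...] = z[x[i...], y[j...]].

    The result always has shape x.shape + y.shape, so u.ndim == x.ndim + y.ndim.
    (Hypothetical helper for illustration; not part of pangolin.)
    """
    z, x, y = np.asarray(z), np.asarray(x), np.asarray(y)
    if z.ndim != 2:
        raise ValueError("z must be 2-D")
    # Pad the index arrays with singleton axes so that NumPy's broadcasting
    # produces an outer product of x and y instead of pairing them up.
    xi = x.reshape(x.shape + (1,) * y.ndim)
    yi = y.reshape((1,) * x.ndim + y.shape)
    return z[xi, yi]


z = np.arange(12).reshape(3, 4)
x = np.array([0, 2])
y = np.array([[1, 3], [0, 0]])

u = orthogonal_index(z, x, y)
print(u.shape)        # (2, 2, 2): x.ndim + y.ndim == 3 dimensions
print(u[1, 0, 1])     # z[x[1], y[0, 1]] == z[2, 3] == 11

# NumPy's default advanced indexing broadcasts x against y and pairs them up,
# so the result has shape (2, 2) instead.
print(z[x, y].shape)  # (2, 2)
```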

Why (for researchers)?

Pangolin is extremely modular. It's built around a simple internal representation (IR) in which there are only two types of objects: an Op represents a conditional distribution or deterministic function, while an RV contains a single Op and a list of parent RVs. Primitives that make evaluation efficient on modern hardware (e.g. VMap or Scan) wrap individual Ops. That's basically all there is to it.
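A toy version of such an IR fits in a few lines of Python. The class names mirror the description above, but the fields, the `constant` helper, and the `upstream` traversal are assumptions made for illustration, not pangolin's actual classes. The traversal returns every RV that a set of nodes depends on, with parents before children, similar in spirit to what print_upstream displays.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Op:
    """A conditional distribution or deterministic function (toy version)."""
    name: str                      # e.g. "normal" or "constant"
    random: bool                   # True for distributions, False for functions
    value: Optional[float] = None  # only used by constant ops


@dataclass(eq=False)  # compare RVs by identity: each one is a distinct graph node
class RV:
    """A single Op applied to a tuple of parent RVs (toy version)."""
    op: Op
    parents: tuple = ()


NORMAL = Op("normal", random=True)


def constant(v):
    """Wrap a fixed number as a parentless, deterministic RV."""
    return RV(Op("constant", random=False, value=v))


def upstream(nodes):
    """Return every RV that `nodes` depend on (including themselves),
    ordered so that parents always come before their children."""
    order, seen = [], set()

    def visit(rv):
        if id(rv) in seen:
            return
        seen.add(id(rv))
        for parent in rv.parents:
            visit(parent)
        order.append(rv)

    for rv in nodes:
        visit(rv)
    return order


# The calculator model: z ~ normal(0, 2), x ~ normal(z, 6).
z = RV(NORMAL, (constant(0.0), constant(2.0)))
x = RV(NORMAL, (z, constant(6.0)))

for i, rv in enumerate(upstream([x])):
    print(i, rv.op.name, rv.op.value)
```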

All other parts of Pangolin are decoupled: they only depend on the IR, not on each other. For example, the interface offers a friendly way for users to specify models, with optional broadcasting, program transformations, and so on. Internally, this is quite complicated, but its output is just a model in the IR. The different backends only look at the IR, and don't even know that the interface layer exists.

This makes many things easy that are typically quite difficult in modern PPLs:

  • Say you want to create a new inference algorithm that programmatically inspects the model. That's easy, because the IR is just a static graph of random variables.

  • Say you want to create a probabilistic model and share it with collaborators, some of whom use JAX and some of whom use PyTorch. That's fine. The former group can use the JAX backend while the latter group uses the torch backend.

  • Say you want to create a new "backend" that will do inference using a different array computing framework instead of JAX or PyTorch. This is pretty easy. The torch backend is around 1000 lines. The (more capable) JAX backend is around 2000 lines. The blackjax interface is 400 lines.

  • Say you hate Pangolin's interface. That's fine. Make a new one! As long as you produce models into the Pangolin IR, you can still use the existing backends.

  • In the future, we hope to make the IR language independent, so interfaces and backends could be in other languages, e.g. R or Julia. (This is possible in principle now, but could be made easier.)
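As an illustration of the first point, here is a sketch of an inference algorithm written purely against a graph of random variables: likelihood weighting, i.e. importance sampling with the prior as the proposal. It uses its own hypothetical toy IR (`RV`, `upstream`, and `likelihood_weighting_mean` are illustrative names, not pangolin's API), supports only normal and constant nodes, and assumes a single observed node with no children. Pangolin's real IR and backends are richer than this.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(eq=False)
class RV:
    """Hypothetical toy IR node: op is "normal" or "constant"."""
    op: str
    parents: tuple = ()
    value: float = 0.0  # only used when op == "constant"


def upstream(nodes):
    """All RVs that `nodes` depend on, parents before children."""
    order, seen = [], set()

    def visit(rv):
        if id(rv) not in seen:
            seen.add(id(rv))
            for parent in rv.parents:
                visit(parent)
            order.append(rv)

    for rv in nodes:
        visit(rv)
    return order


def normal_logpdf(v, mu, sigma):
    return -0.5 * ((v - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


def likelihood_weighting_mean(query, observed, obs_value, n=200_000, seed=0):
    """Estimate E[query | observed = obs_value] by likelihood weighting.

    Walks the graph in topological order: constants are broadcast, latent
    normals are sampled given their already-sampled parents, and the observed
    node is clamped to obs_value, adding its log-density to each sample's
    log-weight. Assumes `observed` has no children in the graph.
    """
    rng = np.random.default_rng(seed)
    vals = {}
    log_w = np.zeros(n)
    for rv in upstream([query, observed]):
        if rv.op == "constant":
            vals[id(rv)] = np.full(n, rv.value)
            continue
        mu, sigma = (vals[id(p)] for p in rv.parents)
        if rv is observed:
            vals[id(rv)] = np.full(n, obs_value)
            log_w += normal_logpdf(obs_value, mu, sigma)
        else:
            vals[id(rv)] = rng.normal(mu, sigma)
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    return float(np.sum(w * vals[id(query)]) / np.sum(w))


# The calculator model from the examples section: z ~ normal(0, 2), x ~ normal(z, 6).
z = RV("normal", (RV("constant", value=0.0), RV("constant", value=2.0)))
x = RV("normal", (z, RV("constant", value=6.0)))
print(likelihood_weighting_mean(z, x, -10.0))  # a Monte Carlo estimate near -1
```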

Examples and comparisons

Simple "probabilistic calculator"

If z ~ normal(0,2) and x ~ normal(z,6), then what is E[z | x = -10]?

from pangolin import interface as pi
from pangolin.blackjax import E

z = pi.normal(0,2)
x = pi.normal(z,6)
print(E(z, x, -10.0))

Here is the same model in other PPLs. (Or see calculator-ppls.ipynb.)

PyMC
import numpy as np
import pymc as pm

with pm.Model():
    z = pm.Normal('z', 0, 2)
    x = pm.Normal('x', z, 6, observed=-10)
    trace = pm.sample(chains=1)
    z_samps = trace.posterior['z'].values
    print(np.mean(z_samps))
Pyro
import numpy as np
import pyro
import torch

def model():
    z = pyro.sample('z', pyro.distributions.Normal(0, 2))
    x = pyro.sample('x', pyro.distributions.Normal(z, 6), obs=torch.tensor(-10.0))

nuts_kernel = pyro.infer.mcmc.NUTS(model)
mcmc = pyro.infer.mcmc.MCMC(nuts_kernel, warmup_steps=500, num_samples=1000, num_chains=1)
mcmc.run()
z_samps = mcmc.get_samples()['z'].numpy()
print(np.mean(z_samps))
NumPyro
import numpyro
import jax
import jax.numpy as jnp
import numpy as np

def model():
    z = numpyro.sample('z', numpyro.distributions.Normal(0, 2))
    x = numpyro.sample('x', numpyro.distributions.Normal(z, 6), obs=-10)

nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(42))
z_samps = mcmc.get_samples()['z']
print(np.mean(z_samps))
JAGS
import numpy as np
import pyjags

model_code = """
model {
  z ~ dnorm(0, 1/2^2)
  x ~ dnorm(z, 1/6^2)
}
"""

model = pyjags.Model(
    code=model_code,
    data={'x': -10},
    chains=1,
    adapt=500
)

samples = model.sample(1000, ['z'])
z_samps = samples['z'].flatten()
print(np.mean(z_samps))
Stan
import cmdstanpy
import numpy as np
import tempfile
from pathlib import Path

stan_code = """
data {
  real x;
}
parameters {
  real z;
}
model {
  z ~ normal(0, 2);
  x ~ normal(z, 6);
}
"""

with tempfile.TemporaryDirectory() as tmpdir:
    stan_file = Path(tmpdir) / "calculator_model.stan"
    stan_file.write_text(stan_code)

    model = cmdstanpy.CmdStanModel(stan_file=str(stan_file))

    fit = model.sample(
        data={'x': -10.0},
        chains=1,
        iter_warmup=500,
        iter_sampling=1000,
        seed=42
    )
    z_samps = fit.stan_variable('z')

    print(np.mean(z_samps))
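Because both z and x are normal, this posterior is available in closed form, which gives a check on any of the sampler outputs above. With the standard conjugate-normal update (plain Python, independent of pangolin), precisions add and the posterior mean is a precision-weighted average of the prior mean and the observation, so E[z | x = -10] = (1/36)(-10) / (1/4 + 1/36) = -1.

```python
import math

# z ~ normal(0, 2), x | z ~ normal(z, 6), observed x = -10.
prior_mean, prior_sd = 0.0, 2.0
noise_sd = 6.0
x_obs = -10.0

# Conjugate normal update: precisions (inverse variances) add, and the
# posterior mean is the precision-weighted average of prior mean and data.
prior_prec = 1 / prior_sd**2   # 1/4
lik_prec = 1 / noise_sd**2     # 1/36
post_prec = prior_prec + lik_prec
post_mean = (prior_prec * prior_mean + lik_prec * x_obs) / post_prec
post_sd = math.sqrt(1 / post_prec)

print(post_mean, post_sd)  # -1 and sqrt(3.6) ≈ 1.897, up to float rounding
```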

Beta-Bernoulli model

This is arguably the simplest Bayesian model. If you've seen a bunch of coinflips from a bent coin, what is the true bias? To start, generate synthetic data.

# synthetic data
import numpy as np
np.random.seed(67)
z_true = 0.7
N = 20
x_obs = np.random.binomial(1, z_true, N)

# create model
import pangolin.blackjax
from pangolin import interface as pi
z = pi.beta(2,2)
x = pi.vmap(pi.bernoulli, None, N)(z)

# do inference
z_samps = pangolin.blackjax.sample(z, x, x_obs) # p(z | x = x_obs)

# plot
import seaborn as sns
sns.histplot(z_samps, binrange=[0,1])

Here is the same model in other PPLs. (Or see beta-bernoulli-ppls.ipynb.)

PyMC
import pymc as pm

with pm.Model() as coin_model:
    z = pm.Beta('z', alpha=2, beta=2)
    x = pm.Bernoulli('x', z, observed=x_obs)
    trace = pm.sample(chains=1)
    z_samps = trace.posterior['z'].values
    print(np.mean(z_samps), np.std(z_samps))
Pyro
import pyro
import torch

x_obs_torch = torch.tensor(x_obs, dtype=torch.float)

def model():
    z = pyro.sample('z', pyro.distributions.Beta(2.0, 2.0))
    with pyro.plate('N', N):
        x = pyro.sample('x', pyro.distributions.Bernoulli(z), obs=x_obs_torch)

nuts_kernel = pyro.infer.mcmc.NUTS(model)
mcmc = pyro.infer.mcmc.MCMC(nuts_kernel, warmup_steps=500, num_samples=1000, num_chains=1)
mcmc.run()
z_samps = mcmc.get_samples()['z'].numpy()
print(np.mean(z_samps), np.std(z_samps))
NumPyro
import numpyro
import jax
import jax.numpy as jnp

def model():
    z = numpyro.sample('z', numpyro.distributions.Beta(2.0, 2.0))
    with numpyro.plate('data', N):
        x = numpyro.sample('x', numpyro.distributions.Bernoulli(z), obs=x_obs)

nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(42))
z_samps = mcmc.get_samples()['z']
print(np.mean(z_samps), np.std(z_samps))
JAGS
import pyjags

model_code = """
model {
  z ~ dbeta(2, 2)
  for (i in 1:N) {
    x[i] ~ dbern(z)
  }
}
"""

model = pyjags.Model(
    code=model_code,
    data={'N': N, 'x': x_obs.tolist()},
    chains=1,
    adapt=500
)

samples = model.sample(1000, ['z'])
z_samps = samples['z'].flatten()
print(np.mean(z_samps), np.std(z_samps))
Stan
import cmdstanpy
import tempfile
from pathlib import Path

stan_code = """
data {
  int<lower=0> N;
  array[N] int<lower=0, upper=1> x;
}
parameters {
  real<lower=0, upper=1> z;
}
model {
  z ~ beta(2, 2);
  x ~ bernoulli(z);
}
"""

with tempfile.TemporaryDirectory() as tmpdir:
    stan_file = Path(tmpdir) / "coin_model.stan"
    stan_file.write_text(stan_code)

    model = cmdstanpy.CmdStanModel(stan_file=str(stan_file))

    fit = model.sample(
        data={'N': N, 'x': x_obs},
        chains=1,
        iter_warmup=500,
        iter_sampling=1000,
        seed=42
    )
    z_samps = fit.stan_variable('z')
    print(np.mean(z_samps), np.std(z_samps))
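The Beta prior is conjugate to the Bernoulli likelihood, so the exact posterior is Beta(2 + heads, 2 + tails), and its mean and standard deviation can be compared against the sampler outputs above. A NumPy sketch that regenerates the same synthetic data:

```python
import numpy as np

# Regenerate the same synthetic coinflips as in the pangolin example.
np.random.seed(67)
z_true = 0.7
N = 20
x_obs = np.random.binomial(1, z_true, N)

# Beta(a, b) prior + Bernoulli likelihood => Beta(a + heads, b + tails) posterior.
a_prior, b_prior = 2, 2
heads = int(x_obs.sum())
a_post = a_prior + heads
b_post = b_prior + (N - heads)

# Mean and standard deviation of a Beta(a, b) distribution.
post_mean = a_post / (a_post + b_post)
post_std = np.sqrt(a_post * b_post / ((a_post + b_post) ** 2 * (a_post + b_post + 1)))
print(post_mean, post_std)
```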

Eight-schools

Bayesian inference on the classic 8-schools model:

# setup
import numpy as np
N = 8
stddevs = np.array([15, 10, 16, 11, 9, 11, 10, 18])
x_obs = np.array([28, 8, -3, 7, -1, 1, 18, 12])

# inference
import pangolin.blackjax
from pangolin import interface as pi

mu  = pi.normal(0,10)
tau = pi.lognormal(0,5)
z   = pi.vmap(pi.normal, None, N)(mu, tau)
x   = pi.vmap(pi.normal)(z, stddevs)
z_samps = pangolin.blackjax.sample(z, x, x_obs, niter=10000)

# plot
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
sns.swarmplot(np.array(z_samps)[:,::50].T,s=2,zorder=0)
plt.xlabel('school')
plt.ylabel('treatment effect')

Here is the same model in other PPLs. (Or see eight-schools-ppls.ipynb.)

PyMC
import pymc as pm

with pm.Model():
    mu = pm.Normal('mu', 0, 10)
    tau = pm.LogNormal('tau', 0, 5)
    z = pm.Normal('z', mu, tau, size=N)
    x = pm.Normal('x', z, stddevs, observed=x_obs)

    trace = pm.sample(draws=10000, chains=1)
    z_samps = trace.posterior['z'].values[0,:,:]
Pyro
import pyro
import torch

stddevs_torch = torch.tensor(stddevs, dtype=torch.float)
x_obs_torch = torch.tensor(x_obs, dtype=torch.float)

def model():
    mu = pyro.sample('mu', pyro.distributions.Normal(0, 10))
    tau = pyro.sample('tau', pyro.distributions.LogNormal(0, 5))
    with pyro.plate('N', N):
        z = pyro.sample('z', pyro.distributions.Normal(mu, tau))
        x = pyro.sample('x', pyro.distributions.Normal(z, stddevs_torch), obs=x_obs_torch)

nuts_kernel = pyro.infer.mcmc.NUTS(model)
mcmc = pyro.infer.mcmc.MCMC(nuts_kernel, warmup_steps=500, num_samples=1000, num_chains=1)
mcmc.run()
z_samps = mcmc.get_samples()['z'].numpy()
NumPyro
import numpyro
import jax

def model():
    mu = numpyro.sample('mu', numpyro.distributions.Normal(0, 10))
    tau = numpyro.sample('tau', numpyro.distributions.LogNormal(0, 5))
    with numpyro.plate('N',N):
        z = numpyro.sample('z', numpyro.distributions.Normal(mu, tau))
        x = numpyro.sample('x', numpyro.distributions.Normal(z, stddevs), obs=x_obs)

nuts_kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
mcmc.run(jax.random.PRNGKey(42))
z_samps = mcmc.get_samples()['z']
JAGS
import pyjags

model_code = """
model {
  mu ~ dnorm(0, 1/10^2)
  tau ~ dlnorm(0, 1/5^2)
  for (i in 1:N) {
    z[i] ~ dnorm(mu, 1/tau^2)
    x[i] ~ dnorm(z[i], 1/stddevs[i]^2)
  }
}
"""

model = pyjags.Model(
    code=model_code,
    data={'N': N, 'stddevs': stddevs.tolist(), 'x': x_obs.tolist()},
    chains=1,
    adapt=5000
)

samples = model.sample(100000, ['z'])
z_samps = np.array(samples['z'])[:,:,0].T
Stan
import cmdstanpy
import tempfile
from pathlib import Path

stan_code = """
data {
  int<lower=0> N;
  array[N] real x;
  array[N] real stddevs;
}
parameters {
  real mu;
  real<lower=0> tau;
  array[N] real z;
}
model {
  mu ~ normal(0, 10);
  tau ~ lognormal(0, 5);
  for (i in 1:N) {
    z[i] ~ normal(mu, tau);
    x[i] ~ normal(z[i], stddevs[i]);
  }
}
"""

with tempfile.TemporaryDirectory() as tmpdir:
    stan_file = Path(tmpdir) / "8schools_model.stan"
    stan_file.write_text(stan_code)

    model = cmdstanpy.CmdStanModel(stan_file=str(stan_file))

    fit = model.sample(
        data={'N': N, 'stddevs': stddevs, 'x': x_obs},
        chains=1,
        iter_warmup=500,
        iter_sampling=1000,
        seed=42
    )
    z_samps = fit.stan_variable('z')
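The eight-schools posterior has no closed form, but conditional on mu and tau each school's effect does. With z_i ~ N(mu, tau^2) and x_i | z_i ~ N(z_i, stddevs_i^2), the conditional posterior of z_i is normal with precision 1/tau^2 + 1/stddevs_i^2 and a precision-weighted mean, which is why tau controls how strongly the schools are pulled toward a common mean. A NumPy sketch; the function name `conditional_z` and the fixed value mu = 8 are arbitrary illustrative choices, not fitted estimates or part of any library.

```python
import numpy as np

stddevs = np.array([15, 10, 16, 11, 9, 11, 10, 18])
x_obs = np.array([28, 8, -3, 7, -1, 1, 18, 12])


def conditional_z(mu, tau):
    """Mean and sd of p(z_i | mu, tau, x_i) for every school at once.

    z_i ~ N(mu, tau^2) and x_i | z_i ~ N(z_i, stddevs_i^2) give a normal
    conditional whose precision is the sum of the two precisions and whose
    mean is the precision-weighted average of mu and x_i.
    """
    prec = 1 / tau**2 + 1 / stddevs**2
    mean = (mu / tau**2 + x_obs / stddevs**2) / prec
    return mean, np.sqrt(1 / prec)


# mu = 8 is an arbitrary illustrative value, not an estimate.
# Small tau pulls every school toward mu; large tau leaves each near its own x_i.
for tau in [1.0, 10.0, 100.0]:
    mean, sd = conditional_z(mu=8.0, tau=tau)
    print(f"tau={tau:>5}: conditional means {np.round(mean, 1)}")
```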

More examples

For more examples, take a look at the demos. Here's a recommended order:

See also

An earlier version of Pangolin, based on much the same ideas but supporting only JAGS as a backend, is also available. It can be found, with documentation, in the pangolin-jags directory.

Download files

Source Distribution

pangolin-0.0.3.tar.gz (3.7 MB)

Uploaded Source

Built Distribution

pangolin-0.0.3-py3-none-any.whl (143.5 kB)

Uploaded Python 3

File details

Details for the file pangolin-0.0.3.tar.gz.

File metadata

  • Download URL: pangolin-0.0.3.tar.gz
  • Upload date:
  • Size: 3.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.12

File hashes

Hashes for pangolin-0.0.3.tar.gz
Algorithm Hash digest
SHA256 4a6eed356febf5474fdd693872fd3cba7a8e2a3ce86b9aa4758bc487b4fdda30
MD5 964f1f4fdd57c78d3317b3ff47084e4f
BLAKE2b-256 bdf8a1de7f0e351a6d9a6cbfd8e4a8378f721ac70ef1bf9fac672c63d98d465b

File details

Details for the file pangolin-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: pangolin-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 143.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.10.12

File hashes

Hashes for pangolin-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 16eab1be87a18e5075fe7cd9f7c06a9e68b89cae5f03d6225f9f2fec7a6be056
MD5 9b945822ea9bbd982cda092fa2b6c1bc
BLAKE2b-256 24e3723bc2f47a74aa50ed3064cc64af8e2e1fb525da146111227dab71a0c9b4
