aimz: Scalable probabilistic impact modeling
Overview
aimz is a Python library for flexible and scalable probabilistic impact modeling to assess the effects of interventions on outcomes of interest. Designed to work with user-defined models with probabilistic primitives, the library builds on NumPyro, JAX, Xarray, and Zarr to enable efficient inference workflows.
Features
- An intuitive API that combines ease of use from ML frameworks with the flexibility of probabilistic modeling.
- Scalable computation via parallelism and distributed data processing—no manual orchestration required.
- Variational inference as the primary inference engine, supporting custom optimization strategies and results.
- Support for interventional causal inference for modeling counterfactuals and causal relations.
- MLflow integration for experiment tracking and model management.
Usage
Workflow
- Outline the model, considering the data generating process, latent variables, and causal relationships, if any.
- Translate the model into a kernel (i.e., a function) using NumPyro and JAX.
- Integrate the kernel into the provided API to train the model and perform inference.
Example 1: Regression Using a scikit-learn-like Workflow
This example demonstrates a simple regression model following a typical ML workflow. The ImpactModel class provides .fit() and .fit_on_batch() for variational inference and posterior sampling, and .predict() and .predict_on_batch() for posterior predictive sampling. The optional .cleanup() removes posterior predictive samples saved as temporary files.
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from aimz import ImpactModel
# Load California Housing dataset
housing = fetch_california_housing()
X, y = housing.data, housing.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
# NumPyro model: linear regression
def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Bayesian linear regression."""
    n_features = X.shape[1]
    # Priors for weights, bias, and observation noise
    w = sample("w", dist.Normal(jnp.zeros(n_features), jnp.ones(n_features)))
    b = sample("b", dist.Normal())
    sigma = sample("sigma", dist.Exponential())
    # Plate over data
    mu = jnp.dot(X, w) + b
    with plate("data", X.shape[0]):
        sample("y", dist.Normal(mu, sigma), obs=y)
# Wrap with ImpactModel
im = ImpactModel(
    model,
    rng_key=random.key(42),
    inference=SVI(
        model,
        guide=AutoNormal(model),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
# Fit the model: variational inference followed by posterior sampling
im.fit_on_batch(X_train, y_train)
# Predict on new data using posterior predictive sampling
idata = im.predict(X_test)
# Clean up posterior predictive samples saved to disk during `.predict()`
im.cleanup()
The training step can be skipped if pre-trained variational inference results or posterior samples are available. These can be integrated into the ImpactModel, after which .predict() can be used directly.
Example 2: Causal Network with Confounder
This example illustrates a simple causal network. The variable Z has a direct causal effect on the outcome Y, while both are influenced by a shared confounder, C. An additional variable, X, is an observed exogenous factor that influences Z but has no direct effect on Y.
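As a rough illustration of the structure just described, the network can be written down as a small directed graph. The sketch below is plain Python and is not part of aimz; the `edges` dictionary and the `parents` helper are hypothetical names introduced only for this example. It shows that C is the only variable that is a parent of both Z and Y, which is what makes it a confounder of the Z -> Y relationship.

```python
# Directed edges of the causal network described above (parent -> children).
# This is an illustrative encoding, not an aimz data structure.
edges = {
    "C": ["Z", "Y"],  # shared confounder
    "X": ["Z"],       # observed exogenous factor, no direct edge to Y
    "Z": ["Y"],       # variable whose effect on Y we want to estimate
    "Y": [],          # outcome
}


def parents(node: str) -> set[str]:
    """Return the set of direct causes (parents) of a node."""
    return {p for p, children in edges.items() if node in children}


# A confounder of Z -> Y is a common parent of Z and Y.
confounders = parents("Z") & parents("Y")
print(confounders)  # {'C'}
```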
Our objective is to estimate the causal effect of Z (or alternatively X) on Y, while properly accounting for the confounding influence of C. We assume the following generative model for the observed data:
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import nn, random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO, init_to_feasible
from numpyro.infer.autoguide import AutoNormal
from aimz import ImpactModel
# NumPyro model: C and X influence Z; C and Z influence y
def model(X: ArrayLike, C: ArrayLike, y: ArrayLike | None = None) -> None:
    # Observed confounder
    c = sample("c", dist.Exponential(), obs=C)
    # Priors for coefficients in the structural model
    # C -> Z and C -> Y
    beta_cz = sample("beta_cz", dist.Normal())
    beta_cy = sample("beta_cy", dist.Normal())
    # X -> Z and Z -> Y
    beta_xz = sample("beta_xz", dist.Normal())
    beta_zy = sample("beta_zy", dist.Normal())
    # Intercepts
    beta_z = sample("beta_z", dist.Normal())
    beta_y = sample("beta_y", dist.Normal())
    # Observation noise for Z
    sigma = sample("sigma", dist.Exponential())
    # Plate over data
    with plate("data", X.shape[0]):
        mu_z = beta_z + beta_cz * c + beta_xz * X.squeeze(axis=1)
        z = sample("z", dist.LogNormal(mu_z, sigma))
        logits = beta_y + beta_cy * c + beta_zy * z
        sample("y", dist.Bernoulli(logits=logits), obs=y)
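For readers who prefer equations, the kernel above can be read as the following generative model. This is a transcription of the code, assuming NumPyro's default parameters (`dist.Normal()` as a standard normal and `dist.Exponential()` with rate 1); the index i runs over observations.

```latex
\begin{aligned}
C_i &\sim \mathrm{Exponential}(1) \\
\beta_{cz},\ \beta_{cy},\ \beta_{xz},\ \beta_{zy},\ \beta_z,\ \beta_y &\sim \mathcal{N}(0, 1) \\
\sigma &\sim \mathrm{Exponential}(1) \\
Z_i \mid C_i, X_i &\sim \mathrm{LogNormal}\bigl(\beta_z + \beta_{cz} C_i + \beta_{xz} X_i,\ \sigma\bigr) \\
Y_i \mid C_i, Z_i &\sim \mathrm{Bernoulli}\bigl(\operatorname{logit}^{-1}(\beta_y + \beta_{cy} C_i + \beta_{zy} Z_i)\bigr)
\end{aligned}
```

Here Z is the only latent variable that varies per observation; all coefficients and the scale parameter are shared across the data.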
Simulating data under a known structural model
We generate synthetic data consistent with the assumed causal structure:
- C is drawn from an exponential distribution.
- X is a count variable from a Poisson distribution.
- Z is generated as a noisy exponential function of C and X.
- Y is a binary outcome influenced by both C and Z through a logistic model.
# Create a pseudo-random number generator key for JAX
rng_key = random.key(42)
# Sample C from an Exponential distribution
rng_key, rng_subkey = random.split(rng_key)
C = random.exponential(rng_subkey, shape=(100,))
# Sample X from a Poisson distribution
rng_key, rng_subkey = random.split(rng_key)
X = random.poisson(rng_subkey, lam=1, shape=(100, 1))
# Generate Z influenced by C and X
rng_key, rng_subkey = random.split(rng_key)
mu_z = -1.0 + 0.5 * C - 1.5 * X.squeeze()
sigma_z = 10.0 # Add substantial noise to reduce correlation between C and Z
Z = jnp.exp(random.normal(rng_subkey, shape=(100,)) * sigma_z + mu_z)
# Generate Y from a logistic regression on C and Z
rng_key, rng_subkey = random.split(rng_key)
logits = -2.0 + 5.0 * C + 0.1 * Z
p = nn.sigmoid(logits)
y = random.bernoulli(rng_subkey, p=p).astype(jnp.int32)
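Because the simulation coefficients are known, the true change in P(Y = 1) from setting Z to zero can be computed directly for any given C and Z. The sketch below uses NumPy rather than JAX to stay dependency-light; the helper name `true_effect_of_zeroing_z` and the example values are hypothetical, and the sign convention (counterfactual minus factual) is an assumption made for this illustration.

```python
import numpy as np


def sigmoid(x):
    """Logistic function mapping logits to probabilities."""
    return 1.0 / (1.0 + np.exp(-x))


def true_effect_of_zeroing_z(c, z):
    """Change in P(Y=1) when Z is set to 0, using the simulation's coefficients."""
    p_factual = sigmoid(-2.0 + 5.0 * c + 0.1 * z)  # logits as in the simulation
    p_counterfactual = sigmoid(-2.0 + 5.0 * c)     # same logits with Z = 0
    return p_counterfactual - p_factual


# Hypothetical example values, not drawn from the simulated data
c_example = np.array([0.0, 0.5, 1.0])
z_example = np.array([10.0, 10.0, 10.0])
effect = true_effect_of_zeroing_z(c_example, z_example)
print(effect)
```

Since the coefficient on Z is positive, zeroing a positive Z always lowers the probability of Y = 1, so every entry of `effect` is negative. Quantities like this give a reference point when judging the estimates produced by the fitted model.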
Fitting the model and estimating causal effects
We fit the model using stochastic variational inference. Once trained, we perform a counterfactual analysis to isolate the effect of Z on Y.
- idata_factual represents predictions under the factual setting (with observed Z).
- idata_counterfactual represents predictions under a counterfactual intervention where Z is set to zero.

Comparing these two distributions allows us to estimate the causal effect of Z on Y, adjusted for the influence of C.
# Fit the model with SVI
im = ImpactModel(
    model,
    rng_key=rng_key,
    inference=SVI(
        model,
        guide=AutoNormal(model, init_loc_fn=init_to_feasible()),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
im.fit_on_batch(X, y, C=C)
# Predict under factual (Z) and counterfactual (zeroed Z) scenarios
idata_factual = im.predict_on_batch(X, C=C, intervention={"z": Z})
idata_counterfactual = im.predict_on_batch(
    X,
    C=C,
    intervention={"z": jnp.zeros_like(Z)},
)
# Estimate causal effect of intervening on Z while conditioning on C
impact = im.estimate_effect(
    output_baseline=idata_factual,
    output_intervention=idata_counterfactual,
)
Note: A local latent variable (z in this model) requires .predict_on_batch() here. Prefer .predict() whenever it is compatible with the model.
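To make the comparison of the two predictive distributions concrete, here is a minimal NumPy sketch of one way such a contrast could be summarized by hand. It does not reproduce what .estimate_effect() does internally and does not assume anything about the structure of the returned objects: `y_factual` and `y_counterfactual` are hypothetical arrays of shape (draws, observations) filled with synthetic values purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for posterior predictive draws of y,
# shaped (n_draws, n_observations). These are synthetic values,
# not output from the model above.
y_factual = rng.binomial(1, 0.7, size=(1000, 100))
y_counterfactual = rng.binomial(1, 0.5, size=(1000, 100))

# Average effect per posterior draw: mean outcome under the intervention
# minus mean outcome under the factual setting.
effect_per_draw = y_counterfactual.mean(axis=1) - y_factual.mean(axis=1)

# Summarize with a point estimate and a 95% interval across draws.
effect_mean = effect_per_draw.mean()
lower, upper = np.quantile(effect_per_draw, [0.025, 0.975])
print(f"mean effect: {effect_mean:.3f}, 95% interval: [{lower:.3f}, {upper:.3f}]")
```

Averaging over observations within each draw and then summarizing across draws keeps the posterior uncertainty in the effect estimate, rather than collapsing it to a single number too early.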
Getting Help
For feature requests, assistance, or any other inquiries, contact the maintainers or open an issue or pull request.
Contributing
See CONTRIBUTING.md to get started.