
Scalable probabilistic impact modeling


aimz: Scalable probabilistic impact modeling

Project status: Active. The project has reached a stable, usable state and is being actively developed. License: Apache 2.0.


Overview

aimz is a Python library for flexible and scalable probabilistic impact modeling to assess the effects of interventions on outcomes of interest. Designed to work with user-defined models with probabilistic primitives, the library builds on NumPyro, JAX, Xarray, and Zarr to enable efficient inference workflows.

Features

  • An intuitive API that combines the ease of use of ML frameworks with the flexibility of probabilistic modeling.
  • Scalable computation via parallelism and distributed data processing, with no manual orchestration required.
  • Variational inference as the primary inference engine, supporting custom optimization strategies and pre-trained results.
  • Support for interventional causal inference to model counterfactuals and causal relations.
  • MLflow integration for experiment tracking and model management.

Usage

Workflow

  1. Outline the model, considering the data generating process, latent variables, and causal relationships, if any.
  2. Translate the model into a kernel (i.e., a function) using NumPyro and JAX.
  3. Integrate the kernel into the provided API to train the model and perform inference.

Example 1: Regression Using a scikit-learn-like Workflow

This example demonstrates a simple regression model following a typical ML workflow. The ImpactModel class provides .fit() and .fit_on_batch() for variational inference and posterior sampling, and .predict() and .predict_on_batch() for posterior predictive sampling. The optional .cleanup() removes posterior predictive samples saved as temporary files.

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

from aimz import ImpactModel

# Load California Housing dataset
housing = fetch_california_housing()
X, y = housing.data, housing.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)


# NumPyro model: linear regression
def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Bayesian linear regression."""
    n_features = X.shape[1]

    # Priors for weights, bias, and observation noise
    w = sample("w", dist.Normal(jnp.zeros(n_features), jnp.ones(n_features)))
    b = sample("b", dist.Normal())
    sigma = sample("sigma", dist.Exponential())

    # Plate over data
    mu = jnp.dot(X, w) + b
    with plate("data", X.shape[0]):
        sample("y", dist.Normal(mu, sigma), obs=y)


# Wrap with ImpactModel
im = ImpactModel(
    model,
    rng_key=random.key(42),
    inference=SVI(
        model,
        guide=AutoNormal(model),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)

# Fit the model: variational inference followed by posterior sampling
im.fit_on_batch(X_train, y_train)

# Predict on new data using posterior predictive sampling
idata = im.predict(X_test)

# Clean up posterior predictive samples saved to disk during `.predict()`
im.cleanup()

The training step can be skipped if pre-trained variational inference results or posterior samples are available. These can be integrated into the ImpactModel, after which .predict() can be called directly.

Example 2: Causal Network with Confounder

This example illustrates a simple causal network. The variable Z has a direct causal effect on the outcome Y, while both are influenced by a shared confounder, C. An additional variable, X, is an observed exogenous factor that influences Z but has no direct effect on Y.
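The graph described above can be written down as a plain adjacency mapping. The short sketch below (pure Python; the `edges` dictionary and both helper functions are illustrative and not part of aimz) checks the two structural claims: C is a direct common parent of Z and Y, and every directed path from X to Y passes through Z.

```python
# Causal graph from the text, as parent -> set of children
edges = {
    "C": {"Z", "Y"},  # shared confounder
    "X": {"Z"},       # exogenous factor with no direct edge to Y
    "Z": {"Y"},       # variable whose effect on Y we want
    "Y": set(),       # outcome
}


def parents(node):
    """Return the direct parents of `node`."""
    return {p for p, children in edges.items() if node in children}


def directed_paths(src, dst, path=None):
    """Enumerate all directed paths from src to dst (the graph is acyclic)."""
    path = (path or []) + [src]
    if src == dst:
        return [path]
    return [p for child in sorted(edges[src]) for p in directed_paths(child, dst, path)]


# Direct common parents of Z and Y: here, the confounder C
confounders = parents("Z") & parents("Y")
print(confounders)               # {'C'}
print(directed_paths("X", "Y"))  # [['X', 'Z', 'Y']]
```

Because the only route from X to Y runs through Z, any effect of X on Y in this model is mediated entirely by Z.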

Our objective is to estimate the causal effect of Z (or alternatively X) on Y, while properly accounting for the confounding influence of C. We assume the following generative model for the observed data:

import jax.numpy as jnp
import numpyro.distributions as dist
from jax import nn, random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO, init_to_feasible
from numpyro.infer.autoguide import AutoNormal

from aimz import ImpactModel


# NumPyro model: Z depends on C and X; y depends on C and Z, so Z mediates the effect of X on y
def model(X: ArrayLike, C: ArrayLike, y: ArrayLike | None = None) -> None:
    # Observed confounder
    c = sample("c", dist.Exponential(), obs=C)

    # Priors for coefficients in the structural model
    # C -> Z and C -> Y
    beta_cz = sample("beta_cz", dist.Normal())
    beta_cy = sample("beta_cy", dist.Normal())

    # X -> Z and Z -> Y
    beta_xz = sample("beta_xz", dist.Normal())
    beta_zy = sample("beta_zy", dist.Normal())

    # Intercepts
    beta_z = sample("beta_z", dist.Normal())
    beta_y = sample("beta_y", dist.Normal())

    # Noise scale for the latent Z (on the log scale)
    sigma = sample("sigma", dist.Exponential())

    # Plate over data
    with plate("data", X.shape[0]):
        mu_z = beta_z + beta_cz * c + beta_xz * X.squeeze(axis=1)
        z = sample("z", dist.LogNormal(mu_z, sigma))

        logits = beta_y + beta_cy * c + beta_zy * z
        sample("y", dist.Bernoulli(logits=logits), obs=y)
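For reference, the kernel above corresponds to the following structural model, read directly off the code. Here i indexes observations in the `data` plate, X is treated as a fixed covariate with no prior, and C enters as an observed site with an Exponential(1) likelihood:

```latex
\begin{aligned}
\beta_z,\ \beta_y,\ \beta_{cz},\ \beta_{cy},\ \beta_{xz},\ \beta_{zy} &\sim \mathcal{N}(0, 1),
\qquad \sigma \sim \operatorname{Exponential}(1), \\
C_i &\sim \operatorname{Exponential}(1), \\
Z_i \mid C_i, X_i &\sim \operatorname{LogNormal}\!\left(\beta_z + \beta_{cz} C_i + \beta_{xz} X_i,\ \sigma\right), \\
Y_i \mid C_i, Z_i &\sim \operatorname{Bernoulli}\!\left(\operatorname{logit}^{-1}\!\left(\beta_y + \beta_{cy} C_i + \beta_{zy} Z_i\right)\right),
\qquad i = 1, \dots, n.
\end{aligned}
```

The simulation that follows draws data from one fixed member of this family (for Z: beta_z = -1.0, beta_cz = 0.5, beta_xz = -1.5, sigma = 10; for Y: beta_y = -2.0, beta_cy = 5.0, beta_zy = 0.1), so the fitted posterior can be compared against known ground truth.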

Simulating data under a known structural model

We generate synthetic data consistent with the assumed causal structure:

  • C is drawn from an exponential distribution.
  • X is a count variable from a Poisson distribution.
  • Z is log-normal: the exponential of a linear function of C and X plus Gaussian noise with a large scale.
  • Y is a binary outcome influenced by both C and Z through a logistic model.

# Create a pseudo-random number generator key for JAX
rng_key = random.key(42)

# Sample C from an Exponential distribution
rng_key, rng_subkey = random.split(rng_key)
C = random.exponential(rng_subkey, shape=(100,))

# Sample X from a Poisson distribution
rng_key, rng_subkey = random.split(rng_key)
X = random.poisson(rng_subkey, lam=1, shape=(100, 1))

# Generate Z influenced by C and X
rng_key, rng_subkey = random.split(rng_key)
mu_z = -1.0 + 0.5 * C - 1.5 * X.squeeze()
sigma_z = 10.0  # Add substantial noise to reduce correlation between C and Z
Z = jnp.exp(random.normal(rng_subkey, shape=(100,)) * sigma_z + mu_z)

# Generate Y from a logistic regression on C and Z
rng_key, rng_subkey = random.split(rng_key)
logits = -2.0 + 5.0 * C + 0.1 * Z
p = nn.sigmoid(logits)
y = random.bernoulli(rng_subkey, p=p).astype(jnp.int32)

Fitting the model and estimating causal effects

We fit the model using stochastic variational inference. Once trained, we perform a counterfactual analysis to isolate the effect of Z on Y.

  • idata_factual represents predictions under the factual setting (with observed Z).
  • idata_counterfactual represents predictions under a counterfactual intervention where Z is set to zero. Comparing these two distributions allows us to estimate the causal effect of Z on Y, adjusted for the influence of C.

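The README does not spell out what `estimate_effect` computes internally. One common way to compare two sets of posterior predictive samples, shown here as a hypothetical NumPy sketch with made-up Bernoulli draws standing in for the factual and counterfactual predictions, is to take the difference in mean outcome within each posterior draw, which yields a distribution of average effects:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior predictive draws of Y, shape (num_draws, num_obs):
# one array under the factual Z and one under the intervention Z = 0.
num_draws, num_obs = 1000, 100
y_factual = rng.binomial(1, 0.7, size=(num_draws, num_obs))
y_counterfactual = rng.binomial(1, 0.4, size=(num_draws, num_obs))

# Average effect within each posterior draw: mean over observations
effect_per_draw = y_factual.mean(axis=1) - y_counterfactual.mean(axis=1)

# Summarize the resulting distribution of average effects
effect_mean = effect_per_draw.mean()
lower, upper = np.quantile(effect_per_draw, [0.025, 0.975])
print(f"effect: {effect_mean:.3f} (95% interval {lower:.3f} to {upper:.3f})")
```

With real model output, the two arrays would be the predictive draws of `y` taken from the two prediction results; that extraction step depends on the return type of `.predict_on_batch()` and is not shown.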
# Fit the model with SVI
im = ImpactModel(
    model,
    rng_key=rng_key,
    inference=SVI(
        model,
        guide=AutoNormal(model, init_loc_fn=init_to_feasible()),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
im.fit_on_batch(X, y, C=C)

# Predict under factual (Z) and counterfactual (zeroed Z) scenarios
idata_factual = im.predict_on_batch(X, C=C, intervention={"z": Z})
idata_counterfactual = im.predict_on_batch(
    X,
    C=C,
    intervention={"z": jnp.zeros_like(Z)},
)

# Estimate causal effect of intervening on Z while conditioning on C
impact = im.estimate_effect(
    output_baseline=idata_factual,
    output_intervention=idata_counterfactual,
)

Because z is a local latent variable (sampled inside the data plate, with one value per observation), .predict_on_batch() is required here. Prefer .predict() whenever it is compatible with the model.

Getting Help

For feature requests, assistance, or other inquiries, contact the maintainers or open an issue or pull request.

Contributing

See CONTRIBUTING.md to get started.


Download files


Source Distribution

aimz-0.6.0.tar.gz (46.7 kB)

Uploaded Source

Built Distribution


aimz-0.6.0-py3-none-any.whl (50.9 kB)

Uploaded Python 3

File details

Details for the file aimz-0.6.0.tar.gz.

File metadata

  • Download URL: aimz-0.6.0.tar.gz
  • Upload date:
  • Size: 46.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for aimz-0.6.0.tar.gz
Algorithm Hash digest
SHA256 2fd086b7e867b812e977f01d22d718775e3c76712de6e9ca79ca75b6b54ed681
MD5 8e9f7245bd00dfc21d58ac7969bd14d6
BLAKE2b-256 e4a5278e34e7f4ccf271cd297948a8ff7e0fbfd9c531b64b6ea8d9e154d6c21a


Provenance

The following attestation bundles were made for aimz-0.6.0.tar.gz:

Publisher: publish.yaml on markean/aimz


File details

Details for the file aimz-0.6.0-py3-none-any.whl.

File metadata

  • Download URL: aimz-0.6.0-py3-none-any.whl
  • Upload date:
  • Size: 50.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for aimz-0.6.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c1624936b22fe4aac69143c5d13332a3d688c205f3ae14927b8673890090afdb
MD5 9fd225f71b10f1a761edb952448e22f8
BLAKE2b-256 6714d762e8a2a670f8971ca9cfdd11f3c7930a00c4041168bd054c3614a4775f


Provenance

The following attestation bundles were made for aimz-0.6.0-py3-none-any.whl:

Publisher: publish.yaml on markean/aimz

