Adaptive Importance Sampling and quilted-model scaling for Bayesian cross-validation

Bayesianquilts

A JAX-based library for building interpretable Bayesian models using piecewise linear regression and gradient-flow adaptive importance sampling for leave-one-out cross-validation.

Overview

Bayesianquilts provides tools for building truly interpretable input-output maps based on the principle of piecewise linearity. Rather than using black-box neural networks, this library combines representation learning, clustering, and multilevel linear regression modeling to create transparent, interpretable models suitable for high-stakes applications like healthcare and scientific research.

The library includes two major research contributions:

  1. Piecewise Linear Regression Models: An additive decomposition approach where parameter values arise as sums of contributions at different length scales
  2. Gradient-Flow Adaptive Importance Sampling (AIS): Advanced methods for Bayesian leave-one-out cross-validation

Key Features

  • Interpretable by Design: Models are constructed to be inherently interpretable, not just post-hoc explainable
  • Parameter Decomposition: Additive decomposition of parameters across interaction dimensions
  • Flexible Model Types: Supports classification, regression, and matrix factorization
  • Advanced Cross-Validation: Gradient-flow adaptive importance sampling for LOO-CV
  • Bayesian Inference: Full support for variational inference (ADVI) and importance sampling
  • JAX-Accelerated: Built on JAX for GPU/TPU acceleration and automatic differentiation
  • Robust Training: Includes gradient clipping, learning rate scheduling, NaN recovery, and checkpointing

Installation

From PyPI

pip install bayesianquilts

From Source

git clone https://github.com/mederrata/bayesianquilts.git
cd bayesianquilts
pip install -r requirements.txt
pip install -e .

Requirements

  • Python >= 3.8
  • JAX >= 0.7.1
  • TensorFlow Probability >= 0.25.0
  • Flax >= 0.11.2
  • NumPy, Pandas, SciPy, Scikit-learn
  • Optax (optimization)
  • Orbax (checkpointing)
  • ArviZ (Bayesian diagnostics)

See requirements.txt for the complete dependency list.

Quick Start

Piecewise Linear Classification

from bayesianquilts.predictors.classification import LogisticBayesianquilt
from bayesianquilts.util import training_loop
import jax
import jax.numpy as jnp

# Prepare your data (replace the ... placeholders with real arrays)
X_train = jnp.array(...)  # Features, shape (num_samples, num_features)
y_train = jnp.array(...)  # Labels

# Create model
model = LogisticBayesianquilt(
    num_features=X_train.shape[1],
    num_classes=2
)

# Initialize parameters from an explicit PRNG key
random_key = jax.random.PRNGKey(0)
params = model.initialize(random_key)

# Train with built-in utilities
losses, trained_params = training_loop(
    initial_values=params,
    loss_fn=lambda p: model.loss(p, X_train, y_train),
    num_epochs=100,
    learning_rate=0.01,
    clip_norm=1.0,
    patience=10
)

Adaptive Importance Sampling for LOO-CV

from bayesianquilts.metrics.ais import (
    AdaptiveImportanceSampler,
    LogisticRegressionLikelihood
)

# Define likelihood function
likelihood_fn = LogisticRegressionLikelihood()

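# Create AIS sampler
# (prior_fn and surrogate_fn are not defined in this snippet; supply your own
# log-density callables for the prior and the surrogate posterior)
ais_sampler = AdaptiveImportanceSampler(
    likelihood_fn,
    prior_log_prob_fn=prior_fn,
    surrogate_log_prob_fn=surrogate_fn
)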

# Compute LOO-CV with multiple transformation strategies
results = ais_sampler.adaptive_is_loo(
    data={'X': X, 'y': y},
    params=trained_params,
    hbar=1.0,
    variational=False,
    transformations=['ll', 'kl', 'var', 'identity']
)

# Access results
print(f"LOO log-likelihood: {results['ll_loo_psis']}")
print(f"Effective parameters (p_loo): {results['p_loo_psis']}")
print(f"PSIS k-hat diagnostic: {results['khat']}")

Core Concepts

Parameter Decomposition

The fundamental innovation is an additive decomposition of model parameters:

θ_effective = θ_global + θ_group1 + θ_group2 + ... + θ_local

Each parameter value arises as a sum of contributions at different hierarchical levels (length scales), enabling:

  • Automatic regularization through hierarchical priors
  • Interpretable multi-level effects
  • Interaction modeling across categorical and continuous variables

See notebooks/decomposition.ipynb for detailed examples.
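
As a rough illustration of the decomposition above, the following plain-Python sketch sums contributions from a global level, a grouping level, and an interaction level for a single observation. The level names (`region`, `region_x_age`), the toy values, and the `effective_parameter` helper are hypothetical and are not part of the bayesianquilts API; the library itself operates on JAX arrays.

```python
# Hypothetical illustration only; none of these names come from bayesianquilts.

def effective_parameter(components, group_ids):
    """Sum the contribution of every level for one observation.

    components: dict mapping level name -> {group index -> contribution}
    group_ids:  dict mapping level name -> group index of this observation
    """
    return sum(components[level][group_ids[level]] for level in components)


# Toy contributions to a single regression slope at three length scales.
components = {
    "global": {0: 1.0},                      # shared by every observation
    "region": {0: 0.25, 1: -0.5},            # one offset per region
    "region_x_age": {                        # one offset per (region, age band)
        (0, 0): 0.125, (0, 1): 0.0,
        (1, 0): -0.25, (1, 1): 0.5,
    },
}

# An observation in region 1, age band 0.
obs = {"global": 0, "region": 1, "region_x_age": (1, 0)}
theta = effective_parameter(components, obs)
print(theta)  # 1.0 + (-0.5) + (-0.25) = 0.25
```

Because each level contributes additively, a reader can attribute the final value to specific groups, and shrinking the finer-scale offsets toward zero with a prior pulls sparsely observed cells back toward the coarser levels.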

Adaptive Importance Sampling

The AIS framework implements gradient-flow transformations for stable LOO-CV:

  • T_ll: Likelihood descent using negative log-likelihood gradients
  • T_kl: KL-divergence weighted gradients using posterior weights
  • T_var: Variance-based adaptation using Hessian curvature
  • T_I: Identity (baseline, no transformation)

These transformations are combined with Pareto smoothed importance sampling (PSIS) for robust weight estimation.
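
For orientation, the untransformed importance-sampling estimate that these transformations aim to stabilize can be written in a few lines. This sketch assumes you already have log p(y_i | θ_s) for one held-out observation across S posterior draws. The function names are hypothetical, it uses only the standard library, and it omits both the gradient-flow transformations and the PSIS smoothing described above, so it roughly corresponds to the identity case only.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def is_loo_pointwise(log_lik):
    """Plain importance-sampling LOO for one observation.

    log_lik: list of log p(y_i | theta_s) over posterior draws s.
    Returns (elpd_i, normalized weights, effective sample size).
    """
    neg = [-v for v in log_lik]              # log of raw weights 1 / p(y_i | theta_s)
    log_norm = logsumexp(neg)
    elpd_i = math.log(len(log_lik)) - log_norm
    weights = [math.exp(v - log_norm) for v in neg]
    ess = 1.0 / sum(w * w for w in weights)
    return elpd_i, weights, ess


def p_loo_pointwise(log_lik):
    """Contribution to p_loo: full-data log predictive density minus elpd_i."""
    lpd_i = logsumexp(log_lik) - math.log(len(log_lik))
    elpd_i, _, _ = is_loo_pointwise(log_lik)
    return lpd_i - elpd_i


# Two posterior draws that assign the observation likelihood 0.5 and 0.25.
demo_log_lik = [math.log(0.5), math.log(0.25)]
demo_elpd, demo_weights, demo_ess = is_loo_pointwise(demo_log_lik)
print(demo_elpd, demo_weights, demo_ess)     # elpd equals -log(3) analytically
print(p_loo_pointwise(demo_log_lik))
```

An effective sample size far below the number of draws signals that a few draws dominate the weights, which is the situation the adaptive transformations and PSIS are meant to address.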

Available Models

Classification

  • LogisticBayesianquilt: Piecewise linear logistic regression
  • LogisticRegression: Standard Bayesian logistic regression with decomposition
  • LogisticRelunet: ReLU neural network classifier
  • LogisticGamiNet: Generalized additive model with neural networks
  • LogisticRidge: Ridge-regularized logistic regression

Regression

  • RegressionQuilt: Piecewise linear regression
  • HierarchicalAttention: Attention-based regression

Matrix Factorization

  • GaussianFactorization: Continuous latent factor models
  • PoissonFactorization: Count data factorization
  • BernoulliFactorization: Binary data factorization

Neural Network Components

  • DenseHorseshoe: Dense layers with horseshoe priors
  • DenseGaussian: Dense layers with Gaussian priors
  • GamiNetUnivariate: Univariate shape functions
  • GamiNetPairwise: Pairwise interaction networks

Training Utilities

The util.py module provides robust training infrastructure:

from bayesianquilts.util import training_loop

losses, params = training_loop(
    initial_values=initial_params,
    loss_fn=loss_function,
    data_iterator=data_batches,
    steps_per_epoch=100,
    num_epochs=50,
    learning_rate=0.01,
    clip_norm=1.0,              # Gradient clipping
    patience=10,                 # Early stopping
    lr_decay_factor=0.5,        # Learning rate decay
    checkpoint_dir="./ckpts",   # Automatic checkpointing
    recover_from_nan=True       # NaN recovery strategies
)

Features:

  • Gradient clipping for stability
  • Learning rate scheduling with decay
  • Early stopping with patience
  • Automatic checkpointing with Orbax
  • NaN/Inf detection and recovery
  • Progress tracking with tqdm
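
The interaction between these features can be sketched without any dependencies. The code below is not the `training_loop` implementation: it is a minimal gradient-descent loop, under the assumption that the learning rate decays on every non-improving epoch and that training stops after `patience` consecutive non-improving epochs. The `fit` and `clip_by_global_norm` helpers and the quadratic test problem are hypothetical.

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Rescale the gradient list so its Euclidean norm is at most clip_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= clip_norm:
        return list(grads)
    return [g * (clip_norm / norm) for g in grads]


def fit(params, loss_fn, grad_fn, num_epochs=200, learning_rate=0.1,
        clip_norm=1.0, patience=10, lr_decay_factor=0.5):
    """Gradient descent with clipping, LR decay, early stopping, NaN recovery."""
    best_loss, best_params, stale = float("inf"), list(params), 0
    losses = []
    for _ in range(num_epochs):
        grads = clip_by_global_norm(grad_fn(params), clip_norm)
        params = [p - learning_rate * g for p, g in zip(params, grads)]
        loss = loss_fn(params)
        losses.append(loss)
        if not math.isfinite(loss):
            # NaN/Inf recovery: return to the best parameters, take smaller steps.
            params = list(best_params)
            learning_rate *= lr_decay_factor
            continue
        if loss < best_loss:
            best_loss, best_params, stale = loss, list(params), 0
        else:
            stale += 1
            learning_rate *= lr_decay_factor   # assumed decay rule
            if stale >= patience:              # early stopping
                break
    return losses, best_params


# Toy problem: minimize (x0 - 3)^2 + (x1 + 1)^2 starting from the origin.
def quad_loss(p):
    return (p[0] - 3.0) ** 2 + (p[1] + 1.0) ** 2


def quad_grad(p):
    return [2.0 * (p[0] - 3.0), 2.0 * (p[1] + 1.0)]


demo_losses, demo_params = fit([0.0, 0.0], quad_loss, quad_grad)
print(demo_params)
```

Clipping bounds the size of each early step while the gradient is large, and tracking the best parameters gives both early stopping and NaN recovery a safe state to fall back on.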

Custom Distributions

Bayesianquilts includes several custom probability distributions:

  • GeneralizedGamma: Flexible shape for positive continuous data
  • PiecewiseExponential: For survival/duration modeling
  • TransformedHorseshoe: Sparsity-inducing priors
  • TransformedCauchy: Heavy-tailed priors
  • TransformedInverseGamma: Scale parameter priors

Examples and Notebooks

  • notebooks/decomposition.ipynb: Parameter decomposition methodology
  • notebooks/ovarian/: Medical claims modeling examples
  • notebooks/roach/: Logistic regression case studies
  • notebooks/enset/: Model comparison demonstrations
  • test_ais_framework.py: Complete AIS usage examples
  • test_gradient_clipping.py: Training utilities demonstration

Publications

This library implements methods from:

Piecewise Linear Models

  • Chang TL, Xia H, Mahajan S, Mahajan R, Maisog J, et al. (2024). Interpretable (not just posthoc-explainable) medical claims modeling for discharge placement to reduce preventable all-cause readmissions or death. PLOS ONE 19(5): e0302871. https://doi.org/10.1371/journal.pone.0302871

  • Xia H, Chang JC, Nowak S, Mahajan S, Mahajan R, Chang TL, Chow CC (2023). Proceedings of the 8th Machine Learning for Healthcare Conference, PMLR 219:884-905.

Adaptive Importance Sampling

  • Chang JC, Li X, Xu S, Yao HR, Porcino J, Chow CC (2024). Gradient-flow adaptive importance sampling for Bayesian leave one out cross-validation with application to sigmoidal classification models. ArXiv [Preprint] 2402.08151v2. PMID: 38711425; PMCID: PMC11071546. https://arxiv.org/abs/2402.08151

Project Structure

bayesianquilts/
├── bayesianquilts/
│   ├── model.py              # Base BayesianModel class
│   ├── util.py               # Training loops and utilities
│   ├── features.py           # Feature engineering
│   ├── jax/
│   │   └── parameter.py      # Parameter decomposition (Decomposed, Interactions)
│   ├── predictors/
│   │   ├── classification/   # Classification models
│   │   ├── regression/       # Regression models
│   │   ├── nn/               # Neural network components
│   │   └── factorization/    # Matrix factorization
│   ├── metrics/
│   │   ├── ais.py           # Adaptive importance sampling
│   │   ├── psis.py          # Pareto smoothed IS
│   │   └── nppsis.py        # NumPy/JAX PSIS
│   ├── vi/
│   │   ├── advi.py          # ADVI implementation
│   │   └── minibatch.py     # Minibatch VI
│   ├── distributions/        # Custom distributions
│   └── plotting/            # Visualization utilities
├── notebooks/               # Example notebooks
├── requirements.txt         # Dependencies
└── setup.py                # Package setup

API Status

The API is currently evolving as we prepare manuscripts on the methodology and theory. We will stabilize the API in future releases. For production use, please pin to specific versions.

Contributing

Contributions are welcome! Please:

  1. Fork the repository
  2. Create a feature branch
  3. Add tests for new functionality
  4. Ensure all tests pass
  5. Submit a pull request

License

MIT License - see LICENSE file for details.

Contact and Support

Supporting This Project

Mederrata Research LLC is a 501(c)(3) non-profit organization. Tax-deductible monetary contributions are welcome and help support the development and maintenance of open-source tools for interpretable machine learning in healthcare and scientific research.

To make a contribution or learn more, please contact us at info@mederrata.com.

Citation

If you use this library in your research, please cite:

@article{chang2024interpretable,
  title={Interpretable (not just posthoc-explainable) medical claims modeling for discharge placement to reduce preventable all-cause readmissions or death},
  author={Chang, Ted L and Xia, Hongjing and Mahajan, Sonya and Mahajan, Rohit and Maisog, Jose and others},
  journal={PLOS ONE},
  volume={19},
  number={5},
  pages={e0302871},
  year={2024},
  publisher={Public Library of Science}
}

@article{chang2024gradient,
  title={Gradient-flow adaptive importance sampling for Bayesian leave one out cross-validation with application to sigmoidal classification models},
  author={Chang, Joshua C and Li, Xu and Xu, Shuang and Yao, Howard R and Porcino, John and Chow, Carson C},
  journal={arXiv preprint arXiv:2402.08151},
  year={2024}
}

Acknowledgments

This work was developed by the Mederrata Research team with support from the research community. Special thanks to all contributors and users who have provided feedback and helped improve the library.
