Adaptive Importance Sampling and quilted-model scaling for Bayesian cross-validation
Bayesianquilts
A JAX-based library for building interpretable Bayesian models using piecewise linear regression and gradient-flow adaptive importance sampling for leave-one-out cross-validation.
Overview
Bayesianquilts provides tools for building truly interpretable input-output maps based on the principle of piecewise linearity. Rather than using black-box neural networks, this library combines representation learning, clustering, and multilevel linear regression modeling to create transparent, interpretable models suitable for high-stakes applications like healthcare and scientific research.
The library includes two major research contributions:
- Piecewise Linear Regression Models: An additive decomposition approach where parameter values arise as sums of contributions at different length scales
- Gradient-Flow Adaptive Importance Sampling (AIS): Advanced methods for Bayesian leave-one-out cross-validation
Key Features
- Interpretable by Design: Models are constructed to be inherently interpretable, not just post-hoc explainable
- Parameter Decomposition: Additive decomposition of parameters across interaction dimensions
- Flexible Model Types: Supports classification, regression, and matrix factorization
- Advanced Cross-Validation: Gradient-flow adaptive importance sampling for LOO-CV
- Bayesian Inference: Full support for variational inference (ADVI) and importance sampling
- JAX-Accelerated: Built on JAX for GPU/TPU acceleration and automatic differentiation
- Robust Training: Includes gradient clipping, learning rate scheduling, NaN recovery, and checkpointing
Installation
From PyPI (once published)
pip install bayesianquilts
From Source
git clone https://github.com/mederrata/bayesianquilts.git
cd bayesianquilts
pip install -r requirements.txt
pip install -e .
Requirements
- Python >= 3.8
- JAX >= 0.7.1
- TensorFlow Probability >= 0.25.0
- Flax >= 0.11.2
- NumPy, Pandas, SciPy, Scikit-learn
- Optax (optimization)
- Orbax (checkpointing)
- ArviZ (Bayesian diagnostics)
See requirements.txt for the complete dependency list.
Quick Start
Piecewise Linear Classification
import jax
import jax.numpy as jnp
from bayesianquilts.predictors.classification import LogisticBayesianquilt
from bayesianquilts.util import training_loop
# Prepare your data
X_train = jnp.array(...)  # Features
y_train = jnp.array(...)  # Labels
# Create model
model = LogisticBayesianquilt(
    num_features=X_train.shape[1],
    num_classes=2,
)
# Initialize parameters from an explicit PRNG key
random_key = jax.random.PRNGKey(0)
params = model.initialize(random_key)
# Train with built-in utilities
losses, trained_params = training_loop(
    initial_values=params,
    loss_fn=lambda p: model.loss(p, X_train, y_train),
    num_epochs=100,
    learning_rate=0.01,
    clip_norm=1.0,
    patience=10,
)
Adaptive Importance Sampling for LOO-CV
from bayesianquilts.metrics.ais import (
    AdaptiveImportanceSampler,
    LogisticRegressionLikelihood,
)
# prior_fn and surrogate_fn are user-supplied log-density functions;
# X, y are the training data and trained_params comes from the fitted model.
# Define likelihood function
likelihood_fn = LogisticRegressionLikelihood()
# Create AIS sampler
ais_sampler = AdaptiveImportanceSampler(
    likelihood_fn,
    prior_log_prob_fn=prior_fn,
    surrogate_log_prob_fn=surrogate_fn,
)
# Compute LOO-CV with multiple transformation strategies
results = ais_sampler.adaptive_is_loo(
    data={'X': X, 'y': y},
    params=trained_params,
    hbar=1.0,
    variational=False,
    transformations=['ll', 'kl', 'var', 'identity'],
)
# Access results
print(f"LOO log-likelihood: {results['ll_loo_psis']}")
print(f"Effective parameters (p_loo): {results['p_loo_psis']}")
print(f"PSIS k-hat diagnostic: {results['khat']}")
Core Concepts
Parameter Decomposition
The fundamental innovation is an additive decomposition of model parameters:
θ_effective = θ_global + θ_group1 + θ_group2 + ... + θ_local
Each parameter value arises as a sum of contributions at different hierarchical levels (length scales), enabling:
- Automatic regularization through hierarchical priors
- Interpretable multi-level effects
- Interaction modeling across categorical and continuous variables
See notebooks/decomposition.ipynb for detailed examples.
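The additive decomposition above can be illustrated with a minimal NumPy sketch. It does not use the bayesianquilts API: the grouping variables (`region`, `age band`), the array shapes, and the helper `effective_theta` are all hypothetical, chosen only to show how a per-observation parameter is assembled from a global term, per-group terms, and an interaction-cell term.

```python
import numpy as np

rng = np.random.default_rng(0)

num_features = 3
num_regions = 2      # hypothetical grouping variable 1
num_age_bands = 4    # hypothetical grouping variable 2

# One component per "length scale": a global term, one term per group level,
# and one term per interaction cell. Finer components get smaller scales here
# to mimic the shrinkage a hierarchical prior would impose (an assumption).
theta_global = rng.normal(size=(num_features,))
theta_region = 0.5 * rng.normal(size=(num_regions, num_features))
theta_age = 0.5 * rng.normal(size=(num_age_bands, num_features))
theta_local = 0.1 * rng.normal(size=(num_regions, num_age_bands, num_features))


def effective_theta(region_idx, age_idx):
    """Sum the contributions that apply to each observation's cell."""
    return (
        theta_global
        + theta_region[region_idx]
        + theta_age[age_idx]
        + theta_local[region_idx, age_idx]
    )


# Three observations, each with its own group membership.
region_idx = np.array([0, 1, 1])
age_idx = np.array([2, 0, 3])
X = rng.normal(size=(3, num_features))

theta = effective_theta(region_idx, age_idx)   # shape (3, num_features)
linear_predictor = np.sum(X * theta, axis=-1)  # one linear model per cell
print(theta.shape, linear_predictor.shape)
```

Because every observation in the same cell shares the same summed coefficients, the resulting map is linear within a cell and changes only across cell boundaries, which is what makes each component separately inspectable.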
Adaptive Importance Sampling
The AIS framework implements gradient-flow transformations for stable LOO-CV:
- T_ll: Likelihood descent using negative log-likelihood gradients
- T_kl: KL-divergence weighted gradients using posterior weights
- T_var: Variance-based adaptation using Hessian curvature
- T_I: Identity (baseline, no transformation)
These transformations are combined with Pareto smoothed importance sampling (PSIS) for robust weight estimation.
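For orientation, the identity baseline (T_I) corresponds to plain importance-sampling LOO, where posterior draws are reweighted by the inverse likelihood of the held-out point. The sketch below is written in NumPy and is not the library's implementation: the function names `logsumexp` and `is_loo`, the synthetic data, and the stand-in "posterior draws" are assumptions, and the PSIS smoothing step is omitted. The gradient-flow transformations listed above would move the draws before weighting in order to reduce the variance of these weights.

```python
import numpy as np


def logsumexp(a, axis=0):
    """Numerically stable log(sum(exp(a))) along an axis."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def is_loo(log_lik):
    """Plain IS-LOO from an (S, N) matrix of log p(y_i | theta_s)."""
    num_samples = log_lik.shape[0]
    # Unnormalized log weight of draw s when point i is left out.
    log_w = -log_lik
    log_w_norm = log_w - logsumexp(log_w, axis=0)
    # elpd_loo_i = -log mean_s[1 / p(y_i | theta_s)]
    elpd_loo = np.log(num_samples) - logsumexp(log_w, axis=0)
    # In-sample log pointwise predictive density.
    lpd = logsumexp(log_lik, axis=0) - np.log(num_samples)
    p_loo = np.sum(lpd - elpd_loo)
    # Effective sample size of the normalized weights, per data point.
    ess = 1.0 / np.sum(np.exp(2.0 * log_w_norm), axis=0)
    return elpd_loo, p_loo, ess


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
y = (rng.uniform(size=50) < 0.5).astype(float)
theta = rng.normal(scale=0.3, size=(200, 2))      # stand-in for posterior draws
logits = theta @ X.T                              # shape (200, 50)
log_lik = y * logits - np.logaddexp(0.0, logits)  # Bernoulli log-likelihood

elpd_loo, p_loo, ess = is_loo(log_lik)
print(elpd_loo.sum(), p_loo, ess.min())
```

A low effective sample size for a data point signals that a few draws dominate its weights, which is the situation the adaptive transformations and the PSIS k-hat diagnostic are meant to address.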
Available Models
Classification
- LogisticBayesianquilt: Piecewise linear logistic regression
- LogisticRegression: Standard Bayesian logistic regression with decomposition
- LogisticRelunet: ReLU neural network classifier
- LogisticGamiNet: Generalized additive model with neural networks
- LogisticRidge: Ridge-regularized logistic regression
Regression
- RegressionQuilt: Piecewise linear regression
- HierarchicalAttention: Attention-based regression
Matrix Factorization
- GaussianFactorization: Continuous latent factor models
- PoissonFactorization: Count data factorization
- BernoulliFactorization: Binary data factorization
Neural Network Components
- DenseHorseshoe: Dense layers with horseshoe priors
- DenseGaussian: Dense layers with Gaussian priors
- GamiNetUnivariate: Univariate shape functions
- GamiNetPairwise: Pairwise interaction networks
Training Utilities
The util.py module provides robust training infrastructure:
from bayesianquilts.util import training_loop
losses, params = training_loop(
    initial_values=initial_params,
    loss_fn=loss_function,
    data_iterator=data_batches,
    steps_per_epoch=100,
    num_epochs=50,
    learning_rate=0.01,
    clip_norm=1.0,             # Gradient clipping
    patience=10,               # Early stopping
    lr_decay_factor=0.5,       # Learning rate decay
    checkpoint_dir="./ckpts",  # Automatic checkpointing
    recover_from_nan=True,     # NaN recovery strategies
)
Features:
- Gradient clipping for stability
- Learning rate scheduling with decay
- Early stopping with patience
- Automatic checkpointing with Orbax
- NaN/Inf detection and recovery
- Progress tracking with tqdm
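How the features above interact can be sketched with a small, self-contained gradient-descent loop in NumPy. This is not the code in util.py: the function names `clip_by_global_norm` and `fit`, the improvement tolerance, and the choice to decay the learning rate when a non-finite loss is encountered are assumptions made for illustration. Only the parameter names `clip_norm`, `patience`, and `lr_decay_factor` are borrowed from the call shown earlier.

```python
import numpy as np


def clip_by_global_norm(grad, clip_norm):
    """Rescale grad so that its L2 norm does not exceed clip_norm."""
    norm = np.linalg.norm(grad)
    scale = min(1.0, clip_norm / (norm + 1e-12))
    return grad * scale


def fit(loss_and_grad, params, num_epochs=200, learning_rate=0.1,
        clip_norm=1.0, patience=10, lr_decay_factor=0.5):
    best_loss, best_params, stale = np.inf, params.copy(), 0
    losses = []
    for _ in range(num_epochs):
        loss, grad = loss_and_grad(params)
        if not np.isfinite(loss):
            # NaN/Inf recovery (assumed strategy): roll back to the best
            # parameters seen so far and shrink the step size.
            params = best_params.copy()
            learning_rate *= lr_decay_factor
            continue
        losses.append(loss)
        if loss < best_loss - 1e-8:
            best_loss, best_params, stale = loss, params.copy(), 0
        else:
            stale += 1
            if stale >= patience:
                break  # early stopping: no improvement for `patience` epochs
        params = params - learning_rate * clip_by_global_norm(grad, clip_norm)
    return losses, best_params


# Toy problem: a quadratic bowl centered at `target`.
target = np.array([1.0, 2.0])


def quadratic(p):
    diff = p - target
    return 0.5 * float(diff @ diff), diff


losses, best = fit(quadratic, np.array([10.0, -10.0]), num_epochs=500)
print(len(losses), best)
```

Far from the optimum the clipped update moves a fixed distance per step, which keeps early iterations stable; once the gradient norm falls below `clip_norm` the loop behaves like ordinary gradient descent until the patience counter stops it.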
Custom Distributions
Bayesianquilts includes several custom probability distributions:
- GeneralizedGamma: Flexible shape for positive continuous data
- PiecewiseExponential: For survival/duration modeling
- TransformedHorseshoe: Sparsity-inducing priors
- TransformedCauchy: Heavy-tailed priors
- TransformedInverseGamma: Scale parameter priors
Examples and Notebooks
- notebooks/decomposition.ipynb: Parameter decomposition methodology
- notebooks/ovarian/: Medical claims modeling examples
- notebooks/roach/: Logistic regression case studies
- notebooks/enset/: Model comparison demonstrations
- test_ais_framework.py: Complete AIS usage examples
- test_gradient_clipping.py: Training utilities demonstration
Publications
This library implements methods from:
Piecewise Linear Models
- Chang TL, Xia H, Mahajan S, Mahajan R, Maisog J, et al. (2024). Interpretable (not just posthoc-explainable) medical claims modeling for discharge placement to reduce preventable all-cause readmissions or death. PLOS ONE 19(5): e0302871. https://doi.org/10.1371/journal.pone.0302871
- Xia H, Chang JC, Nowak S, Mahajan S, Mahajan R, Chang TL, Chow CC (2023). Proceedings of the 8th Machine Learning for Healthcare Conference, PMLR 219:884-905.
Adaptive Importance Sampling
- Chang JC, Li X, Xu S, Yao HR, Porcino J, Chow CC (2024). Gradient-flow adaptive importance sampling for Bayesian leave one out cross-validation with application to sigmoidal classification models. ArXiv [Preprint] 2402.08151v2. PMID: 38711425; PMCID: PMC11071546. https://arxiv.org/abs/2402.08151
Project Structure
bayesianquilts/
├── bayesianquilts/
│ ├── model.py # Base BayesianModel class
│ ├── util.py # Training loops and utilities
│ ├── features.py # Feature engineering
│ ├── jax/
│ │ └── parameter.py # Parameter decomposition (Decomposed, Interactions)
│ ├── predictors/
│ │ ├── classification/ # Classification models
│ │ ├── regression/ # Regression models
│ │ ├── nn/ # Neural network components
│ │ └── factorization/ # Matrix factorization
│ ├── metrics/
│ │ ├── ais.py # Adaptive importance sampling
│ │ ├── psis.py # Pareto smoothed IS
│ │ └── nppsis.py # NumPy/JAX PSIS
│ ├── vi/
│ │ ├── advi.py # ADVI implementation
│ │ └── minibatch.py # Minibatch VI
│ ├── distributions/ # Custom distributions
│ └── plotting/ # Visualization utilities
├── notebooks/ # Example notebooks
├── requirements.txt # Dependencies
└── setup.py # Package setup
API Status
The API is currently evolving as we prepare manuscripts on the methodology and theory. We will stabilize the API in future releases. For production use, please pin to specific versions.
Contributing
Contributions are welcome! Please:
- Fork the repository
- Create a feature branch
- Add tests for new functionality
- Ensure all tests pass
- Submit a pull request
License
MIT License - see LICENSE file for details.
Contact and Support
- Organization: Mederrata Research LLC (501(c)(3) non-profit)
- Email: info@mederrata.com
- Repository: https://github.com/mederrata/bayesianquilts
Supporting This Project
Mederrata Research LLC is a 501(c)(3) non-profit organization. Tax-deductible monetary contributions are welcome and help support the development and maintenance of open-source tools for interpretable machine learning in healthcare and scientific research.
To make a contribution or learn more, please contact us at info@mederrata.com.
Citation
If you use this library in your research, please cite:
@article{chang2024interpretable,
  title={Interpretable (not just posthoc-explainable) medical claims modeling for discharge placement to reduce preventable all-cause readmissions or death},
  author={Chang, Ted L and Xia, Hongjing and Mahajan, Sonya and Mahajan, Rohit and Maisog, Jose and others},
  journal={PLOS ONE},
  volume={19},
  number={5},
  pages={e0302871},
  year={2024},
  publisher={Public Library of Science}
}
@article{chang2024gradient,
  title={Gradient-flow adaptive importance sampling for Bayesian leave one out cross-validation with application to sigmoidal classification models},
  author={Chang, Joshua C and Li, Xu and Xu, Shuang and Yao, Howard R and Porcino, John and Chow, Carson C},
  journal={arXiv preprint arXiv:2402.08151},
  year={2024}
}
Acknowledgments
This work was developed by the Mederrata Research team with support from the research community. Special thanks to all contributors and users who have provided feedback and helped improve the library.