
A JAX implementation of the Metalog distribution for flexible probability modeling.

metalog-jax

GPU-accelerated metalog distributions for modern probabilistic modeling


The Problem

Traditional probability distributions often fail to model real-world data accurately. You find yourself:

  • Trying multiple distributions, hoping one fits
  • Truncating distributions to enforce bounds
  • Using mixture models that are hard to interpret
  • Writing custom likelihood functions for edge cases

What if one distribution family could fit virtually any continuous data?

The Solution

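metalog-jax implements the metalog distributions introduced by Tom Keelin (2016). Metalogs are a continuous, semi-parametric family defined through their quantile function, which is linear in its coefficients, so fitting reduces to a regression of data quantiles on a set of basis functions of the cumulative probability. With enough terms, a metalog can closely approximate a wide range of continuous distributions.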

import jax.numpy as jnp
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
from metalog_jax.utils import DEFAULT_Y

# Your data
data = jnp.array([2.1, 3.5, 4.2, 5.8, 6.1, 7.3, 8.9, 12.4, 15.2, 18.7])

# Create validated input data
input_data = MetalogInputData.from_values(data, DEFAULT_Y, precomputed_quantiles=False)

# Configure metalog parameters
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Fit the metalog distribution
metalog = fit(input_data, params)

# Query it through a scipy.stats-style API
median = metalog.ppf(jnp.array([0.5]))   # Quantile function
density = metalog.pdf(jnp.array([0.5]))  # Probability density
cumulative = metalog.cdf(10.0)           # CDF at x=10
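To make quantile-based fitting concrete, the following is a minimal NumPy sketch of the unbounded metalog from Keelin (2016), written independently of metalog-jax (whose internals may differ). The metalog quantile function is a linear combination of basis functions of the cumulative probability y, so fitting reduces to ordinary least squares. The helpers metalog_basis, fit_metalog_ols, and metalog_quantile are hypothetical, and the probabilities y paired with the data are assumed to be given (for example, empirical plotting positions).

```python
import numpy as np


def metalog_basis(y: np.ndarray, num_terms: int) -> np.ndarray:
    """Design matrix of Keelin's (2016) unbounded metalog basis functions.

    Row i holds the first num_terms basis functions evaluated at y[i].
    """
    y = np.asarray(y, dtype=float)
    logit = np.log(y / (1.0 - y))  # ln(y / (1 - y))
    c = y - 0.5
    cols = [np.ones_like(y), logit]
    if num_terms >= 3:
        cols.append(c * logit)
    if num_terms >= 4:
        cols.append(c)
    for j in range(5, num_terms + 1):
        if j % 2 == 1:  # odd terms: (y - 0.5)^((j - 1) / 2)
            cols.append(c ** ((j - 1) // 2))
        else:  # even terms: (y - 0.5)^(j / 2 - 1) * ln(y / (1 - y))
            cols.append(c ** (j // 2 - 1) * logit)
    return np.column_stack(cols[:num_terms])


def fit_metalog_ols(x: np.ndarray, y: np.ndarray, num_terms: int) -> np.ndarray:
    """Least-squares coefficients a such that x ~= metalog_basis(y) @ a."""
    design = metalog_basis(y, num_terms)
    coeffs, *_ = np.linalg.lstsq(design, np.asarray(x, dtype=float), rcond=None)
    return coeffs


def metalog_quantile(coeffs: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Evaluate the fitted quantile function M(y)."""
    return metalog_basis(y, len(coeffs)) @ coeffs


# Usage: quantiles of a logistic distribution, x = 10 + 2 * ln(y / (1 - y))
y = np.linspace(0.05, 0.95, 19)
x = 10.0 + 2.0 * np.log(y / (1.0 - y))
coeffs = fit_metalog_ols(x, y, num_terms=5)
print(np.round(coeffs, 6))
```

Because the logistic quantile function is exactly the 2-term metalog, the fitted third to fifth coefficients come out at (numerically) zero.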

Why metalog-jax?

| Feature | metalog-jax | pymetalog |
|---|---|---|
| GPU acceleration | Yes | No |
| Automatic differentiation | Yes | No |
| JIT compilation | Yes | No |
| Bounded distribution support | Yes | Yes |
| Multiple regression methods | Yes (OLS, LASSO) | OLS only |
| Hyperparameter grid search | Yes (vectorized) | No |
| Serialization (save/load) | Yes | No |
| Active development | Yes | No |

Installation

Using pip

pip install metalog-jax

Using uv (recommended)

uv add metalog-jax

From source

git clone https://github.com/tjefferies/metalog_jax.git
cd metalog_jax
make install    # Install all dependencies

Requirements

  • Python >= 3.11
  • JAX >= 0.8.0
  • Flax >= 0.12.0
  • Chex >= 0.1.91
  • Plotly >= 6.4.0

Quick Start

Fitting a Distribution

import jax.numpy as jnp
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
from metalog_jax.utils import DEFAULT_Y

# Sample data (e.g., response times in milliseconds)
response_times = jnp.array([
    45, 52, 58, 61, 67, 72, 78, 85, 93, 102,
    115, 128, 145, 167, 195, 238, 312, 456, 623, 892
], dtype=jnp.float32)

# Create validated input data
data = MetalogInputData.from_values(response_times, DEFAULT_Y, precomputed_quantiles=False)

# Configure the metalog (bounded below by 0)
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Fit the distribution
metalog = fit(data, params)

# Analyze your distribution
print(f"Median response time: {float(metalog.ppf(jnp.array([0.5]))[0]):.1f} ms")
print(f"95th percentile: {float(metalog.ppf(jnp.array([0.95]))[0]):.1f} ms")
print(f"99th percentile: {float(metalog.ppf(jnp.array([0.99]))[0]):.1f} ms")

Working with Bounded Data

Metalog natively supports four boundedness types:

from metalog_jax.base import MetalogParameters, MetalogBoundedness, MetalogFitMethod

# Unbounded: support on (-inf, +inf)
# Example: temperature anomalies, returns
params = MetalogParameters(
    boundedness=MetalogBoundedness.UNBOUNDED,
    lower_bound=0.0,  # ignored
    upper_bound=0.0,  # ignored
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Lower-bounded: support on (lower_bound, +inf)
# Example: response times, prices, distances
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,  # ignored
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Upper-bounded: support on (-inf, upper_bound)
# Example: time until deadline, remaining capacity
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_UPPER_BOUND,
    lower_bound=0.0,  # ignored
    upper_bound=100.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)

# Fully bounded: support on (lower_bound, upper_bound)
# Example: percentages, probabilities, test scores
params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=100.0,
    method=MetalogFitMethod.OLS,
    num_terms=5,
)
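Keelin (2016) builds the bounded variants by fitting an unbounded metalog to transformed data and mapping the fitted quantile function back: ln(x - lower) for a lower bound, -ln(upper - x) for an upper bound, and ln((x - lower) / (upper - x)) for both. The sketch below assumes metalog-jax follows the same construction; the string labels and the functions to_unbounded and from_unbounded are hypothetical.

```python
import numpy as np

# Labels mirror the MetalogBoundedness names; the strings are hypothetical.


def to_unbounded(x, kind, lower=0.0, upper=0.0):
    """Map data from its bounded support onto the real line before fitting."""
    x = np.asarray(x, dtype=float)
    if kind == "UNBOUNDED":
        return x
    if kind == "STRICTLY_LOWER_BOUND":  # log metalog
        return np.log(x - lower)
    if kind == "STRICTLY_UPPER_BOUND":
        return -np.log(upper - x)
    if kind == "BOUNDED":  # logit metalog
        return np.log((x - lower) / (upper - x))
    raise ValueError(f"unknown boundedness: {kind}")


def from_unbounded(m, kind, lower=0.0, upper=0.0):
    """Map an unbounded metalog quantile M(y) back onto the bounded support."""
    m = np.asarray(m, dtype=float)
    if kind == "UNBOUNDED":
        return m
    if kind == "STRICTLY_LOWER_BOUND":
        return lower + np.exp(m)
    if kind == "STRICTLY_UPPER_BOUND":
        return upper - np.exp(-m)
    if kind == "BOUNDED":
        e = np.exp(m)
        return (lower + upper * e) / (1.0 + e)
    raise ValueError(f"unknown boundedness: {kind}")


# Usage: the round trip recovers test scores on (0, 100)
scores = np.array([5.0, 20.0, 60.0, 95.0])
z = to_unbounded(scores, "BOUNDED", lower=0.0, upper=100.0)
print(from_unbounded(z, "BOUNDED", lower=0.0, upper=100.0))
```

Whatever unbounded quantile M(y) is fitted to z, the back-transform keeps every quantile inside the declared support.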

Regularized Fitting

For noisy data or when fitting many terms, LASSO (L1) regularization can improve stability. This example reuses the data object created in the Quick Start:

from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.metalog import fit
from metalog_jax.regression import LassoParameters

# LASSO (L1 regularization for sparse coefficients)
lasso_params = LassoParameters(
    lam=0.01,              # L1 regularization strength
    learning_rate=1e-3,
    num_iters=1000,
    tol=1e-6,
    momentum=0.9,
)
params = MetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    lower_bound=0.0,
    upper_bound=0.0,
    method=MetalogFitMethod.Lasso,
    num_terms=6,
)
metalog = fit(data, params, regression_hyperparams=lasso_params)
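LassoParameters exposes a learning rate, an iteration cap, a tolerance, and momentum, which suggests an iterative first-order solver. Purely as an illustration, here is plain proximal gradient descent (ISTA, without momentum), a standard LASSO solver that may differ from what metalog-jax actually runs. The names lasso_ista and soft_threshold are hypothetical, X stands in for a metalog design matrix, and this sketch penalizes every coefficient uniformly, including the constant term.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||v||_1: shrink each entry toward zero by t."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def lasso_ista(X, x, lam, learning_rate=None, num_iters=5000, tol=1e-10):
    """Minimize (1 / 2n) * ||X a - x||^2 + lam * ||a||_1 by proximal gradient."""
    X = np.asarray(X, dtype=float)
    x = np.asarray(x, dtype=float)
    n = X.shape[0]
    if learning_rate is None:
        # Step 1 / L, where L is the Lipschitz constant of the smooth gradient.
        learning_rate = n / np.linalg.eigvalsh(X.T @ X).max()
    a = np.zeros(X.shape[1])
    for _ in range(num_iters):
        grad = X.T @ (X @ a - x) / n  # gradient of the least-squares term
        a_next = soft_threshold(a - learning_rate * grad, learning_rate * lam)
        if np.max(np.abs(a_next - a)) < tol:  # stop once updates stall
            return a_next
        a = a_next
    return a


# Usage: lam = 0 recovers least squares; a large lam zeroes every coefficient
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
x = X @ np.array([2.0, 0.0, -1.0])
print(lasso_ista(X, x, lam=0.0))
print(lasso_ista(X, x, lam=100.0))
```

Larger lam values push more coefficients exactly to zero, which is what makes LASSO attractive when num_terms is large relative to the data.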

SPT Metalog (3-Term Analytical Fitting)

For a quick approximation from minimal information, use the Symmetric Percentile Triplet (SPT) method. An SPT metalog computes its three coefficients analytically from just three quantiles (the alpha, 0.5, and 1 - alpha quantiles) and checks feasibility up front, confirming that the resulting distribution has a valid (non-negative) PDF before returning. Failing fast on infeasible fits prevents errors from surfacing downstream.

import jax.numpy as jnp
from metalog_jax.base import MetalogBoundedness, SPTMetalogParameters
from metalog_jax.metalog import fit_spt_metalog

# Sample data (e.g., response times in milliseconds)
response_times = jnp.array([
    45, 52, 58, 61, 67, 72, 78, 85, 93, 102,
    115, 128, 145, 167, 195, 238, 312, 456, 623, 892
], dtype=jnp.float32)

# Only needs 3 quantiles: alpha, median, 1-alpha
# Use STRICTLY_LOWER_BOUND for non-negative data like response times
spt_params = SPTMetalogParameters(
    boundedness=MetalogBoundedness.STRICTLY_LOWER_BOUND,
    alpha=0.1,  # Uses 10th, 50th, and 90th percentiles
    lower_bound=0.0,
    upper_bound=0.0,
)

spt_metalog = fit_spt_metalog(response_times, spt_params)
print(f"SPT Median: {float(spt_metalog.ppf(jnp.array([0.5]))[0]):.1f}")
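The analytic fit follows from the 3-term metalog M(y) = a1 + a2 L(y) + a3 (y - 0.5) L(y), with L(y) = ln(y / (1 - y)). Evaluating it at y = alpha, 0.5, and 1 - alpha and writing r = ln((1 - alpha) / alpha) gives a1 = x_0.5, a2 = (x_{1-alpha} - x_alpha) / (2r), and a3 = (x_{1-alpha} + x_alpha - 2 x_0.5) / ((1 - 2 alpha) r). The NumPy sketch below applies these formulas, with an optional log transform for a strict lower bound, and checks feasibility numerically by requiring dM/dy > 0 on a fine grid. All names are hypothetical, and both the quantile estimator (NumPy's default linear interpolation) and the grid-based check are assumptions; metalog-jax may estimate quantiles differently or test feasibility analytically.

```python
import numpy as np


def spt_coefficients(q_low, q_med, q_high, alpha):
    """Closed-form 3-term metalog coefficients from the alpha, 0.5, 1 - alpha quantiles."""
    r = np.log((1.0 - alpha) / alpha)
    a1 = q_med
    a2 = (q_high - q_low) / (2.0 * r)
    a3 = (q_high + q_low - 2.0 * q_med) / ((1.0 - 2.0 * alpha) * r)
    return np.array([a1, a2, a3])


def is_feasible(coeffs, grid_size=10_001, eps=1e-9):
    """Check numerically that M(y) is strictly increasing on (0, 1)."""
    _, a2, a3 = coeffs
    y = np.linspace(eps, 1.0 - eps, grid_size)
    logit = np.log(y / (1.0 - y))
    # dM/dy is the reciprocal of the density, so it must stay positive.
    dm = a2 / (y * (1.0 - y)) + a3 * (logit + (y - 0.5) / (y * (1.0 - y)))
    return bool(np.all(dm > 0))


def fit_spt(data, alpha=0.1, lower_bound=None):
    """Fit an SPT metalog; log-transform first if a strict lower bound is given."""
    data = np.asarray(data, dtype=float)
    z = data if lower_bound is None else np.log(data - lower_bound)
    q_low, q_med, q_high = np.quantile(z, [alpha, 0.5, 1.0 - alpha])
    coeffs = spt_coefficients(q_low, q_med, q_high, alpha)
    if not is_feasible(coeffs):  # fail fast instead of returning a bad fit
        raise ValueError("infeasible SPT metalog: density would be negative")
    return coeffs


times = np.array([45, 52, 58, 61, 67, 72, 78, 85, 93, 102,
                  115, 128, 145, 167, 195, 238, 312, 456, 623, 892])
coeffs = fit_spt(times, alpha=0.1, lower_bound=0.0)
print("median:", np.exp(coeffs[0]))  # back-transform of M(0.5) = a1
```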

Features

Full Probability Distribution API

import jax.numpy as jnp

# Quantile function (inverse CDF)
quantiles = metalog.ppf(jnp.array([0.1, 0.25, 0.5, 0.75, 0.9]))

# Probability density function
pdf_values = metalog.pdf(jnp.array([0.2, 0.4, 0.6, 0.8]))

# Log probability density (numerically stable)
log_pdf = metalog.logpdf(jnp.array([0.2, 0.4, 0.6, 0.8]))

# Cumulative distribution function
cdf_values = metalog.cdf(jnp.array([50.0, 100.0, 150.0]))

# Survival function (1 - CDF)
sf_values = metalog.sf(jnp.array([50.0, 100.0, 150.0]))

# Inverse survival function
isf_values = metalog.isf(jnp.array([0.1, 0.05, 0.01]))

# Summary statistics (properties, not methods)
mean = metalog.mean
variance = metalog.var
std_dev = metalog.std
mode = metalog.mode
median = metalog.median
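Metalogs are defined through their quantile function M(y). The density has a closed form in terms of probability, f(M(y)) = 1 / M'(y), while the CDF has no closed form and must be found by numerically inverting M. This is presumably why the examples above pass cumulative probabilities to ppf, pdf, and logpdf but data values to cdf and sf. A minimal NumPy sketch for a 3-term unbounded metalog with hypothetical coefficients; bisection is one simple root finder, and metalog-jax may use a different one.

```python
import numpy as np


def quantile(coeffs, y):
    """3-term unbounded metalog quantile function M(y)."""
    a1, a2, a3 = coeffs
    logit = np.log(y / (1.0 - y))
    return a1 + a2 * logit + a3 * (y - 0.5) * logit


def density_at_probability(coeffs, y):
    """PDF evaluated at the y-quantile: f(M(y)) = 1 / M'(y)."""
    _, a2, a3 = coeffs
    logit = np.log(y / (1.0 - y))
    dm = a2 / (y * (1.0 - y)) + a3 * (logit + (y - 0.5) / (y * (1.0 - y)))
    return 1.0 / dm


def cdf(coeffs, x, iters=200, eps=1e-12):
    """Invert the increasing M by bisection: find y in (0, 1) with M(y) = x."""
    lo, hi = eps, 1.0 - eps
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if quantile(coeffs, mid) < x:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def sf(coeffs, x):
    """Survival function, 1 - CDF."""
    return 1.0 - cdf(coeffs, x)


def isf(coeffs, q):
    """Inverse survival function: the (1 - q)-quantile."""
    return quantile(coeffs, 1.0 - q)


c = np.array([10.0, 2.0, 1.0])  # hypothetical coefficients
x30 = quantile(c, 0.3)
print(x30, cdf(c, x30), density_at_probability(c, 0.3))
```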

Random Sampling

Two PRNG backends are supported for random variate generation:

from metalog_jax.base import MetalogRandomVariableParameters
from metalog_jax.utils import JaxUniformDistributionParameters, HDRPRNGParameters

# JAX-based random sampling (standard approach)
rv_params = MetalogRandomVariableParameters(
    prng_params=JaxUniformDistributionParameters(seed=42),
    size=10000,
)
samples = metalog.rvs(rv_params)

# HDR PRNG for reproducible Monte Carlo simulations
# Multi-dimensional, counter-based PRNG ideal for parallel simulations
hdr_params = HDRPRNGParameters(
    trial=1,       # Simulation trial/iteration
    variable=0,    # Random variable identifier
    entity=0,      # Entity being simulated
    time=0,        # Time step
    agent=0,       # Agent/actor identifier
)
rv_params = MetalogRandomVariableParameters(
    prng_params=hdr_params,
    size=10000,
)
samples = metalog.rvs(rv_params)
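Whichever PRNG backend supplies the uniforms, drawing metalog variates amounts to inverse-transform sampling: generate U ~ Uniform(0, 1) and return M(U). A NumPy sketch for a 3-term unbounded metalog with hypothetical coefficients; the uniforms here come from NumPy's default generator rather than from either backend above, and sample_metalog is a hypothetical name.

```python
import numpy as np


def sample_metalog(coeffs, size, seed=0):
    """Inverse-transform sampling from a 3-term unbounded metalog."""
    rng = np.random.default_rng(seed)
    # Draw from [tiny, 1) so that ln(u / (1 - u)) stays finite.
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=size)
    a1, a2, a3 = coeffs
    logit = np.log(u / (1.0 - u))
    return a1 + a2 * logit + a3 * (u - 0.5) * logit


samples = sample_metalog(np.array([5.0, 1.0, 0.0]), size=100_000, seed=42)
print(np.median(samples))  # close to M(0.5) = 5.0
```

A fixed seed makes the draws reproducible; counter-based generators such as the HDR PRNG extend that idea to many independent streams keyed by trial, variable, entity, time, and agent.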

Interactive Visualization

from metalog_jax.base import MetalogPlotOptions

# Plot PDF
fig = metalog.plot(MetalogPlotOptions.PDF)
fig.show()

# Plot CDF
fig = metalog.plot(MetalogPlotOptions.CDF)
fig.show()

# Plot Survival Function
fig = metalog.plot(MetalogPlotOptions.SF)
fig.show()

Serialization

from pathlib import Path
from metalog_jax.metalog import Metalog

# Save to JSON
metalog.save(Path("my_distribution.json"))

# Load from JSON
loaded = Metalog.load(Path("my_distribution.json"))

# String serialization
json_str = metalog.dumps()
loaded2 = Metalog.loads(json_str)
assert metalog == loaded == loaded2

Unified Grid Search with fit_grid

The fit_grid function provides a unified interface for hyperparameter optimization, automatically detecting which axes to search based on inputs.

Grid Search over L1 Penalties

import jax.numpy as jnp
import numpy as np
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.grid_search import fit_grid, find_best_config, extract_metalog
from metalog_jax.utils import DEFAULT_Y

# Generate sample data
np.random.seed(42)
samples = jnp.array(np.random.beta(2, 5, 100))

# Create validated input data
data = MetalogInputData.from_values(samples, DEFAULT_Y, precomputed_quantiles=False)

params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=1.0,
    method=MetalogFitMethod.Lasso,
    num_terms=5,
)

# Grid search over L1 penalties
l1_penalties = jnp.array([0.0, 0.001, 0.01, 0.1])
result = fit_grid(data.x, data.y, params, l1_penalties=l1_penalties)

# Find the best configuration
best_idx, best_ks = find_best_config(result.ks_dist)
print(f"Best L1 penalty: {float(l1_penalties[int(best_idx)]):.4f}")
print(f"Best KS distance: {float(best_ks):.4f}")

# Extract the best metalog for use
best_metalog = extract_metalog(result, int(best_idx))
print(f"Median: {float(best_metalog.ppf(jnp.array([0.5]))[0]):.4f}")
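find_best_config ranks configurations by KS distance. For reference, the one-sample KS distance between data and a fitted CDF F is the largest gap between the empirical CDF and F, checked on both sides of each jump in the empirical CDF. A NumPy sketch; ks_distance is a hypothetical helper, and exactly how metalog-jax evaluates the fitted CDF for ks_dist is not documented here.

```python
import numpy as np


def ks_distance(samples, cdf):
    """One-sample Kolmogorov-Smirnov statistic sup_x |F_n(x) - F(x)|."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    f = cdf(x)
    i = np.arange(1, n + 1)
    d_plus = np.max(i / n - f)          # empirical CDF above the model CDF
    d_minus = np.max(f - (i - 1) / n)   # model CDF above the empirical CDF
    return float(max(d_plus, d_minus))


# Usage: three points against the Uniform(0, 1) CDF F(x) = x
print(ks_distance([0.7, 0.1, 0.4], lambda v: v))
```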

2D Grid Search (L1 x num_terms)

import jax.numpy as jnp
import numpy as np
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.grid_search import fit_grid, find_best_config, extract_metalog
from metalog_jax.utils import DEFAULT_Y

# Generate sample data
np.random.seed(42)
samples = jnp.array(np.random.beta(2, 5, 100))

# Create validated input data
data = MetalogInputData.from_values(samples, DEFAULT_Y, precomputed_quantiles=False)

params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=1.0,
    method=MetalogFitMethod.Lasso,
    num_terms=5,
)

# 2D grid search over L1 penalties and num_terms
l1_penalties = jnp.array([0.0, 0.01, 0.1])
num_terms_list = [5, 7, 9]

result = fit_grid(
    data.x, data.y, params,
    l1_penalties=l1_penalties,
    num_terms=num_terms_list
)

# Result shape is (len(l1_penalties), len(num_terms_list))
print(f"Grid shape: {result.ks_dist.shape}")  # (3, 3)

# Find the best configuration
best_idx, best_ks = find_best_config(result.ks_dist)
best_l1_idx, best_terms_idx = int(best_idx[0]), int(best_idx[1])
print(f"Best L1 penalty: {float(l1_penalties[best_l1_idx]):.4f}")
print(f"Best num_terms: {num_terms_list[best_terms_idx]}")
print(f"Best KS distance: {float(best_ks):.4f}")

# Extract the best metalog for use
best_metalog = extract_metalog(result, best_l1_idx, best_terms_idx)
print(f"Median: {float(best_metalog.ppf(jnp.array([0.5]))[0]):.4f}")

Batch Multiple Datasets with Full 3D Grid

import jax.numpy as jnp
import numpy as np
from metalog_jax.base import MetalogInputData, MetalogParameters
from metalog_jax.base import MetalogBoundedness, MetalogFitMethod
from metalog_jax.grid_search import fit_grid, find_best_config, extract_metalog
from metalog_jax.utils import DEFAULT_Y

# Create 3 different datasets with different distributions
np.random.seed(42)
samples1 = jnp.array(np.random.beta(5, 2, 100))  # left-skewed (long left tail)
samples2 = jnp.array(np.random.beta(2, 5, 100))  # right-skewed (long right tail)
samples3 = jnp.array(np.random.beta(2, 2, 100))  # symmetric

data1 = MetalogInputData.from_values(samples1, DEFAULT_Y, precomputed_quantiles=False)
data2 = MetalogInputData.from_values(samples2, DEFAULT_Y, precomputed_quantiles=False)
data3 = MetalogInputData.from_values(samples3, DEFAULT_Y, precomputed_quantiles=False)

# Stack datasets for batch processing
batched_x = jnp.stack([data1.x, data2.x, data3.x])
batched_y = jnp.stack([data1.y, data2.y, data3.y])

params = MetalogParameters(
    boundedness=MetalogBoundedness.BOUNDED,
    lower_bound=0.0,
    upper_bound=1.0,
    method=MetalogFitMethod.Lasso,
    num_terms=5,
)

l1_penalties = jnp.array([0.0, 0.01])
num_terms_list = [5, 7]

result = fit_grid(
    batched_x, batched_y, params,
    l1_penalties=l1_penalties,
    num_terms=num_terms_list
)

# Result shape is (n_datasets, len(l1_penalties), len(num_terms_list))
print(f"Grid shape: {result.ks_dist.shape}")  # (3, 2, 2)

# Find best configuration for each dataset and extract the metalog
dataset_names = ["left-skewed", "right-skewed", "symmetric"]
for d, name in enumerate(dataset_names):
    best_idx, best_ks = find_best_config(result.ks_dist[d])
    best_l1_idx, best_terms_idx = int(best_idx[0]), int(best_idx[1])

    # Extract the best metalog for this dataset
    best_metalog = extract_metalog(result, d, best_l1_idx, best_terms_idx)

    print(f"Dataset '{name}': L1={float(l1_penalties[best_l1_idx]):.4f}, "
          f"terms={num_terms_list[best_terms_idx]}, KS={float(best_ks):.4f}, "
          f"median={float(best_metalog.ppf(jnp.array([0.5]))[0]):.4f}")
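The loop above calls find_best_config once per dataset. Assuming find_best_config returns the position of the minimum KS distance, the same selection can be vectorized over the leading dataset axis. A NumPy sketch; best_per_dataset is a hypothetical helper, and ties resolve to the first minimum, as with np.argmin.

```python
import numpy as np


def best_per_dataset(ks_dist):
    """For ks_dist of shape (n_datasets, *grid), return argmin grid indices and KS values."""
    ks = np.asarray(ks_dist)
    flat = ks.reshape(ks.shape[0], -1)        # flatten every grid axis per dataset
    flat_idx = np.argmin(flat, axis=1)        # best flat position per dataset
    best_ks = flat[np.arange(flat.shape[0]), flat_idx]
    grid_idx = np.stack(np.unravel_index(flat_idx, ks.shape[1:]), axis=1)
    return grid_idx, best_ks


# Usage with a toy (3 datasets, 2 L1 penalties, 2 num_terms) grid
ks = np.array([[[0.3, 0.2], [0.1, 0.4]],
               [[0.05, 0.2], [0.3, 0.4]],
               [[0.3, 0.2], [0.5, 0.01]]])
idx, best = best_per_dataset(ks)
print(idx, best)
```

With the grid above, best_per_dataset(np.asarray(result.ks_dist)) would yield one (L1 index, num_terms index) row per dataset.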

fit_grid handles all 8 (2 x 2 x 2) combinations of these axes automatically:

  • Single/batched datasets
  • With/without L1 penalty grid
  • With/without num_terms grid

JAX Transformations

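metalog-jax is designed to work with JAX's transformation primitives. For example, automatic differentiation:

import jax
import jax.numpy as jnp

# jax.grad needs a scalar-valued function, so wrap the scalar probability
# in an array and return the single quantile
def quantile_at(prob):
    return metalog.ppf(jnp.atleast_1d(prob))[0]

gradient_fn = jax.grad(quantile_at)
gradient = gradient_fn(0.5)  # dQ/dp at the median (the quantile density)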
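Differentiating the quantile function with respect to probability yields the quantile density dM/dy, which is the reciprocal of the PDF at that quantile. A JAX-independent way to sanity-check such a gradient is a central finite difference; the sketch below does this in NumPy for a 3-term unbounded metalog with hypothetical coefficients and compares against the analytic derivative.

```python
import numpy as np


def quantile(coeffs, y):
    """3-term unbounded metalog quantile function M(y)."""
    a1, a2, a3 = coeffs
    logit = np.log(y / (1.0 - y))
    return a1 + a2 * logit + a3 * (y - 0.5) * logit


def quantile_density_fd(coeffs, y, h=1e-6):
    """Central finite-difference estimate of dM/dy."""
    return (quantile(coeffs, y + h) - quantile(coeffs, y - h)) / (2.0 * h)


def quantile_density_exact(coeffs, y):
    """Analytic dM/dy for the 3-term metalog."""
    _, a2, a3 = coeffs
    logit = np.log(y / (1.0 - y))
    return a2 / (y * (1.0 - y)) + a3 * (logit + (y - 0.5) / (y * (1.0 - y)))


c = np.array([10.0, 2.0, 1.0])
print(quantile_density_fd(c, 0.3), quantile_density_exact(c, 0.3))
```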

Choosing the Number of Terms

| Terms | Use case | Data requirements |
|---|---|---|
| 2 | Simple symmetric distributions | 6+ observations |
| 3-4 | Moderate skewness | 9-12+ observations |
| 5-6 | Heavy tails, asymmetry | 15-18+ observations |
| 7-10 | Complex multimodal shapes | 21-30+ observations |
| 10+ | Highly irregular distributions | 30+ observations |

Rule of thumb: use at least 3 observations per term.
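The rule of thumb translates directly into an upper bound on num_terms. A trivial sketch; max_terms is a hypothetical helper, not part of metalog-jax.

```python
def max_terms(num_observations: int, obs_per_term: int = 3) -> int:
    """Most terms supported by the rule of thumb (obs_per_term observations per term)."""
    if num_observations < 2 * obs_per_term:
        raise ValueError("fewer observations than the 2-term minimum in the table")
    return num_observations // obs_per_term


# The 20 response times in the Quick Start support at most 20 // 3 = 6 terms,
# so num_terms=5 there stays within the rule of thumb.
print(max_terms(20))
```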

Architecture

metalog_jax/
├── base/                    # Core abstractions
│   ├── core.py             # MetalogBase class with distribution methods
│   ├── data.py             # Input data validation and containers
│   ├── enums.py            # MetalogBoundedness, MetalogFitMethod
│   └── parameters.py       # Configuration dataclasses
├── regression/              # Fitting algorithms
│   ├── base.py             # RegressionModel, RegularizedParameters
│   ├── ols.py              # Ordinary Least Squares
│   └── lasso.py            # LASSO (L1 regularization)
├── metalog.py              # Metalog, SPTMetalog, fit, fit_spt_metalog
├── grid_search.py          # Unified fit_grid for hyperparameter optimization
└── utils.py                # HDRPRNG, KS distance, DEFAULT_Y, helpers

Comparison with Standard Distributions

Metalogs can closely approximate a wide range of continuous distributions. The table below shows the fit quality achieved on several standard scipy distributions:

| Distribution | KS distance (5 terms) | KS distance (7 terms) |
|---|---|---|
| Normal | < 0.001 | < 0.0001 |
| Log-Normal | < 0.002 | < 0.0005 |
| Gamma | < 0.003 | < 0.001 |
| Beta | < 0.002 | < 0.0005 |
| Weibull | < 0.003 | < 0.001 |
| Chi-Square | < 0.004 | < 0.001 |
| Student's t | < 0.003 | < 0.001 |

KS Distance: Kolmogorov-Smirnov distance (lower is better)

Documentation

Tutorials

Interactive notebooks are available in two formats:

Jupyter Notebooks (pre-executed, viewable in browser):

Marimo Notebooks (interactive, run locally):

  • marimo run examples/basic_usage.py
  • marimo run examples/fitting_grids.py

Contributing

We welcome contributions! Please follow these guidelines to ensure a smooth review process.

Before You Start

  1. Clone the repository locally
  2. Create a feature branch: git checkout -b feature/amazing-feature
  3. Install dependencies: make install

Development Workflow

This project uses a Makefile that mirrors all CI/CD checks. Run make help to see all available targets.

Write and test your changes:

make test-quick    # Fast iteration (no coverage)
make test          # Full test suite with coverage

Check code quality before committing:

make format        # Auto-format code
make lint          # Check for issues
make typecheck     # Verify types

Build documentation (if you modified docstrings):

make docs          # Build once
make docs-live     # Live-reload during development

Before pushing, run the full quality gate to catch CI failures early:

make quality-gate

This runs all 8 checks that CI runs, covering formatting, linting, type checking, complexity metrics, tests, license compliance, and security scans.

Submitting Changes

  1. Commit with a clear message: git commit -m 'Add amazing feature'
  2. Push to your branch: git push origin feature/amazing-feature
  3. Open a Pull Request with a description of your changes

Note: Pull Requests with failing CI will not be reviewed. Run make quality-gate locally first.

Development Setup

git clone https://github.com/tjefferies/metalog_jax.git
cd metalog_jax
make install    # Install all dependencies
make test       # Run tests with coverage
make docs       # Build documentation
make help       # Show all available Make targets

Code Style

  • Follow existing code conventions in the repository
  • Use type hints for all function signatures
  • Write comprehensive docstrings following Google style
  • Keep functions focused and single-purpose
  • Prefer immutable data structures (Flax dataclasses)

Citation

If you use metalog-jax in your research, please cite:

@software{metalog_jax,
  author = {Jefferies, Travis},
  title = {metalog-jax: GPU-accelerated metalog distributions for JAX},
  year = {2026},
  url = {https://github.com/tjefferies/metalog_jax}
}

And the original metalog paper:

@article{keelin2016metalog,
  author = {Keelin, Thomas W.},
  title = {The Metalog Distributions},
  journal = {Decision Analysis},
  volume = {13},
  number = {4},
  pages = {243-277},
  year = {2016},
  doi = {10.1287/deca.2016.0338}
}

References

Metalog Distributions

Regression Methods

  • Hastie, T., Tibshirani, R., & Friedman, J. (2009). The Elements of Statistical Learning: Data Mining, Inference, and Prediction (2nd ed.). Springer. Chapter 3: Linear Methods for Regression.
  • Tibshirani, R. (1996). Regression Shrinkage and Selection via the Lasso. Journal of the Royal Statistical Society: Series B (Methodological), 58(1), 267-288.

Optimization Algorithms

Statistical Methods

  • Kolmogorov, A. N. (1933). Sulla determinazione empirica di una legge di distribuzione. Giornale dell'Istituto Italiano degli Attuari, 4, 83-91.
  • Smirnov, N. V. (1948). Table for estimating the goodness of fit of empirical distributions. Annals of Mathematical Statistics, 19(2), 279-281.

Random Number Generation

License

This project is licensed under the MIT License - see the LICENSE file for details.


Built with JAX by Travis Jefferies
