
High-performance Bayesian inference engine written in Rust

Project description

rustmc

Bayesian inference engine written in Rust. Python API via PyO3.

Why rustmc

rustmc is built for production workloads where the same model structure is fit repeatedly:

  • Rust-native inference loop with no Python in the hot path.
  • Rayon-parallel chains and batch inference for repeated-model throughput.
  • Graph-based execution with cached buffers, transforms, and Jacobians.
  • Fast paths for linear regression and high-dimensional X @ beta models.
  • Built-in diagnostics, predictive checks, pointwise log-likelihood, and ArviZ export.

It is a strong fit for repeated Bayesian regression, forecasting, and hierarchical workflows on CPU. It is not yet a general-purpose probabilistic programming language that can replace PyMC or Stan for arbitrary models.

What sets rustmc apart is the execution model: it shares one compiled Rust core across chains and across many independent models, so throughput stays high when the same structure is applied to thousands of datasets.

PyMC and Stan are excellent general-purpose tools, but they are optimized around a broader single-model workflow. rustmc is optimized for the repeated-model setting where Python orchestration, per-model overhead, and deployment friction start to dominate.

10,000 Bayesian demand models in 70 seconds, with full posterior uncertainty.

Fitting the same 10,000 models sequentially takes ~160 seconds with ARIMA and ~28 minutes with Prophet, and neither produces full posterior credible intervals.

Benchmark

10 parameters, 100,000 observations, 8 chains, 2,000 draws:

Method           Time    Speedup
rustmc (NUTS)    72 s    5.3x
PyMC (NUTS)      383 s   1.0x

Batch inference, 10,000 independent 3-parameter models:

Method                 Total time   Per model   Uncertainty
rustmc (batch NUTS)    70 s         7 ms        Yes (full posterior)
ARIMA (sequential)     160 s        16 ms       No
Prophet (sequential)   28 min       170 ms      Partial

Quick start

git clone https://github.com/tbosier/rustmc.git
cd rustmc
python -m venv .venv && source .venv/bin/activate
pip install numpy maturin
maturin develop --manifest-path python_bindings/Cargo.toml --release

Alternatively, install the published package from PyPI (prebuilt wheels are currently available for CPython 3.12):

pip install rustmc

Single model

import numpy as np
import rustmc as rmc

np.random.seed(42)
x = np.random.randn(1000)
y = 2.5 * x + np.random.randn(1000)

builder = rmc.ModelBuilder()
beta = builder.normal_prior("beta", mu=0.0, sigma=1.0)
mu_expr = beta * "x"
builder.normal_likelihood("obs", mu_expr=mu_expr, sigma=1.0, observed_key="y")
model = builder.build()

fit = rmc.sample(model_spec=model, data={"x": x, "y": y}, chains=4, draws=1000)
print(fit.summary())

Output:

4 chains x 1000 draws per chain

Parameter        mean      std     hdi_3%    hdi_97%   ess_bulk   ess_tail    r_hat  mcse_mean
-----------------------------------------------------------------------------------------------
beta           2.4575   0.0313     2.3982     2.5133       2638       2966   1.0055   0.000610
-----------------------------------------------------------------------------------------------
Mean accept rate: 0.94  |  Divergences: 0
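As a back-of-envelope check on the table, the mcse_mean column is close to the common estimate sd / √ESS. This sketch plugs in the reported std and ess_bulk; the formula is an assumption here, and rustmc's exact estimator may use a different ESS variant:

```python
import math

# Values copied from the rustmc summary output above (parameter "beta")
posterior_sd = 0.0313
ess_bulk = 2638

# Monte Carlo standard error of the posterior mean, assuming mcse = sd / sqrt(ESS)
mcse_mean = posterior_sd / math.sqrt(ess_bulk)
print(f"mcse_mean ≈ {mcse_mean:.6f}")  # → mcse_mean ≈ 0.000609
```

The small gap to the printed 0.000610 comes from the std being rounded to three significant figures.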

Batch inference (10,000 models)

import numpy as np
import rustmc as rmc

n_models, n_weeks = 10_000, 52
t = np.arange(n_weeks, dtype=np.float64) / n_weeks  # shared time index, built once

# Placeholder data so the example runs; replace with your per-SKU series, shape (n_models, n_weeks)
rng = np.random.default_rng(0)
some_data = 100.0 + 10.0 * t + rng.normal(0.0, 5.0, size=(n_models, n_weeks))

models = []
for i in range(n_models):
    builder = rmc.ModelBuilder()
    intercept = builder.normal_prior("intercept", mu=0.0, sigma=200.0)
    trend = builder.normal_prior("trend", mu=0.0, sigma=20.0)
    mu_expr = intercept + trend * "t"
    builder.normal_likelihood("obs", mu_expr=mu_expr, sigma=5.0, observed_key="y")
    model = builder.build()

    models.append((model, {"t": t, "y": some_data[i]}))

results = rmc.batch_sample(models, draws=500, warmup=300)

# Each result is a BatchResult with .mean(), .std(), .get_samples()
for r in results[:5]:
    print(r)

Vector parameter model (high-dimensional regression)

For models with many parameters, such as a regression with thousands of features, use normal_prior with @ so that X @ beta is dispatched to faer. rustmc detects that beta is used in a matrix multiply, infers its length from the number of columns of X, and promotes it to a contiguous vector parameter block:

import numpy as np
import rustmc as rmc

N, P = 10_000, 500
X = np.random.randn(N, P)           # 2-D array → stored row-major, viewed by faer
beta_true = np.random.randn(P)
y = X @ beta_true + np.random.randn(N)

builder = rmc.ModelBuilder()
beta = builder.normal_prior("beta", mu=0.0, sigma=1.0)
mu_expr = beta @ "X"                # auto-promoted to faer GEMV
builder.normal_likelihood("obs", mu_expr=mu_expr, sigma=1.0, observed_key="y")
model = builder.build()

fit = rmc.sample(model_spec=model, data={"X": X, "y": y}, chains=4, draws=500)
print(fit.summary())

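Instead of 500 separate scalar graph nodes (one per coefficient), rustmc allocates a single MatVecMul op backed by faer. The forward pass computes X @ beta as one matrix-vector product, and the gradient with respect to beta is the transposed product X^T @ r, where r is the adjoint of the linear predictor passed back from the likelihood. Both directions therefore run as dense, cache-efficient BLAS-level kernels instead of hundreds of scalar graph ops; the work still scales with N × P, but the per-coefficient graph overhead disappears.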
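The math behind the MatVecMul op can be checked in a few lines. This stdlib-only sketch (plain Python lists standing in for faer, not rustmc's Rust code) evaluates a Normal log-likelihood with mu = X @ beta, computes the gradient with respect to beta as X^T @ adjoint, and compares it with central finite differences:

```python
import math
import random

def matvec(X, v):
    """Forward pass: mu = X @ v (X is a list of rows)."""
    return [sum(x_ij * v_j for x_ij, v_j in zip(row, v)) for row in X]

def matvec_transposed(X, r):
    """Backward pass: g = X^T @ r, accumulated row by row."""
    g = [0.0] * len(X[0])
    for row, r_i in zip(X, r):
        for j, x_ij in enumerate(row):
            g[j] += x_ij * r_i
    return g

def loglik_and_grad(beta, X, y, sigma=1.0):
    """Normal log-likelihood with mu = X @ beta, and its gradient w.r.t. beta."""
    mu = matvec(X, beta)
    resid = [y_i - mu_i for y_i, mu_i in zip(y, mu)]
    log_norm = -math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    logp = sum(log_norm - 0.5 * (r / sigma) ** 2 for r in resid)
    adjoint = [r / sigma**2 for r in resid]      # d logp / d mu
    return logp, matvec_transposed(X, adjoint)   # chain rule: X^T @ adjoint

rng = random.Random(0)
N, P = 50, 4
X = [[rng.gauss(0.0, 1.0) for _ in range(P)] for _ in range(N)]
y = [rng.gauss(0.0, 1.0) for _ in range(N)]
beta = [rng.gauss(0.0, 1.0) for _ in range(P)]
_, grad = loglik_and_grad(beta, X, y)

# Independent check with central finite differences
eps = 1e-6
finite_diff = []
for j in range(P):
    up, down = beta.copy(), beta.copy()
    up[j] += eps
    down[j] -= eps
    f_up, _ = loglik_and_grad(up, X, y)
    f_down, _ = loglik_and_grad(down, X, y)
    finite_diff.append((f_up - f_down) / (2 * eps))

print(max(abs(a - b) for a, b in zip(grad, finite_diff)))
```

Because the log-likelihood is quadratic in beta, the two gradients agree up to floating-point roundoff.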

For explicit control over the vector size, vector_normal_prior("beta", n=P) is also available.

The builder supports scalar hierarchical priors today. For normal_prior, both mu and sigma can be other parameters; for half_normal_prior, sigma can be a parameter; exponential_prior and log_normal_prior also accept parameter-valued hyperparameters; and likelihood sigma or alpha can be parameter-valued as well. Scalar hierarchical normals are automatically compiled through a non-centered path where appropriate. Vector-valued hierarchical priors are not yet supported.
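Non-centering is the standard reparameterization theta = mu + sigma * z with z ~ Normal(0, 1). theta keeps the same distribution, but the sampler moves in z, whose prior does not depend on sigma, which removes the funnel-shaped coupling between a small sigma and theta. A stdlib sketch that checks the two forms agree in distribution (an illustration of the technique, not rustmc's compiler pass):

```python
import random
from statistics import fmean, stdev

def centered_draw(rng, mu, sigma):
    """Centered form: draw theta ~ Normal(mu, sigma) directly."""
    return rng.gauss(mu, sigma)

def non_centered_draw(rng, mu, sigma):
    """Non-centered form: draw z ~ Normal(0, 1), then set theta = mu + sigma * z."""
    z = rng.gauss(0.0, 1.0)
    return mu + sigma * z

rng = random.Random(1)
mu, sigma, n = 3.0, 0.5, 100_000
centered = [centered_draw(rng, mu, sigma) for _ in range(n)]
non_centered = [non_centered_draw(rng, mu, sigma) for _ in range(n)]

print(f"centered:     mean={fmean(centered):.3f} sd={stdev(centered):.3f}")
print(f"non-centered: mean={fmean(non_centered):.3f} sd={stdev(non_centered):.3f}")
```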

What is implemented

Sampling

  • NUTS (No-U-Turn Sampler) with multinomial candidate selection, generalized U-turn criterion, and divergence detection. Follows Hoffman and Gelman (2014) and Betancourt (2017).
  • HMC with fixed leapfrog steps, available as a fallback via sampler="hmc".
  • Block-structured mass matrix adaptation with 3-phase warmup (step-size only, mass matrix estimation, final step-size tuning).
  • Auto step-size initialization via binary search.
  • Deterministic per-chain RNG (ChaCha8) for reproducible results.
  • Multithreaded chains via Rayon. Batch inference shares the thread pool across all models.
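Both NUTS and the HMC fallback build trajectories with the leapfrog integrator. A minimal pure-Python sketch, assuming an identity mass matrix (rustmc adapts a block-structured one) and a standard normal target; it shows the integrator's near-conservation of energy and its exact reversibility, which the accept/reject and tree-building steps rely on. It is an illustration, not rustmc's Rust implementation:

```python
def leapfrog(q, p, step_size, n_steps, grad_logp):
    """Integrate Hamiltonian dynamics with an identity mass matrix.

    q, p: position and momentum (lists of floats).
    grad_logp: callable returning the gradient of log p(q).
    """
    q, p = list(q), list(p)
    g = grad_logp(q)
    for _ in range(n_steps):
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]  # half step on momentum
        q = [qi + step_size * pi for qi, pi in zip(q, p)]        # full step on position
        g = grad_logp(q)
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]  # half step on momentum
    return q, p

# Target: standard normal in 2-D, log p(q) = -0.5 * |q|^2 (up to a constant)
def logp(q):
    return -0.5 * sum(x * x for x in q)

def grad_logp(q):
    return [-x for x in q]

def hamiltonian(q, p):
    # Potential energy -log p(q) plus kinetic energy 0.5 * |p|^2
    return -logp(q) + 0.5 * sum(x * x for x in p)

q0, p0 = [1.0, -0.5], [0.3, 0.8]
q1, p1 = leapfrog(q0, p0, step_size=0.1, n_steps=20, grad_logp=grad_logp)
energy_error = abs(hamiltonian(q1, p1) - hamiltonian(q0, p0))

# Reversibility: flip the momentum, integrate again, and land back at the start
q2, p2 = leapfrog(q1, [-x for x in p1], step_size=0.1, n_steps=20, grad_logp=grad_logp)
print(f"energy error: {energy_error:.2e}")
```

Any callable pair of log-density and gradient can be plugged in, which mirrors the design principle that the sampler only needs a log-probability and gradient function.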

Distributions

Distribution   Support          Transform      Status
Normal         (-inf, inf)      None           Working
StudentT       (-inf, inf)      None           Working
HalfNormal     (0, inf)         log            Working
Exponential    (0, inf)         log            Working
LogNormal      (0, inf)         log            Working
Gamma          (0, inf)         log            Working
Beta           (0, 1)           logit          Working
Uniform        (a, b)           scaled logit   Working
Bernoulli      {0, 1}           None           Discrete, limited
Poisson        {0, 1, 2, ...}   None           Discrete, limited

Constrained distributions are automatically sampled in unconstrained space via log/logit transforms with Jacobian corrections. Samples are back-transformed before being returned to the user.
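The change of variables can be checked numerically. This sketch (stdlib only; the function names are illustrative, not rustmc's API) writes the HalfNormal log density on u = log(x) and the Beta log density on u = logit(x), each with its log-Jacobian term, and integrates them over the unconstrained line:

```python
import math

def halfnormal_logpdf(x, sigma=1.0):
    """Log density of HalfNormal(sigma) on x > 0."""
    return 0.5 * math.log(2.0 / math.pi) - math.log(sigma) - 0.5 * (x / sigma) ** 2

def halfnormal_logp_unconstrained(u, sigma=1.0):
    # x = exp(u), so log|dx/du| = u is the Jacobian correction
    return halfnormal_logpdf(math.exp(u), sigma) + u

def beta_logp_unconstrained(u, a, b):
    # x = sigmoid(u); log(x) and log(1 - x) computed stably from u
    log_x = -math.log1p(math.exp(-u))
    log_1mx = -math.log1p(math.exp(u))
    log_beta_fn = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    log_density_x = (a - 1) * log_x + (b - 1) * log_1mx - log_beta_fn
    # Jacobian correction: log|dx/du| = log(x) + log(1 - x)
    return log_density_x + log_x + log_1mx

def integrate(f, lo, hi, n=20_000):
    """Trapezoidal rule on [lo, hi]."""
    h = (hi - lo) / n
    total = 0.5 * (f(lo) + f(hi))
    for k in range(1, n):
        total += f(lo + k * h)
    return total * h

# With the Jacobian term, each density integrates to 1 over the unconstrained line
hn_total = integrate(lambda u: math.exp(halfnormal_logp_unconstrained(u)), -30.0, 5.0)
beta_total = integrate(lambda u: math.exp(beta_logp_unconstrained(u, 2.0, 3.0)), -30.0, 30.0)

# Dropping the Jacobian term gives something that is not a probability density
no_jacobian = integrate(lambda u: math.exp(halfnormal_logpdf(math.exp(u))), -30.0, 5.0)
print(f"{hn_total:.6f} {beta_total:.6f} {no_jacobian:.2f}")
```

Both corrected totals come out at 1, while the uncorrected HalfNormal version integrates to a number far from 1, so a sampler using it would target the wrong distribution.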

Bernoulli and Poisson priors are exposed for completeness, but discrete parameters have no gradient, so NUTS and HMC cannot sample them in their current form. For discrete outcomes, use the bernoulli_logit or poisson_log likelihood families on observed data, and keep latent parameters continuous (for example through a continuous relaxation).

Likelihood families

  • normal_likelihood(name, mu_expr, sigma, observed_key)
  • bernoulli_logit_likelihood(name, eta_expr, observed_key)
  • poisson_log_likelihood(name, eta_expr, observed_key)
  • exponential_likelihood(name, eta_expr, observed_key)
  • log_normal_likelihood(name, mu_expr, sigma, observed_key)
  • negative_binomial_likelihood(name, eta_expr, alpha, observed_key)

All GLM-style families use the same expression surface: bare parameters, beta * "x", additive expressions, matrix multiplies via beta @ "X", and additive constants.
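For reference, these are the standard per-observation log-likelihoods that the family names suggest, written against a linear predictor eta. It is a stdlib sketch under stated assumptions: the link functions are inferred from the names (logit for Bernoulli, log for Poisson and negative binomial), and alpha is taken as the inverse-dispersion parameter (variance mu + mu**2 / alpha, as in PyMC's NegativeBinomial); rustmc's internal parameterization may differ.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def bernoulli_logit_loglik(y, eta):
    # P(y = 1) = sigmoid(eta); y is 0 or 1
    return y * eta - softplus(eta)

def poisson_log_loglik(y, eta):
    # Rate = exp(eta); y is a non-negative integer
    return y * eta - math.exp(eta) - math.lgamma(y + 1)

def neg_binomial_log_loglik(y, eta, alpha):
    # Assumed NB2 form: mean mu = exp(eta), variance mu + mu**2 / alpha
    mu = math.exp(eta)
    log_total = math.log(alpha + mu)
    return (math.lgamma(y + alpha) - math.lgamma(alpha) - math.lgamma(y + 1)
            + alpha * (math.log(alpha) - log_total) + y * (eta - log_total))

eta = math.log(2.0)
print(bernoulli_logit_loglik(1, 0.0), poisson_log_loglik(3, eta),
      neg_binomial_log_loglik(3, eta, alpha=5.0))
```

Working on the eta scale with softplus keeps the Bernoulli term finite even for extreme linear predictors, and as alpha grows the negative binomial converges to the Poisson.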

Computation

  • Computational graph with reverse-mode automatic differentiation.
  • Fused linear combination op for regression models. Replaces a separate multiply-add pass per term with a single cache-friendly loop over the data.
  • Zero-allocation evaluator. All vector intermediates are pre-allocated in a flat buffer and reused across gradient evaluations. No heap allocation in the sampling loop.
  • faer-backed matrix-vector multiply (MatVecMul). When a normal_prior parameter is used with @ (e.g. beta @ "X"), rustmc automatically promotes it to a contiguous vector parameter block and dispatches the multiply to faer's GEMV routine. This replaces thousands of individual scalar multiply-add graph ops with a single BLAS-level call. Rayon threads are used for matrices above 100K elements. Explicit vector_normal_prior is also available for manual control.
  • Vectorized Normal prior (VectorNormalLogP). A single graph op evaluates the log-probability of an entire parameter vector under Normal(mu, sigma), replacing one graph node per parameter with a single tight loop. Gradients for all vector parameters accumulate directly into the gradient buffer in one backward pass.
  • 2-D NumPy arrays in the data dict are automatically detected and stored as row-major matrices for use with MatVecMul.
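The VectorNormalLogP idea (one loop over the whole vector, with gradients accumulated straight into a shared flat buffer) can be sketched in Python as below. The names grad_out and offset are hypothetical and stand in for rustmc's preallocated Rust gradient buffer:

```python
import math

def vector_normal_logp(theta, mu, sigma, grad_out, offset):
    """Joint Normal(mu, sigma) log density of every element of theta, in one loop.

    d logp / d theta[j] is added into grad_out[offset + j] in place, so the op
    allocates nothing and leaves other entries of the gradient buffer untouched.
    """
    inv_var = 1.0 / (sigma * sigma)
    log_norm = -math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    total = 0.0
    for j, value in enumerate(theta):
        diff = value - mu
        total += log_norm - 0.5 * diff * diff * inv_var
        grad_out[offset + j] += -diff * inv_var
    return total

# Flat gradient buffer for 5 parameters; the vector block occupies slots 2..4.
# Slots 0-2 already hold contributions from other ops (e.g. the likelihood).
grad = [0.7, 0.1, 0.5, 0.0, 0.0]
logp = vector_normal_logp([0.0, 1.0, -2.0], mu=0.0, sigma=1.0, grad_out=grad, offset=2)
print(logp, grad)
```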

Diagnostics

  • Split R-hat with rank normalization (Vehtari et al. 2021).
  • Bulk and tail effective sample size (ESS).
  • Monte Carlo standard error (MCSE).
  • 94% highest density interval.
  • Per-chain acceptance rates, step sizes, and divergence counts.
  • Automatic warnings for convergence issues.
  • Recovery suite covering canonical synthetic models in CI.
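For intuition about what R-hat measures, here is a stdlib sketch of the classic split R-hat: each chain is split in half and the between-half variance is compared with the within-half variance. It deliberately omits the rank normalization from Vehtari et al. (2021) that rustmc applies, so its numbers will not match fit.summary() exactly:

```python
import math
import random
from statistics import fmean, variance

def split_rhat(chains):
    """Split R-hat (Gelman-Rubin on half-chains), without rank normalization.

    chains: list of equal-length lists of draws for one scalar parameter.
    """
    halves = []
    for chain in chains:
        n = len(chain) // 2
        halves.append(chain[:n])
        halves.append(chain[n:2 * n])  # drops the final draw when the length is odd
    n = len(halves[0])
    within = fmean(variance(h) for h in halves)         # W: mean within-half variance
    between = n * variance([fmean(h) for h in halves])  # B: between-half variance
    var_plus = (n - 1) / n * within + between / n
    return math.sqrt(var_plus / within)

rng = random.Random(7)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
stuck = [chain[:] for chain in mixed]
stuck[0] = [x + 3.0 for x in stuck[0]]  # one chain stuck in a different region

print(f"mixed: {split_rhat(mixed):.3f}  stuck: {split_rhat(stuck):.3f}")
```

Well-mixed chains give a value very close to 1, while a single chain exploring a different region pushes R-hat well above the usual 1.01 threshold.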

Diagnostics are available via fit.summary() as a formatted table or via fit.diagnostics() for programmatic access.

Predictive checks

  • sample_prior_predictive() returns prior draws plus simulated observations.
  • FitResult.posterior_predictive() returns simulated observations from posterior draws.
  • FitResult.log_likelihood() returns pointwise log-likelihood arrays with shape (chain, draw, obs).
  • FitResult.to_arviz() exports posterior, sample stats, posterior predictive, and pointwise log-likelihood for downstream ArviZ/LOO/WAIC workflows.
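In practice the exported data would go to ArviZ (az.waic or az.loo). To show what the (chain, draw, obs) log-likelihood array feeds into, here is a stdlib sketch of WAIC using the standard definition: the log pointwise predictive density minus the sum of per-observation posterior variances of the log-likelihood. The toy numbers are made up, not rustmc output:

```python
import math
from statistics import variance

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def waic(log_lik):
    """Return (elpd_waic, p_waic) from pointwise log-likelihoods shaped (chain, draw, obs)."""
    draws = [draw for chain in log_lik for draw in chain]  # pool chains: S draws x obs
    n_draws = len(draws)
    elpd, p_waic = 0.0, 0.0
    for i in range(len(draws[0])):
        column = [d[i] for d in draws]
        lppd_i = logsumexp(column) - math.log(n_draws)  # log of the mean likelihood
        p_i = variance(column)                          # effective-parameter penalty
        elpd += lppd_i - p_i
        p_waic += p_i
    return elpd, p_waic

# Toy input: 2 chains x 3 draws x 2 observations
toy = [
    [[-1.0, -2.0], [-1.2, -2.1], [-0.9, -1.8]],
    [[-1.1, -2.2], [-1.0, -1.9], [-0.8, -2.0]],
]
elpd_waic, p_waic = waic(toy)
print(f"elpd_waic={elpd_waic:.3f}, p_waic={p_waic:.3f}")
```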

Progress reporting

Live progress bar rendered from Rust at 10 Hz using atomic counters, with no GIL involvement:

Sampling 8 chains ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% | 24.0k/24.0k | 0 divergences | 384.0k grad evals | 6.7s

Architecture

Python (orchestration only)
  |
  v  GIL released
Rust Core
  +-- Graph         Computational DAG, nodes, ops, data + matrix storage
  +-- Autodiff      Forward evaluation + reverse-mode gradient
  +-- Distributions  Scalar priors, GLM likelihood families, automatic transforms
  +-- NUTS          Multinomial tree-building, U-turn detection
  +-- HMC           Fixed-step leapfrog (fallback)
  +-- Sampler       Multi-chain parallel runner, batch inference
  +-- Diagnostics   R-hat, ESS, MCSE, HDI
  +-- Progress      Atomic counters, background render thread
  +-- faer          BLAS-level MatVecMul for high-dimensional parameter vectors

Design principles:

  • Model graph is built once and shared read-only across chains.
  • Sampler accepts any log-probability + gradient function derived from a Graph.
  • No global state. All state is explicit and owned.
  • Deterministic RNG per chain (ChaCha8 seeded from base_seed + chain_index).
  • Parameter transforms and Jacobian corrections are handled in the graph, not the sampler.
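The per-chain seeding scheme can be sketched in a few lines. Python's random.Random (Mersenne Twister) stands in for ChaCha8 here, and chain_rngs is a hypothetical helper: each chain gets its own generator seeded with base_seed + chain_index, so a run is reproducible no matter which thread executes which chain, or in what order.

```python
import random

def chain_rngs(base_seed, n_chains):
    """One independent, reproducible RNG per chain, seeded with base_seed + chain_index."""
    return [random.Random(base_seed + chain_index) for chain_index in range(n_chains)]

run_a = [[rng.random() for _ in range(3)] for rng in chain_rngs(42, 4)]
run_b = [[rng.random() for _ in range(3)] for rng in chain_rngs(42, 4)]

# Simulate a different thread schedule: run the chains in reverse order
rngs = chain_rngs(42, 4)
reversed_run = {i: [rngs[i].random() for _ in range(3)] for i in reversed(range(4))}
print(run_a[0])
```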

Data structures (Rust vs JAX)

The hot path uses plain Rust types only: the graph is Vec<Node> and Vec<Op>, parameters and gradients are Vec<f64>, and the autodiff evaluator uses contiguous vec_buf / adj_vec_buf (flat Vec<f64>) for all vector intermediates. For high-dimensional parameter vectors, data matrices are stored row-major as Vec<f64> inside the graph and handed to faer's matmul kernel as zero-copy views. ndarray appears only in the Python bindings for converting incoming 2-D NumPy arrays; it is not present in the inner loop. Benefits of this layout:

  • Cache-friendly: One pass over the graph touches sequential memory; vector slots are in a single allocation.
  • Zero allocation in the loop: Buffers are allocated once per chain and reused for every gradient evaluation.
  • No Python or FFI in the inner loop: The entire NUTS/HMC step runs in Rust; Python is only used to build the model and consume results.
  • Fixed graph traversal: The same DAG is walked every time; there is no tracing or recompilation per model or per step.
  • BLAS-level throughput for large parameter vectors: MatVecMul calls faer's GEMV, which uses SIMD intrinsics and can spawn Rayon threads for matrices above 100K elements. An X @ beta term with 5,000 coefficients, which would otherwise need 5,000 individual scalar multiply-add nodes in the graph, is a single op.
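To make the row-major layout concrete, this sketch uses Python's array("d") as a stand-in for Vec<f64>: element (i, j) of an n_rows × n_cols matrix lives at index i * n_cols + j, and the matrix-vector product writes into an output buffer allocated once. It mirrors the layout described above, not faer's SIMD kernel; the helper names are illustrative.

```python
from array import array

def to_row_major(rows):
    """Flatten equal-length rows into one contiguous float64 buffer (row-major)."""
    n_rows, n_cols = len(rows), len(rows[0])
    buf = array("d", (value for row in rows for value in row))
    return buf, n_rows, n_cols

def matvec_into(buf, n_rows, n_cols, v, out):
    """out[i] = sum_j X[i, j] * v[j], where X[i, j] lives at buf[i * n_cols + j].

    Results are written into the preallocated `out`, so repeated calls reuse memory.
    """
    for i in range(n_rows):
        row_start = i * n_cols
        acc = 0.0
        for j in range(n_cols):
            acc += buf[row_start + j] * v[j]
        out[i] = acc

X, n_rows, n_cols = to_row_major([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
out = array("d", [0.0] * n_rows)  # allocated once

matvec_into(X, n_rows, n_cols, [1.0, 0.0, -1.0], out)
first = list(out)
matvec_into(X, n_rows, n_cols, [0.5, 0.5, 0.5], out)  # same buffer reused
print(first, list(out))
```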

JAX, by contrast, traces Python and compiles to XLA. That gives flexibility and GPU support but adds compilation and dispatch overhead for each newly traced model. For many small, independent models (e.g. 10,000 SKUs), rustmc's "compile once, run a fixed graph over contiguous buffers" approach often wins on CPU because there is no per-model JAX trace/compile and no Python in the inner loop. Nutpie, a Rust NUTS sampler that runs PyMC models compiled through Numba or JAX, is faster than PyMC's default sampler for a single model; the batch example compares rustmc's batch NUTS against PyMC+nutpie run in a loop over the same number of models.

Roadmap

Near term:

  • Expose compiled model artifacts as a first-class public workflow in Python and Rust.
  • Extend automatic non-centering beyond scalar hierarchical normals to grouped/vector random effects.
  • Add a benchmark regression harness for wall time, ESS/s, memory, and divergence budgets.
  • Expand the modeling layer with production helpers such as offsets, exposure terms, and panel/hierarchical templates.

Medium term:

  • MAP estimation (L-BFGS)
  • Laplace approximation
  • Sparse indicator variable support
  • Stochastic gradient MCMC (SGLD/SGHMC) for large datasets
  • Model serialization (compile once, deploy without Python)

Long term:

  • Variational inference (ADVI)
  • GPU-accelerated log-probability via wgpu
  • WASM compilation for browser/edge inference
  • Distributed posterior aggregation
  • Automatic reparameterization for funnel geometries
  • C FFI for embedding in non-Python systems

License

MIT

Download files

Source Distribution

rustmc-0.8.0.tar.gz (92.1 kB)

Built Distributions

rustmc-0.8.0-cp312-cp312-win_amd64.whl (652.8 kB): CPython 3.12, Windows x86-64

rustmc-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (910.9 kB): CPython 3.12, manylinux (glibc 2.17+) x86-64

rustmc-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (845.7 kB): CPython 3.12, manylinux (glibc 2.17+) ARM64

rustmc-0.8.0-cp312-cp312-macosx_11_0_arm64.whl (734.0 kB): CPython 3.12, macOS 11.0+ ARM64

rustmc-0.8.0-cp312-cp312-macosx_10_12_x86_64.whl (791.2 kB): CPython 3.12, macOS 10.12+ x86-64

File details

Details for the file rustmc-0.8.0.tar.gz.

File metadata

  • Download URL: rustmc-0.8.0.tar.gz
  • Size: 92.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for rustmc-0.8.0.tar.gz
Algorithm Hash digest
SHA256 bedad63d164d377e2fb6f56df9f21edafd2a571589c481772d498626fdf989fd
MD5 1e7d7944124c8b9fb0fa2f91f6757660
BLAKE2b-256 95189249daceab84c6899386014b63ae7df69157cf858a1721b198473907dd7f

Provenance

The following attestation bundles were made for rustmc-0.8.0.tar.gz:

Publisher: ci.yml on tbosier/rustmc

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file rustmc-0.8.0-cp312-cp312-win_amd64.whl.

File metadata

  • Download URL: rustmc-0.8.0-cp312-cp312-win_amd64.whl
  • Size: 652.8 kB
  • Tags: CPython 3.12, Windows x86-64
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for rustmc-0.8.0-cp312-cp312-win_amd64.whl
Algorithm Hash digest
SHA256 8c6767b7a2c5fcc31c9c64192c235d1e38d128e3ece7f660c5a4f379fd6bf60d
MD5 6b46adecf90b3659e5acc191c4ea002e
BLAKE2b-256 807f19d4ad04eed4715cfd10cadc354d9f0d79d074aa80c94a1083532b68e635

Provenance

The following attestation bundles were made for rustmc-0.8.0-cp312-cp312-win_amd64.whl:

Publisher: ci.yml on tbosier/rustmc

File details

Details for the file rustmc-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for rustmc-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 c59184b5aefc1b1649ba02130ef1fe97212948bec8e2916e53023e418795801b
MD5 e41e7b7267793e8abddd06bb804dbeb6
BLAKE2b-256 869cc2b997a4d39d0d74187a64400fe4466c1fd84bae2e8467c2f7b5e6323b1b

Provenance

The following attestation bundles were made for rustmc-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

Publisher: ci.yml on tbosier/rustmc

File details

Details for the file rustmc-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl.

File hashes

Hashes for rustmc-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
Algorithm Hash digest
SHA256 64720b8f7362b2290053a4f9917af9d2a7a683690ce8cccefb62f731c5a9d4f7
MD5 af864ecf7793d4432e2d3b5ade58c34e
BLAKE2b-256 dcf681158c1eccfe01adb18a8ed42bc6d5c0b5615d3184d87f355bd4c8ac5b42

Provenance

The following attestation bundles were made for rustmc-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl:

Publisher: ci.yml on tbosier/rustmc

File details

Details for the file rustmc-0.8.0-cp312-cp312-macosx_11_0_arm64.whl.

File hashes

Hashes for rustmc-0.8.0-cp312-cp312-macosx_11_0_arm64.whl
Algorithm Hash digest
SHA256 37b1623e5e444862dc253dac629566e43ea7fdb4bf99db1a159bbc9abd8ff6f7
MD5 f9341b0314f7f87d33ec6f7f1f35057c
BLAKE2b-256 902ef2bf80607ce40178c3458c479381b4d6bffe53bcbc0234b02d797ea2f7cc

Provenance

The following attestation bundles were made for rustmc-0.8.0-cp312-cp312-macosx_11_0_arm64.whl:

Publisher: ci.yml on tbosier/rustmc

File details

Details for the file rustmc-0.8.0-cp312-cp312-macosx_10_12_x86_64.whl.

File hashes

Hashes for rustmc-0.8.0-cp312-cp312-macosx_10_12_x86_64.whl
Algorithm Hash digest
SHA256 bbbe31e36b6d5399655f4c329b91444641d38401683a909f89e9376a083c4204
MD5 06f24aad712d7872ee1727732c115886
BLAKE2b-256 5dc7475c87393480dd81ceb182701c52201479a4c43a49f4f006c8d8deff3dd7

Provenance

The following attestation bundles were made for rustmc-0.8.0-cp312-cp312-macosx_10_12_x86_64.whl:

Publisher: ci.yml on tbosier/rustmc
