
Deep Learning for Bayesian Inference: Stochastic Process Simulators


Stochastic Process Simulators (dl4bi-sps)

Install

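Pick the install command that matches your hardware. If JAX isn't installed yet, we recommend one of the dl4bi-sps[<extra>] variants, which also install a matching JAX build (CPU, CUDA 12, or CUDA 13). In zsh, quote the brackets, e.g. pip install "dl4bi-sps[cpu]". The distribution name is dl4bi-sps; the import package is dl4bi_sps.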

pip install dl4bi-sps # import dl4bi_sps
pip install dl4bi-sps[cpu] # import dl4bi_sps + jax for CPU
pip install dl4bi-sps[cuda12] # import dl4bi_sps + jax for CUDA-12
pip install dl4bi-sps[cuda13] # import dl4bi_sps + jax for CUDA-13
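
After installing, a few lines of Python can confirm what actually landed in the environment. This is a small sketch that only uses importlib and jax.default_backend(); the helper name report_install is ours and is not part of dl4bi_sps.

```python
import importlib.util


def report_install() -> dict:
    """Report whether dl4bi_sps and JAX are importable, plus JAX's default backend."""
    # The distribution is "dl4bi-sps", but the import name is "dl4bi_sps".
    has_sps = importlib.util.find_spec("dl4bi_sps") is not None
    has_jax = importlib.util.find_spec("jax") is not None
    backend = "missing"
    if has_jax:
        import jax

        # Typically "cpu" for the [cpu] extra and "gpu" for the CUDA extras.
        backend = jax.default_backend()
    return {"dl4bi_sps": has_sps, "jax": has_jax, "backend": backend}


print(report_install())
```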

View Documentation (Locally)

git clone git@github.com:MLGlobalHealth/sps.git
cd sps
uv sync --extra cpu # or --extra cuda12, or --extra cuda13
uv run --with pdoc pdoc --docformat google --math dl4bi_sps
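
The --docformat google --math flags tell pdoc to parse Google-style docstrings (Args:, Returns: sections) and to render LaTeX math. As an illustration of that docstring style, here are the textbook Matérn 3/2 and 5/2 covariances written as NumPy functions of distance. They borrow the names of the kernels imported in the demo below, but this is a sketch, not dl4bi_sps's implementation, whose signatures may differ (for example, taking point sets rather than distances); the dollar-delimited math assumes pdoc's --math rendering.

```python
import numpy as np


def matern_3_2(r, ls=1.0, var=1.0):
    r"""Matérn 3/2 covariance as a function of distance.

    $$k(r) = \sigma^2 \left(1 + \frac{\sqrt{3}\, r}{\ell}\right) \exp\left(-\frac{\sqrt{3}\, r}{\ell}\right)$$

    Args:
        r: Non-negative distances, any shape.
        ls: Lengthscale $\ell > 0$.
        var: Marginal variance $\sigma^2$.

    Returns:
        Covariances with the same shape as `r`.
    """
    a = np.sqrt(3.0) * np.asarray(r, dtype=float) / ls
    return var * (1.0 + a) * np.exp(-a)


def matern_5_2(r, ls=1.0, var=1.0):
    r"""Matérn 5/2 covariance as a function of distance.

    $$k(r) = \sigma^2 \left(1 + \frac{\sqrt{5}\, r}{\ell} + \frac{5 r^2}{3 \ell^2}\right) \exp\left(-\frac{\sqrt{5}\, r}{\ell}\right)$$

    Args:
        r: Non-negative distances, any shape.
        ls: Lengthscale $\ell > 0$.
        var: Marginal variance $\sigma^2$.

    Returns:
        Covariances with the same shape as `r`.
    """
    a = np.sqrt(5.0) * np.asarray(r, dtype=float) / ls
    # a**2 / 3 == 5 r^2 / (3 ls^2)
    return var * (1.0 + a + a**2 / 3.0) * np.exp(-a)
```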

Demo

import matplotlib.pyplot as plt

import jax
from jax import random

from dl4bi_sps.gp import GP
from dl4bi_sps.priors import Prior
from dl4bi_sps.utils import build_grid
from dl4bi_sps.kernels import matern_3_2, matern_5_2

rng = random.key(42)

s_1d = build_grid([{"start": -2, "stop": 2, "num": 128}])
s_2d = build_grid([{"start": -1.5, "stop": 1.5, "num": 300}, {"start": -2.5, "stop": 2.5, "num": 500}])
batch_size = 1
approx = True
lengthscales = [0.05, 0.1, 0.2]
for name, s in zip(["1d", "2d"], [s_1d, s_2d]):
    fig, axes = plt.subplots(len(lengthscales), 1)
    for i, ls in enumerate(lengthscales):
        gp = GP(matern_3_2, ls=Prior("fixed", {"value": ls}))
        f, *_ = gp.simulate(rng, s, batch_size, approx)
        axes[i].set_title(f"ls={ls}")
        if name == "1d":
            axes[i].plot(s, f.squeeze().T)
        else:
            axes[i].imshow(f.squeeze().reshape(300, 500), cmap="Spectral_r")
    plt.tight_layout()
    plt.savefig(f"{name}_gp.png", dpi=150)
    plt.close(fig)  # free the figure; a new one is created on the next iteration

# create a simple (forever) dataloader
def dataloader(rng, gp, s, batch_size=64, approx=False):
    while True:
        rng_i, rng = random.split(rng)
        yield gp.simulate(rng_i, s, batch_size, approx)


gp = GP(matern_5_2, ls=Prior("beta", {"a": 2.5, "b": 5}))
loader = dataloader(rng, gp, s_2d, batch_size, approx=True)
f, var, ls, period, z = next(loader)


# within IPython, speed test Kronecker (approx) vs. Cholesky methods 
rng, batch_size = random.key(42), 1024
s = build_grid([{"start": 0, "stop": 1, "num": 64}] * 2) # 64x64 grid
%timeit gp.simulate(rng, s, batch_size, approx=True) # ~5 ms
%timeit gp.simulate(rng, s, batch_size, approx=False) # ~50 ms
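
The timing gap comes from the grid structure. Below is a minimal NumPy sketch of the idea, assuming (as the "Kronecker (approx)" comment suggests) that the fast path replaces the 2D kernel with a product of 1D kernels. On an n1 × n2 grid the covariance then factors as K1 ⊗ K2, so a sample (L1 ⊗ L2) z can be computed as L1 Z L2ᵀ at O(n1³ + n2³) cost instead of factorizing the full (n1·n2) × (n1·n2) matrix at O((n1·n2)³). A product of two 1D Matérn kernels is not the 2D Matérn kernel, which is one plausible reason the fast path is labelled approximate; how dl4bi_sps actually implements approx=True is not shown here, and all function names below are hypothetical.

```python
import numpy as np


def matern_3_2(r, ls):
    """Textbook Matérn 3/2 correlation (unit variance) at distance r."""
    a = np.sqrt(3.0) * np.asarray(r, dtype=float) / ls
    return (1.0 + a) * np.exp(-a)


def chol_1d(x, ls, jitter=1e-6):
    """Cholesky factor of the 1D Matérn 3/2 Gram matrix on coordinates x (plus jitter)."""
    K = matern_3_2(np.abs(x[:, None] - x[None, :]), ls)
    return np.linalg.cholesky(K + jitter * np.eye(len(x)))


def sample_kronecker(rng, x1, x2, ls, batch_size):
    """Sample a product-kernel GP on the grid x1 × x2 via two small Cholesky factors.

    Returns shape (batch_size, len(x1), len(x2)).
    """
    L1, L2 = chol_1d(x1, ls), chol_1d(x2, ls)
    Z = rng.standard_normal((batch_size, len(x1), len(x2)))
    # (L1 ⊗ L2) @ Z.ravel() equals (L1 @ Z @ L2.T).ravel() for row-major ravel,
    # so the (n1*n2) x (n1*n2) factor is never formed.
    return L1 @ Z @ L2.T


def sample_dense(rng, x1, x2, ls, batch_size, jitter=1e-6):
    """Sample the same product-kernel GP by factorizing the full Gram matrix."""
    g1, g2 = np.meshgrid(x1, x2, indexing="ij")
    s = np.stack([g1.ravel(), g2.ravel()], axis=-1)  # (n1*n2, 2), x1 varies slowest
    d = np.abs(s[:, None, :] - s[None, :, :])  # per-axis pairwise distances
    K = matern_3_2(d[..., 0], ls) * matern_3_2(d[..., 1], ls)
    L = np.linalg.cholesky(K + jitter * np.eye(len(s)))  # O((n1*n2)^3)
    z = rng.standard_normal((batch_size, len(s)))
    return (z @ L.T).reshape(batch_size, len(x1), len(x2))
```

With x1 = x2 = np.linspace(0, 1, 32), for instance, both functions return arrays of shape (batch_size, 32, 32), but sample_dense factorizes a 1024 × 1024 matrix while sample_kronecker only factorizes two 32 × 32 matrices.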

More examples can be found here.
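
Because the dataloader in the demo loops forever, a consumer typically takes a fixed number of batches, for example with itertools.islice. The sketch below shows that pattern in NumPy, using SeedSequence.spawn as a stand-in for the per-batch random.split in the JAX version; toy_simulate is a hypothetical placeholder for gp.simulate.

```python
import itertools

import numpy as np


def dataloader(seed, simulate, batch_size=64):
    """Yield simulate(rng, batch_size) forever, with a fresh child seed per batch.

    NumPy analogue of splitting a JAX key on every iteration: each batch draws
    from an independent stream, and the whole sequence is reproducible from seed.
    """
    root = np.random.SeedSequence(seed)
    while True:
        (child,) = root.spawn(1)  # spawn() advances its counter, so children differ
        yield simulate(np.random.default_rng(child), batch_size)


def toy_simulate(rng, batch_size):
    """Hypothetical stand-in for gp.simulate: white noise on 16 points."""
    return rng.standard_normal((batch_size, 16))


# Take a finite number of batches from the infinite generator.
batches = list(itertools.islice(dataloader(0, toy_simulate, batch_size=8), 3))
```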

Gotchas

Small lengthscales can cause numerical instability. Enabling 64-bit floating-point precision often helps, but it roughly doubles memory use and may reduce throughput on accelerators.

import jax
# use 64-bit precision globally
jax.config.update("jax_enable_x64", True)
# use 64-bit precision only inside this context manager
with jax.enable_x64():
    # Do something in 64-bit precision
    ...
# Back to default 32-bit precision
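
A common complement to 64-bit precision, not specific to dl4bi_sps, is to add a small "jitter" to the covariance diagonal and retry the Cholesky factorization with growing jitter when it fails. A minimal NumPy sketch; the starting jitter and the tenfold growth are arbitrary choices, and in 32-bit precision a larger starting value is usually needed.

```python
import numpy as np


def cholesky_with_jitter(K, jitter=1e-8, max_tries=8):
    """Cholesky factor of K + jitter * I, growing jitter tenfold after each failure."""
    eye = np.eye(K.shape[0], dtype=K.dtype)
    for _ in range(max_tries):
        try:
            return np.linalg.cholesky(K + jitter * eye)
        except np.linalg.LinAlgError:
            jitter *= 10.0
    raise np.linalg.LinAlgError(f"not positive definite even with jitter={jitter / 10:g}")


# A rank-1 matrix (all entries equal) is exactly singular, so plain Cholesky fails on it.
K = np.ones((3, 3))
L = cholesky_with_jitter(K)
```

The jitter acts like a tiny amount of added white noise, so it should stay small relative to the kernel variance.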

Development Setup

  • Install uv.
  • Clone the repository and cd into it: git clone git@github.com:MLGlobalHealth/sps.git && cd sps
  • Install the pinned Python version if needed: uv python install
  • Sync the project, development dependencies, and one JAX extra: uv sync --extra cpu (or --extra cuda12, or --extra cuda13)
  • Run the test suite: uv run pytest

uv sync creates a local .venv/ and installs the project in editable mode, so changes in dl4bi_sps/ are reflected immediately.

Build and Publish to PyPI

Create a local .env file with the publish tokens:

TEST_PYPI_TOKEN=pypi-...
PYPI_TOKEN=pypi-...

Run the release helper from a clean main checkout:

uv run python scripts/release.py .env "Release notes"

The helper bumps the patch version, commits and tags v<version> <message>, rebuilds dist/, publishes to TestPyPI and PyPI, pushes main and the tag, and smoke-tests the published install targets for dl4bi-sps.
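
The version-bump step can be pictured as follows. This is a hypothetical sketch of the MAJOR.MINOR.PATCH arithmetic and the v<version> tag name, not the contents of scripts/release.py.

```python
def bump_patch(version: str) -> str:
    """Return MAJOR.MINOR.(PATCH + 1) for a plain 'MAJOR.MINOR.PATCH' string."""
    major, minor, patch = (int(part) for part in version.split("."))
    return f"{major}.{minor}.{patch + 1}"


def release_tag(version: str) -> str:
    """Git tag name in the v<version> form used by the release helper."""
    return f"v{version}"


new_version = bump_patch("0.1.2")
tag = release_tag(new_version)
```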

Download files


Source Distribution

dl4bi_sps-0.1.2.tar.gz (15.6 kB)

Built Distribution


dl4bi_sps-0.1.2-py3-none-any.whl (15.9 kB)

File details

Details for the file dl4bi_sps-0.1.2.tar.gz.

File metadata

  • Download URL: dl4bi_sps-0.1.2.tar.gz
  • Upload date:
  • Size: 15.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Void","version":null,"id":null,"libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for dl4bi_sps-0.1.2.tar.gz
Algorithm     Hash digest
SHA256        4259a3edd7653c12783d2a26378d4214afa74fb2bc98e71b0c2da9ae27dabbc2
MD5           2db18a5bd5d38d5ae689a03bda0a1c1a
BLAKE2b-256   77d99423c0f96bcc2a27e82b78b394b3346f44ca51d348b7a44639328a1348ec
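
To check a downloaded file against the SHA256 digest above, hashing it locally is enough. A small standard-library sketch (sha256_of is a name chosen here for illustration). pip can also enforce hashes itself in hash-checking mode, via --hash=sha256:<digest> entries in a requirements file together with --require-hashes.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    h = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


EXPECTED = "4259a3edd7653c12783d2a26378d4214afa74fb2bc98e71b0c2da9ae27dabbc2"
# After downloading the sdist into the current directory:
# assert sha256_of("dl4bi_sps-0.1.2.tar.gz") == EXPECTED
```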


File details

Details for the file dl4bi_sps-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: dl4bi_sps-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 15.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Void","version":null,"id":null,"libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for dl4bi_sps-0.1.2-py3-none-any.whl
Algorithm     Hash digest
SHA256        19606644631970cca44e23fe676cf645923d5b0b85190662eef6b9f669105ba6
MD5           fe2deb8e9a0df0d05d1661294e440e04
BLAKE2b-256   40700d0a5f4a746c21d603c462505231ea89b7c34973e480257a2f91ab8f5308

