
Deep Learning for Bayesian Inference: Stochastic Process Simulators

Project description

Stochastic Process Simulators (dl4bi-sps)

Install

Install with the command that matches your setup. If JAX isn't installed already, we recommend one of the dl4bi-sps[<jax-version>] extras, which also install JAX for the chosen platform. The distribution name is dl4bi-sps; the import package is dl4bi_sps.

# quoting the extras keeps shells such as zsh from treating [...] as a glob pattern
pip install dl4bi-sps            # import dl4bi_sps
pip install "dl4bi-sps[cpu]"     # import dl4bi_sps + jax for CPU
pip install "dl4bi-sps[cuda12]"  # import dl4bi_sps + jax for CUDA-12
pip install "dl4bi-sps[cuda13]"  # import dl4bi_sps + jax for CUDA-13
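To confirm which JAX backend an install actually picked up, here is a quick check that uses only standard JAX calls (it does not touch dl4bi_sps). A CUDA extra on a machine without a usable NVIDIA driver typically falls back to "cpu".

```python
import jax

# Report the backend JAX selected: "gpu" is expected after a CUDA extra on a
# machine with a working NVIDIA driver, "cpu" after the cpu extra.
backend = jax.default_backend()
devices = jax.devices()
print(f"backend={backend}, devices={devices}")
```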

View Documentation (Locally)

git clone git@github.com:MLGlobalHealth/sps.git
cd sps
uv sync --extra cpu  # or --extra cuda12 / --extra cuda13; pick exactly one
uv run --with pdoc pdoc --docformat google --math dl4bi_sps

Demo

import matplotlib.pyplot as plt

import jax
from jax import random

from dl4bi_sps.gp import GP
from dl4bi_sps.priors import Prior
from dl4bi_sps.utils import build_grid
from dl4bi_sps.kernels import matern_3_2, matern_5_2

rng = random.key(42)

s_1d = build_grid([{"start": -2, "stop": 2, "num": 128}])
s_2d = build_grid([{"start": -1.5, "stop": 1.5, "num": 300}, {"start": -2.5, "stop": 2.5, "num": 500}])
batch_size = 1
approx = True
lengthscales = [0.05, 0.1, 0.2]
for name, s in zip(["1d", "2d"], [s_1d, s_2d]):
    fig, axes = plt.subplots(len(lengthscales), 1)
    for i, ls in enumerate(lengthscales):
        gp = GP(matern_3_2, ls=Prior("fixed", {"value": ls}))
        f, *_ = gp.simulate(rng, s, batch_size, approx)
        axes[i].set_title(f"ls={ls}")
        if name == "1d":
            axes[i].plot(s, f.squeeze().T)
        else:
            axes[i].imshow(f.squeeze().reshape(300, 500), cmap="Spectral_r")
    plt.tight_layout()
    plt.savefig(f"{name}_gp.png", dpi=150)
    plt.close(fig)  # free this figure before the next loop iteration creates a new one

# create a simple (forever) dataloader
def dataloader(rng, gp, s, batch_size=64, approx=False):
    while True:
        rng_i, rng = random.split(rng)
        yield gp.simulate(rng_i, s, batch_size, approx)


gp = GP(matern_5_2, ls=Prior("beta", {"a": 2.5, "b": 5}))
loader = dataloader(rng, gp, s_2d, batch_size, approx=True)
f, var, ls, period, z = next(loader)


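# within IPython, speed test Kronecker (approx) vs. Cholesky methods
# jax.block_until_ready waits for JAX's asynchronous dispatch to finish, so the
# timings include the computation itself; the figures below vary with hardware
rng, batch_size = random.key(42), 1024
s = build_grid([{"start": 0, "stop": 1, "num": 64}] * 2)  # 64x64 grid
%timeit jax.block_until_ready(gp.simulate(rng, s, batch_size, approx=True))  # ~5 ms
%timeit jax.block_until_ready(gp.simulate(rng, s, batch_size, approx=False))  # ~50 ms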
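The speed test labels the approx=True path "Kronecker (approx)". We have not checked how dl4bi_sps implements it, so the following is only a minimal sketch of the general idea, assuming a covariance that factorizes across grid axes (here, a product of 1D Matérn-3/2 kernels). On an n1 by n2 grid such a covariance is K1 ⊗ K2, and its Cholesky factor is L1 ⊗ L2, so one sample needs two small factorizations, O(n1³ + n2³), instead of one of size n1·n2, O((n1·n2)³). A 2D Matérn kernel on Euclidean distance does not factorize this way, which is presumably why the Kronecker path is described as an approximation. The helper names matern_3_2_1d, kron_sample and dense_sample are hypothetical.

```python
import math

import jax.numpy as jnp
from jax import random


def matern_3_2_1d(x, y, ls):
    """Unit-variance Matern-3/2 kernel between two 1D point sets (illustrative)."""
    r = math.sqrt(3.0) * jnp.abs(x[:, None] - y[None, :]) / ls
    return (1.0 + r) * jnp.exp(-r)


def kron_sample(key, x1, x2, ls, jitter=1e-6):
    """One GP draw on the grid x1 by x2 using K = K1 kron K2 (two small Choleskys)."""
    n1, n2 = x1.shape[0], x2.shape[0]
    L1 = jnp.linalg.cholesky(matern_3_2_1d(x1, x1, ls) + jitter * jnp.eye(n1))
    L2 = jnp.linalg.cholesky(matern_3_2_1d(x2, x2, ls) + jitter * jnp.eye(n2))
    z = random.normal(key, (n1, n2))
    # (L1 kron L2) @ vec(Z) == vec(L1 @ Z @ L2.T) when vec flattens row-major
    return (L1 @ z @ L2.T).reshape(-1)


def dense_sample(key, x1, x2, ls, jitter=1e-6):
    """The same draw via one Cholesky of the full (n1*n2) x (n1*n2) covariance."""
    n1, n2 = x1.shape[0], x2.shape[0]
    K1 = matern_3_2_1d(x1, x1, ls) + jitter * jnp.eye(n1)
    K2 = matern_3_2_1d(x2, x2, ls) + jitter * jnp.eye(n2)
    L = jnp.linalg.cholesky(jnp.kron(K1, K2))
    z = random.normal(key, (n1, n2))  # same key, so the same standard-normal draws
    return L @ z.reshape(-1)


key = random.key(0)
x1, x2 = jnp.linspace(0.0, 1.0, 8), jnp.linspace(0.0, 1.0, 6)
f_kron = kron_sample(key, x1, x2, ls=0.1)
f_dense = dense_sample(key, x1, x2, ls=0.1)
print(jnp.max(jnp.abs(f_kron - f_dense)))
```

Both functions consume the same key and the same standard-normal draws, so they return the same sample up to float32 rounding; only the cost differs, and the gap grows quickly with grid size.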

More examples can be found here.
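In the 2D demo, f.squeeze().reshape(300, 500) only recovers the image if the grid points are stored in a consistent order. The sketch below assumes build_grid returns one row per point with the last axis varying fastest (row-major order, as produced by meshgrid with indexing="ij"); that is an assumption about its output rather than documented behavior, and make_grid is a hypothetical stand-in.

```python
import jax.numpy as jnp


def make_grid(axes):
    """Hypothetical stand-in for build_grid.

    axes: list of {"start", "stop", "num"} dicts, one per dimension.
    Returns an array of shape (prod(num), len(axes)); the last axis varies fastest.
    """
    coords = [jnp.linspace(a["start"], a["stop"], a["num"]) for a in axes]
    mesh = jnp.meshgrid(*coords, indexing="ij")
    return jnp.stack([m.reshape(-1) for m in mesh], axis=-1)


s_2d = make_grid([{"start": -1.5, "stop": 1.5, "num": 300},
                  {"start": -2.5, "stop": 2.5, "num": 500}])
# Any per-point field on s_2d reshapes back to the (300, 500) image layout;
# here the second coordinate itself, which varies along the columns.
image = s_2d[:, 1].reshape(300, 500)
```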

Gotchas

Small lengthscales can cause numerical instability. Enabling 64-bit floating-point precision often helps, but it roughly doubles memory use and can reduce throughput on accelerators.

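import jax

# Option 1: use 64-bit precision globally (best set at startup, before creating any arrays)
jax.config.update("jax_enable_x64", True)

# Option 2: use 64-bit precision only inside this context manager
# (recent JAX releases; older ones provide jax.experimental.enable_x64() instead)
with jax.enable_x64():
    # Do something in 64-bit precision
    ...
# Back to the previous precision (32-bit unless Option 1 was used)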
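Besides enabling 64-bit precision, a common workaround for unstable covariance factorizations (not something dl4bi_sps documents) is to add a small jitter to the diagonal and retry with a larger value when the factorization fails. jnp.linalg.cholesky does not raise on a matrix that is not numerically positive definite; it returns NaNs, so the sketch checks for them. safe_cholesky and its default values are hypothetical.

```python
import jax.numpy as jnp


def safe_cholesky(K, jitter=1e-6, max_tries=6, factor=10.0):
    """Cholesky factor of K + jitter * I, growing jitter until no NaNs appear.

    Hypothetical helper, not part of dl4bi_sps. It runs eagerly (not under jit)
    because it branches on the numerical result.
    Returns (L, jitter_used).
    """
    eye = jnp.eye(K.shape[-1], dtype=K.dtype)
    for _ in range(max_tries):
        L = jnp.linalg.cholesky(K + jitter * eye)
        if not bool(jnp.isnan(L).any()):
            return L, jitter
        jitter *= factor
    raise ValueError(f"Cholesky failed for all jitter values up to {jitter / factor:.1e}")


# A rank-1 covariance: positive semi-definite but singular, so plain Cholesky fails.
K = jnp.ones((3, 3))
L, used = safe_cholesky(K)
```

Jitter acts like adding independent noise of variance jitter at every point, so the smallest value that succeeds is usually preferred.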

Development Setup

  • Install uv.
  • Clone the repository and cd into it: git clone git@github.com:MLGlobalHealth/sps.git && cd sps
  • Install the pinned Python version if needed: uv python install
  • Sync the project, development dependencies, and exactly one JAX extra: uv sync --extra cpu (or --extra cuda12, or --extra cuda13)
  • Run the test suite: uv run pytest

uv sync creates a local .venv/ and installs the project in editable mode, so changes in dl4bi_sps/ are reflected immediately.

Build and Publish to PyPI

  1. Bump the package version:
     uv version --bump patch --frozen
  2. Build the source distribution and wheel:
     uv build --no-sources
  3. Publish to TestPyPI first:
     UV_PUBLISH_TOKEN=$TEST_PYPI_TOKEN uv publish \
       --publish-url https://test.pypi.org/legacy/ \
       --check-url https://test.pypi.org/simple/
  4. After validating the release, publish the same artifacts to PyPI:
     UV_PUBLISH_TOKEN=$PYPI_TOKEN uv publish
  5. Smoke-test the published install targets in fresh environments:
     uv run --isolated --with "dl4bi-sps==<version>" --no-project -- python -c "import dl4bi_sps"
     uv run --isolated --with "dl4bi-sps[cpu]==<version>" --no-project -- python -c "import dl4bi_sps"
     uv run --isolated --with "dl4bi-sps[cuda12]==<version>" --no-project -- python -c "import dl4bi_sps"
     uv run --isolated --with "dl4bi-sps[cuda13]==<version>" --no-project -- python -c "import dl4bi_sps"


Download files


Source Distribution

dl4bi_sps-0.1.1.tar.gz (15.7 kB)

Uploaded Source

Built Distribution


dl4bi_sps-0.1.1-py3-none-any.whl (15.9 kB)

Uploaded Python 3

File details

Details for the file dl4bi_sps-0.1.1.tar.gz.

File metadata

  • Download URL: dl4bi_sps-0.1.1.tar.gz
  • Upload date:
  • Size: 15.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Void","version":null,"id":null,"libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for dl4bi_sps-0.1.1.tar.gz
Algorithm Hash digest
SHA256 8d58732401b79ce6722cc9e7168b4cbb0dbf444959f0757d47a84b03e660f253
MD5 f4ff768eaf3232188546e6c7c2aad3ab
BLAKE2b-256 2aad9e2f6e87c092a9339fb1e6764aa2d60103a1b335b35b63927f39e9a5e763
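To check a manually downloaded file against the SHA256 digest listed above, here is a standard-library sketch; sha256_of and verify are illustrative names, not part of any tool.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Return the lowercase hex SHA256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_sha256):
    """True if the file matches the published digest (comparison is case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# Usage, after downloading the sdist into the current directory:
# verify("dl4bi_sps-0.1.1.tar.gz",
#        "8d58732401b79ce6722cc9e7168b4cbb0dbf444959f0757d47a84b03e660f253")
```

pip can enforce the same check at install time through its hash-checking mode (--require-hashes).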


File details

Details for the file dl4bi_sps-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: dl4bi_sps-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 15.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.11.1 {"installer":{"name":"uv","version":"0.11.1","subcommand":["publish"]},"python":null,"implementation":{"name":null,"version":null},"distro":{"name":"Void","version":null,"id":null,"libc":null},"system":{"name":null,"release":null},"cpu":null,"openssl_version":null,"setuptools_version":null,"rustc_version":null,"ci":null}

File hashes

Hashes for dl4bi_sps-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 95dfbc5288676448308ba97f9eb40d633bb15cce7efe5faf49f66875d1f62416
MD5 86907e60be158c191edb48560dd66642
BLAKE2b-256 f7c4db6df181944200d3a0c056723a8a4a85a701f6fb7e129f9df67118c6cd03

