Deep Learning for Bayesian Inference: Stochastic Process Simulators
Stochastic Process Simulators (dl4bi-sps)
Install
Install with the command that matches your hardware. If JAX is not already installed, we recommend one of the `dl4bi-sps[<jax-version>]` extras, which also install a matching JAX build. The distribution name is `dl4bi-sps`; the import package is `dl4bi_sps`.
```bash
pip install dl4bi-sps            # import dl4bi_sps
pip install "dl4bi-sps[cpu]"     # import dl4bi_sps + jax for CPU
pip install "dl4bi-sps[cuda12]"  # import dl4bi_sps + jax for CUDA 12
pip install "dl4bi-sps[cuda13]"  # import dl4bi_sps + jax for CUDA 13
```
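Because the distribution name (`dl4bi-sps`) and the import name (`dl4bi_sps`) differ, it is easy to check the wrong one when debugging an environment. The following is a minimal standard-library sketch, not part of the package, that reports the installed distribution version (if any) and whether the import package and JAX can be found. The helper name `report_install` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec


def report_install(dist: str = "dl4bi-sps", module: str = "dl4bi_sps") -> dict:
    """Hypothetical helper: summarise what is installed in the current environment."""
    try:
        # Installed metadata is looked up by the *distribution* name.
        dist_version = version(dist)
    except PackageNotFoundError:
        dist_version = None
    return {
        "dist_version": dist_version,
        # find_spec resolves the *import* package name without importing it.
        "importable": find_spec(module) is not None,
        # JAX is only pulled in automatically by the [cpu]/[cuda12]/[cuda13] extras.
        "jax_importable": find_spec("jax") is not None,
    }


print(report_install())
```

If `dist_version` is set but `jax_importable` is `False`, the plain `pip install dl4bi-sps` was used and JAX still needs to be installed separately.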
View Documentation (Locally)
```bash
git clone git@github.com:MLGlobalHealth/sps.git
cd sps
uv sync --extra cpu  # or: --extra cuda12 / --extra cuda13 (pick one)
uv run --with pdoc pdoc --docformat google --math dl4bi_sps
```
Demo
```python
import matplotlib.pyplot as plt
import jax
from jax import random
from dl4bi_sps.gp import GP
from dl4bi_sps.priors import Prior
from dl4bi_sps.utils import build_grid
from dl4bi_sps.kernels import matern_3_2, matern_5_2

rng = random.key(42)
s_1d = build_grid([{"start": -2, "stop": 2, "num": 128}])
s_2d = build_grid([{"start": -1.5, "stop": 1.5, "num": 300}, {"start": -2.5, "stop": 2.5, "num": 500}])
batch_size = 1
approx = True
lengthscales = [0.05, 0.1, 0.2]
for name, s in zip(["1d", "2d"], [s_1d, s_2d]):
    fig, axes = plt.subplots(len(lengthscales), 1)
    for i, ls in enumerate(lengthscales):
        # "fixed" prior: every sample uses exactly this lengthscale
        gp = GP(matern_3_2, ls=Prior("fixed", {"value": ls}))
        f, *_ = gp.simulate(rng, s, batch_size, approx)
        axes[i].set_title(f"ls={ls}")
        if name == "1d":
            axes[i].plot(s, f.squeeze().T)
        else:
            # reshape the flattened 2D grid back to (300, 500) for display
            axes[i].imshow(f.squeeze().reshape(300, 500), cmap="Spectral_r")
    plt.tight_layout()
    plt.savefig(f"{name}_gp.png", dpi=150)
    plt.close(fig)  # release the figure before the next iteration
```
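For reference, `matern_3_2` and `matern_5_2` presumably implement the standard Matérn covariance functions with smoothness ν = 3/2 and ν = 5/2. The sketch below writes out those textbook formulas in plain Python for a distance `r`, lengthscale `ls`, and variance `var`; it illustrates the math only and is not the library's kernel signature, which may differ (for example by operating on arrays of points). The function names are hypothetical.

```python
import math


def matern_3_2_k(r: float, ls: float, var: float = 1.0) -> float:
    """Matérn nu=3/2: var * (1 + sqrt(3) r / ls) * exp(-sqrt(3) r / ls)."""
    a = math.sqrt(3.0) * abs(r) / ls
    return var * (1.0 + a) * math.exp(-a)


def matern_5_2_k(r: float, ls: float, var: float = 1.0) -> float:
    """Matérn nu=5/2: var * (1 + sqrt(5) r / ls + 5 r^2 / (3 ls^2)) * exp(-sqrt(5) r / ls)."""
    a = math.sqrt(5.0) * abs(r) / ls
    return var * (1.0 + a + a * a / 3.0) * math.exp(-a)


# Assuming build_grid includes both endpoints (like linspace), neighbouring points
# on the 1D demo grid (128 points on [-2, 2]) are 4/127 apart.
dx = 4 / 127
for ls in [0.05, 0.1, 0.2]:
    print(ls, round(matern_3_2_k(dx, ls), 3), round(matern_5_2_k(dx, ls), 3))
```

The printed neighbour correlations grow with the lengthscale, which is why the `ls=0.05` panels look much rougher than the `ls=0.2` panels.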
```python
# create a simple (forever) dataloader
def dataloader(rng, gp, s, batch_size=64, approx=False):
    while True:
        rng_i, rng = random.split(rng)
        yield gp.simulate(rng_i, s, batch_size, approx)


gp = GP(matern_5_2, ls=Prior("beta", {"a": 2.5, "b": 5}))
loader = dataloader(rng, gp, s_2d, batch_size, approx=True)
f, var, ls, period, z = next(loader)
```
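`Prior("beta", {"a": 2.5, "b": 5})` presumably draws each lengthscale from a Beta(a=2.5, b=5) distribution, whose support is (0, 1) and whose mean is a/(a+b) = 1/3. Because `dataloader` never terminates, a finite number of batches has to be taken explicitly, for example with `itertools.islice`. The sketch below mimics the loader pattern with the standard library only: `fake_simulate` is a hypothetical stand-in for `gp.simulate` that returns only sampled lengthscales, and seeding a fresh `random.Random` per batch stands in for `jax.random.split`.

```python
import itertools
import random


def fake_simulate(rng: random.Random, batch_size: int) -> list[float]:
    """Hypothetical stand-in for gp.simulate: only draws lengthscales ~ Beta(2.5, 5)."""
    return [rng.betavariate(2.5, 5.0) for _ in range(batch_size)]


def dataloader(seed: int, batch_size: int = 64):
    """Forever loader: derive an independent sub-generator per batch."""
    rng = random.Random(seed)
    while True:
        # analogue of `rng_i, rng = jax.random.split(rng)`
        rng_i = random.Random(rng.getrandbits(64))
        yield fake_simulate(rng_i, batch_size)


# The loader never stops, so bound it explicitly.
batches = list(itertools.islice(dataloader(seed=42, batch_size=1024), 10))
samples = [ls for batch in batches for ls in batch]
print(len(batches), sum(samples) / len(samples))  # empirical mean should be close to 1/3
```

Using a fixed seed makes the whole stream reproducible, which mirrors how the JAX version is fully determined by the initial `rng` key.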
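```python
# within IPython, speed test Kronecker (approx) vs. Cholesky methods
# note: JAX dispatches work asynchronously; wrap a call in jax.block_until_ready(...)
# if you want to be sure the full computation is timed
rng, batch_size = random.key(42), 1024
s = build_grid([{"start": 0, "stop": 1, "num": 64}] * 2)  # 64x64 grid
%timeit gp.simulate(rng, s, batch_size, approx=True)  # ~5 ms
%timeit gp.simulate(rng, s, batch_size, approx=False)  # ~50 ms
```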
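One way to see why a Kronecker path can be much cheaper: on an n₁ × n₂ grid with a separable covariance K = K₁ ⊗ K₂, a sample f = (L₁ ⊗ L₂) z can be computed from the two small Cholesky factors without forming the full matrix, because (L₁ ⊗ L₂) vec(Z) = vec(L₁ Z L₂ᵀ) for row-major vec. Factoring then costs O(n₁³ + n₂³) rather than O((n₁n₂)³), and memory drops too: for the 300 × 500 demo grid a dense float32 covariance would need 150,000² × 4 bytes ≈ 90 GB, versus (300² + 500²) × 4 bytes ≈ 1.4 MB for the two factors. An isotropic 2D Matérn kernel does not factor across dimensions this way, which may be one reason the library labels this path approximate; that is an inference, not documented behaviour. The NumPy sketch below uses a product of 1D Matérn-3/2 kernels as an illustrative separable kernel; all names are hypothetical.

```python
import numpy as np


def matern_3_2_gram(x: np.ndarray, ls: float) -> np.ndarray:
    """Unit-variance 1D Matérn-3/2 Gram matrix for the points x."""
    a = np.sqrt(3.0) * np.abs(x[:, None] - x[None, :]) / ls
    return (1.0 + a) * np.exp(-a)


def chol_1d(x: np.ndarray, ls: float, jitter: float = 1e-6) -> np.ndarray:
    """Lower Cholesky factor of the jittered 1D Gram matrix."""
    return np.linalg.cholesky(matern_3_2_gram(x, ls) + jitter * np.eye(len(x)))


def kron_sample(L1: np.ndarray, L2: np.ndarray, z: np.ndarray) -> np.ndarray:
    """Grid sample using (L1 kron L2) vec(z) == vec(L1 @ z @ L2.T) with row-major vec."""
    return L1 @ z @ L2.T


rng = np.random.default_rng(0)
x1, x2 = np.linspace(0.0, 1.0, 8), np.linspace(0.0, 1.0, 6)
L1, L2 = chol_1d(x1, 0.2), chol_1d(x2, 0.2)
z = rng.standard_normal((8, 6))
f_kron = kron_sample(L1, L2, z)

# Dense check, only feasible for tiny grids: form the full 48 x 48 factor explicitly.
f_dense = (np.kron(L1, L2) @ z.reshape(-1)).reshape(8, 6)
print(np.allclose(f_kron, f_dense))  # True

# Factorization cost on the 64 x 64 timing grid (both scale as ~n^3/3, so the 1/3 cancels).
n1 = n2 = 64
dense_cost = (n1 * n2) ** 3
kron_cost = n1**3 + n2**3
print(dense_cost // kron_cost)  # 131072
```

The ~10x gap in the timings above is far smaller than this flop ratio, likely because factorization is only part of the work when drawing a batch of 1024 samples.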
More examples can be found here.
Gotchas
Small lengthscales can cause numerical instability. Enabling 64-bit floating-point precision often helps, but it roughly doubles memory usage and may reduce throughput on accelerators.
```python
import jax

# Option 1: use 64-bit precision globally
jax.config.update("jax_enable_x64", True)

# Option 2: use 64-bit precision only inside this context manager
with jax.enable_x64():
    # Do something in 64-bit precision
    ...
# Back to the previous setting (32-bit by default)
```
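A quick standard-library illustration of what 32-bit precision costs and saves: float32 carries roughly 7 significant decimal digits, so a small diagonal jitter added to a unit-variance covariance can be rounded away entirely, while float64 keeps it at twice the memory. The `to_float32` helper and the jitter value `1e-8` are illustrative choices, not taken from the library.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


jitter = 1e-8  # illustrative diagonal "nugget" on a unit-variance covariance
print(1.0 + jitter == 1.0)              # False: float64 keeps the jitter
print(to_float32(1.0 + jitter) == 1.0)  # True: float32 rounds it away

# Memory for a dense covariance on the 64 x 64 timing grid (4096 x 4096 entries).
n = 64 * 64
print(n * n * 4 / 2**20, "MiB in float32")  # 64.0 MiB in float32
print(n * n * 8 / 2**20, "MiB in float64")  # 128.0 MiB in float64
```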
Development Setup
- Install uv.
- Clone the repository and `cd` into it: `git clone git@github.com:MLGlobalHealth/sps.git && cd sps`
- Install the pinned Python version if needed: `uv python install`
- Sync the project, development dependencies, and one JAX extra: `uv sync --extra cpu` (or `--extra cuda12` / `--extra cuda13`)
- Run the test suite: `uv run pytest`
uv sync creates a local .venv/ and installs the project in editable mode,
so changes in dl4bi_sps/ are reflected immediately.
Build and Publish to PyPI
Create a local .env file with the publish tokens:
```
TEST_PYPI_TOKEN=pypi-...
PYPI_TOKEN=pypi-...
```
Run the release helper from a clean main checkout:
```bash
uv run python scripts/release.py .env "Release notes"
```
The helper bumps the patch version, commits and tags v<version> <message>,
rebuilds dist/, publishes to TestPyPI and PyPI, pushes main and the tag,
and smoke-tests the published install targets for dl4bi-sps.
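The version bump and tag naming described above can be sketched as follows. This is a hypothetical illustration of the rule "bump the patch version, tag `v<version>`", not the actual contents of `scripts/release.py`; `bump_patch` and `tag_name` are made-up names.

```python
import re


def bump_patch(version: str) -> str:
    """Increment the patch component of a MAJOR.MINOR.PATCH version string."""
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"not a MAJOR.MINOR.PATCH version: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return f"{major}.{minor}.{patch + 1}"


def tag_name(version: str) -> str:
    """Git tag for a release, following the v<version> convention."""
    return f"v{version}"


new_version = bump_patch("0.1.2")
print(new_version, tag_name(new_version))  # 0.1.3 v0.1.3
```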
File details
Source distribution: dl4bi_sps-0.1.2.tar.gz (15.6 kB)

| Algorithm | Hash digest |
|---|---|
| SHA256 | 4259a3edd7653c12783d2a26378d4214afa74fb2bc98e71b0c2da9ae27dabbc2 |
| MD5 | 2db18a5bd5d38d5ae689a03bda0a1c1a |
| BLAKE2b-256 | 77d99423c0f96bcc2a27e82b78b394b3346f44ca51d348b7a44639328a1348ec |

Built distribution: dl4bi_sps-0.1.2-py3-none-any.whl (Python 3 wheel, 15.9 kB)

| Algorithm | Hash digest |
|---|---|
| SHA256 | 19606644631970cca44e23fe676cf645923d5b0b85190662eef6b9f669105ba6 |
| MD5 | fe2deb8e9a0df0d05d1661294e440e04 |
| BLAKE2b-256 | 40700d0a5f4a746c21d603c462505231ea89b7c34973e480257a2f91ab8f5308 |

Both files were uploaded with uv 0.11.1, without Trusted Publishing.