NumPyro dialect for mental midgets
🚒 firetruck
!!! NOTE: This repo is just an experiment for now, not ready for any kind of serious use. The package is not published on PyPI either !!!
firetruck is a NumPyro dialect for mental midgets. This means:
- No `numpyro.deterministic` and `numpyro.sample`: just write your code like a normal human and assign the variables you want to track to `self`.
- You can just return your outcome variable from the function, no `obs` bullshit!
- Greatly simplified sampling and VI. No bespoke solutions, just good defaults for 90% of your use cases.
- You can deal with latent categorical variables without having to do anything, yaaay!
- WebGL-accelerated Plotly plots. You don't need to know ArviZ, Matplotlib or any of that jazz. It not only looks better, it's also interactive and faster.
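The README doesn't say how firetruck handles latent categoricals. The standard trick, which NumPyro itself uses when it enumerates discrete sample sites, is to sum the discrete variable out of the likelihood so that only continuous parameters are left for NUTS or VI. Below is a minimal hand-written version of that marginalization for a Gaussian mixture in plain JAX. It illustrates the idea only; `mixture_log_likelihood` is a hypothetical helper, not firetruck's code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


# Illustration only: hand-written marginalization, not firetruck's implementation
def mixture_log_likelihood(y, weights, locs, scale):
    """Log-likelihood of y under a Gaussian mixture with the label z summed out.

    log p(y_i) = log sum_k weights[k] * Normal(y_i | locs[k], scale)
    """
    # log_joint[i, k] = log p(z_i = k) + log p(y_i | z_i = k)
    log_joint = jnp.log(weights) + norm.logpdf(y[:, None], locs, scale)
    # logsumexp over k marginalizes z_i; the sum runs over independent observations
    return logsumexp(log_joint, axis=-1).sum()


y = jnp.array([-0.5, 0.1, 2.8, 3.3])
weights = jnp.array([0.4, 0.6])
locs = jnp.array([0.0, 3.0])
log_lik = mixture_log_likelihood(y, weights, locs, 1.0)
# No discrete variable is left, so gradients with respect to the locations exist
grad_locs = jax.grad(mixture_log_likelihood, argnums=2)(y, weights, locs, 1.0)
```

Because the component label no longer appears, the result is a smooth function of `weights`, `locs` and `scale`, which is what gradient-based samplers and VI need.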
Example
I modified the Waffle House example in the NumPyro docs to use firetruck.
```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import pandas as pd

import firetruck as ftr

DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
marriage = jnp.array(dset["Marriage"])
divorce = jnp.array(dset["Divorce"])
age = jnp.array(dset["MedianAgeMarriage"])


# Don't forget this decorator!! Very important
@ftr.compact
def model(self, marriage, age):
    # Just assign variables to self that you want to track,
    # and they will be named automatically!!
    self.a = dist.Normal(0.0, 0.2)
    self.bM = dist.Normal(0.0, 0.5)
    self.bA = dist.Normal(0.0, 0.5)
    self.sigma = dist.Exponential(0.5)
    mu = self.a + self.bM * marriage + self.bA * age
    return dist.Normal(mu, self.sigma)
```
```python
# Sample from the prior predictive distribution
rng_key = jax.random.key(42)
rng_key, subkey = jax.random.split(rng_key)
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)

# Add inputs to the model and condition on the output
conditioned_model = model.add_input(marriage, age).condition_on(divorce)

# Fit the model using mean-field VI
rng_key, subkey = jax.random.split(rng_key)
res = conditioned_model.meanfield_vi(subkey)

# Sample from the posterior using NUTS
rng_key, subkey = jax.random.split(rng_key)
mcmc = conditioned_model.sample_posterior(subkey)
# Prints this automatically, because why the hell wouldn't you want it:
#          mean   std  median   5.0%  95.0%    n_eff  r_hat
#      a   0.01  0.20    0.02  -0.31   0.35  2348.61   1.00
#     bA   0.17  0.05    0.17   0.10   0.25  1634.25   1.00
#     bM   0.26  0.06    0.26   0.16   0.35  1649.36   1.00
#  sigma   1.82  0.19    1.81   1.52   2.11  2419.83   1.00
#
# Number of divergences: 0
```
```python
# Plot the sampling trace
fig = ftr.plot_trace(mcmc)
fig.show()

# Forest plot of posterior samples
fig = ftr.plot_forest(mcmc)
fig.show()

# Effective sample size (ESS) plot
fig = ftr.plot_ess(mcmc)
fig.show()

# Sample the prior predictive and plot a prior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# NOTE that I'm using the unconditioned model
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)
fig = ftr.plot_predictive_check(prior_predictive, obs=divorce)
fig.show()

# Sample the posterior predictive and plot a posterior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# Note that I'm passing the posterior samples to the function
posterior_predictive = model.add_input(marriage, age).sample_predictive(
    subkey, posterior_samples=mcmc.get_samples()
)
fig = ftr.plot_predictive_check(posterior_predictive, obs=divorce)
fig.show()
```
Download files
File details
Details for the file firetruck-0.1.6.tar.gz.
File metadata
- Download URL: firetruck-0.1.6.tar.gz
- Upload date:
- Size: 7.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.3.2 CPython/3.12.11 Linux/6.17.0-20-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0c12e507e065df19836679fcbfb68f310299c5d8301df0c59afe3a59deb0dbce |
| MD5 | b27989aefa8101ff0580381620f47327 |
| BLAKE2b-256 | 9d041f50e3cfb13851e66b7b2bf16a38bbcf29b991e926cf5e86a44cff17112c |
File details
Details for the file firetruck-0.1.6-py3-none-any.whl.
File metadata
- Download URL: firetruck-0.1.6-py3-none-any.whl
- Upload date:
- Size: 8.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/2.3.2 CPython/3.12.11 Linux/6.17.0-20-generic
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bfc694e49e1d0f052a27420b08e38bafda4985f38462ad472b15f314ee1007f2 |
| MD5 | 85f003b6ca56c039b7fc9575df5819d0 |
| BLAKE2b-256 | 424c0d8df27973acd8c96449596424f19aa6d0f664b494733886c3d01018f7e9 |