
NumPyro dialect for mental midgets


🚒 firetruck

!!! NOTE: This repo is just an experiment for now, not ready for any kind of serious use. The package is not published on PyPI either !!!

firetruck is a NumPyro dialect for mental midgets. This means:

  • No numpyro.deterministic or numpyro.sample: just write your code like a normal human and assign the variables you want to track to self.
  • You can just return your outcome variable from the function, no obs bullshit!
  • Greatly simplified sampling and VI. No bespoke solutions, just good defaults for 90% of your use cases.
  • You can deal with latent categorical variables without having to do anything, yaaay!
  • WebGL-accelerated Plotly plots. You don't need to know ArviZ, Matplotlib or any of that jazz. The plots not only look better, they are also interactive and faster.

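The README doesn't say how firetruck handles latent categorical variables. One standard technique, which NumPyro's discrete enumeration also relies on, is to sum the discrete variable out of the likelihood. A minimal pure-JAX sketch for a hypothetical two-component Gaussian mixture, where each data point's component label is the latent categorical:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def marginal_log_lik(y, weights, locs, scale):
    """log p(y) = sum_i log sum_k w_k N(y_i | loc_k, scale), with the label z_i summed out."""
    # Shape (n, K): log w_k + log N(y_i | loc_k, scale) for every point/component pair.
    log_joint = jnp.log(weights)[None, :] + norm.logpdf(y[:, None], locs[None, :], scale)
    # Sum over the K possible labels in log space, then add up the data points.
    return logsumexp(log_joint, axis=1).sum()


# Hypothetical data and mixture parameters.
y = jnp.array([-1.2, 0.3, 2.9, 3.4])
weights = jnp.array([0.4, 0.6])
locs = jnp.array([0.0, 3.0])
print(marginal_log_lik(y, weights, locs, 1.0))
```

Because the label is summed out, the remaining parameters are all continuous, so gradient-based samplers such as NUTS can be used.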
Example

I modified the Waffle House example in the NumPyro docs to use firetruck.
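Written out as math, the model in the code below is the following, with i running over the states in the dataset (as in numpyro.distributions, the second argument of Normal is the standard deviation and Exponential takes a rate):

```latex
\begin{aligned}
a &\sim \mathrm{Normal}(0,\ 0.2) \\
b_M &\sim \mathrm{Normal}(0,\ 0.5) \\
b_A &\sim \mathrm{Normal}(0,\ 0.5) \\
\sigma &\sim \mathrm{Exponential}(0.5) \\
\mu_i &= a + b_M\,\mathrm{marriage}_i + b_A\,\mathrm{age}_i \\
\mathrm{divorce}_i &\sim \mathrm{Normal}(\mu_i,\ \sigma)
\end{aligned}
```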

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import pandas as pd

import firetruck as ftr

DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
marriage = jnp.array(dset["Marriage"])
divorce = jnp.array(dset["Divorce"])
age = jnp.array(dset["MedianAgeMarriage"])


# Don't forget this decorator!! Very important
@ftr.compact
def model(self, marriage, age):
    # Just assign the variables you want to track to self,
    # and they will be named automatically!!
    self.a = dist.Normal(0.0, 0.2)
    self.bM = dist.Normal(0.0, 0.5)
    self.bA = dist.Normal(0.0, 0.5)
    self.sigma = dist.Exponential(0.5)
    mu = self.a + self.bM * marriage + self.bA * age
    return dist.Normal(mu, self.sigma)


# Sample the prior predictive distribution
rng_key = jax.random.key(42)
rng_key, subkey = jax.random.split(rng_key)
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)

# Add inputs to the model and condition on the output
conditioned_model = model.add_input(marriage, age).condition_on(divorce)

# Fit model using meanfield VI
rng_key, subkey = jax.random.split(rng_key)
res = conditioned_model.meanfield_vi(subkey)

# Sample from model using NUTS
rng_key, subkey = jax.random.split(rng_key)
mcmc = conditioned_model.sample_posterior(subkey)

# Prints this automatically, because why the hell wouldn't you want this:
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#          a      0.01      0.20      0.02     -0.31      0.35   2348.61      1.00
#         bA      0.17      0.05      0.17      0.10      0.25   1634.25      1.00
#         bM      0.26      0.06      0.26      0.16      0.35   1649.36      1.00
#      sigma      1.82      0.19      1.81      1.52      2.11   2419.83      1.00
#
# Number of divergences: 0

# Plot sampling trace
fig = ftr.plot_trace(mcmc)
fig.show()
# Forest plot of posterior samples
fig = ftr.plot_forest(mcmc)
fig.show()
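A forest plot boils down to a point estimate and an interval per parameter. The sketch below computes the posterior mean and a 90% highest posterior density interval (HPDI, the shortest interval holding 90% of the draws) from a dict of draws shaped like the one mcmc.get_samples() returns. The hpdi and forest_summary helpers are written by hand for illustration, the draws are hypothetical, and the README doesn't say whether plot_forest uses HPDI or equal-tailed quantiles.

```python
import jax
import jax.numpy as jnp


def hpdi(x, prob=0.9):
    """Shortest interval containing a fraction `prob` of the 1-D samples `x`."""
    sorted_x = jnp.sort(x)
    n = sorted_x.shape[0]
    k = int(prob * n)  # each candidate interval spans k sorted steps
    # Width of every candidate interval [sorted_x[i], sorted_x[i + k]].
    widths = sorted_x[k:] - sorted_x[: n - k]
    start = jnp.argmin(widths)
    return sorted_x[start], sorted_x[start + k]


def forest_summary(samples, prob=0.9):
    """Map each parameter name to (mean, hpdi_low, hpdi_high)."""
    out = {}
    for name, draws in samples.items():
        low, high = hpdi(draws, prob)
        out[name] = (float(draws.mean()), float(low), float(high))
    return out


# Hypothetical posterior draws standing in for mcmc.get_samples().
key_a, key_s = jax.random.split(jax.random.PRNGKey(0))
samples = {
    "bA": 0.17 + 0.05 * jax.random.normal(key_a, (4000,)),
    "sigma": 1.8 + 0.2 * jax.random.normal(key_s, (4000,)),
}
for name, (mean, low, high) in forest_summary(samples).items():
    print(f"{name:>6}: mean={mean:.2f}  90% HPDI=[{low:.2f}, {high:.2f}]")
```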
# Effective sample size (ESS) plot
fig = ftr.plot_ess(mcmc)
fig.show()
# Sampling prior predictive and plotting prior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# NOTE that I'm using the unconditioned model
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)
fig = ftr.plot_predictive_check(prior_predictive, obs=divorce)
fig.show()
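A prior predictive check simulates outcomes from parameters drawn from the priors alone and asks whether they land on a plausible scale compared with the observed data. The hand-rolled JAX sketch below does this for the regression in the example, purely to show the mechanics (in firetruck, sample_predictive on the unconditioned model handles this). The prior_predictive helper is hypothetical and the covariate values are made up rather than taken from the dataset.

```python
import jax
import jax.numpy as jnp


def prior_predictive(key, marriage, age, num_samples=1000):
    """Draw parameters from the priors, then one simulated divorce rate per state."""
    k_a, k_bm, k_ba, k_s, k_y = jax.random.split(key, 5)
    a = 0.2 * jax.random.normal(k_a, (num_samples, 1))           # Normal(0, 0.2)
    bM = 0.5 * jax.random.normal(k_bm, (num_samples, 1))         # Normal(0, 0.5)
    bA = 0.5 * jax.random.normal(k_ba, (num_samples, 1))         # Normal(0, 0.5)
    sigma = jax.random.exponential(k_s, (num_samples, 1)) / 0.5  # Exponential(rate=0.5)
    mu = a + bM * marriage[None, :] + bA * age[None, :]
    # Result has shape (num_samples, num_states).
    return mu + sigma * jax.random.normal(k_y, mu.shape)


# Made-up covariates on the raw (unstandardized) scale used in the example.
marriage = jnp.array([15.0, 20.0, 25.0, 30.0])
age = jnp.array([24.0, 25.5, 27.0, 29.0])
sims = prior_predictive(jax.random.PRNGKey(0), marriage, age)
lo, hi = jnp.percentile(sims, jnp.array([5.0, 95.0]))
print(f"90% of prior-predictive draws fall in [{float(lo):.1f}, {float(hi):.1f}]")
```

Comparing that range with the observed divorce rates is the actual check: if the simulated values are wildly wider or narrower than the data, the priors deserve a second look.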
# Sampling posterior predictive and plotting posterior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# Note that I'm passing the posterior samples to the function
posterior_predictive = model.add_input(marriage, age).sample_predictive(
    subkey, posterior_samples=mcmc.get_samples()
)
fig = ftr.plot_predictive_check(posterior_predictive, obs=divorce)
fig.show()



Download files

Download the file for your platform.

Source Distribution

firetruck-0.1.4.tar.gz (7.5 kB)

Uploaded Source

Built Distribution


firetruck-0.1.4-py3-none-any.whl (8.8 kB)

Uploaded Python 3

File details

Details for the file firetruck-0.1.4.tar.gz.

File metadata

  • Download URL: firetruck-0.1.4.tar.gz
  • Upload date:
  • Size: 7.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.3.2 CPython/3.12.11 Linux/6.17.0-14-generic

File hashes

Hashes for firetruck-0.1.4.tar.gz
Algorithm    Hash digest
SHA256       f9f7c25a797c40fe0f9c3f399c2abc70bf9b443c174adaedde6f5e99c7ae2e33
MD5          6be0f819f75683506d80049ade4bbebd
BLAKE2b-256  a030bf2cb3e2acee9c4a4cddc7f36696a147c29f3922097382ff8a93fe7da898


File details

Details for the file firetruck-0.1.4-py3-none-any.whl.

File metadata

  • Download URL: firetruck-0.1.4-py3-none-any.whl
  • Upload date:
  • Size: 8.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.3.2 CPython/3.12.11 Linux/6.17.0-14-generic

File hashes

Hashes for firetruck-0.1.4-py3-none-any.whl
Algorithm    Hash digest
SHA256       7b97c771afcc3f516023f8ec80e18c679af7eeebedb0409246308fdcd3ca81cc
MD5          7c8ecc55bf4a35434ca8f16233e0fadb
BLAKE2b-256  e09de92db177ab58a520229c662ea97bb1fedf44f0b1247b52c09fe64cf1ad1b

