
NumPyro dialect for mental midgets


🚒 firetruck

!!! NOTE: This repo is just an experiment for now, not ready for any kind of serious use. The package is not published on PyPI either !!!

firetruck is a NumPyro dialect for mental midgets. This means:

  • No numpyro.deterministic or numpyro.sample: just write your code like a normal human and assign the variables you want to track to self.
  • You can just return the distribution of your outcome variable from the function, no obs bullshit!
  • Greatly simplified sampling and VI. No bespoke solutions, just good defaults for 90% of your use cases.
  • You can deal with latent categorical variables without having to do anything, yaaay!
  • WebGL-accelerated Plotly plots. You don't need to know ArviZ, Matplotlib or any of that jazz. The plots not only look better, they are also interactive and faster.
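The README doesn't say how firetruck handles latent categorical variables. NumPyro itself can marginalize (sum out) enumerable discrete latent sites, and firetruck presumably builds on that, but treat this as an assumption. The underlying idea is simple: for z ~ Categorical(w) and y | z = k ~ Normal(mu_k, sigma_k), the likelihood p(y) = sum_k w_k * Normal(y | mu_k, sigma_k) no longer contains z. The pure-Python sketch below computes that sum in log space; the helper names are our own and hypothetical, not firetruck functions.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normal_logpdf(y, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def mixture_loglik(y, weights, mus, sigmas):
    """log p(y) with the latent categorical z summed out.

    Model: z ~ Categorical(weights), y | z = k ~ Normal(mus[k], sigmas[k]),
    so p(y) = sum_k weights[k] * Normal(y | mus[k], sigmas[k]).
    """
    terms = [
        math.log(w) + normal_logpdf(y, m, s)
        for w, m, s in zip(weights, mus, sigmas)
    ]
    return logsumexp(terms)
```

Because z is gone, the result is a smooth function of the weights, means and scales, which is what gradient-based samplers such as NUTS need.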

Example

I modified the Waffle House example in the NumPyro docs to use firetruck.

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import pandas as pd

import firetruck as ftr

DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
marriage = jnp.array(dset["Marriage"])
divorce = jnp.array(dset["Divorce"])
age = jnp.array(dset["MedianAgeMarriage"])


# Don't forget this decorator!! Very important
@ftr.compact
def model(self, marriage, age):
    # Just assign the variables you want to track to self,
    # and they will be named automatically!!
    self.a = dist.Normal(0.0, 0.2)
    self.bM = dist.Normal(0.0, 0.5)
    self.bA = dist.Normal(0.0, 0.5)
    self.sigma = dist.Exponential(0.5)
    mu = self.a + self.bM * marriage + self.bA * age
    # Return the outcome distribution; the observed data is attached later with condition_on
    return dist.Normal(mu, self.sigma)
```


```python
# Sample from the prior predictive distribution
rng_key = jax.random.key(42)
rng_key, subkey = jax.random.split(rng_key)
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)

# Add inputs to the model and condition on the output
conditioned_model = model.add_input(marriage, age).condition_on(divorce)

# Fit the model using mean-field VI
rng_key, subkey = jax.random.split(rng_key)
res = conditioned_model.meanfield_vi(subkey)
```

```python
# Sample from the posterior using NUTS
rng_key, subkey = jax.random.split(rng_key)
mcmc = conditioned_model.sample_posterior(subkey)

# Prints this automatically, because why the hell would you not need this:
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#          a      0.01      0.20      0.02     -0.31      0.35   2348.61      1.00
#         bA      0.17      0.05      0.17      0.10      0.25   1634.25      1.00
#         bM      0.26      0.06      0.26      0.16      0.35   1649.36      1.00
#      sigma      1.82      0.19      1.81      1.52      2.11   2419.83      1.00
#
# Number of divergences: 0

# Plot the sampling trace
fig = ftr.plot_trace(mcmc)
fig.show()
```
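The r_hat column in the summary above is a convergence diagnostic: it compares the variance between chains with the variance within them, and values close to 1.00, as in the table, suggest the chains agree. As far as we know, NumPyro's summary reports the split variant, where each chain is cut in half so that drift within a single chain also shows up. Below is a NumPy sketch of split R-hat for one scalar parameter; it is our own illustration, and NumPyro's numpyro.diagnostics.split_gelman_rubin may differ in details.

```python
import numpy as np


def split_rhat(chains):
    """Split R-hat for one scalar parameter.

    `chains` has shape (n_chains, n_draws). Each chain is cut in half and
    the halves are treated as separate chains.
    """
    chains = np.asarray(chains, dtype=float)
    _, n_draws = chains.shape
    half = n_draws // 2
    # If n_draws is odd, the middle draw is dropped
    split = np.concatenate([chains[:, :half], chains[:, n_draws - half:]], axis=0)
    _, n = split.shape
    within = split.var(axis=1, ddof=1).mean()         # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)      # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n     # pooled posterior variance estimate
    return np.sqrt(var_plus / within)
```

Chains stuck around different values push the statistic well above 1, which is the same problem a bad trace plot shows visually.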
```python
# Forest plot of posterior samples
fig = ftr.plot_forest(mcmc)
fig.show()
```
```python
# Effective sample size (ESS) plot
fig = ftr.plot_ess(mcmc)
fig.show()
```
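The n_eff column and the ESS plot measure how many independent draws the autocorrelated MCMC samples are worth. One standard single-chain estimator is Geyer's initial positive sequence: sum the autocorrelations in consecutive pairs and stop at the first pair whose sum turns negative. The NumPy sketch below implements that; it is our own illustration, and NumPyro's numpyro.diagnostics.effective_sample_size, which also pools multiple chains, may give somewhat different numbers.

```python
import numpy as np


def autocorr(x):
    """Normalized autocorrelation of a 1-D chain at lags 0..n-1, via FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    x = x - x.mean()
    # Zero-pad to length 2n so the FFT product gives a linear, not circular, correlation
    f = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(f * np.conjugate(f))[:n]
    return acov / acov[0]


def ess(x):
    """Single-chain ESS using Geyer's initial positive sequence."""
    rho = autocorr(x)
    n = len(rho)
    # Start at -1 so that tau = -1 + 2 * sum(pairs) = 1 + 2 * sum_{t >= 1} rho_t
    tau = -1.0
    for t in range(0, n - 1, 2):
        pair = rho[t] + rho[t + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau
```

For an AR(1) chain with coefficient 0.9, the theoretical ESS is N * (1 - 0.9) / (1 + 0.9), roughly 5% of N, and the estimate should land in that neighborhood.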
```python
# Sample the prior predictive and plot a prior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# NOTE that I'm using the unconditioned model
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)
fig = ftr.plot_predictive_check(prior_predictive, obs=divorce)
fig.show()
```
```python
# Sample the posterior predictive and plot a posterior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# Note that I'm passing the posterior samples to the function
posterior_predictive = model.add_input(marriage, age).sample_predictive(
    subkey, posterior_samples=mcmc.get_samples()
)
fig = ftr.plot_predictive_check(posterior_predictive, obs=divorce)
fig.show()
```
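Conceptually, both predictive checks draw parameter values (from the priors, or from the posterior samples passed in) and then simulate a fresh outcome vector from the likelihood for each draw. The NumPy sketch below spells that out for this particular model; draw_prior and predictive are hypothetical helper names, not part of firetruck, and the priors are copied from the model definition above.

```python
import numpy as np


def draw_prior(rng, n_draws):
    """Draw parameter values from the priors of the model above (hypothetical helper)."""
    return {
        "a": rng.normal(0.0, 0.2, n_draws),
        "bM": rng.normal(0.0, 0.5, n_draws),
        "bA": rng.normal(0.0, 0.5, n_draws),
        # dist.Exponential(0.5) is parameterized by a rate; NumPy expects scale = 1 / rate
        "sigma": rng.exponential(1.0 / 0.5, n_draws),
    }


def predictive(rng, samples, marriage, age):
    """Simulate one replicated outcome vector per parameter draw (hypothetical helper).

    `samples` maps parameter names to 1-D arrays of draws; the result has
    shape (n_draws, n_observations).
    """
    a = samples["a"][:, None]
    bM = samples["bM"][:, None]
    bA = samples["bA"][:, None]
    sigma = samples["sigma"][:, None]
    # Same linear predictor as the model, broadcasting draws against observations
    mu = a + bM * marriage[None, :] + bA * age[None, :]
    return rng.normal(mu, sigma)
```

For a posterior check you would pass the NUTS draws instead, for example {k: np.asarray(v) for k, v in mcmc.get_samples().items()}, assuming get_samples() returns a dict of per-parameter arrays as NumPyro's MCMC.get_samples() does.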



Download files


Source Distribution

firetruck-0.1.1.tar.gz (6.7 kB)

Built Distribution


firetruck-0.1.1-py3-none-any.whl (7.8 kB)

File details

Details for the file firetruck-0.1.1.tar.gz.

File metadata

  • Download URL: firetruck-0.1.1.tar.gz
  • Upload date:
  • Size: 6.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.3.2 CPython/3.12.11 Linux/6.17.0-14-generic

File hashes

Hashes for firetruck-0.1.1.tar.gz

Algorithm     Hash digest
SHA256        aacf779830828e743254ec985e02ea5bbd4bd16a42c06552471ce6917d04b4e6
MD5           d75e8173a822e9509596a79c3f2c2563
BLAKE2b-256   e8df71609d2f48d437c9255804b1aa3a0d7ae144ea6b1da8dc697ef2d5d66e1d


File details

Details for the file firetruck-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: firetruck-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 7.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.3.2 CPython/3.12.11 Linux/6.17.0-14-generic

File hashes

Hashes for firetruck-0.1.1-py3-none-any.whl

Algorithm     Hash digest
SHA256        bae11e32d2a2485256b9cef7b5122e878aa7b5bbc788736530e4550d5dbb2ca1
MD5           6589c7a009c53ea4d610f1767d998b05
BLAKE2b-256   38e995d29520e860203bf688a2de0e236b96053cc2b37584a6c28fe1bb5c8e17
