NumPyro dialect for mental midgets

🚒 firetruck

!!! NOTE: This repo is just an experiment for now, not ready for any kind of serious use. The package is not published on PyPI either !!!

firetruck is a NumPyro dialect for mental midgets. This means:

  • No numpyro.deterministic and numpyro.sample, just write your code like a normal human, and assign variables you want to track to self
  • You can just return your outcome variable from the function, no obs bullshit!
  • Greatly simplified sampling and VI. No bespoke solutions, just good defaults for 90% of your use cases.
  • You can deal with latent categorical variables without having to do anything, yaaay!
  • WebGL-accelerated Plotly plots. You don't need to know ArviZ, Matplotlib, or any of that jazz. The plots not only look better, they are also interactive and faster.
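
To make the "assign variables to self and they get named automatically" idea concrete, here is a minimal pure-Python sketch of the mechanism. It is NOT firetruck's implementation: the `Tracker` class and this `compact` decorator are hypothetical stand-ins, and plain floats take the place of distributions. It only shows how `__setattr__` can record a name for everything assigned to `self`, while ordinary local variables stay untracked.

```python
import functools


class Tracker:
    """Hypothetical stand-in for the `self` argument of a firetruck-style model."""

    def __init__(self):
        # Create the registry without triggering our own __setattr__.
        object.__setattr__(self, "sites", {})

    def __setattr__(self, name, value):
        # The attribute name doubles as the site name, so nothing is named by hand.
        self.sites[name] = value
        object.__setattr__(self, name, value)


def compact(fn):
    """Hypothetical decorator: call `fn` with a fresh Tracker as its first argument."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        tracker = Tracker()
        outcome = fn(tracker, *args, **kwargs)
        # Return the outcome together with everything that was assigned to self.
        return outcome, dict(tracker.sites)

    return wrapper


@compact
def model(self, x):
    self.a = 1.0  # tracked under the name "a"
    self.b = 2.0  # tracked under the name "b"
    scale = 10.0  # plain local variable, not tracked
    return self.a + self.b * x


outcome, sites = model(3.0)
print(outcome)  # 7.0
print(sites)    # {'a': 1.0, 'b': 2.0}
```

In a real probabilistic setting the assigned values would be distributions that get sampled, which this sketch deliberately leaves out.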

Example

I modified the Waffle House example in the NumPyro docs to use firetruck.

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import pandas as pd

import firetruck as ftr

DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
marriage = jnp.array(dset["Marriage"])
divorce = jnp.array(dset["Divorce"])
age = jnp.array(dset["MedianAgeMarriage"])


# Don't forget this decorator!! Very important
@ftr.compact
def model(self, marriage, age):
    # Just assign variables to self that you want to track,
    # And they will be named automatically!!
    self.a = dist.Normal(0.0, 0.2)
    self.bM = dist.Normal(0.0, 0.5)
    self.bA = dist.Normal(0.0, 0.5)
    self.sigma = dist.Exponential(0.5)
    mu = self.a + self.bM * marriage + self.bA * age
    return dist.Normal(mu, self.sigma)


# Sampling Prior predictive distribution
rng_key = jax.random.key(42)
rng_key, subkey = jax.random.split(rng_key)
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)

# Add inputs to the model and condition on the output
conditioned_model = model.add_input(marriage, age).condition_on(divorce)

# Fit model using meanfield VI
rng_key, subkey = jax.random.split(rng_key)
res = conditioned_model.meanfield_vi(subkey)

# Sample from model using NUTS
rng_key, subkey = jax.random.split(rng_key)
mcmc = conditioned_model.sample_posterior(subkey)

# Prints this automatically, cause why the hell would you not need this:
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#          a      0.01      0.20      0.02     -0.31      0.35   2348.61      1.00
#         bA      0.17      0.05      0.17      0.10      0.25   1634.25      1.00
#         bM      0.26      0.06      0.26      0.16      0.35   1649.36      1.00
#      sigma      1.82      0.19      1.81      1.52      2.11   2419.83      1.00
#
# Number of divergences: 0

# Plot sampling trace
fig = ftr.plot_trace(mcmc)
fig.show()
# Forest plot of posterior samples
fig = ftr.plot_forest(mcmc)
fig.show()
# Plot effective sample size (ESS)
fig = ftr.plot_ess(mcmc)
fig.show()
# Sampling prior predictive and plotting prior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# NOTE that I'm using the unconditioned model
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)
fig = ftr.plot_predictive_check(prior_predictive, obs=divorce)
fig.show()
# Sampling posterior predictive and plotting posterior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# Note that I'm passing the posterior_samples to the function
posterior_predictive = model.add_input(marriage, age).sample_predictive(
    subkey, posterior_samples=mcmc.get_samples()
)
fig = ftr.plot_predictive_check(posterior_predictive, obs=divorce)
fig.show()
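
Every stochastic call in the example is preceded by `jax.random.split`. Here is a minimal sketch in plain JAX (no firetruck involved) of why the fresh `subkey` is the one to hand over, rather than the `rng_key` you keep carrying forward: a key is meant to be consumed once, and reusing it reproduces the same random numbers. The names `subkey_a`, `subkey_b` and the draw shapes are just for illustration.

```python
import jax
import jax.numpy as jnp

rng_key = jax.random.key(42)

# Split once per stochastic call: keep `rng_key` for future splits,
# and consume each subkey exactly once.
rng_key, subkey_a = jax.random.split(rng_key)
draw_a = jax.random.normal(subkey_a, (3,))

rng_key, subkey_b = jax.random.split(rng_key)
draw_b = jax.random.normal(subkey_b, (3,))

# Reusing an already consumed key gives back the very same numbers.
draw_a_again = jax.random.normal(subkey_a, (3,))

same_key_same_draw = bool(jnp.array_equal(draw_a, draw_a_again))
fresh_key_new_draw = not bool(jnp.array_equal(draw_a, draw_b))
print(same_key_same_draw, fresh_key_new_draw)
```

So if you pass `rng_key` to a sampler and then split it again afterwards, later draws are derived from a key that has already been used, which is exactly what the split-then-use-`subkey` pattern avoids.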
