NumPyro dialect for mental midgets

🚒 firetruck

!!! NOTE: This repo is just an experiment for now, not ready for any kind of serious use. The package is not published on PyPI either !!!

firetruck is a NumPyro dialect for mental midgets. This means:

  • No numpyro.deterministic or numpyro.sample. Just write your code like a normal human and assign the variables you want to track to self
  • You can just return your outcome variable from the function, no obs bullshit!
  • Greatly simplified sampling and VI. No bespoke solutions, just good defaults for 90% of your use cases.
  • You can deal with latent categorical variables without having to do anything, yaaay!
  • WebGL-accelerated Plotly plots. You don't need to know ArviZ, Matplotlib or any of that jazz. The plots not only look better, they're also interactive and faster.

Example

I modified the Waffle House example in the NumPyro docs to use firetruck.

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import pandas as pd

import firetruck as ftr

DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
marriage = jnp.array(dset["Marriage"])
divorce = jnp.array(dset["Divorce"])
age = jnp.array(dset["MedianAgeMarriage"])


# Don't forget this decorator!! Very important
@ftr.compact
def model(self, marriage, age):
    # Just assign variables to self that you want to track,
    # And they will be named automatically!!
    self.a = dist.Normal(0.0, 0.2)
    self.bM = dist.Normal(0.0, 0.5)
    self.bA = dist.Normal(0.0, 0.5)
    self.sigma = dist.Exponential(0.5)
    mu = self.a + self.bM * marriage + self.bA * age
    return dist.Normal(mu, self.sigma)


# Sampling Prior predictive distribution
rng_key = jax.random.key(42)
rng_key, subkey = jax.random.split(rng_key)
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)

# Add inputs to the model and condition on the output
conditioned_model = model.add_input(marriage, age).condition_on(divorce)

# Fit model using meanfield VI
rng_key, subkey = jax.random.split(rng_key)
res = conditioned_model.meanfield_vi(subkey)

# Sample from model using NUTS
rng_key, subkey = jax.random.split(rng_key)
mcmc = conditioned_model.sample_posterior(subkey)

# Prints this automatically, cause why the hell would you not need this:
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#          a      0.01      0.20      0.02     -0.31      0.35   2348.61      1.00
#         bA      0.17      0.05      0.17      0.10      0.25   1634.25      1.00
#         bM      0.26      0.06      0.26      0.16      0.35   1649.36      1.00
#      sigma      1.82      0.19      1.81      1.52      2.11   2419.83      1.00
#
# Number of divergences: 0

# Plot sampling trace
fig = ftr.plot_trace(mcmc)
fig.show()
# Forest plot of posterior samples
fig = ftr.plot_forest(mcmc)
fig.show()
# Plot effective sample size (ESS)
fig = ftr.plot_ess(mcmc)
fig.show()
# Sampling prior predictive and plotting prior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# NOTE that I'm using the unconditioned model
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)
fig = ftr.plot_predictive_check(prior_predictive, obs=divorce)
fig.show()
# Sampling posterior predictive and plotting posterior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# Note that I'm passing the posterior_samples to the function
posterior_predictive = model.add_input(marriage, age).sample_predictive(
    subkey, posterior_samples=mcmc.get_samples()
)
fig = ftr.plot_predictive_check(posterior_predictive, obs=divorce)
fig.show()
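
A note on the rng_key, subkey = jax.random.split(rng_key) lines in the example: JAX random functions are deterministic in the key they receive, so each stochastic step should get its own fresh subkey. The sketch below uses only plain JAX to show this; the variable names are made up for illustration and nothing in it is firetruck API.

```python
import jax

key = jax.random.PRNGKey(42)

# Split once per stochastic step: carry `key` forward, consume the subkey.
key, subkey_a = jax.random.split(key)
key, subkey_b = jax.random.split(key)

draw_a = jax.random.normal(subkey_a, (3,))
draw_b = jax.random.normal(subkey_b, (3,))

# Reusing a key reproduces exactly the same numbers.
draw_a_again = jax.random.normal(subkey_a, (3,))

print("same key, same draw:     ", bool((draw_a == draw_a_again).all()))
print("different keys, same draw:", bool((draw_a == draw_b).all()))
```

The same discipline applies to any call in the example that takes a key: hand it the subkey produced by the most recent split, and keep rng_key only for further splitting.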
