
NumPyro dialect for mental midgets


🚒 firetruck

!!! NOTE: This repo is just an experiment for now, not ready for any kind of serious use. The package is not published on PyPI either !!!

firetruck is a NumPyro dialect for mental midgets. This means:

  • No numpyro.deterministic or numpyro.sample: just write your code like a normal human and assign the variables you want to track to self
  • You can just return your outcome variable from the function, no obs bullshit!
  • Greatly simplified sampling and VI. No bespoke solutions, just good defaults for 90% of your use cases.
  • You can deal with latent categorical variables without having to do anything, yaaay!
  • WebGL-accelerated Plotly plots. You don't need to know ArviZ, Matplotlib or any of that jazz. The plots not only look better, they are also interactive and faster.
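
To give a feel for the "assign to self" idea, here is a toy sketch of how attribute assignment can be intercepted so that variables get named automatically. This is NOT firetruck's implementation: `Tracker`, `compact_sketch`, and the toy `Normal` class are all made up for illustration, and it uses only the standard library instead of NumPyro. In NumPyro terms, the two branches in `__setattr__` would roughly correspond to numpyro.sample and numpyro.deterministic.

```python
import random


class Normal:
    """Toy stand-in for a distribution object (not numpyro.distributions)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)


class Tracker:
    """Records every attribute assigned to it, keyed by the attribute name."""

    def __init__(self, rng):
        # Bypass our own __setattr__ so internal state is not tracked.
        object.__setattr__(self, "_rng", rng)
        object.__setattr__(self, "trace", {})

    def __setattr__(self, name, value):
        # A distribution is replaced by a draw from it;
        # any other value is stored unchanged.
        if hasattr(value, "sample"):
            value = value.sample(self._rng)
        self.trace[name] = value
        object.__setattr__(self, name, value)


def compact_sketch(fn):
    """Hypothetical decorator: calls fn with a fresh Tracker as `self`."""

    def run(*inputs, seed=0):
        tracker = Tracker(random.Random(seed))
        outcome = fn(tracker, *inputs)
        return outcome, tracker.trace

    return run


@compact_sketch
def toy_model(self, x):
    self.a = Normal(0.0, 0.2)
    self.b = Normal(0.0, 0.5)
    self.mu = self.a + self.b * x  # plain values are tracked too
    return Normal(self.mu, 1.0)


outcome, trace = toy_model(2.0, seed=1)
print(list(trace))  # names come from the attribute names: a, b, mu
```

The point of the design is that the attribute name doubles as the variable name, so there is no separate string to keep in sync with the Python variable.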

Example

I modified the Waffle House example in the NumPyro docs to use firetruck.

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
import pandas as pd

import firetruck as ftr

DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, sep=";")
marriage = jnp.array(dset["Marriage"])
divorce = jnp.array(dset["Divorce"])
age = jnp.array(dset["MedianAgeMarriage"])


# Don't forget this decorator!! Very important
@ftr.compact
def model(self, marriage, age):
    # Just assign variables to self that you want to track,
    # And they will be named automatically!!
    self.a = dist.Normal(0.0, 0.2)
    self.bM = dist.Normal(0.0, 0.5)
    self.bA = dist.Normal(0.0, 0.5)
    self.sigma = dist.Exponential(0.5)
    mu = self.a + self.bM * marriage + self.bA * age
    return dist.Normal(mu, self.sigma)


# Sampling Prior predictive distribution
rng_key = jax.random.key(42)
rng_key, subkey = jax.random.split(rng_key)
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)

# Add inputs to the model and condition on the output
conditioned_model = model.add_input(marriage, age).condition_on(divorce)

# Fit model using meanfield VI
rng_key, subkey = jax.random.split(rng_key)
res = conditioned_model.meanfield_vi(subkey)

# Sample from model using NUTS
rng_key, subkey = jax.random.split(rng_key)
mcmc = conditioned_model.sample_posterior(subkey)

# Prints this automatically, cause why the hell would you not need this:
#                 mean       std    median      5.0%     95.0%     n_eff     r_hat
#          a      0.01      0.20      0.02     -0.31      0.35   2348.61      1.00
#         bA      0.17      0.05      0.17      0.10      0.25   1634.25      1.00
#         bM      0.26      0.06      0.26      0.16      0.35   1649.36      1.00
#      sigma      1.82      0.19      1.81      1.52      2.11   2419.83      1.00
#
# Number of divergences: 0

# Plot sampling trace
fig = ftr.plot_trace(mcmc)
fig.show()
# Forest plot of posterior samples
fig = ftr.plot_forest(mcmc)
fig.show()
# Plot effective sample size (ESS) of posterior samples
fig = ftr.plot_ess(mcmc)
fig.show()
# Sampling prior predictive and plotting prior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# NOTE that I'm using the unconditioned model
prior_predictive = model.add_input(marriage, age).sample_predictive(subkey)
fig = ftr.plot_predictive_check(prior_predictive, obs=divorce)
fig.show()
# Sampling posterior predictive and plotting posterior-predictive check
rng_key, subkey = jax.random.split(rng_key)
# Note that I'm passing the posterior_samples to the function
posterior_predictive = model.add_input(marriage, age).sample_predictive(
    subkey, posterior_samples=mcmc.get_samples()
)
fig = ftr.plot_predictive_check(posterior_predictive, obs=divorce)
fig.show()
