
A convenient object-oriented wrapper for working with numpyro models.


An object-oriented interface to numpyro


This package provides a wrapper for working with numpyro models. It aims to remain model-agnostic while packaging up much of the model-fitting code to reduce repetition.

It is intended to make life a bit easier for people who are already familiar with Numpyro and Bayesian modelling. It is not intended to fill the same high-level wrapper role as packages such as brms: the user is still required to write the model from scratch. This is an intentional choice. Writing the model from scratch takes longer and is less convenient for standard models like GLMs, but it gives deeper insight into what is happening under the hood, and it makes custom models that don't fit a standard mould more transparent to implement.

Getting started

pip install numpyro-oop

The basic idea is that the user defines a new class that inherits from BaseNumpyroModel and defines (at minimum) the model to be fit by overriding the model method:

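from numpyro_oop import BaseNumpyroModel

class DemoModel(BaseNumpyroModel):
    # override model with your own numpyro model definition
    def model(self, data=None, ...):
        ...

# df holds the data to fit (e.g. a pandas DataFrame)
m1 = DemoModel(data=df, seed=42)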

All subsequent sampling and prediction steps are then handled by numpyro-oop or related libraries (e.g. ArviZ):

# sample from the model:
m1.sample()
# generate model predictions for the dataset given at initialization:
preds = m1.predict(...)
# generate an ArviZ InferenceData object, stored in m1.arviz_data:
m1.generate_arviz_data()

A complete demo can be found in /scripts/demo_1.ipynb.
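The inherit-and-override idea can be illustrated in plain Python. The sketch below is hypothetical: MiniBaseModel is not numpyro-oop's BaseNumpyroModel, and we do not claim the library uses the abc module internally. It only shows how an abstract model method forces every subclass to supply its own model definition, while the base class keeps shared state such as data and seed.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class MiniBaseModel(ABC):
    """Hypothetical stand-in for a base class; not numpyro-oop's implementation."""

    def __init__(self, data: Any, seed: int = 0) -> None:
        self.data = data  # dataset supplied at initialization
        self.seed = seed  # seed for the random number generator

    @abstractmethod
    def model(self, data: Optional[Any] = None, **kwargs: Any) -> None:
        """Subclasses must override this with their own model definition."""


class Incomplete(MiniBaseModel):
    pass  # does not override model, so it cannot be instantiated


class DemoModel(MiniBaseModel):
    def model(self, data=None, **kwargs):
        return None  # a real subclass would define numpyro sample sites here


m1 = DemoModel(data=[1.0, 2.0, 3.0], seed=42)
```

Calling Incomplete(data=[]) raises a TypeError because model is still abstract, so a forgotten override is caught when the object is constructed rather than later, when sampling starts.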

Requirements of the model method

Consider the following model method:

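import numpyro
import numpyro.distributions as dist
from numpyro_oop import BaseNumpyroModel

class DemoModel(BaseNumpyroModel):
    def model(self, data=None, sample_conditional=True, ...):
        ...  # priors for mu and sigma go here

        # condition on the observed data only when sample_conditional is True
        if sample_conditional:
            obs = data["y"].values
        else:
            obs = None
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=obs)

m1 = DemoModel(data=df, seed=42)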

First, note that data is an optional keyword argument that defaults to None. If data is not passed to the model method directly, model automatically falls back to self.data, which is set when the class instance is initialised.
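A minimal sketch of this fallback, using a hypothetical FallbackDemo class rather than numpyro-oop itself. We assume the library does something along these lines, but its actual code may differ.

```python
from typing import Any, Optional


class FallbackDemo:
    """Hypothetical class illustrating the data-fallback pattern."""

    def __init__(self, data: Any) -> None:
        self.data = data  # default dataset, stored at initialization

    def model(self, data: Optional[Any] = None) -> Any:
        # Use the stored dataset only when no data is passed explicitly.
        if data is None:
            data = self.data
        return data  # a real model would build its sample sites from data


m = FallbackDemo(data={"y": [1.0, 2.0]})
default_data = m.model()  # falls back to m.data
new_data = m.model(data={"y": [5.0]})  # uses the data passed in
```

Checking data is None explicitly, rather than writing data or self.data, matters for pandas DataFrames: their truth value is ambiguous, so the or form raises a ValueError.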

Second, note the sample_conditional argument and the pattern that follows it. To use Numpyro's Predictive class, we need to be able to set any observed data that a sample site is conditioned on (typically the likelihood) to None. See the Numpyro docs for examples. Currently, numpyro-oop requires the user to implement this in the model definition in some way; the pattern above is one suggestion.

After the model is sampled, we can then generate posterior predictive distributions by passing sample_conditional=False as a model_kwarg:

m1.predict(model_kwargs={"sample_conditional": False})

Using reparameterizations

One particularly useful feature of Numpyro is the ability to define reparameterizations of variables that can be applied to the model object (see docs). To use these with numpyro-oop, the user must override the generate_reparam_config method of BaseNumpyroModel so that it returns a reparameterization dictionary:

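from numpyro.infer.reparam import LocScaleReparam

# defined as a method on your BaseNumpyroModel subclass
def generate_reparam_config(self) -> dict:
    # map sample-site names to numpyro reparameterizers
    reparam_config = {
        "theta": LocScaleReparam(0),  # 0 = fully non-centered
    }
    return reparam_config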

In this example, the sample site theta in the model will be reparameterized with a location/scale reparameterization if use_reparam=True when the class instance is created. This is handy, because you can then test the effect of the reparameterization by simply setting use_reparam=False and refitting the model. See examples/demo.ipynb for a working example.

Roadmap after initial release

  • Include doctests and improved examples
  • Demo and tests for multiple group variables
  • Export docs to a static site (Read the Docs or similar), with detailed information on class methods and attributes
  • Contributor guidelines
  • Fix type hints via linter checks

Development notes

Install the library with development dependencies via pip install -e ".[dev]".

