A JAX/NumPyro port of Pyro's forecasting module.

NumPyro Forecast

📖 Documentation: https://juanitorduz.github.io/numpyro_forecast/

Scope

numpyro_forecast is a small, focused toolkit for Bayesian time-series forecasting with NumPyro. You write the generative model; the package handles the train/forecast plumbing, inference, and evaluation:

  • A single model both trains and forecasts. In-sample time latents use a fixed site name (drift); the forecast horizon uses a separate _future site, so the variational guide is never resized and the forecast suffix is drawn from the prior. The horizon is inferred from shapes: covariates that extend beyond the data define the forecast window.
  • Two inference backends: stochastic variational inference (Forecaster, via AutoNormal) and Hamiltonian Monte Carlo / NUTS (HMCForecaster).
  • Backtesting over rolling windows plus probabilistic and point metrics.
  • Works for univariate, multivariate and hierarchical models.
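
The backtesting bullet above can be pictured with a small, library-independent sketch. The function name `rolling_windows` and its parameters (`min_train`, `test_window`, `stride`) are hypothetical, chosen for illustration, and are not the package's API. The sketch only shows one common way to enumerate expanding training windows followed by fixed-length test windows over a series of length `duration`.

```python
def rolling_windows(duration, min_train, test_window, stride):
    """Yield (train_end, test_end) index pairs for an expanding-window backtest.

    Hypothetical helper for illustration; not part of numpyro_forecast.
    Training data would be series[:train_end]; test data series[train_end:test_end].
    """
    train_end = min_train
    # Stop as soon as the test window would run past the end of the series.
    while train_end + test_window <= duration:
        yield train_end, train_end + test_window
        train_end += stride


windows = list(rolling_windows(duration=182, min_train=104, test_window=26, stride=26))
for train_end, test_end in windows:
    print(f"train on [0, {train_end}), evaluate on [{train_end}, {test_end})")
```

Each window would be fitted and scored independently, and the per-window scores then summarized.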

Arrays follow Pyro's layout: time at axis -2, the observation/event dimension at -1, and batch dimensions to the left.
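
As a rough illustration of this layout, the sketch below splits an array shape into its batch, time, and observation parts. `describe_layout` is a hypothetical helper written for this example, not a function of the package, and it works on plain shape tuples, so it applies equally to the `.shape` of a JAX or NumPy array.

```python
def describe_layout(shape):
    """Split an array shape into (batch_shape, num_time_steps, obs_dim).

    Assumes the layout described above: time at axis -2, observation dim at -1,
    and any batch dimensions to the left. Hypothetical helper for illustration.
    """
    if len(shape) < 2:
        raise ValueError("expected at least a (time, obs) shape")
    *batch_shape, num_time_steps, obs_dim = shape
    return tuple(batch_shape), num_time_steps, obs_dim


print(describe_layout((156, 1)))      # ((), 156, 1): one univariate series
print(describe_layout((10, 156, 3)))  # ((10,), 156, 3): 10 batched series, 3 observed dims
```

Note that a univariate series still carries a trailing observation dimension of size 1, which is why the Quickstart builds its time index with `[:, None]`.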

It is not an AutoML or "fit-any-series" library: there is no pre-built model zoo and no automatic feature pipeline. You define the NumPyro model; the package gives you a clean path from model to forecasts and scores.

Installation

Requires Python >= 3.12. Install from PyPI:

uv add numpyro_forecast
# or, with pip:
pip install numpyro_forecast

To install the latest development version from source:

uv add "numpyro_forecast @ git+https://github.com/juanitorduz/numpyro_forecast"
# or, with pip:
pip install "numpyro_forecast @ git+https://github.com/juanitorduz/numpyro_forecast"

For a local checkout:

uv sync --all-extras   # or: pip install -e ".[dataframes]"

The optional dataframes extra adds pandas/polars support.

Quickstart

Define a model, fit it with SVI, and draw probabilistic forecasts:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.reparam import LocScaleReparam

from numpyro_forecast.evaluate import eval_crps
from numpyro_forecast.forecaster import Forecaster, ForecastingModel
from numpyro_forecast.util import fourier_features


class SeasonalForecaster(ForecastingModel):
    """Local-level random walk + Fourier seasonality, Student-T noise."""

    def model(self, zero_data, covariates):
        num_features = covariates.shape[-1]
        bias = numpyro.sample("bias", dist.Normal(0.0, 10.0))
        weight = numpyro.sample(
            "weight", dist.Normal(0.0, 0.1).expand([num_features]).to_event(1)
        )
        drift_scale = numpyro.sample("drift_scale", dist.LogNormal(-3.0, 1.0))
        sigma = numpyro.sample("sigma", dist.LogNormal(-2.0, 1.0))
        nu = numpyro.sample("nu", dist.Gamma(10.0, 2.0))

        drift = self.time_series(
            "drift",
            lambda: dist.Normal(0.0, drift_scale),
            reparam=LocScaleReparam(0),
        )
        level = jnp.cumsum(drift, axis=-2)  # random-walk level
        regression = (weight * covariates).sum(axis=-1, keepdims=True)
        prediction = level + bias + regression

        self.predict(dist.StudentT(df=nu, loc=0.0, scale=sigma), prediction)


# Synthetic seasonal series (period 52, e.g. weekly data with a yearly cycle): time at axis -2, one observation dim at -1.
period, t_obs, horizon = 52.0, 156, 26
duration = t_obs + horizon
covariates = fourier_features(duration, period=period, num_terms=3)
t = jnp.arange(duration)[:, None]
truth = jnp.sin(2 * jnp.pi * t / period) + 0.01 * t
data = truth[:t_obs]

key_fit, key_pred = random.split(random.PRNGKey(0))
forecaster = Forecaster(
    key_fit,
    SeasonalForecaster(),
    data,
    covariates[:t_obs],
    num_steps=1_500,
)

# Draw 100 forecast samples over the held-out horizon, shaped (sample, future, obs).
samples = forecaster(key_pred, data, covariates, num_samples=100)
print("forecast samples:", samples.shape)
print("CRPS:", eval_crps(samples, truth[t_obs:]))
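
For intuition about the score printed above, here is a minimal pure-Python sketch of the standard sample-based CRPS estimator for a single scalar observation, CRPS ≈ E|X − y| − ½·E|X − X′|, where X and X′ are independent draws from the forecast samples. The name `sample_crps` is hypothetical, and this is not necessarily how `eval_crps` is implemented or how it aggregates over time steps and observation dimensions.

```python
def sample_crps(samples, observation):
    """Sample-based CRPS estimate for one scalar observation.

    Illustrative sketch only: mean absolute error of the samples minus half
    the mean absolute difference between all pairs of samples.
    """
    n = len(samples)
    abs_error = sum(abs(x - observation) for x in samples) / n
    spread = sum(abs(x - x2) for x in samples for x2 in samples) / (n * n)
    return abs_error - 0.5 * spread


print(sample_crps([1.0, 1.0, 1.0], 1.0))  # 0.0: a perfect point forecast
print(sample_crps([0.0, 2.0], 1.0))       # 0.5
```

Lower values are better. For a forecast with no spread, the estimator reduces to the absolute error.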

Two APIs: functional core and OOP shim

The package is built around a pure functional core (numpyro_forecast.functional) and a thin object-oriented shim (numpyro_forecast.forecaster) that ports Pyro's class-based API. The two are fully interchangeable: both produce the same NumPyro model callable (covariates, data=None) and consume the same posterior dict of latent draws, so you can fit with one and forecast with the other.

  • Functional core. The train/forecast split is an explicit, immutable Horizon value (derived from the covariate and data shapes) that is threaded into pure primitives. You write a model body (Horizon, covariates) -> None that calls time_series(...) and predict(...), wrap it with forecasting_model(...), and drive inference with the free functions fit_svi / fit_mcmc, draw_posterior, and forecast. There is no global parameter store, and PRNG keys are threaded explicitly.
  • OOP shim (Pyro-compatible). Subclass ForecastingModel and implement model(self, zero_data, covariates), calling self.time_series(...) and self.predict(...) exactly as in Pyro's pyro.contrib.forecast. The Forecaster (SVI) and HMCForecaster (NUTS) classes carry the horizon as instance state and delegate to the functional core under the hood. This is the API used in the Quickstart above.
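
To make the idea of an immutable horizon value concrete, the following sketch derives a train/forecast split from data and covariate shapes. `HorizonSketch`, its fields `t_obs` and `t_future`, and the `from_shapes` constructor are all assumptions made for illustration; the real `Horizon` in `numpyro_forecast.functional` may expose different fields and construction logic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HorizonSketch:
    """Hypothetical stand-in for an immutable train/forecast split."""

    t_obs: int     # number of observed (in-sample) time steps
    t_future: int  # number of forecast time steps

    @classmethod
    def from_shapes(cls, data_shape, covariates_shape):
        # Time lives at axis -2 in both arrays.
        t_obs = data_shape[-2]
        duration = covariates_shape[-2]
        if duration < t_obs:
            raise ValueError("covariates must cover at least the observed time steps")
        return cls(t_obs=t_obs, t_future=duration - t_obs)


# Training: covariates span exactly the observed window, so nothing is forecast.
train_split = HorizonSketch.from_shapes((156, 1), (156, 6))
# Forecasting: covariates extend 26 steps past the data.
forecast_split = HorizonSketch.from_shapes((156, 1), (182, 6))
print(train_split, forecast_split)
```

Because the dataclass is frozen, a split cannot be mutated after creation, which is what allows it to be passed safely through pure functions.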

The same SeasonalForecaster model, written and run through the functional API:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.reparam import LocScaleReparam

from numpyro_forecast.evaluate import eval_crps
from numpyro_forecast.functional import (
    Horizon,
    draw_posterior,
    fit_svi,
    forecast,
    forecasting_model,
    predict,
    time_series,
)
from numpyro_forecast.util import fourier_features


def seasonal_body(h: Horizon, covariates):
    """Local-level random walk + Fourier seasonality, Student-T noise."""
    num_features = covariates.shape[-1]
    bias = numpyro.sample("bias", dist.Normal(0.0, 10.0))
    weight = numpyro.sample(
        "weight", dist.Normal(0.0, 0.1).expand([num_features]).to_event(1)
    )
    drift_scale = numpyro.sample("drift_scale", dist.LogNormal(-3.0, 1.0))
    sigma = numpyro.sample("sigma", dist.LogNormal(-2.0, 1.0))
    nu = numpyro.sample("nu", dist.Gamma(10.0, 2.0))

    drift = time_series(
        h, "drift", lambda: dist.Normal(0.0, drift_scale), reparam=LocScaleReparam(0)
    )
    level = jnp.cumsum(drift, axis=-2)  # random-walk level
    regression = (weight * covariates).sum(axis=-1, keepdims=True)
    prediction = level + bias + regression

    predict(h, dist.StudentT(df=nu, loc=0.0, scale=sigma), prediction)


# Same synthetic series as the Quickstart.
period, t_obs, horizon = 52.0, 156, 26
duration = t_obs + horizon
covariates = fourier_features(duration, period=period, num_terms=3)
t = jnp.arange(duration)[:, None]
truth = jnp.sin(2 * jnp.pi * t / period) + 0.01 * t
data = truth[:t_obs]

model = forecasting_model(seasonal_body)
key_fit, key_post, key_pred = random.split(random.PRNGKey(0), 3)

fit = fit_svi(key_fit, model, data, covariates[:t_obs], num_steps=1_500)
posterior = draw_posterior(key_post, fit, num_samples=100)
samples = forecast(key_pred, model, posterior, data, covariates)
print("forecast samples:", samples.shape)
print("CRPS:", eval_crps(samples, truth[t_obs:]))
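
Both examples build their covariates with `fourier_features(duration, period=period, num_terms=3)`. As a hedged illustration of what such seasonal covariates usually look like, the sketch below builds sine/cosine pairs at the first `num_terms` harmonics of the period using only the standard library. `fourier_features_sketch` is a hypothetical name, and the package's own function may order, scale, or shape its columns differently.

```python
import math


def fourier_features_sketch(duration, period, num_terms):
    """Return a duration x (2 * num_terms) nested list of sin/cos features.

    Illustrative only; numpyro_forecast.util.fourier_features may differ
    in column order or scaling.
    """
    rows = []
    for t in range(duration):
        row = []
        for k in range(1, num_terms + 1):
            angle = 2.0 * math.pi * k * t / period
            row.extend([math.sin(angle), math.cos(angle)])
        rows.append(row)
    return rows


features = fourier_features_sketch(duration=182, period=52.0, num_terms=3)
print(len(features), len(features[0]))  # 182 6
```

Because these features are deterministic functions of time, they can be computed for the full duration in advance, which is what lets the covariates extend past the observed data.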

Development

This project uses uv for environment management, ruff for linting/formatting, ty for type checking, and prek to run the pre-commit hooks.

uv sync --all-extras       # create the environment
prek install               # install git hooks
prek run --all-files       # lint + format + type check
uv run pytest              # run the tests

See CONTRIBUTING.md for the full workflow and guidelines.

License

Apache-2.0.
