NumPyro Forecast
A JAX/NumPyro port of Pyro's forecasting module.
Scope
numpyro_forecast is a small, focused toolkit for Bayesian time-series
forecasting with NumPyro. You write the generative model; the package handles
the train/forecast plumbing, inference, and evaluation:
- A single model both trains and forecasts. In-sample time latents use a fixed site name (drift). The forecast horizon uses a separate _future site, so the variational guide is never resized and the forecast suffix is drawn from the prior. The horizon is inferred from shapes: when covariates are longer than data, the extra time steps form the forecast horizon.
- Two inference backends: stochastic variational inference (Forecaster, via AutoNormal) and Hamiltonian Monte Carlo / NUTS (HMCForecaster).
- Backtesting over rolling windows, plus probabilistic and point metrics.
- Works for univariate, multivariate and hierarchical models.
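Backtesting over rolling windows means refitting and scoring the model on a sequence of train/test splits whose forecast origin moves forward in time. The sketch below only enumerates such splits, using an expanding training window. The helper rolling_windows and its parameters (min_train, test_window, stride) are hypothetical illustrations, not the package's backtesting API.

```python
def rolling_windows(num_steps: int, min_train: int, test_window: int, stride: int = 1):
    """Hypothetical helper: list (train_end, test_end) index pairs.

    Each split trains on [0, train_end) and tests on [train_end, test_end).
    The training window expands while its end advances by `stride`.
    """
    splits = []
    train_end = min_train
    while train_end + test_window <= num_steps:
        splits.append((train_end, train_end + test_window))
        train_end += stride
    return splits


# 182 steps, at least 104 for training, 26-step test windows, origin moves by 26.
for train_end, test_end in rolling_windows(182, min_train=104, test_window=26, stride=26):
    print(f"train [0, {train_end})  test [{train_end}, {test_end})")
```

Each split would then be fitted and scored with the metrics of your choice; a fixed-length (sliding) training window is an equally common alternative.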
Arrays follow Pyro's layout: time at axis -2, the observation/event
dimension at -1, and batch dimensions to the left.
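As a quick illustration of this layout, the shapes below are examples only; treating the groups of a hierarchical model as a leading batch dimension is an assumption made for the sake of the sketch. Because time always sits at axis -2, "over time" operations such as the random-walk cumsum used later work unchanged for any batch shape.

```python
import jax.numpy as jnp

# Pyro-style layout: (..., time, obs). Batch dimensions, if any, go on the left.
univariate = jnp.zeros((156, 1))       # 156 time steps, 1 observed series
multivariate = jnp.zeros((156, 3))     # 156 time steps, 3 observed dims
hierarchical = jnp.zeros((4, 156, 3))  # illustrative: 4 groups (batch) x 156 steps x 3 dims

# Reductions "over time" use axis=-2, independent of the batch shape.
drift = jnp.ones((4, 156, 3))
level = jnp.cumsum(drift, axis=-2)  # random-walk level per group and dim
print(level.shape)  # (4, 156, 3)
```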
It is not an AutoML or "fit-any-series" library: there is no pre-built model zoo and no automatic feature pipeline. You define the NumPyro model; the package gives you a clean path from model to forecasts and scores.
Installation
Requires Python >= 3.12. The package is not yet published on PyPI; install it from source:
uv add "numpyro_forecast @ git+https://github.com/juanitorduz/numpyro_forecast"
# or, with pip:
pip install "numpyro_forecast @ git+https://github.com/juanitorduz/numpyro_forecast"
For a local checkout:
uv sync --all-extras # or: pip install -e ".[dataframes]"
The optional dataframes extra adds pandas/polars support.
Quickstart
Define a model, fit it with SVI, and draw probabilistic forecasts:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.reparam import LocScaleReparam
from numpyro_forecast.evaluate import eval_crps
from numpyro_forecast.forecaster import Forecaster, ForecastingModel
from numpyro_forecast.util import fourier_features
class SeasonalForecaster(ForecastingModel):
    """Local-level random walk + Fourier seasonality, Student-T noise."""

    def model(self, zero_data, covariates):
        num_features = covariates.shape[-1]
        bias = numpyro.sample("bias", dist.Normal(0.0, 10.0))
        weight = numpyro.sample(
            "weight", dist.Normal(0.0, 0.1).expand([num_features]).to_event(1)
        )
        drift_scale = numpyro.sample("drift_scale", dist.LogNormal(-3.0, 1.0))
        sigma = numpyro.sample("sigma", dist.LogNormal(-2.0, 1.0))
        nu = numpyro.sample("nu", dist.Gamma(10.0, 2.0))
        drift = self.time_series(
            "drift",
            lambda: dist.Normal(0.0, drift_scale),
            reparam=LocScaleReparam(0),
        )
        level = jnp.cumsum(drift, axis=-2)  # random-walk level
        regression = (weight * covariates).sum(axis=-1, keepdims=True)
        prediction = level + bias + regression
        self.predict(dist.StudentT(df=nu, loc=0.0, scale=sigma), prediction)
# Synthetic series with period-52 seasonality and a linear trend: time at axis -2, one observation dim at -1.
period, t_obs, horizon = 52.0, 156, 26
duration = t_obs + horizon
covariates = fourier_features(duration, period=period, num_terms=3)
t = jnp.arange(duration)[:, None]
truth = jnp.sin(2 * jnp.pi * t / period) + 0.01 * t
data = truth[:t_obs]
key_fit, key_pred = random.split(random.PRNGKey(0))
forecaster = Forecaster(
key_fit,
SeasonalForecaster(),
data,
covariates[:t_obs],
num_steps=1_500,
)
# Draw 100 forecast samples over the held-out horizon, shaped (sample, future, obs).
samples = forecaster(key_pred, data, covariates, num_samples=100)
print("forecast samples:", samples.shape)
print("CRPS:", eval_crps(samples, truth[t_obs:]))
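eval_crps scores the forecast samples against the held-out truth with the continuous ranked probability score (lower is better). A common sample-based estimator is CRPS = E|X - y| - 0.5 * E|X - X'|, where X and X' are independent forecast draws and y is the observation. The sketch below assumes the sample axis comes first and averages over all remaining elements. Pyro's eval_crps computes this quantity with an O(n log n) sorted-sample algorithm; whether this package's eval_crps matches it exactly is an assumption, and crps_from_samples is a hypothetical name.

```python
import jax.numpy as jnp


def crps_from_samples(samples, truth):
    """Sample-based CRPS estimate, averaged over all non-sample elements.

    samples: (num_samples, ...) forecast draws; truth: (...) observed values.
    Uses the O(n^2) pairwise form for readability.
    """
    abs_error = jnp.abs(samples - truth).mean(axis=0)                        # E|X - y|
    spread = jnp.abs(samples[:, None] - samples[None, :]).mean(axis=(0, 1))  # E|X - X'|
    return (abs_error - 0.5 * spread).mean()


truth = jnp.zeros((26, 1))
# A point-mass forecast has zero spread, so CRPS reduces to absolute error.
print(crps_from_samples(jnp.full((100, 26, 1), 2.0), truth))  # 2.0
```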
Two APIs: functional core and OOP shim
The package is built around a pure functional core (numpyro_forecast.functional)
and a thin object-oriented shim (numpyro_forecast.forecaster) that ports Pyro's
class-based API. The two are fully interchangeable: both produce the same kind of
NumPyro model callable, with signature (covariates, data=None), and consume the same
posterior dict of latent draws, so you can fit with one and forecast with the other.
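A posterior dict of this kind maps latent site names to arrays whose leading axis indexes the draws. The stand-in below is built by hand: the site names follow the SeasonalForecaster example, while the shapes (and the choice of sites) are illustrative assumptions, since a real posterior would also contain the time-series latents. Standard JAX pytree utilities work on such a dict, for example to pick one joint draw or to thin the draws.

```python
import jax
import jax.numpy as jnp

num_samples, num_features = 100, 6
# Hand-made stand-in for a posterior: site name -> (num_samples, *site_shape).
posterior = {
    "bias": jnp.zeros((num_samples,)),
    "weight": jnp.zeros((num_samples, num_features)),
    "sigma": jnp.ones((num_samples,)),
}

first_draw = jax.tree_util.tree_map(lambda x: x[0], posterior)  # one joint draw
thinned = jax.tree_util.tree_map(lambda x: x[::10], posterior)   # every 10th draw
print({name: value.shape for name, value in thinned.items()})
```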
- Functional core. The train/forecast split is an explicit, immutable Horizon value (derived from the covariate and data shapes) that is threaded into pure primitives. You write a model body(Horizon, covariates) -> None that calls time_series(...) and predict(...), wrap it with forecasting_model(...), and drive inference with the free functions fit_svi / fit_mcmc, draw_posterior, and forecast. There is no global parameter store, and PRNGKeys are threaded explicitly.
- OOP shim (Pyro-compatible). Subclass ForecastingModel and implement model(self, zero_data, covariates), calling self.time_series(...) and self.predict(...) in the style of Pyro's pyro.contrib.forecast. The Forecaster (SVI) and HMCForecaster (NUTS) classes carry the horizon as instance state and delegate to the functional core under the hood. This is the API used in the Quickstart above.
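To make the idea of a shape-derived, immutable horizon concrete, here is a hypothetical stand-in. The class name HorizonSketch, its fields and its split helper are illustrative only and are not the package's actual Horizon definition.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class HorizonSketch:
    """Hypothetical, immutable train/forecast split derived from array shapes."""

    t_obs: int     # number of observed (in-sample) time steps
    duration: int  # total time steps covered by the covariates

    @classmethod
    def from_shapes(cls, data, covariates):
        # Time lives at axis -2 for both arrays.
        t_obs, duration = data.shape[-2], covariates.shape[-2]
        if duration < t_obs:
            raise ValueError("covariates must cover at least the observed period")
        return cls(t_obs=t_obs, duration=duration)

    @property
    def future(self) -> int:
        # Extra covariate steps beyond the data form the forecast horizon.
        return self.duration - self.t_obs

    def split(self, x):
        """Split a (..., duration, obs) array into in-sample and forecast parts."""
        return x[..., : self.t_obs, :], x[..., self.t_obs :, :]


h = HorizonSketch.from_shapes(jnp.zeros((156, 1)), jnp.zeros((182, 6)))
past, future = h.split(jnp.zeros((182, 1)))
print(h.future, past.shape, future.shape)  # 26 (156, 1) (26, 1)
```

During training the covariates and data have the same length, so the sketch's future is 0; passing longer covariates is what turns the same model into a forecaster.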
The same SeasonalForecaster model, written and run through the functional API:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer.reparam import LocScaleReparam
from numpyro_forecast.evaluate import eval_crps
from numpyro_forecast.functional import (
    Horizon,
    draw_posterior,
    fit_svi,
    forecast,
    forecasting_model,
    predict,
    time_series,
)
from numpyro_forecast.util import fourier_features

def seasonal_body(h: Horizon, covariates):
    """Local-level random walk + Fourier seasonality, Student-T noise."""
    num_features = covariates.shape[-1]
    bias = numpyro.sample("bias", dist.Normal(0.0, 10.0))
    weight = numpyro.sample(
        "weight", dist.Normal(0.0, 0.1).expand([num_features]).to_event(1)
    )
    drift_scale = numpyro.sample("drift_scale", dist.LogNormal(-3.0, 1.0))
    sigma = numpyro.sample("sigma", dist.LogNormal(-2.0, 1.0))
    nu = numpyro.sample("nu", dist.Gamma(10.0, 2.0))
    drift = time_series(
        h, "drift", lambda: dist.Normal(0.0, drift_scale), reparam=LocScaleReparam(0)
    )
    level = jnp.cumsum(drift, axis=-2)  # random-walk level
    regression = (weight * covariates).sum(axis=-1, keepdims=True)
    prediction = level + bias + regression
    predict(h, dist.StudentT(df=nu, loc=0.0, scale=sigma), prediction)
# Same synthetic series as the Quickstart.
period, t_obs, horizon = 52.0, 156, 26
duration = t_obs + horizon
covariates = fourier_features(duration, period=period, num_terms=3)
t = jnp.arange(duration)[:, None]
truth = jnp.sin(2 * jnp.pi * t / period) + 0.01 * t
data = truth[:t_obs]
model = forecasting_model(seasonal_body)
key_fit, key_post, key_pred = random.split(random.PRNGKey(0), 3)
fit = fit_svi(key_fit, model, data, covariates[:t_obs], num_steps=1_500)
posterior = draw_posterior(key_post, fit, num_samples=100)
samples = forecast(key_pred, model, posterior, data, covariates)
print("forecast samples:", samples.shape)
print("CRPS:", eval_crps(samples, truth[t_obs:]))
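Both examples build their seasonal covariates with fourier_features(duration, period=..., num_terms=...). Features of this kind are usually the sine/cosine pairs sin(2*pi*k*t/P) and cos(2*pi*k*t/P) for k = 1..num_terms. The sketch below computes them in plain JAX; the column ordering and the absence of any scaling are assumptions, so fourier_features_sketch (a hypothetical name) is an illustration rather than a drop-in replacement for the package's helper.

```python
import jax.numpy as jnp


def fourier_features_sketch(duration: int, period: float, num_terms: int):
    """Return a (duration, 2 * num_terms) array laid out as [sin terms, cos terms]."""
    t = jnp.arange(duration)[:, None]          # (duration, 1): time at axis -2
    k = jnp.arange(1, num_terms + 1)[None, :]  # (1, num_terms): harmonic index
    angle = 2 * jnp.pi * k * t / period        # (duration, num_terms)
    return jnp.concatenate([jnp.sin(angle), jnp.cos(angle)], axis=-1)


feats = fourier_features_sketch(182, period=52.0, num_terms=3)
print(feats.shape)  # (182, 6)
```

Each column repeats every 52 steps, which is what lets a linear regression on these features (the weight term in the model) represent a smooth seasonal pattern.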
Development
This project uses uv for environment management, ruff for linting/formatting, ty for type checking, and prek to run the pre-commit hooks.
uv sync --all-extras # create the environment
prek install # install git hooks
prek run --all-files # lint + format + type check
uv run pytest # run the tests
See CONTRIBUTING.md for the full workflow and guidelines.
License
Apache-2.0.
Download files
File details
Details for the file numpyro_forecast-0.1.0.tar.gz.
File metadata
- Download URL: numpyro_forecast-0.1.0.tar.gz
- Size: 12.7 MB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dd1644b215f248ce7623a7c8cf8ae1e3eaf2d6590e38cdbaa9e96634a12319a4 |
| MD5 | ebe4d8b4cd4e08a2c968180caef20972 |
| BLAKE2b-256 | 6234931befb87cd682b99eb2e18b004e6a98f5b566902ed55805e74972e487d9 |
Provenance
The following attestation bundles were made for numpyro_forecast-0.1.0.tar.gz:
Publisher: release.yml on juanitorduz/numpyro_forecast
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: numpyro_forecast-0.1.0.tar.gz
- Subject digest: dd1644b215f248ce7623a7c8cf8ae1e3eaf2d6590e38cdbaa9e96634a12319a4
- Sigstore transparency entry: 1929431590
- Permalink: juanitorduz/numpyro_forecast@f820ee847ebfa73b9679119001bf42ef264c2912
- Branch / Tag: refs/tags/0.1.0
- Owner: https://github.com/juanitorduz
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@f820ee847ebfa73b9679119001bf42ef264c2912
- Trigger Event: release
File details
Details for the file numpyro_forecast-0.1.0-py3-none-any.whl.
File metadata
- Download URL: numpyro_forecast-0.1.0-py3-none-any.whl
- Size: 26.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 79f6b319cca53a37664d92718beb085b7a5c1fa11f7ec55a25c8b9c80f10db96 |
| MD5 | 91bc0b1116e35491f66a1eee4ad4c124 |
| BLAKE2b-256 | e78b7b2ad7cb8ed4ec29c3f4074b5b24abf8e2fc2c28a8581b18a7b1ae00d54a |
Provenance
The following attestation bundles were made for numpyro_forecast-0.1.0-py3-none-any.whl:
Publisher: release.yml on juanitorduz/numpyro_forecast
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: numpyro_forecast-0.1.0-py3-none-any.whl
- Subject digest: 79f6b319cca53a37664d92718beb085b7a5c1fa11f7ec55a25c8b9c80f10db96
- Sigstore transparency entry: 1929431673
- Permalink: juanitorduz/numpyro_forecast@f820ee847ebfa73b9679119001bf42ef264c2912
- Branch / Tag: refs/tags/0.1.0
- Owner: https://github.com/juanitorduz
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: release.yml@f820ee847ebfa73b9679119001bf42ef264c2912
- Trigger Event: release