tbats-jax

JAX-native TBATS (forecast::tbats port) — multi-seasonal forecasting with vmap panels on CPU/GPU

JAX-native port of R's forecast::tbats. Innovations-form TBATS with multi-seasonal Fourier harmonics, Box-Cox, missing data, and ARMA errors — fit via jax.grad + optimistix BFGS on CPU or GPU. Ships with vmap panel fitting that scales to thousands of series in a single call. At N=10000, T=1500 on a Colab A100: ~200 s warm fit (after ~200 s one-time JIT compile). Vs single-core CPU JAX (~33 min extrapolated): 10× warm / 5× cold. Vs R forecast::tbats sequential (~57 min extrapolated): 17× warm / 8.5× cold.

Experimental: scan-based Levenberg-Marquardt fit (fit_lm, TPU-compatible but lower convergence quality than the main path) and a NumPyro Bayesian scaffold (sampler converges poorly on non-trivial priors — see docs/DEV_NOTES.md).

Install

pip install tbats-jax                  # core (CPU/GPU)
pip install "tbats-jax[data]"          # + pyreadr for fetch_taylor (GPL-3 source, never bundled)
pip install "tbats-jax[bench]"         # + Python `tbats` for comparison
pip install "tbats-jax[bayes]"         # + numpyro (svi_tbats / bayes_tbats)

Quickstart

import jax
jax.config.update("jax_enable_x64", True)   # recommended for fit quality

import numpy as np
from tbats_jax import TBATSSpec, fit_jax, forecast
from tbats_jax.datasets import synthesize_daily

# Daily series with weekly + yearly seasonality (bundled synthetic generator)
y = synthesize_daily(n=730, seed=0)

spec = TBATSSpec(
    seasonal=((7.0, 3), (365.25, 5)),       # (period, k_harmonics) each
    use_trend=True,
    use_damping=True,
)

# Fit (~0.5s on CPU, ~0.02s on A100)
result = fit_jax(y, spec)

# 30-day forecast
preds = forecast(y, result.theta, spec, horizon=30)

Batch 1000 series on GPU

import numpy as np
from tbats_jax import TBATSSpec, fit_panel

panel = np.stack([make_series(s) for s in range(1000)])  # shape (1000, T); make_series is your own loader/generator
spec  = TBATSSpec(seasonal=((24.0, 3), (168.0, 5)), use_trend=True, use_damping=True)

thetas_raw, nlls, compile_t, wall_t = fit_panel(panel, spec)
# A100 @ T=1500, N=1000: ~25s wall for all 1000 fits

Auto-select seasonal harmonics

from tbats_jax import auto_fit_jax_cv

r = auto_fit_jax_cv(y, periods=(7.0, 365.25),
                    use_trend=True, use_damping=True,
                    val_size=30)  # walk-forward CV over last 30 points
# r.spec has the chosen k-vector; r.fit.theta is the final fit

Scope (v0.1.0)

  • trend (optional), damping (optional), multi-seasonal harmonics
  • Box-Cox transform (optional, use_box_cox=True)
  • Missing-data support (NaN values in y handled automatically)
  • innovations state-space form: y_hat_t = w' x_{t-1}, x_t = F x_{t-1} + g e_t
  • Gaussian negative log-likelihood + hybrid admissibility barrier (exact eigvals on CPU, power-iteration spectral norm on GPU/TPU) + gamma-ridge
  • Two optimizers: scipy L-BFGS-B (fit) and optimistix BFGS (fit_jax)
  • Uniform panel fit via vmap (fit_panel)
  • Heterogeneous panel fit across different specs (fit_panel_hetero)
  • Auto k-vector search (auto_fit_jax) — R-compatible greedy per-period AIC
  • CPU or GPU (JAX device-agnostic); Colab notebook in notebooks/
  • ARMA errors (p, q in TBATSSpec)
  • TPU-viable scan-based LM optimizer (fit_lm, fit_lm_multistart)
  • Experimental Bayesian TBATS via NumPyro (HMC)
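
The innovations recursion above drops straight into jax.lax.scan. A minimal sketch with toy dimensions (the w, F, g here are hand-picked illustrations, not the package's actual matrix builders):

```python
import jax
import jax.numpy as jnp

def innovations_sse(y, w, F, g, x0):
    """Sum of squared one-step errors for y_hat_t = w' x_{t-1}, x_t = F x_{t-1} + g e_t.
    The Gaussian NLL used in the library is a monotone function of this quantity."""
    def step(x_prev, y_t):
        y_hat = w @ x_prev          # one-step-ahead prediction
        e = y_t - y_hat             # innovation
        x = F @ x_prev + g * e      # state update
        return x, e
    _, errs = jax.lax.scan(step, x0, y)
    return jnp.sum(errs ** 2)

# toy example: 2-state local-trend model
w = jnp.array([1.0, 1.0])
F = jnp.array([[1.0, 1.0], [0.0, 1.0]])
g = jnp.array([0.1, 0.05])
y = jnp.linspace(0.0, 1.0, 50)
sse = innovations_sse(y, w, F, g, jnp.zeros(2))
```

Because the whole recursion is one scan, jax.grad differentiates through it, which is what makes the gradient-based fits possible.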

Layout

tbats_jax/
├── tbats_jax/          # library (pure, device-agnostic)
│   ├── spec.py         # TBATSSpec
│   ├── params.py       # pack / unpack / init
│   ├── matrices.py     # F, g, w builders
│   ├── kernel.py       # lax.scan + likelihood
│   ├── forecast.py     # h-step point forecast
│   └── fit.py          # scipy L-BFGS-B + jax.grad
├── benchmarks/
│   ├── data.py
│   ├── bench_single.py # JAX vs Python `tbats` package
│   └── bench_panel_full.py  # vmap vs R sequential
└── tests/test_smoke.py

Setup

cd tbats_jax
python3.12 -m venv .venv
source .venv/bin/activate
pip install -e ".[dev,bench,bayes]"

pyproject.toml is the single source of truth for deps. Extras let end users install only what they need — [dev] for tests, [bench] for bench_single.py's comparison against Skorupa's Python tbats, [bayes] for bayes_tbats / svi_tbats. Add [data] if you need the Python fetch_taylor() wrapper.

Run

pytest tests/                        # smoke tests
python -m benchmarks.bench_single    # single-series SSR, 4 backends
python -m benchmarks.bench_oos       # out-of-sample MAE/RMSE, 4 backends
python -m benchmarks.bench_panel_full [N] [T]  # full fits: vmap vs R sequential
python -m benchmarks.colab_panel_gpu [N] [T]   # device-agnostic; GPU on Colab
# Real data (requires R + one-time fetch):
Rscript benchmarks/fetch_real_data.R data
python -m benchmarks.bench_real      # forecast::taylor OOS

The R comparison is auto-detected. If Rscript is on PATH and the forecast package is installed, bench_single includes it; otherwise it cleanly skips. To enable:

brew install r
Rscript -e 'install.packages("forecast", repos="https://cloud.r-project.org")'

Colab

Same code. Switch runtime → GPU, install requirements, run the same benchmarks. No source changes required. Enable float64 in the first cell:

import jax; jax.config.update("jax_enable_x64", True)

Results on Apple Silicon (M-series CPU, float64)

Series: n=1500, two seasonal periods (24, 168), k=(3,5), trend+damping. Four-way comparison: two JAX paths (optimistix BFGS and scipy L-BFGS-B) vs the Python tbats package (pure-NumPy port) vs R forecast::tbats (the original, with its C++ kernel via Rcpp/Armadillo).

Single-series in-sample fit (SSR) — bench_single.py

Synthetic series, n=1500, periods=(24, 168), k=(3, 5), trend+damping.

Implementation                           Wall     SSR    Notes
JAX fit_jax (optimistix BFGS)            0.72 s   361.8  on-graph, + gamma-ridge + log-hinge
R forecast:::fitSpecificTBATS (same k)   0.35 s   374.9  original C++ kernel, Nelder-Mead
JAX fit (scipy L-BFGS-B)                 0.80 s   379.9  scipy path retained for comparison
Python tbats (same k pinned)             1.19 s   375.0  Nelder-Mead over pure numpy
R forecast::tbats (auto-search)          1.57 s   369.5  chose k=[6,6]
Python tbats (auto-search)               11.5 s   373.0  chose k=[6,2]

JAX finds better in-sample SSR than R (361.8 vs 374.9 at matched structure) at ~2× R's wall time. The C++ baseline is hard to beat on a single fit, but quality is now unambiguously at parity.

Panel fit — bench_panel_full.py

32 synthetic series × 1500 obs each. fit_panel JIT-compiles one fused kernel across all series; R runs sequential Rscript calls.

Backend                 Total    Per-series   Mean SSR
JAX fit_panel (vmap)    7.3 s    227 ms       374.8
JAX loop (no vmap)      59.3 s   1852 ms      375.2
R forecast sequential   23.7 s   740 ms       382.8

JAX vmap is 3.3× faster than R sequential and also finds lower mean SSR (374.8 vs 382.8, 2.1% better). This is the headline result: quality parity at scaled throughput. At N=100, per-series time stays at ~213 ms as JIT overhead amortizes further.

Out-of-sample forecast accuracy — bench_oos.py

Train on first 1350 points, forecast last 150.

Implementation          Wall     Train SSR   Test MAE   Test RMSE
R forecast (fixed k)    0.32 s   334         0.4273     0.5250
R forecast (auto)       1.51 s   330         0.4224     0.5193
Python tbats (auto)     10.5 s   n/a         0.4280     0.5260
JAX fit_jax             0.47 s   335         0.4447     0.5384

Test MAE gap vs R: 4.1%, down from ~14% before the v0.0.3 regularization fixes. JAX now also beats Python tbats on wall time by 22×.

Real dataset — bench_real.py (forecast::taylor)

Half-hourly UK electricity demand, n=4032, periods=(48, 336), k=(3, 5). One-week hold-out (336 points). The canonical TBATS benchmark.

Backend                       Wall      Train SSR   Test MAE   Test RMSE
R forecast fixed              0.33 s    6.99e8      1030       1332
JAX fit_jax                   4.16 s    6.96e8      1041       1343
R forecast auto (k=[12,5])    10.03 s   3.01e8      1273       1499

JAX test MAE is within 1.1% of R's fixed fit: essentially parity on the metric that matters. The wall-time penalty is larger than on short series (4.2 s vs 0.3 s) because n=4032 makes the sequential scan longer; this cost grows with series length, not with the number of vmapped series.

Accelerator comparison — fit_panel

Two workloads tested: a daily panel (T=730, weekly+yearly seasonality) for a baseline sanity check, and an hourly panel (T=1500, periods=24+168) which is where the GPU story actually matters for long-series production data.

Daily panel (T=730, k=(3,5))

Backend                      Compile   Warm wall         Per-series   vs CPU           Notes
Apple Silicon M-series CPU   5.4 s     4.7 s (N=100)     45 ms        1.0×             eigvals path, exact rho
CUDA T4 16 GB                33 s      24 s (N=1000)     24 ms        1.85×            power-iter path
CUDA T4 16 GB (N=2000)       62 s      56 s              28 ms        1.60×            per-series creeps as HBM loads up
TPU v5e-1                    1459 s    1429 s (N=1000)   1429 ms      0.03× (slower!)  see below

Hourly panel (T=1500, k=(3,5)) — the headline

Per-series CPU work is 4× heavier here (196 ms vs 45 ms). T4 hit its bandwidth wall and lost to CPU at N=5000. A100 had the headroom to scale:

Backend                               N=1000           N=5000                 N=10000   N=20000
CPU (Apple Silicon M-series)          196 ms           196 ms*                196 ms*   196 ms*
CUDA T4 16 GB                         104 ms (1.88×)   214 ms (0.92×, lost)
CUDA A100 40 GB (warm ms/series)      53 ms            20 ms                  20 ms     30 ms
A100 compile (one-time JIT)           59 s             106 s                  206 s     607 s
A100 warm wall                        52 s             99 s                   200 s     597 s
A100 warm vs CPU                      3.8×             9.9×                   9.8×      6.6×
A100 warm vs R forecast::tbats seq.   6.6×             17.3×                  17.1×     11.5×
A100 cold (compile + warm) vs CPU     1.8×             4.7×                   4.8×      3.3×
A100 cold vs R forecast::tbats seq.   3.1×             8.2×                   8.5×      5.8×

*CPU figures at large N are linear extrapolations from the measured 196 ms/series at N=500.

A100 fits 10,000 independent TBATS models in ~200 s warm (after a ~200 s one-time JIT compile). Total first-run wall: ~406 s. Sequential R forecast::tbats on the same workload would take ~57 minutes. CPU (single core) takes ~33 minutes.

The warm numbers assume a long-lived process that reuses the compiled panel function (production batch retrain, serving). The cold numbers reflect one-shot scripts. Both are honest; pick whichever matches your workflow.

Sweet spot is N=5000–10000 where per-series time drops 2.7× from the N=1000 launch-bound regime to the steady-state ~20 ms. Beyond N≈15000 compile time starts dominating and the per-series curve turns back up (HBM pressure); recommended practice is to chunk larger panels into 10k-series buckets.
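
Chunking an oversized panel into buckets needs no JAX machinery; a sketch with a generic per-chunk fit callable standing in for fit_panel (the helper name and the dummy fit are illustrative, not part of the package):

```python
import numpy as np

def fit_in_chunks(panel, fit_fn, chunk_size=10_000):
    """Fit an (N, T) panel in fixed-size row chunks, one call each.

    fit_fn is any callable mapping an (n, T) array to per-series results,
    e.g. a partial of fit_panel with the spec bound. Fixed-size slices keep
    chunk shapes identical, so XLA reuses one compilation; only a ragged
    final chunk triggers one extra compile.
    """
    return [fit_fn(panel[i:i + chunk_size])
            for i in range(0, len(panel), chunk_size)]

# toy usage with a dummy "fit" that returns per-series means
panel = np.random.default_rng(0).normal(size=(25, 40))
chunks = fit_in_chunks(panel, lambda c: c.mean(axis=1), chunk_size=10)
```

Padding the last chunk up to chunk_size (and discarding the padded results) would avoid even that final recompile.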

Why A100 scales where T4 didn't:

  • 5× HBM bandwidth (1.6 TB/s vs 320 GB/s) — T4's bandwidth wall hit at N=5000 simply doesn't appear until N>20000 on A100
  • Enough SMs to saturate on the vmap axis — at N=1000 both T4 and A100 are launch-bound, but A100 scales the lanes per kernel; T4 stalls
  • 39× FP64 throughput — mostly unused in a launch-bound kernel, but helps at the margins

Why TPU v5e-1 loses so badly

The compile alone is 24 min and the warm fit is another 24 min on a workload GPU finishes in under a minute. Three structural reasons:

  1. optimistix.BFGS uses lax.while_loop. TPU XLA compilation of unbounded-iteration loops is expensive. Every step size, line search, and convergence check is a symbolic branch the compiler has to lower into TPU bundles. GPU XLA handles this gracefully; TPU XLA does not.
  2. v5e-1 is a single "Lite" chip, optimized for inference matmul throughput (bf16/int8). Our per-step FLOPs are tiny (~300 on an 18×18 matmul) — well below the granularity the TPU wants for good utilization.
  3. Scan serialization over T=730 hits the same launch-coordination issue on TPU that it hits on GPU — but without GPU's faster step dispatch.

A fixed-iteration optimizer (e.g., a hand-written BFGS expressed as a lax.scan with bounded steps, instead of while_loop) would likely fix this and make v5e-1 competitive. That's a real rewrite, not a tuning change — we'd give up adaptive convergence tolerance in exchange for TPU-friendly compilation.
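
The bounded-loop idea can be sketched with fixed-iteration gradient steps inside lax.scan; this shows the compilation-friendly pattern, not the actual BFGS update the rewrite would need:

```python
import jax
import jax.numpy as jnp

def scan_minimize(loss_fn, x0, n_steps=200, lr=0.1):
    """Fixed-iteration minimizer expressed as lax.scan (no while_loop),
    so XLA sees a statically bounded loop -- the property TPU compilation wants.
    No line search or convergence test; we trade adaptive tolerance for
    a compile-friendly graph."""
    grad_fn = jax.grad(loss_fn)
    def step(x, _):
        x = x - lr * grad_fn(x)     # plain gradient step; a real port would do BFGS updates here
        return x, loss_fn(x)
    x_final, losses = jax.lax.scan(step, x0, None, length=n_steps)
    return x_final, losses

# toy quadratic with minimum at (1, -2)
loss = lambda x: jnp.sum((x - jnp.array([1.0, -2.0])) ** 2)
x_opt, losses = scan_minimize(loss, jnp.zeros(2))
```

The same pattern extends to a bounded-step BFGS by carrying the inverse-Hessian estimate in the scan state.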

Practical recommendations

  • For development and small panels (N ≤ 200): use CPU. eigvals gives tight barriers, compile is fast, per-series throughput is fine.
  • For moderate panels (N = 500–2000): T4 (Colab free tier) is enough. Expect ~1.7-2× over CPU, compile ~30 s.
  • For large hourly panels (N = 1000–10000, T ≈ 1500): use A100 (Colab Pro). Warm runs scale to ~10× over CPU, ~17× over R at N=5000–10000. Total first-run wall ≈ compile + warm ≈ 2× warm; run the fit repeatedly (or at least twice) to amortize the JIT.
  • For very large panels (N > 20000): compile cost and HBM load start to degrade per-series time on A100 too. Either chunk the panel, or use H100 if available (untested here).
  • Don't use v5e-1 TPU without a fixed-iteration optimizer rewrite. Its design assumptions don't match our while_loop-based kernel.

How the v0.0.2 → v0.0.3 regressions were fixed

The earlier OOS gap (14% synthetic, 2× on taylor) traced to three things, each discovered by reading forecast::tbats source and diffing fitted parameters backend-to-backend:

  1. Over-constrained transforms. v0.0.2 sigmoid-mapped alpha to [0,1], beta to [0,1], gammas to [-0.5, 0.5]. The R reference (checkAdmissibility.R) in fact only bounds phi ∈ [0.8, 1.0]; alpha, beta, and gammas are unbounded. R's fitted alpha on taylor is 1.68 (!) and beta is −0.24. Fix: identity transforms for everything except phi.

  2. Wrong starting values. v0.0.2 used gammas=0.001, phi=0.98. R uses gammas=0, phi=0.999 (fitTBATS.R lines 160-179). Matching gave another ~9% MAE reduction on taylor.

  3. Missing implicit regularization. R's makeParscale() sets parscale=1e-5 for gammas, which makes Nelder-Mead take microscopic steps in gamma space — an accidental regularizer that keeps ||gamma|| tiny (~1e-3 in R vs ~2.9 in our v0.0.2). We match this explicitly with an L2 ridge on gammas, gamma_ridge=1e6 by default. This closed most of the remaining taylor gap.
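
The explicit ridge in fix 3 is just an L2 penalty added on-graph. A sketch assuming some nll_fn over a packed parameter vector (names and the gamma index layout are hypothetical, for illustration only):

```python
import jax.numpy as jnp

def penalized_nll(theta, nll_fn, gamma_idx, gamma_ridge=1e6):
    """Objective plus L2 ridge on the seasonal smoothing gammas,
    mimicking the implicit regularization of R's parscale=1e-5."""
    gammas = theta[gamma_idx]
    return nll_fn(theta) + gamma_ridge * jnp.sum(gammas ** 2)

# toy check: with gammas ~1e-3, the 1e6 ridge contributes O(10) to the objective
theta = jnp.array([0.5, 0.1, 0.003, -0.002])   # [alpha, beta, gamma1, gamma2] (toy layout)
val = penalized_nll(theta, lambda t: jnp.sum(t ** 2), jnp.array([2, 3]))
```

A large ridge keeps ||gamma|| tiny at the optimum, which is exactly the behavior R's microscopic Nelder-Mead steps produce by accident.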

What's honest to claim today

Works well:

  • Single-series fit quality better than R's Nelder-Mead on synthetic data.
  • Panel vmap is 3.3× faster than R sequential and 2% better mean SSR.
  • Out-of-sample: within 4% of R on synthetic, within 1% on real taylor.
  • Faithful mirror of forecast::tbats bounds, init, and implicit regularization.

Remaining gaps (v0.0.6):

  • Single-series wall time ~2× R (compiled C++ vs JIT'd JAX). Won't close without dropping past JAX, not worth it.
  • Box-Cox is implemented (use_box_cox=True) but interacts with gamma_ridge defaults: turn ridge down to ~1e3-1e4 when Box-Cox is on until auto-scaling is added.
  • No ARMA errors in v0.0.6 (added since; p, q in TBATSSpec as of v0.1.0).
  • Auto k-search is implemented but AIC selection doesn't always predict OOS performance on specific datasets — same issue R has.
  • jax-metal not a practical backend; CPU + CUDA only.
  • TPU v5e-1 not viable with the current BFGS while_loop design. A fixed-iteration optimizer is required for TPU compile times to be reasonable.

The structural value proposition — panel vmap, gradients for Bayesian TBATS, differentiable TBATS as a layer — is now backed by quality parity with the R reference, not just "works in JAX."

Panel micro-bench (50 series × 1000 obs):

Mode                                Total wall   Per-series
Looped JAX fits                     29.5 s       590 ms
Batched likelihood (vmap, no fit)   1.4 ms       0.03 ms

The batched-likelihood number is what motivated moving the optimizer inside JAX (optimistix): with that done, the full fit becomes one fused vmapped call, which is the fit_panel path benchmarked above.
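
Batching a likelihood is one jax.vmap over the leading axis; here a toy scan-based SSE stands in for the TBATS likelihood:

```python
import jax
import jax.numpy as jnp

def series_sse(theta, y):
    """One-step-ahead SSE of a level-only exponential smoother (toy stand-in
    for the TBATS NLL; one theta per series)."""
    alpha = theta[0]
    def step(level, y_t):
        e = y_t - level
        return level + alpha * e, e
    _, errs = jax.lax.scan(step, y[0], y)
    return jnp.sum(errs ** 2)

# 8 toy series; in a real panel each series carries its own fitted theta
panel = jnp.stack([jnp.sin(jnp.linspace(0.0, 6.0, 100)) + 0.1 * i for i in range(8)])
thetas = jnp.full((8, 1), 0.3)
sses = jax.vmap(series_sse)(thetas, panel)   # one fused kernel, shape (8,)
```

Wrapping an optimizer around this vmapped objective is the structure that turns N sequential fits into a single device call.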

Developer notes

Current state, known limitations, and the ranked next-steps queue live in docs/DEV_NOTES.md. That's the resumption anchor for any contributor picking up work.

Contributing

Issues and PRs welcome at the primary repo on Codeberg. The GitHub mirror is read-only and syncs automatically — don't open PRs against it.
