JAX-native TBATS (forecast::tbats port) — multi-seasonal forecasting with vmap panels on CPU/GPU
Project description
tbats-jax
JAX-native port of R's forecast::tbats. Innovations-form TBATS with
multi-seasonal Fourier harmonics, Box-Cox, missing data, and ARMA errors
— fit via jax.grad + optimistix BFGS on CPU or GPU. Ships with vmap
panel fitting that scales to thousands of series in a single call. At
N=10000, T=1500 on a Colab A100: ~200 s warm fit (after ~200 s
one-time JIT compile). Vs single-core CPU JAX (~33 min extrapolated):
10× warm / 5× cold. Vs R forecast::tbats sequential (~57 min
extrapolated): 17× warm / 8.5× cold.
Experimental: scan-based Levenberg-Marquardt fit (fit_lm, TPU-compatible
but lower convergence quality than the main path) and a NumPyro Bayesian
scaffold (sampler converges poorly on non-trivial priors — see
docs/DEV_NOTES.md).
Install
pip install tbats-jax # core (CPU/GPU)
pip install "tbats-jax[data]" # + pyreadr for fetch_taylor (GPL-3 source, never bundled)
pip install "tbats-jax[bench]" # + Python `tbats` for comparison
pip install "tbats-jax[bayes]" # + numpyro (svi_tbats / bayes_tbats)
Quickstart
import jax
jax.config.update("jax_enable_x64", True) # recommended for fit quality
import numpy as np
from tbats_jax import TBATSSpec, fit_jax, forecast
from tbats_jax.datasets import synthesize_daily
# Daily series with weekly + yearly seasonality (bundled synthetic generator)
y = synthesize_daily(n=730, seed=0)
spec = TBATSSpec(
seasonal=((7.0, 3), (365.25, 5)), # (period, k_harmonics) each
use_trend=True,
use_damping=True,
)
# Fit (~0.5s on CPU, ~0.02s on A100)
result = fit_jax(y, spec)
# 30-day forecast
preds = forecast(y, result.theta, spec, horizon=30)
Batch 1000 series on GPU
import numpy as np
from tbats_jax import TBATSSpec, fit_panel
panel = np.stack([make_series(s) for s in range(1000)])  # make_series: your own per-series generator; shape (1000, T)
spec = TBATSSpec(seasonal=((24.0, 3), (168.0, 5)), use_trend=True, use_damping=True)
thetas_raw, nlls, compile_t, wall_t = fit_panel(panel, spec)
# A100 @ T=1500, N=1000: ~25s wall for all 1000 fits
Auto-select seasonal harmonics
from tbats_jax import auto_fit_jax_cv
r = auto_fit_jax_cv(y, periods=(7.0, 365.25),
use_trend=True, use_damping=True,
val_size=30) # walk-forward CV over last 30 points
# r.spec has the chosen k-vector; r.fit.theta is the final fit
Scope (v0.1.0)
- Trend (optional), damping (optional), multi-seasonal harmonics
- Box-Cox transform (optional, use_box_cox=True)
- Missing-data support (NaN values in y handled automatically)
- Innovations state-space form: y_hat_t = w' x_{t-1}, x_t = F x_{t-1} + g e_t
- Gaussian negative log-likelihood + hybrid admissibility barrier (exact eigvals on CPU, power-iteration spectral norm on GPU/TPU) + gamma-ridge
- Two optimizers: scipy L-BFGS-B (fit) and optimistix BFGS (fit_jax)
- Uniform panel fit via vmap (fit_panel)
- Heterogeneous panel fit across different specs (fit_panel_hetero)
- Auto k-vector search (auto_fit_jax) — R-compatible greedy per-period AIC
- CPU or GPU (JAX device-agnostic); Colab notebook in notebooks/
- ARMA errors (p, q in TBATSSpec)
- TPU-viable scan-based LM optimizer (fit_lm, fit_lm_multistart)
- Experimental Bayesian TBATS via NumPyro (HMC)
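The innovations recursion in the list above is the entire filtering kernel: given F, g, and w built from the spec, one pass over the series produces the one-step errors the likelihood needs. A minimal plain-NumPy sketch of that recursion (the library runs the same loop inside kernel.py via lax.scan; the function name here is illustrative, not the library API):

```python
import numpy as np

def innovations_filter(y, F, g, w, x0):
    """One-step-ahead filter for the innovations form:
    y_hat_t = w' x_{t-1};  e_t = y_t - y_hat_t;  x_t = F x_{t-1} + g e_t.
    Plain-NumPy sketch of the recursion the library runs via lax.scan."""
    x = x0.copy()
    errs = np.empty_like(y)
    for t, yt in enumerate(y):
        e = yt - w @ x          # one-step-ahead prediction error
        errs[t] = e
        x = F @ x + g * e       # state update driven by the same error
    return errs, x

# Toy check: a 1-d local-level model (F=[[1]], w=[1], g=[alpha]) reduces
# to simple exponential smoothing of the level.
y = np.array([1.0, 2.0, 3.0])
errs, x = innovations_filter(y, np.eye(1), np.array([0.5]), np.array([1.0]), np.array([0.0]))
```

The Gaussian NLL is then a function of `errs` alone, which is what makes the whole fit a single differentiable graph.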
Layout
tbats_jax/
├── tbats_jax/ # library (pure, device-agnostic)
│ ├── spec.py # TBATSSpec
│ ├── params.py # pack / unpack / init
│ ├── matrices.py # F, g, w builders
│ ├── kernel.py # lax.scan + likelihood
│ ├── forecast.py # h-step point forecast
│ └── fit.py # scipy L-BFGS-B + jax.grad
├── benchmarks/
│ ├── data.py
│ ├── bench_single.py # JAX vs Python `tbats` package
│ └── bench_panel_full.py # vmap vs R sequential
└── tests/test_smoke.py
Setup
cd tbats_jax
python3.12 -m venv .venv
source .venv/bin/activate
pip install -e ".[dev,bench,bayes]"
pyproject.toml is the single source of truth for deps. Extras let
end users install only what they need — [dev] for tests,
[bench] for bench_single.py's comparison against Skorupa's Python
tbats, [bayes] for bayes_tbats / svi_tbats. Add [data] if
you need the Python fetch_taylor() wrapper.
Run
pytest tests/ # smoke tests
python -m benchmarks.bench_single # single-series SSR, 4 backends
python -m benchmarks.bench_oos # out-of-sample MAE/RMSE, 4 backends
python -m benchmarks.bench_panel_full [N] [T] # full fits: vmap vs R sequential
python -m benchmarks.colab_panel_gpu [N] [T] # device-agnostic; GPU on Colab
# Real data (requires R + one-time fetch):
Rscript benchmarks/fetch_real_data.R data
python -m benchmarks.bench_real # forecast::taylor OOS
The R comparison is auto-detected. If Rscript is on PATH and the
forecast package is installed, bench_single includes it; otherwise it
cleanly skips. To enable:
brew install r
Rscript -e 'install.packages("forecast", repos="https://cloud.r-project.org")'
Colab
Same code. Switch runtime → GPU, install requirements, run the same benchmarks. No source changes required. Enable float64 in the first cell:
import jax; jax.config.update("jax_enable_x64", True)
Results on Apple Silicon (M-series CPU, float64)
Series: n=1500, two seasonal periods (24, 168), k=(3,5), trend+damping.
Four-way comparison: JAX vs Python tbats (pure numpy port) vs R
forecast::tbats (the original, C++ kernel via Rcpp/Armadillo).
Single-series in-sample fit (SSR) — bench_single.py
Synthetic series, n=1500, periods=(24, 168), k=(3, 5), trend+damping.
| Implementation | Wall | SSR | Notes |
|---|---|---|---|
| JAX fit_jax (optimistix BFGS) | 0.72 s | 361.8 | on-graph, + gamma-ridge + log-hinge |
| R forecast:::fitSpecificTBATS (same k) | 0.35 s | 374.9 | original C++ kernel, Nelder-Mead |
| JAX fit (scipy L-BFGS-B) | 0.80 s | 379.9 | scipy path retained for comparison |
| Python tbats (same k pinned) | 1.19 s | 375.0 | Nelder-Mead over pure numpy |
| R forecast::tbats (auto-search) | 1.57 s | 369.5 | chose k=[6,6] |
| Python tbats (auto-search) | 11.5 s | 373.0 | chose k=[6,2] |
JAX finds better in-sample SSR than R (361.8 vs 374.9 at matched structure) at ~2× R's wall time. The C++ baseline is hard to beat on a single fit, but quality is now unambiguously at parity.
Panel fit — bench_panel_full.py
32 synthetic series × 1500 obs each. fit_panel JIT-compiles one fused
kernel across all series; R runs sequential Rscript calls.
| Backend | Total | Per-series | Mean SSR |
|---|---|---|---|
| JAX fit_panel (vmap) | 7.3 s | 227 ms | 374.8 |
| JAX loop (no vmap) | 59.3 s | 1852 ms | 375.2 |
| R forecast sequential | 23.7 s | 740 ms | 382.8 |
JAX vmap is 3.3× faster than R sequential AND finds lower mean SSR (374.8 vs 382.8, 2.1% better). This is the headline result: parity quality at scaled throughput. At N=100 per-series time stays at ~213 ms as JIT overhead amortizes further.
Out-of-sample forecast accuracy — bench_oos.py
Train on first 1350 points, forecast last 150.
| Implementation | Wall | Train SSR | Test MAE | Test RMSE |
|---|---|---|---|---|
| R forecast (fixed k) | 0.32 s | 334 | 0.4273 | 0.5250 |
| R forecast (auto) | 1.51 s | 330 | 0.4224 | 0.5193 |
| Python tbats (auto) | 10.5 s | — | 0.4280 | 0.5260 |
| JAX fit_jax | 0.47 s | 335 | 0.4447 | 0.5384 |
Test MAE gap vs R: 4.1%, down from ~14% before the v0.0.3 regularization
fixes. JAX now also beats Python tbats on wall time by 22×.
Real dataset — bench_real.py (forecast::taylor)
Half-hourly UK electricity demand, n=4032, periods=(48, 336), k=(3, 5). One-week hold-out (336 points). The canonical TBATS benchmark.
| Backend | Wall | Train SSR | Test MAE | Test RMSE |
|---|---|---|---|---|
| R forecast fixed | 0.33 s | 6.99e8 | 1030 | 1332 |
| JAX fit_jax | 4.16 s | 6.96e8 | 1041 | 1343 |
| R forecast auto (k=[12,5]) | 10.03 s | 3.01e8 | 1273 | 1499 |
JAX test MAE is within 1.1% of R's fixed fit — essentially parity on the
metric that matters. The wall-time penalty is larger than on short series
(4.2 s vs 0.3 s) because n=4032 makes the scan longer; this scales with
panel width, not vmap'd series count.
Accelerator comparison — fit_panel
Two workloads tested: a daily panel (T=730, weekly+yearly seasonality) for a baseline sanity check, and an hourly panel (T=1500, periods=24+168) which is where the GPU story actually matters for long-series production data.
Daily panel (T=730, k=(3,5))
| Backend | Compile | Warm wall | Per-series | vs CPU | Notes |
|---|---|---|---|---|---|
| Apple Silicon M-series CPU | 5.4 s | 4.7 s (N=100) | 45 ms | 1.0× | eigvals path, exact rho |
| CUDA T4 16 GB | 33 s | 24 s (N=1000) | 24 ms | 1.85× | power-iter path |
| CUDA T4 16 GB (N=2000) | 62 s | 56 s | 28 ms | 1.60× | per-series creeps as HBM loads up |
| TPU v5e-1 | 1459 s | 1429 s (N=1000) | 1429 ms | 0.03× (slower!) | see below |
Hourly panel (T=1500, k=(3,5)) — the headline
Per-series CPU work is 4× heavier here (196 ms vs 45 ms). T4 hit its bandwidth wall and lost to CPU at N=5000. A100 had the headroom to scale:
| Backend | N=1000 | N=5000 | N=10000 | N=20000 |
|---|---|---|---|---|
| CPU (Apple Silicon M-series) | 196 ms | 196 ms* | 196 ms* | 196 ms* |
| CUDA T4 16 GB | 104 ms (1.88×) | 214 ms (0.92× — lost) | — | — |
| CUDA A100 40 GB (warm ms/ser) | 53 ms | 20 ms | 20 ms | 30 ms |
| A100 compile (one-time JIT) | 59 s | 106 s | 206 s | 607 s |
| A100 warm wall | 52 s | 99 s | 200 s | 597 s |
| A100 warm vs CPU | 3.8× | 9.9× | 9.8× | 6.6× |
| A100 warm vs R forecast::tbats seq. | 6.6× | 17.3× | 17.1× | 11.5× |
| A100 cold (compile + warm) vs CPU | 1.8× | 4.7× | 4.8× | 3.3× |
| A100 cold vs R forecast::tbats seq. | 3.1× | 8.2× | 8.5× | 5.8× |
*CPU at large N is a linear extrapolation from the measured 196 ms/series at N=500.
A100 fits 10,000 independent TBATS models in ~200 s warm (after a
~200 s one-time JIT compile). Total first-run wall: ~406 s. Sequential
R forecast::tbats on the same workload would take ~57 minutes. CPU
(single core) takes ~33 minutes.
The warm numbers assume a long-lived process that reuses the compiled panel function (production batch retrain, serving). The cold numbers reflect one-shot scripts. Both are honest; pick whichever matches your workflow.
Sweet spot is N=5000–10000 where per-series time drops 2.7× from the N=1000 launch-bound regime to the steady-state ~20 ms. Beyond N≈15000 compile time starts dominating and the per-series curve turns back up (HBM pressure); recommended practice is to chunk larger panels into 10k-series buckets.
Why A100 scales where T4 didn't:
- 5× HBM bandwidth (1.6 TB/s vs 320 GB/s) — T4's bandwidth wall hit at N=5000 simply doesn't appear until N>20000 on A100
- Enough SMs to saturate on the vmap axis — at N=1000 both T4 and A100 are launch-bound, but A100 scales the lanes per kernel; T4 stalls
- 39× FP64 throughput — mostly unused in a launch-bound kernel, but helps at the margins
Why TPU v5e-1 loses so badly
The compile alone is 24 min and the warm fit is another 24 min on a workload GPU finishes in under a minute. Three structural reasons:
- optimistix.BFGS uses lax.while_loop. TPU XLA compilation of unbounded-iteration loops is expensive: every step size, line search, and convergence check is a symbolic branch the compiler has to lower into TPU bundles. GPU XLA handles this gracefully; TPU XLA does not.
- v5e-1 is a single "Lite" chip, optimized for inference matmul throughput (bf16/int8). Our per-step FLOPs are tiny (~300 on an 18×18 matmul) — well below the granularity the TPU wants for good utilization.
- Scan serialization over T=730 hits the same launch-coordination issue on TPU that it hits on GPU — but without GPU's faster step dispatch.
A fixed-iteration optimizer (e.g., a hand-written BFGS expressed as a
lax.scan with bounded steps, instead of while_loop) would likely fix
this and make v5e-1 competitive. That's a real rewrite, not a tuning
change — we'd give up adaptive convergence tolerance in exchange for
TPU-friendly compilation.
Practical recommendations
- For development and small panels (N ≤ 200): use CPU. eigvals gives tight barriers, compile is fast, per-series throughput is fine.
- For moderate panels (N = 500–2000): T4 (Colab free tier) is enough. Expect ~1.7–2× over CPU, compile ~30 s.
- For large hourly panels (N = 1000–10000, T ≈ 1500): use A100 (Colab Pro). Warm runs scale to ~10× over CPU, ~17× over R at N=5000–10000. Total first-run wall ≈ compile + warm ≈ 2× warm; run the fit repeatedly (or at least twice) to amortize the JIT.
- For very large panels (N > 20000): compile cost and HBM load start to degrade per-series time on A100 too. Either chunk the panel, or use H100 if available (untested here).
- Don't use v5e-1 TPU without a fixed-iteration optimizer rewrite. Its design assumptions don't match our while_loop-based kernel.
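The chunking recommendation for very large panels can be sketched as a thin wrapper (hypothetical names throughout; `fit_fn` stands in for a jitted `fit_panel` closure over a fixed spec):

```python
import numpy as np

def fit_panel_chunked(panel, fit_fn, chunk=10_000):
    """Split a very wide panel into <=chunk-series buckets and fit each
    bucket in turn, per the N > 20000 recommendation above. Reusing one
    chunk width keeps the JIT cache warm across buckets."""
    outs = [fit_fn(panel[i:i + chunk]) for i in range(0, len(panel), chunk)]
    return np.concatenate(outs, axis=0)

# Demo with a stand-in fit_fn that just returns per-series means.
panel = np.arange(12.0).reshape(6, 2)                        # 6 series, T=2
means = fit_panel_chunked(panel, lambda p: p.mean(axis=1), chunk=4)
```

With a real jitted `fit_fn`, a smaller remainder bucket triggers one extra compile; padding the panel with dummy series up to a multiple of the chunk width avoids that.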
How the v0.0.2 → v0.0.3 regressions were fixed
The earlier OOS gap (14% synthetic, 2× on taylor) traced to three things,
each discovered by reading forecast::tbats source and diffing fitted
parameters backend-to-backend:
- Over-constrained transforms. v0.0.2 sigmoid-mapped alpha to [0,1], beta to [0,1], gammas to [-0.5, 0.5]. The R reference (checkAdmissibility.R) in fact only bounds phi ∈ [0.8, 1.0]; alpha, beta, and gammas are unbounded. R's fitted alpha on taylor is 1.68 (!) and beta is −0.24. Fix: identity transforms for everything except phi.
- Wrong starting values. v0.0.2 used gammas=0.001, phi=0.98. R uses gammas=0, phi=0.999 (fitTBATS.R lines 160–179). Matching gave another ~9% MAE reduction on taylor.
- Missing implicit regularization. R's makeParscale() sets parscale=1e-5 for gammas, which makes Nelder-Mead take microscopic steps in gamma space — an accidental regularizer that keeps ||gamma|| tiny (~1e-3 in R vs ~2.9 in our v0.0.2). We match this explicitly with an L2 ridge on gammas, gamma_ridge=1e6 by default. This closed most of the remaining taylor gap.
What's honest to claim today
Works well:
- Single-series fit quality better than R's Nelder-Mead on synthetic data.
- Panel vmap is 3.3× faster than R sequential with 2% better mean SSR.
- Out-of-sample: within 4% of R on synthetic, within 1% on real taylor.
- Faithful mirror of forecast::tbats bounds, init, and implicit regularization.
Remaining gaps (v0.0.6):
- Single-series wall time ~2× R (compiled C++ vs JIT'd JAX). Won't close without dropping past JAX; not worth it.
- Box-Cox is implemented (use_box_cox=True) but interacts with gamma_ridge defaults: turn ridge down to ~1e3–1e4 when Box-Cox is on until auto-scaling is added.
- No ARMA errors.
- Auto k-search is implemented but AIC selection doesn't always predict OOS performance on specific datasets — same issue R has.
- jax-metal is not a practical backend; CPU + CUDA only.
- TPU v5e-1 not viable with current BFGS-while-loop design. Fixed-iteration optimizer required for TPU compile to be reasonable.
The structural value proposition — panel vmap, gradients for Bayesian
TBATS, differentiable TBATS as a layer — is now backed by quality parity
with the R reference, not just "works in JAX."
Panel micro-bench (50 series × 1000 obs):
| Mode | Total wall | Per-series |
|---|---|---|
| Looped JAX fits | 29.5 s | 590 ms |
| Batched likelihood (vmap, no fit) | 1.4 ms | 0.03 ms |
The batched likelihood result is the headline: once the optimizer moves inside JAX (jaxopt/optimistix), the full fit becomes one fused vmapped call — the path to real panel scaling.
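A sketch of what "one fused vmapped call" means, using a toy Gaussian likelihood in place of the real TBATS one (function name and signature are illustrative, not the library API):

```python
import jax
import jax.numpy as jnp

def gaussian_nll(theta, y):
    """Toy per-series negative log-likelihood over (mu, log_sigma);
    stands in for the real scan-based TBATS likelihood."""
    mu, log_sigma = theta
    sigma2 = jnp.exp(2.0 * log_sigma)
    return 0.5 * jnp.sum(jnp.log(2 * jnp.pi * sigma2) + (y - mu) ** 2 / sigma2)

# One vmap turns the per-series NLL into a panel NLL: evaluating it (and
# its grad) for N series is a single fused device call, which is what
# lets the optimizer itself move inside JAX for panel fitting.
panel_nll = jax.jit(jax.vmap(gaussian_nll, in_axes=(0, 0)))

thetas = jnp.zeros((8, 2))        # 8 series, (mu, log_sigma) each
panel = jnp.ones((8, 100))        # 8 series × 100 obs
nlls = panel_nll(thetas, panel)   # shape (8,)
```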
Developer notes
Current state, known limitations, and the ranked next-steps queue live in docs/DEV_NOTES.md. That's the resumption anchor for any contributor picking up work.
Contributing
Issues and PRs welcome at the primary repo on Codeberg. The GitHub mirror is read-only and syncs automatically — don't open PRs against it.
Download files
File details
Details for the file tbats_jax-0.1.1.tar.gz.
File metadata
- Download URL: tbats_jax-0.1.1.tar.gz
- Upload date:
- Size: 55.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 0d8c78365d244b55bcdcbb456cc1e59edbbe2d0bc966899966cbf382ba7048d2 |
| MD5 | 9fc026de381fd4010c6f7fa8a07eb372 |
| BLAKE2b-256 | 137996d0bd80d71ed57969b46cd1eed2fe3c78d0877105ba150b8ce0c38e1781 |
File details
Details for the file tbats_jax-0.1.1-py3-none-any.whl.
File metadata
- Download URL: tbats_jax-0.1.1-py3-none-any.whl
- Upload date:
- Size: 47.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8b44dfe6a9c3d30c38925dd0a6e2acc916c4f04f68d3699093c2ebd401cde249 |
| MD5 | dd9807136dd8d30114fc52f6c82414f0 |
| BLAKE2b-256 | 451c7faa30526fe8fe8a52b5b79b3e89dc6def18c8423376b68c9aa96bf9ea43 |