CLI tool for running JAX training on Google Cloud Spot TPUs with automatic preemption recovery and checkpoint resumption
SpotJAX
CLI tool for running JAX training on Google Cloud Spot TPUs with automatic preemption recovery. Provisions TPUs, uploads code, runs training, and seamlessly retries when Spot instances get preempted.
Installation
pip install spotax
Requires Python 3.10+ and gcloud CLI.
# Verify prerequisites
spotax setup
# Auto-fix issues (SSH keys, OS Login)
spotax setup --fix
Quick Start
spotax run train.py --tpu v5litepod-1 --zone us-central1-a
SpotJAX will:
- Create a GCS bucket for checkpoints
- Provision a Spot TPU
- SSH into all nodes and upload your code via rsync
- Run spotax_setup.sh if present (custom pre-install steps)
- Install requirements.txt dependencies
- Run your script with checkpoint/distributed env vars injected
- On preemption: clean up, provision a new TPU, and resume from last checkpoint
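The retry behaviour in the last step can be sketched as a small loop. This is a minimal illustration under stated assumptions, not SpotJAX's actual orchestrator: `AttemptResult`, `run_with_retries`, and the `run_attempt` callable are hypothetical names, and raising on a non-preemption failure is an assumed policy. Only the default of 5 restarts comes from the CLI reference.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class AttemptResult:
    finished: bool   # True when the training script exited successfully
    preempted: bool  # True when the Spot TPU was reclaimed mid-run


def run_with_retries(
    run_attempt: Callable[[bool], AttemptResult],
    max_retries: int = 5,
) -> int:
    """Re-run `run_attempt` after each preemption, up to `max_retries` restarts.

    `run_attempt(is_restart)` is a hypothetical stand-in for "provision a TPU,
    upload code, run the script, clean up". Returns the number of restarts used.
    """
    restarts = 0
    while True:
        result = run_attempt(restarts > 0)
        if result.finished:
            return restarts
        if not result.preempted:
            # Assumed policy: only preemptions are retried.
            raise RuntimeError("training failed for a reason other than preemption")
        if restarts >= max_retries:
            raise RuntimeError(f"gave up after {max_retries} restarts")
        restarts += 1
```

Passing `is_restart` into each attempt mirrors the idea behind SPOT_IS_RESTART: the first attempt starts fresh, and every later attempt knows it should resume from a checkpoint.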
Project Structure
your-project/
  train.py           # Your training script
  data.py            # Data loading (optional)
  spotax_utils.py    # Checkpoint & distributed utilities (copy from examples/)
  requirements.txt   # Dependencies installed on TPU VMs
  spotax_setup.sh    # Pre-install script (optional)
spotax_utils.py
SpotJAX handles infrastructure recovery automatically: on preemption it provisions a new TPU and reruns your script. Without checkpointing, however, training would restart from step 0 on every retry. spotax_utils.py bridges this gap by saving model state to GCS and restoring it on retry, so training resumes from where it left off.
Copy this file from examples/ into your project. It provides checkpoint management and distributed training setup with no runtime dependency on the spotax package.
from spotax_utils import CheckpointManager, get_config, setup_distributed
config = get_config()
setup_distributed(config) # Required for multi-node (v4-16+), no-op for single-node
ckpt = CheckpointManager(config.checkpoint_dir, save_interval_steps=1000)
state, start_step = ckpt.restore_or_init(initial_state)
for step in range(start_step, max_steps):
    state = train_step(state, batch)
    ckpt.save(step, state)
    if ckpt.reached_preemption(step):
        break  # Orbax already saved checkpoint, orchestrator will retry
ckpt.close()
How checkpointing works:
- SpotJAX enables GCP's autocheckpoint. On preemption, GCP sends SIGTERM to the VM.
- Orbax catches SIGTERM and saves a checkpoint automatically, even outside save_interval_steps.
- reached_preemption() detects this across all hosts and returns True so your script exits cleanly.
- The orchestrator then provisions a new TPU and reruns.
- restore_or_init() picks up from the last checkpoint.
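To make the signal-driven flow above concrete, here is a single-process sketch of the general pattern: a SIGTERM handler sets a flag, and the loop forces a save and stops at the next step boundary. It is not how Orbax or spotax_utils.py implement it (Orbax also coordinates the decision across hosts); `PreemptionFlag`, `train`, `step_fn`, and `save` are hypothetical names used only for illustration.

```python
import signal


class PreemptionFlag:
    """Records that SIGTERM arrived so the loop can stop at a step boundary."""

    def __init__(self) -> None:
        self.requested = False

    def install(self) -> None:
        # Must be called from the main thread; replaces any existing handler.
        signal.signal(signal.SIGTERM, self._handle)

    def _handle(self, signum, frame) -> None:
        # Keep the handler tiny: set a flag and let the training loop react.
        self.requested = True


def train(flag, start_step, max_steps, save_interval_steps, step_fn, save):
    """Toy loop: save on the interval, and save-and-stop once the flag is set.

    Returns the last completed step, or None if no step ran.
    """
    last_done = None
    for step in range(start_step, max_steps):
        step_fn(step)  # one training step
        last_done = step
        if flag.requested:
            save(step)  # forced save, even outside the interval
            break
        if (step + 1) % save_interval_steps == 0:
            save(step)  # regular interval save
    return last_done
```

In a real script you would call `flag.install()` once at startup. Checking the flag only between steps keeps the saved state consistent, because a save never begins in the middle of an update.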
requirements.txt
Standard pip requirements. SpotJAX installs them on each TPU VM using uv with the JAX TPU releases index. Include jax[tpu] and any other dependencies your script needs:
jax[tpu]
flax
optax
orbax-checkpoint
grain
spotax_setup.sh (optional)
Runs before requirements.txt installation. Use it for things pip can't handle: system packages, building from source, patching libraries. The venv is already activated when this runs.
Environment Variables
SpotJAX injects these into your training script (read them via get_config()):
| Variable | Description |
|---|---|
| SPOT_CHECKPOINT_DIR | GCS path for checkpoints (gs://bucket/job-id/ckpt) |
| SPOT_LOG_DIR | GCS path for logs |
| SPOT_JOB_ID | Unique job identifier |
| SPOT_IS_RESTART | "true" if resuming after preemption |
Multi-node only (automatically set for v4-16+, v5litepod-4+, etc.):
| Variable | Description |
|---|---|
| SPOT_WORKER_ID | Node index (0 to N-1) |
| SPOT_NUM_WORKERS | Total node count |
| JAX_COORDINATOR_ADDRESS | Internal IP:port for JAX distributed |
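As a rough picture of what reading these variables might look like, the sketch below parses them into a small config object. It is an assumption-laden illustration, not the real get_config(): `SpotConfig`, `load_config`, every field name other than `checkpoint_dir`, and the single-node defaults used when the multi-node variables are unset are all hypothetical.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class SpotConfig:
    checkpoint_dir: str
    log_dir: str
    job_id: str
    is_restart: bool
    worker_id: int = 0                         # assumed default on single-node runs
    num_workers: int = 1                       # assumed default on single-node runs
    coordinator_address: Optional[str] = None  # only set on multi-node runs

    @property
    def is_multi_node(self) -> bool:
        return self.num_workers > 1


def load_config(env: Mapping[str, str] = os.environ) -> SpotConfig:
    """Build a SpotConfig from the variables listed in the tables above."""
    return SpotConfig(
        checkpoint_dir=env["SPOT_CHECKPOINT_DIR"],
        log_dir=env["SPOT_LOG_DIR"],
        job_id=env["SPOT_JOB_ID"],
        is_restart=env.get("SPOT_IS_RESTART", "false").lower() == "true",
        worker_id=int(env.get("SPOT_WORKER_ID", "0")),
        num_workers=int(env.get("SPOT_NUM_WORKERS", "1")),
        coordinator_address=env.get("JAX_COORDINATOR_ADDRESS"),
    )
```

On a multi-node run, these three values line up with the arguments of `jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...)`, which is presumably what setup_distributed() wraps. Accepting the environment as a parameter also makes the parsing easy to test with a plain dict.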
CLI Reference
spotax run <script> [options]
| Option | Default | Description |
|---|---|---|
| --tpu, -t | v5litepod-1 | TPU type |
| --zone, -z | us-central1-a | GCP zone |
| --project, -p | auto-detect | GCP project ID |
| --bucket, -b | spotax-{project} | GCS bucket for checkpoints |
| --name, -n | timestamp | Job name |
| --max-retries | 5 | Max restart attempts |
| --stream-worker, -w | 0 | Worker index to stream logs from |
| --code-dir, -c | script's parent dir | Directory to upload |
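If you launch jobs from a Python script (for example, a small sweep), the options above can be assembled into an argument list. The helper below is a hypothetical convenience wrapper, not part of spotax; it uses only flags listed in the table and leaves out any option you do not set, so the CLI defaults apply.

```python
import shlex
from typing import List, Optional


def build_run_command(
    script: str,
    tpu: Optional[str] = None,
    zone: Optional[str] = None,
    name: Optional[str] = None,
    max_retries: Optional[int] = None,
) -> List[str]:
    """Return the argv list for `spotax run`, omitting unset options."""
    cmd = ["spotax", "run", script]
    if tpu is not None:
        cmd += ["--tpu", tpu]
    if zone is not None:
        cmd += ["--zone", zone]
    if name is not None:
        cmd += ["--name", name]
    if max_retries is not None:
        cmd += ["--max-retries", str(max_retries)]
    return cmd


def as_shell(cmd: List[str]) -> str:
    """Render the argv list as a copy-pasteable shell command."""
    return shlex.join(cmd)
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`. Building a list instead of a single string avoids shell-quoting problems with job names or paths that contain spaces.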
Examples
- ImageNet EfficientNet - Train EfficientNet-B2 on ImageNet-1K with ArrayRecord data pipeline
- Simple Math SFT - Fine-tune Qwen3 on GSM8K math problems (under development)
Requirements
- Python 3.10+
- GCP project with TPU API enabled and Spot TPU quota
- gcloud CLI authenticated with Application Default Credentials
License
Apache 2.0