
CLI tool for running JAX training on Google Cloud Spot TPUs with automatic preemption recovery and checkpoint resumption


SpotJAX

CLI tool for running JAX training on Google Cloud Spot TPUs with automatic preemption recovery. It provisions TPUs, uploads your code, runs training, and automatically retries when a Spot instance is preempted.

Installation

pip install spotax

Requires Python 3.10+ and the gcloud CLI.

# Verify prerequisites
spotax setup

# Auto-fix issues (SSH keys, OS Login)
spotax setup --fix
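If you want a quick local pre-flight check before running spotax setup (for example in CI), the following is a minimal sketch covering only the two prerequisites named above: Python 3.10+ and a gcloud binary on PATH. check_prerequisites is a hypothetical helper; it does not reproduce what spotax setup actually verifies or fixes (SSH keys, OS Login).

```python
import shutil
import sys


def check_prerequisites(min_python=(3, 10)):
    """Return a list of human-readable problems; an empty list means OK.

    Hypothetical local check, not the logic behind `spotax setup`.
    """
    problems = []
    if sys.version_info[:2] < min_python:
        problems.append(
            f"Python {min_python[0]}.{min_python[1]}+ required, "
            f"found {sys.version_info.major}.{sys.version_info.minor}"
        )
    if shutil.which("gcloud") is None:
        problems.append("gcloud CLI not found on PATH")
    return problems


if __name__ == "__main__":
    issues = check_prerequisites()
    for issue in issues:
        print("FAIL:", issue)
    if not issues:
        print("OK: Python version and gcloud look fine")
```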

Quick Start

spotax run train.py --tpu v5litepod-1 --zone us-central1-a

SpotJAX will:

  1. Create a GCS bucket for checkpoints
  2. Provision a Spot TPU
  3. SSH into all nodes and upload your code via rsync
  4. Run spotax_setup.sh if present (custom pre-install steps)
  5. Install requirements.txt dependencies
  6. Run your script with checkpoint/distributed env vars injected
  7. On preemption: clean up, provision a new TPU, and resume from last checkpoint
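The recovery behaviour in step 7 amounts to a bounded retry loop. The sketch below is a hypothetical illustration, not SpotJAX's code: provision, run, cleanup and the Preempted exception are stand-ins supplied by the caller, and max_retries mirrors the --max-retries option. What happens to the TPU after a successful run is left to the caller, since the steps above do not say.

```python
from typing import Callable


class Preempted(Exception):
    """Raised by the caller's run() when the Spot TPU is reclaimed."""


def run_with_retries(
    provision: Callable[[], str],
    run: Callable[[str, bool], None],
    cleanup: Callable[[str], None],
    max_retries: int = 5,
) -> int:
    """Run a job, re-provisioning after each preemption.

    Returns the number of restarts that were needed. Raises RuntimeError
    once more than max_retries restarts would be required.
    """
    restarts = 0
    while True:
        tpu = provision()
        try:
            run(tpu, restarts > 0)  # second argument plays the role of SPOT_IS_RESTART
        except Preempted:
            cleanup(tpu)  # tear down the reclaimed TPU before trying again
            if restarts >= max_retries:
                raise RuntimeError(f"gave up after {max_retries} restarts") from None
            restarts += 1
            continue
        return restarts
```

Because cleanup runs before the next provision call, this sketch never holds more than one TPU per job at a time.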

Project Structure

your-project/
  train.py              # Your training script
  data.py               # Data loading (optional)
  spotax_utils.py       # Checkpoint & distributed utilities (copy from examples/)
  requirements.txt      # Dependencies installed on TPU VMs
  spotax_setup.sh       # Pre-install script (optional)

spotax_utils.py

SpotJAX handles infrastructure recovery automatically: on preemption it provisions a new TPU and reruns your script. Without checkpointing, however, training would restart from step 0 on every retry. spotax_utils.py bridges this gap by saving model state to GCS and restoring it on retry, so training resumes where it left off.

Copy this file from examples/ into your project. It provides checkpoint management and distributed training setup with no runtime dependency on the spotax package.

from spotax_utils import CheckpointManager, get_config, setup_distributed

# initial_state, train_step, data_iter and max_steps are defined by your own script
config = get_config()
setup_distributed(config)  # Required for multi-node (v4-16+), no-op for single-node

ckpt = CheckpointManager(config.checkpoint_dir, save_interval_steps=1000)
state, start_step = ckpt.restore_or_init(initial_state)

for step in range(start_step, max_steps):
    batch = next(data_iter)
    state = train_step(state, batch)
    ckpt.save(step, state)  # Writes only every save_interval_steps (and on preemption)

    if ckpt.reached_preemption(step):
        break  # Orbax has already saved a checkpoint; the orchestrator will retry

ckpt.close()
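To dry-run the loop above on a laptop without GCS or a TPU, you can swap in a local stand-in with the same method names. LocalCheckpointManager below is hypothetical and makes two assumptions the README does not confirm: restore_or_init returns the step after the last saved one, and save writes only on multiples of save_interval_steps. It pickles state to a local directory and never reports preemption.

```python
import os
import pickle
import tempfile


class LocalCheckpointManager:
    """Hypothetical local stand-in for spotax_utils.CheckpointManager.

    Assumptions (not confirmed by SpotJAX): restore_or_init returns
    (state, last_saved_step + 1), and save() writes only on multiples of
    save_interval_steps. It never reports preemption.
    """

    def __init__(self, directory, save_interval_steps=1000):
        self.directory = directory
        self.save_interval_steps = save_interval_steps
        os.makedirs(directory, exist_ok=True)

    def _saved_steps(self):
        # Checkpoint files are named by step number; ignore anything else.
        return sorted(int(n) for n in os.listdir(self.directory) if n.isdigit())

    def restore_or_init(self, initial_state):
        steps = self._saved_steps()
        if not steps:
            return initial_state, 0
        with open(os.path.join(self.directory, str(steps[-1])), "rb") as f:
            return pickle.load(f), steps[-1] + 1

    def save(self, step, state):
        if step % self.save_interval_steps != 0:
            return False
        path = os.path.join(self.directory, str(step))
        with open(path + ".tmp", "wb") as f:
            pickle.dump(state, f)
        os.replace(path + ".tmp", path)  # atomic rename: no half-written checkpoints
        return True

    def reached_preemption(self, step):
        return False  # no SIGTERM handling in this stand-in

    def close(self):
        pass


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as ckpt_dir:
        ckpt = LocalCheckpointManager(ckpt_dir, save_interval_steps=10)
        state, start_step = ckpt.restore_or_init({"count": 0})
        for step in range(start_step, 25):
            state = {"count": state["count"] + 1}  # stand-in for train_step
            ckpt.save(step, state)
        ckpt.close()
        # A fresh manager resumes after the last saved step (20).
        print(LocalCheckpointManager(ckpt_dir, 10).restore_or_init({"count": 0}))
        # prints ({'count': 21}, 21)
```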

How checkpointing works:

  • SpotJAX enables GCP's autocheckpoint. On preemption, GCP sends SIGTERM to the VM.
  • Orbax catches SIGTERM and saves a checkpoint automatically, even outside save_interval_steps.
  • reached_preemption() detects this across all hosts and returns True so your script exits cleanly.
  • The orchestrator then provisions a new TPU and reruns. restore_or_init() picks up from the last checkpoint.
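The SIGTERM flow above follows a common pattern: the signal handler only records that a preemption notice arrived, and the training loop checks that flag at step boundaries, writes a final checkpoint, and exits. The stdlib sketch below shows the pattern on a single host. It is not how Orbax implements it and does not synchronize the flag across hosts the way reached_preemption() does; PreemptionFlag and train_until_preempted are hypothetical names.

```python
import signal


class PreemptionFlag:
    """Records SIGTERM instead of letting it kill the process (single host only)."""

    def __init__(self):
        self._preempted = False
        # signal.signal must be called from the main thread.
        self._previous = signal.signal(signal.SIGTERM, self._handle)

    def _handle(self, signum, frame):
        self._preempted = True  # do no real work inside the handler

    def is_set(self):
        return self._preempted

    def restore(self):
        """Reinstall the SIGTERM handler that was active before, if known."""
        if self._previous is not None:
            signal.signal(signal.SIGTERM, self._previous)


def train_until_preempted(max_steps, flag, train_step, save):
    """Run steps until finished or preempted; return the last step executed."""
    for step in range(max_steps):
        train_step(step)
        if flag.is_set():
            save(step)  # final checkpoint, outside the regular save interval
            return step
    return max_steps - 1
```

Keeping the handler trivial avoids doing I/O from signal context; all checkpoint work happens in normal code at a step boundary.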

requirements.txt

Standard pip requirements. SpotJAX installs them on each TPU VM using uv with the JAX TPU releases index. Include jax[tpu] and any other dependencies your script needs:

jax[tpu]
flax
optax
orbax-checkpoint
grain

spotax_setup.sh (optional)

Runs before requirements.txt installation. Use it for things pip can't handle: system packages, building from source, patching libraries. The venv is already activated when this runs.
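The install order described above (spotax_setup.sh first, with the venv already active, then requirements.txt) can be sketched as a command list for one worker. bootstrap_commands, its remote_dir argument and the .venv location are assumptions for illustration, and the JAX TPU releases index that SpotJAX passes to uv is omitted because its exact flag is not documented here.

```python
import shlex
import tempfile
from pathlib import Path


def bootstrap_commands(local_dir, remote_dir, venv=".venv"):
    """Hypothetical per-worker bootstrap in the README's order.

    local_dir is checked for the optional files; remote_dir is where the
    code was uploaded on the TPU VM. The venv path is an assumption.
    """
    local_dir = Path(local_dir)
    cmds = [f"cd {shlex.quote(remote_dir)}", f"source {shlex.quote(venv)}/bin/activate"]
    if (local_dir / "spotax_setup.sh").exists():
        cmds.append("bash spotax_setup.sh")  # venv is already active at this point
    if (local_dir / "requirements.txt").exists():
        # SpotJAX also points uv at the JAX TPU releases index (omitted here).
        cmds.append("uv pip install -r requirements.txt")
    return cmds


def as_shell(cmds):
    """Chain with && so a failing setup step stops the install."""
    return " && ".join(cmds)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        (Path(d) / "requirements.txt").write_text("jax[tpu]\n")
        (Path(d) / "spotax_setup.sh").write_text("sudo apt-get install -y htop\n")
        print(as_shell(bootstrap_commands(d, "/home/user/code")))
```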

Environment Variables

SpotJAX injects these into your training script (read them via get_config()):

| Variable | Description |
|---|---|
| SPOT_CHECKPOINT_DIR | GCS path for checkpoints (gs://bucket/job-id/ckpt) |
| SPOT_LOG_DIR | GCS path for logs |
| SPOT_JOB_ID | Unique job identifier |
| SPOT_IS_RESTART | "true" if resuming after preemption |

Multi-node only (set automatically for multi-host slices such as v4-16+ and v5litepod-16+):

| Variable | Description |
|---|---|
| SPOT_WORKER_ID | Node index (0 to N-1) |
| SPOT_NUM_WORKERS | Total node count |
| JAX_COORDINATOR_ADDRESS | Internal IP:port of the JAX distributed coordinator |
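A minimal sketch of reading these variables with only the standard library follows. SpotConfig and read_spot_env are hypothetical and are not the get_config() shipped in spotax_utils.py; the single-node defaults (worker 0 of 1, no coordinator) are assumptions. distributed_init_kwargs shows how the multi-node variables map onto the parameters of jax.distributed.initialize.

```python
import os
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SpotConfig:
    checkpoint_dir: str
    log_dir: str
    job_id: str
    is_restart: bool
    worker_id: int = 0
    num_workers: int = 1
    coordinator_address: Optional[str] = None


def read_spot_env(env=None):
    """Hypothetical reader for the variables above (not spotax_utils.get_config)."""
    env = os.environ if env is None else env
    return SpotConfig(
        checkpoint_dir=env["SPOT_CHECKPOINT_DIR"],
        log_dir=env.get("SPOT_LOG_DIR", ""),
        job_id=env["SPOT_JOB_ID"],
        is_restart=env.get("SPOT_IS_RESTART", "false").lower() == "true",
        # Single-node defaults are assumptions; SpotJAX only sets these on multi-node.
        worker_id=int(env.get("SPOT_WORKER_ID", "0")),
        num_workers=int(env.get("SPOT_NUM_WORKERS", "1")),
        coordinator_address=env.get("JAX_COORDINATOR_ADDRESS"),
    )


def distributed_init_kwargs(cfg):
    """Keyword arguments for jax.distributed.initialize, or None on one node."""
    if cfg.num_workers <= 1:
        return None
    return {
        "coordinator_address": cfg.coordinator_address,
        "num_processes": cfg.num_workers,
        "process_id": cfg.worker_id,
    }
```

On a multi-node slice you would call jax.distributed.initialize(**kwargs) once per host before running any JAX computation. On Cloud TPU VMs, calling jax.distributed.initialize() with no arguments can also detect these values automatically.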

CLI Reference

spotax run <script> [options]
| Option | Default | Description |
|---|---|---|
| --tpu, -t | v5litepod-1 | TPU type |
| --zone, -z | us-central1-a | GCP zone |
| --project, -p | auto-detect | GCP project ID |
| --bucket, -b | spotax-{project} | GCS bucket for checkpoints |
| --name, -n | timestamp | Job name |
| --max-retries | 5 | Maximum restart attempts |
| --stream-worker, -w | 0 | Worker index to stream logs from |
| --code-dir, -c | script's parent directory | Directory to upload |
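The derived defaults in the table can be read as a small resolution step. resolve_run_defaults below is a hypothetical sketch: it takes the project ID explicitly rather than auto-detecting it, and the %Y%m%d-%H%M%S job-name format is an assumption, since the table only says "timestamp".

```python
from datetime import datetime
from pathlib import Path


def resolve_run_defaults(script, project, bucket=None, name=None, code_dir=None, now=None):
    """Fill in the derived CLI defaults (hypothetical helper, not SpotJAX code)."""
    now = now or datetime.now()
    return {
        # --bucket defaults to spotax-{project}
        "bucket": bucket or f"spotax-{project}",
        # --name defaults to a timestamp; this exact format is an assumption
        "name": name or now.strftime("%Y%m%d-%H%M%S"),
        # --code-dir defaults to the directory containing the script
        "code_dir": Path(code_dir) if code_dir else Path(script).resolve().parent,
    }
```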

Examples

Requirements

  • Python 3.10+
  • GCP project with TPU API enabled and Spot TPU quota
  • gcloud CLI authenticated with Application Default Credentials

License

Apache 2.0

Download files

Source distribution: spotax-0.1.0.tar.gz (63.8 kB)
Built distribution: spotax-0.1.0-py3-none-any.whl (37.1 kB, Python 3)
