torchax is a library for running JAX and PyTorch together

torchax: Running PyTorch on TPU via JAX

Docs page: https://google.github.io/torchax/
Discord discussion channel: https://discord.gg/JqeJqGPyzC

torchax is a backend for PyTorch that allows users to run PyTorch programs on Google Cloud TPUs. It also provides graph-level interoperability between PyTorch and JAX.

With torchax, you can:

  • Run PyTorch code on TPUs with minimal code changes.
  • Call JAX functions from PyTorch, passing in jax.Arrays.
  • Call PyTorch functions from JAX, passing in torch.Tensors.
  • Use JAX features like jax.grad, optax, and GSPMD to train PyTorch models.
  • Use a PyTorch model as a feature extractor with a JAX model.

Install

First, install the CPU version of PyTorch:

# On Linux
pip install torch --index-url https://download.pytorch.org/whl/cpu

# On Mac
pip install torch

Next, install JAX for your desired accelerator:

# On Google Cloud TPU
pip install -U jax[tpu]

# On GPU machines
pip install -U jax[cuda12]

# On Linux CPU machines or Macs (see the note below)
pip install -U jax

Note: For Apple devices, you can install the Metal version of JAX for hardware acceleration.

Finally, install torchax:

# Install from PyPI
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/google/torchax

Running a Model

To execute a model with torchax, start with any torch.nn.Module. Here’s an example with a simple 2-layer model:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))

To execute this model with torchax, we need to enable torchax to capture PyTorch ops:

import torchax
torchax.enable_globally()

Then, we can place the model and its inputs on the jax device:

inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor
print(res.jax())  # print the underlying jax.Array

torchax.tensor.Tensor is a torch.Tensor subclass that holds a jax.Array. You can inspect that JAX array with res.jax().

Although the code appears to be standard PyTorch, it's actually running on JAX.

How It Works

torchax uses a torch.Tensor subclass, torchax.tensor.Tensor, which holds a jax.Array and overrides the __torch_dispatch__ method. When a PyTorch operation is executed within the torchax environment (enabled by torchax.enable_globally()), the implementation of that operation is swapped with its JAX equivalent.

When a model is instantiated, tensor constructors like torch.rand create torchax.tensor.Tensor objects containing jax.Arrays. Subsequent operations extract the jax.Array, call the corresponding JAX implementation, and wrap the result back into a torchax.tensor.Tensor.
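The unwrap-call-rewrap pattern described above can be sketched in plain Python. This is a simplified stand-in, not torchax's actual implementation: real torchax intercepts every PyTorch op via __torch_dispatch__, whereas this toy intercepts a single operator, and the BackendArray/WrappedTensor names are illustrative only.

```python
# Simplified sketch of the wrap/unwrap dispatch pattern.
# "BackendArray" stands in for jax.Array and "WrappedTensor" for
# torchax.tensor.Tensor; both names are illustrative, not real APIs.

class BackendArray:
    """Stand-in for jax.Array: a plain value the backend operates on."""
    def __init__(self, value):
        self.value = value

def backend_add(a, b):
    # Stand-in for the JAX implementation of torch.add.
    return BackendArray(a.value + b.value)

class WrappedTensor:
    """Stand-in for torchax.tensor.Tensor: holds a backend array."""
    def __init__(self, backend_array):
        self._elem = backend_array

    def jax(self):
        # Mirrors res.jax(): expose the underlying backend array.
        return self._elem

    def __add__(self, other):
        # Dispatch: unwrap both operands, call the backend
        # implementation, and wrap the result back up.
        result = backend_add(self._elem, other._elem)
        return WrappedTensor(result)

x = WrappedTensor(BackendArray(2))
y = WrappedTensor(BackendArray(3))
z = x + y
print(type(z).__name__, z.jax().value)  # WrappedTensor 5
```

The key property is that user code only ever sees WrappedTensor objects, while every actual computation happens on the backend values inside them.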

For more details, see the How It Works and Ops Registry documentation.

Executing with jax.jit

While torchax can run models in eager mode, compiling with jax.jit gives better performance. torchax provides jax_jit, a wrapper around jax.jit that compiles a function taking and returning torch.Tensors into a faster, JAX-compiled version.

To use jax.jit, you first need a functional version of your model where parameters are passed as inputs:

def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)

Here we use torch.func.functional_call from PyTorch to run the model with its weights replaced by param. It is roughly equivalent to the following (except that functional_call does not mutate m's state):

def model_func(param, inputs):
  m.load_state_dict(param)
  return m(inputs)
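The same functionalization idiom can be illustrated in plain Python, independent of torch. This is a toy sketch of the idea behind torch.func.functional_call: state that normally lives on the object comes in as an explicit argument, which is what makes the function safe to trace and compile. All names here are illustrative.

```python
# Sketch of turning a stateful model into a pure function of
# (params, inputs) -- the idiom torch.func.functional_call captures.

class TinyModel:
    """A 'model' whose behavior depends on mutable state."""
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return self.scale * x

m = TinyModel(scale=2)

def model_func(params, x):
    # Pure function: the state is passed in as an argument instead
    # of being read from the object, so nothing hidden is captured.
    return params["scale"] * x

# The object call and the functional call agree:
assert m(10) == model_func({"scale": 2}, 10) == 20

# Swapping params requires no mutation of m:
print(model_func({"scale": 5}, 10))  # 50
```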

Now, we can apply jax_jit to model_func:

from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(m.state_dict(), inputs))

See more examples at eager_mode.py and the examples folder.

To simplify the idiom of creating a functional model and calling it with parameters, torchax also provides the JittableModule helper class. It lets us rewrite the above as:

from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)

The first time m_jitted is called, it triggers jax.jit to compile the function for the given input shapes. Subsequent calls with the same input shapes are fast because the compiled function is cached.
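That caching behavior can be sketched generically. The toy class below (not jax.jit itself; all names are illustrative) stores one "compiled" callable per input shape, so only the first call with a new shape pays the compilation cost.

```python
# Toy sketch of shape-keyed compilation caching, the behavior
# JittableModule relies on via jax.jit.

class ShapeCachedFunction:
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}          # shape -> "compiled" callable
        self.compile_count = 0   # number of real compilations

    def _compile(self, shape):
        # Stand-in for tracing + compilation for one input shape.
        self.compile_count += 1
        return self.fn

    def __call__(self, xs):
        shape = (len(xs),)  # toy "shape": just the length
        if shape not in self.cache:
            self.cache[shape] = self._compile(shape)
        return self.cache[shape](xs)

f = ShapeCachedFunction(lambda xs: [x * 2 for x in xs])
f([1, 2, 3])   # first call with shape (3,): compiles
f([4, 5, 6])   # same shape: cache hit, no compilation
f([1, 2])      # new shape (2,): compiles again
print(f.compile_count)  # 2
```

This is also why calling a jitted model with many distinct input shapes is slow: each new shape triggers a fresh compilation.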

Saving and Loading Checkpoints

You can save and load your training state using torchax.save_checkpoint and torchax.load_checkpoint. The state can be a dictionary containing the model's weights, optimizer state, and any other relevant information.

import torchax
import torch
import optax

# Assume model, optimizer, and other states are defined
model = MyModel()
weights = dict(model.named_parameters())
buffers = dict(model.named_buffers())
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(weights)  # optax expects a pytree, not a generator
epoch = 10

state = {
    'weights': weights,
    'buffers': buffers,
    'opt_state': opt_state,
    'epoch': epoch,
}

# Save checkpoint
torchax.save_checkpoint(state, '/path/to/checkpoint.pt')

# Load checkpoint
loaded_state = torchax.load_checkpoint('/path/to/checkpoint.pt')

# Restore state
model.load_state_dict(loaded_state['weights'])
opt_state = loaded_state['opt_state']
epoch = loaded_state['epoch']
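The round-trip above can be approximated with standard-library serialization when illustrating the pattern without torchax installed. This is a generic sketch using pickle as a stand-in serializer; torchax.save_checkpoint may use a different on-disk format, and the state values here are dummies.

```python
import os
import pickle
import tempfile

# Generic sketch of the checkpoint round-trip, with plain Python
# values standing in for real weights and optimizer state.
state = {
    'weights': {'fc1': [0.1, 0.2]},
    'opt_state': {'step': 42},
    'epoch': 10,
}

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pkl')

# Save: serialize the whole state dictionary to disk.
with open(path, 'wb') as fh:
    pickle.dump(state, fh)

# Load: deserialize and restore the pieces you need.
with open(path, 'rb') as fh:
    loaded = pickle.load(fh)

assert loaded['epoch'] == 10
print(loaded['weights'])
```

Keeping everything in one dictionary, as the torchax example does, means a single save/load call captures the model, the optimizer, and bookkeeping like the epoch counter together.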

Citation

@software{torchax,
  author = {Han Qi and Chun-nien Chan and Will Cromar and Manfei Bai and Kevin Gleason},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}

Maintainers & Contributors

This library is maintained by a team within Google Cloud. It has benefited from many contributions from both inside and outside the team.

Thank you to our recent contributors:

Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)

A special thank you to @albanD for the initial inspiration for torchax.
