
torchax is a library for running JAX and PyTorch together

torchax: Running PyTorch on TPU via JAX

torchax is a backend for PyTorch, allowing users to run PyTorch on Google Cloud TPUs. torchax is also a library for providing graph-level interoperability between PyTorch and JAX.

This means, with torchax you can:

  • Run PyTorch code on TPUs with as little as 2 lines of code change.
  • Call a JAX function from a PyTorch function, passing in jax.Arrays.
  • Call a PyTorch function from a JAX function, passing in torch.Tensors.
  • Use JAX features such as jax.grad, optax, and GSPMD to train a PyTorch model.
  • Use a PyTorch model as a feature extractor and combine it with a JAX model, and more.

Install

First install torch CPU:

# On Linux.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# Or on Mac.
pip install torch

Then install JAX for the accelerator you want to use:

# On Google Cloud TPU.
pip install -U jax[tpu]

# Or, on GPU machines.
pip install -U jax[cuda12]

# Or, on Linux CPU machines or Macs (see the note below).
pip install -U jax

NOTE: if you want Metal support on Apple devices, install the Metal version of JAX: https://developer.apple.com/metal/jax/

Finally install torchax:

# Install pre-built torchax.
pip install torchax

# Or, install torchax from source.
pip install git+https://github.com/pytorch/xla.git#subdirectory=torchax

Run a model

Now let's execute a model under torchax. We'll start with a simple model with three linear layers. In theory, we can use any instance of torch.nn.Module.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch.
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))

To execute this model with torchax, we need to enable torchax to capture PyTorch ops:

import torchax
torchax.enable_globally()

Then, we can use a jax device:

inputs = torch.randn(3, 3, 28, 28, device='jax')
m = MyModel().to('jax')
res = m(inputs)
print(type(res))  # outputs torchax.tensor.Tensor

torchax.tensor.Tensor is a torch.Tensor subclass that holds a jax.Array. You can inspect that JAX array with res.jax().

What is happening behind the scenes

We took the approach detailed in the new device recipe by Alban (@albanD), using a jax.Array as the raw_data.

In other words, when a torch op is executed inside an env context manager (enabled by torchax.enable_globally()), we swap out that op's implementation with a JAX one.

When a model's constructor runs, it will call some tensor constructor, such as torch.rand, torch.ones, or torch.zeros to create its weights. When torchax is enabled, these constructors will create a torchax.tensor.Tensor, which contains a jax.Array.

Then, each subsequent op will extract the jax.Array, call the op's JAX implementation, and wrap the result back into a torchax.tensor.Tensor.

See more at how it works and ops registry.
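The unwrap-call-rewrap dispatch described above can be sketched in plain Python. This is a conceptual illustration only; `BoxedTensor` and `OPS` are made-up names, not torchax APIs, and the real library dispatches via torch's device and op-registration machinery:

```python
# Minimal sketch of the wrap/unwrap dispatch torchax performs.
# "OPS" stands in for torchax's op registry mapping torch ops to JAX impls.
OPS = {}

def register(name):
    def deco(fn):
        OPS[name] = fn
        return fn
    return deco

@register("add")
def _add(a, b):
    # Stand-in for the JAX implementation of the op.
    return a + b

class BoxedTensor:
    """Wraps a raw backend value, like torchax.tensor.Tensor holding a jax.Array."""
    def __init__(self, raw):
        self.raw = raw

    def __add__(self, other):
        # Unwrap the raw values, call the registered backend op, re-wrap the result.
        return BoxedTensor(OPS["add"](self.raw, other.raw))

x, y = BoxedTensor(2.0), BoxedTensor(3.0)
print((x + y).raw)  # 5.0
```

The key point is that the wrapper class never computes anything itself; every op routes through the registry, which is what lets torchax substitute JAX implementations transparently.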

Executing with jax.jit

The above script will execute the model using eager mode JAX as the backend. This does allow executing torch models on TPUs, but is often slower than what we can achieve with jax.jit.

jax.jit is a transformation that takes a JAX function (i.e. a function that takes JAX arrays and returns JAX arrays) and returns a compiled (thus faster) version of the same function.

We provide a jax_jit decorator that accomplishes the same for functions that take and return torch.Tensors. To use it, the first step is to create a functional version of the model: the parameters should be passed in as inputs instead of being attributes on the module:

def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)

Here we use torch.func.functional_call from PyTorch to replace the model weights with param and then call the model. This is roughly equivalent to:

def model_func(param, inputs):
  m.load_state_dict(param)
  return m(*inputs)
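The "swap in the parameters, run, then restore" behavior can be illustrated in plain Python. This is a conceptual sketch only; `SimpleModule` and this `functional_call` are illustrative stand-ins, not the torch implementations:

```python
# Conceptual sketch of what torch.func.functional_call does.
class SimpleModule:
    def __init__(self):
        self.weight = 2.0

    def __call__(self, x):
        return self.weight * x

def functional_call(module, params, x):
    """Run `module` with `params` substituted, restoring originals afterwards."""
    saved = {k: getattr(module, k) for k in params}
    try:
        for k, v in params.items():
            setattr(module, k, v)
        return module(x)
    finally:
        # The module is left untouched, so the call is effectively pure.
        for k, v in saved.items():
            setattr(module, k, v)

m = SimpleModule()
print(functional_call(m, {"weight": 10.0}, 3.0))  # 30.0
print(m.weight)  # 2.0 (original weight restored)
```

Making the call pure in this way is what allows jax.jit to trace the function: all state flows in through the arguments rather than hiding in module attributes.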

Now, we can apply jax_jit on model_func:

from torchax.interop import jax_jit

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(m.state_dict(), inputs))

See more examples at eager_mode.py and the examples folder.

To ease the idiom of creating a functional model and calling it with parameters, we also provide the JittableModule helper class. It lets us rewrite the above as:

from torchax.interop import JittableModule

m_jitted = JittableModule(m)
res = m_jitted(...)

The first time m_jitted is called, it triggers jax.jit to compile the function for the given input shapes. Subsequent calls with the same input shapes are fast because the compilation is cached.
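The shape-keyed caching behavior can be sketched in plain Python. This is an illustration of the caching pattern only; `CachingJit` is a made-up name, and the real JittableModule delegates the caching to jax.jit itself:

```python
# Sketch of shape-keyed compilation caching, in the style of jax.jit.
compile_count = 0  # tracks how many times "compilation" ran

class CachingJit:
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}  # shape signature -> "compiled" function

    def __call__(self, xs):
        key = len(xs)  # stand-in for the real input shape/dtype signature
        if key not in self.cache:
            global compile_count
            compile_count += 1  # expensive compilation happens once per shape
            self.cache[key] = self.fn
        return self.cache[key](xs)

f = CachingJit(sum)
f([1, 2, 3])
f([4, 5, 6])  # same shape signature: reuses the cached compilation
f([1, 2])     # new shape signature: compiles again
print(compile_count)  # 2
```

This is also why calling a jitted model with many different input shapes is slow: each new shape triggers a fresh compilation, so padding inputs to a fixed shape is a common optimization.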

Citation

@software{torchax,
  author = {Han Qi and Chun-nien Chan and Will Cromar and Manfei Bai and Kevin Gleason},
  title = {torchax: PyTorch on TPU and JAX interoperability},
  url = {https://github.com/pytorch/xla/tree/master/torchax},
  version = {0.0.4},
  date = {2025-02-24},
}

Maintainers & Contributors:

This library is created and maintained by the PyTorch/XLA team at Google Cloud.

It has benefited from many direct and indirect contributions from outside the team, many of them made by fellow Googlers under Google's 20% project policy, and others by partner teams at Google and other companies.

Here is the list of contributors as of 2025-02-25.

Han Qi (qihqi), PyTorch/XLA
Manfei Bai (manfeibai), PyTorch/XLA
Will Cromar (will-cromar), Meta
Milad Mohammadi (miladm), PyTorch/XLA
Siyuan Liu (lsy323), PyTorch/XLA
Bhavya Bahl (bhavya01), PyTorch/XLA
Pei Zhang (zpcore), PyTorch/XLA
Yifei Teng (tengyifei), PyTorch/XLA
Chunnien Chan (chunnienc), Google, ODML
Alban Desmaison (albanD), Meta, PyTorch
Simon Teo (simonteozw), Google (20%)
David Huang (dvhg), Google (20%)
Barni Seetharaman (barney-s), Google (20%)
Anish Karthik (anishfish2), Google (20%)
Yao Gu (guyao), Google (20%)
Yenkai Wang (yenkwang), Google (20%)
Greg Shikhman (commander), Google (20%)
Matin Akhlaghinia (matinehAkhlaghinia), Google (20%)
Tracy Chen (tracych477), Google (20%)
Matthias Guenther (mrguenther), Google (20%)
WenXin Dong (wenxindongwork), Google (20%)
Kevin Gleason (GleasonK), Google, StableHLO
Nupur Baghel (nupurbaghel), Google (20%)
Gwen Mittertreiner (gmittert), Google (20%)
Zeev Melumian (zmelumian), Lightricks
Vyom Sharma (vyom1611), Google (20%)
Shitong Wang (ShitongWang), Adobe
Rémi Doreau (ayshiff), Google (20%)
Lance Wang (wang2yn84), Google, CoreML
Hossein Sarshar (hosseinsarshar), Google (20%)
Daniel Vega-Myhre (danielvegamyhre), Google (20%)
Tianqi Fan (tqfan28), Google (20%)
Jim Lin (jimlinntu), Google (20%)
Fanhai Lu (FanhaiLu1), Google Cloud
DeWitt Clinton (dewitt), Google PyTorch
Aman Gupta (aman2930), Google (20%)
