XLA bridge for PyTorch

PyTorch/XLA


PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs.

Installation

TPU

To install the PyTorch/XLA stable build in a new TPU VM (builds are available for Python 3.11 through 3.13; use one of the supported versions):

# - for venv
# python3.11 -m venv py311
# - for conda
# conda create -n py311 python=3.11

pip install torch==2.8.0 'torch_xla[tpu]==2.8.0'
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]'
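
After installing, you can sanity-check that the runtime sees your TPU with a short Python snippet (a minimal sketch; it uses the long-standing xm.xla_device() and xr.device_type() entry points):

# Minimal install check: create and use a tensor on the XLA device.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

device = xm.xla_device()   # first XLA device, e.g. a TPU core
print(xr.device_type())    # expected to print 'TPU' on a TPU VM
t = torch.randn(2, 2, device=device)
print(t + t)               # forces compilation and execution on the device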

C++11 ABI builds

As of 03/18/2025 and starting with the PyTorch/XLA 2.7 release, C++11 ABI builds are the default, and we no longer provide wheels built with the pre-C++11 ABI.

In PyTorch/XLA 2.6, we provided wheels and docker images built with two C++ ABI flavors: C++11 and pre-C++11. Pre-C++11 was the default to align with PyTorch upstream, but C++11 ABI wheels and docker images have better lazy tensor tracing performance.

To install C++11 ABI flavored 2.6 wheels (Python 3.10 example):

pip install torch==2.6.0+cpu.cxx11.abi \
  https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
  'torch_xla[tpu]' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html \
  -f https://download.pytorch.org/whl/torch

To access C++11 ABI flavored docker image:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11

If your model is tracing bound (e.g. you see that the host CPU is busy tracing the model while the TPUs are idle), switching to the C++11 ABI wheels/docker images can improve performance; a profiler sketch for diagnosing this follows the results below. Mixtral 8x7B benchmarking results on v5p-256, global batch size 1024:

  • Pre-C++11 ABI MFU: 33%
  • C++11 ABI MFU: 39%
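
To check whether a workload is tracing bound, capture a profile and look for host-side gaps between device steps. Here is a minimal sketch using torch_xla's built-in profiler (the port number is an arbitrary choice):

# Start a profiling server inside the training process, then capture a
# trace from TensorBoard's profile plugin while the model runs.
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # keep the reference so the server stays alive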

GitHub Doc Map

Our GitHub repository contains many useful docs on working with different aspects of PyTorch/XLA. Here is a list of useful docs spread around our repository:

Getting Started

The guides below cover two modes:

  • Single process: one Python interpreter controlling a single GPU/TPU at a time
  • Multi process: N Python interpreters are launched, corresponding to N GPU/TPUs found on the system

Another mode is SPMD, where one Python interpreter controls all N GPU/TPUs found on the system. Multi processing is more complex and is not compatible with SPMD. This tutorial does not dive into SPMD; for more on that, check our SPMD guide.
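
For a flavor of what SPMD looks like, here is a minimal sketch (the one-axis mesh and the tensor shapes are illustrative only; see the SPMD guide for the full API):

# Shard a tensor's first dimension across all local devices (SPMD mode).
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # one Python process drives all local devices

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,), ('data',))

t = torch.randn(16, 128).to('xla')
xs.mark_sharding(t, mesh, ('data', None))  # shard dim 0, replicate dim 1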

Simple single process

To update your existing training loop, make the following changes:

+import torch_xla

 def train(model, training_data, ...):
   ...
   for inputs, labels in train_loader:
+    with torch_xla.step():
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

+  torch_xla.sync()
   ...

 if __name__ == '__main__':
   ...
+  # Move the model parameters to your XLA device
+  model.to('xla')
   train(model, training_data, ...)
   ...

The changes above should get your model to train on the TPU.
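
Put together, a runnable end-to-end version looks roughly like this (a sketch only; the toy model and random data stand in for your own):

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla

model = nn.Linear(10, 2)
model.to('xla')  # move the model parameters to the XLA device
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Toy data in place of a real DataLoader.
train_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))
                for _ in range(10)]

for inputs, labels in train_loader:
    with torch_xla.step():
        inputs, labels = inputs.to('xla'), labels.to('xla')
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

torch_xla.sync()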

Multi processing

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to('xla')

   for inputs, labels in train_loader:
+    with torch_xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # torch_xla.launch automatically selects the correct world size
+  torch_xla.launch(_mp_fn, args=())
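
As a complete sketch (again with a toy model; torch_xla.launch spawns one process per local device):

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm

def _mp_fn(index):
    model = nn.Linear(10, 2)
    model.to('xla')  # each process drives its own XLA device
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    train_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))
                    for _ in range(10)]

    for inputs, labels in train_loader:
        with torch_xla.step():
            inputs, labels = inputs.to('xla'), labels.to('xla')
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            loss.backward()
            xm.optimizer_step(optimizer)  # all-reduce gradients across replicas

if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())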

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  model.to('xla')
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

   for inputs, labels in train_loader:
+    with torch_xla.step():
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  torch_xla.launch(_mp_fn, args=())
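
With the xla:// init method, rank and world size are discovered from the XLA runtime, so no MASTER_ADDR/MASTER_PORT setup is needed. Passing gradient_as_bucket_view=True lets gradients share storage with DDP's communication buckets, avoiding an extra copy per step.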

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).

Our comprehensive user guides are available at:

  • Documentation for the latest release
  • Documentation for master branch
  • PyTorch/XLA tutorials

Reference implementations

The AI-Hypercomputer/tpu-recipes repo contains examples for training and serving many LLM and diffusion models.

Available docker images and wheels

Python packages

PyTorch/XLA releases starting with version r2.1 are available on PyPI. You can install the main build with pip install torch_xla. To also install the Cloud TPU plugin corresponding to your installed torch_xla, install the optional tpu dependencies after installing the main build:

pip install 'torch_xla[tpu]'

Nightly builds

You can also pin a nightly wheel to a specific date by appending a yyyymmdd suffix to the version, as in torch_xla-2.8.0.devyyyymmdd (or use the latest dev version). Here is an example:

pip3 install torch==2.8.0.dev20250423+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250423-cp310-cp310-linux_x86_64.whl

The torch wheel version 2.8.0.dev20250423+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.

Older versions
Version Cloud TPU VMs Wheel
2.8 (Python 3.12) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl
2.7 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.6 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.5 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.4 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.3 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.2 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 (XRT + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 (Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl

Docker

NOTE: Since PyTorch/XLA 2.7, all builds use the C++11 ABI by default.

Version Cloud TPU VMs Docker
2.8 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.8.0_3.12_tpuvm
2.7 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.7.0_3.10_tpuvm
2.6 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm
2.6 (C++11 ABI) us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
2.3 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm
2.2 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm
nightly python us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm

To use the above Docker images, pass --privileged --net host --shm-size=16G to docker run. Here is an example:

docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash

Version GPU CUDA 12.6 Docker
2.7 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.7.0_3.10_cuda_12.6

Version GPU CUDA 12.4 Docker
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4

Version GPU CUDA 12.1 Docker
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1
2.3 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1
2.2 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1
nightly us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1
nightly at date us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD

Version GPU CUDA 11.8 Docker
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8
2.0 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8

Use these images to run on compute instances with GPUs.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Meta and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Meta, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open an issue in this repository.

Additional Reads

You can find additional useful reading materials in:

Related Projects
