XLA bridge for PyTorch

PyTorch/XLA

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Take a look at one of our Kaggle notebooks to get started:

Installation

TPU

To install the PyTorch/XLA stable build in a new TPU VM:

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

To install the PyTorch/XLA nightly build in a new TPU VM:

pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html

C++11 ABI builds

Starting with PyTorch/XLA 2.6, we provide wheels and docker images built with two C++ ABI flavors: C++11 and pre-C++11. Pre-C++11 is the default to align with PyTorch upstream, but C++11 ABI wheels and docker images have better lazy tensor tracing performance.
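
The install commands in this README pair a C++11 ABI torch build with a C++11 ABI torch_xla wheel, so it can help to confirm which flavor the installed torch was built with. The following is a minimal sketch, not part of the official instructions; the helper name torch_abi_flavor is hypothetical, and it relies only on torch.compiled_with_cxx11_abi() from PyTorch itself:

```python
from typing import Optional


def torch_abi_flavor() -> Optional[str]:
    """Return "C++11" or "pre-C++11" for the installed torch, or None if torch is missing.

    Hypothetical helper for illustration; not part of torch or torch_xla.
    """
    try:
        import torch  # imported lazily so the check degrades gracefully
    except ImportError:
        return None
    # compiled_with_cxx11_abi() reports whether torch was built with
    # _GLIBCXX_USE_CXX11_ABI=1.
    return "C++11" if torch.compiled_with_cxx11_abi() else "pre-C++11"


print(torch_abi_flavor() or "torch is not installed")
```

If this reports pre-C++11 and you want the C++11 ABI torch_xla wheel, reinstalling torch with the matching flavor first is probably the safer order.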

To install C++11 ABI flavored 2.6 wheels (Python 3.10 example):

pip install torch==2.6.0+cpu.cxx11.abi \
  https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
  'torch_xla[tpu]' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html \
  -f https://download.pytorch.org/whl/torch

The above command works for Python 3.10. We additionally have Python 3.9 and 3.11 wheels:

To use the C++11 ABI flavored docker image:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11

If your model is tracing bound (e.g. you see that the host CPU is busy tracing the model while TPUs are idle), switching to the C++11 ABI wheels/docker images can improve performance. Mixtral 8x7B benchmarking results on v5p-256, global batch size 1024:

  • Pre-C++11 ABI MFU: 33%
  • C++11 ABI MFU: 39%
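
To put the two MFU figures in perspective, the arithmetic can be worked out as below. This is only an illustrative calculation on the numbers quoted above; the function name relative_gain is hypothetical. It assumes the hardware and model are unchanged, in which case throughput scales in proportion to MFU.

```python
def relative_gain(baseline: float, improved: float) -> float:
    """Relative improvement of `improved` over `baseline`, as a fraction."""
    return (improved - baseline) / baseline


pre_cxx11_mfu = 0.33  # pre-C++11 ABI figure quoted above
cxx11_mfu = 0.39      # C++11 ABI figure quoted above

gain = relative_gain(pre_cxx11_mfu, cxx11_mfu)
print(f"Absolute gain: {(cxx11_mfu - pre_cxx11_mfu) * 100:.0f} MFU points")
print(f"Relative gain: {gain:.1%}")  # about 18%
```

In other words, a 6-point MFU difference corresponds to roughly 18% more useful work per step on the same hardware for this benchmark.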

GPU Plugin

PyTorch/XLA now provides GPU support through a plugin package similar to libtpu:

pip install torch~=2.5.0 torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl

The newest stable release with a PyTorch/XLA:GPU wheel is torch_xla 2.5. We do not offer a PyTorch/XLA:GPU wheel in the PyTorch/XLA 2.6 release. We understand this is important and plan to reinstate GPU support by the 2.7 release. PyTorch/XLA remains an open-source project, and we welcome contributions from the community to help maintain and improve it. To contribute, please start with the contributors guide.

Getting Started

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xla.device())

   for inputs, labels in train_loader:
+    with xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xla.launch automatically selects the correct world size
+  xla.launch(_mp_fn, args=())

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xla.device())
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])

   for inputs, labels in train_loader:
+    with xla.step():
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xla.launch(_mp_fn, args=())

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Reference implementations

The AI-Hypercomputer/tpu-recipes repo contains examples for training and serving many LLM and diffusion models.

Available docker images and wheels

Python packages

PyTorch/XLA releases starting with version r2.1 are available on PyPI. You can install the main build with pip install torch_xla. To also install the Cloud TPU plugin corresponding to your installed torch_xla, install the optional tpu dependencies after installing the main build:

pip install torch_xla[tpu] \
  -f https://storage.googleapis.com/libtpu-wheels/index.html \
  -f https://storage.googleapis.com/libtpu-releases/index.html
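
After installing, a quick check that torch and torch_xla come from the same release series can catch mismatched environments early. The sketch below is an assumption-laden illustration: the helper names are hypothetical, and the rule that the major.minor versions should match is inferred from the paired pins used in the install commands in this README (for example torch~=2.6.0 with torch_xla~=2.6.0), not from an official compatibility statement.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def release_series(version_string: str) -> Tuple[int, int]:
    """Extract (major, minor) from a version such as '2.6.0+cpu.cxx11.abi'."""
    public_part = version_string.split("+")[0]  # drop the local version label
    major, minor = public_part.split(".")[:2]
    return int(major), int(minor)


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def same_series(torch_version: str, xla_version: str) -> bool:
    """True when both versions share the same major.minor (assumed compatibility rule)."""
    return release_series(torch_version) == release_series(xla_version)


torch_v = installed_version("torch")
xla_v = installed_version("torch_xla")
if torch_v is None or xla_v is None:
    print("torch and/or torch_xla is not installed")
else:
    status = "match" if same_series(torch_v, xla_v) else "MISMATCH"
    print(f"torch {torch_v} / torch_xla {xla_v}: {status}")
```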

GPU and nightly builds are available in our public GCS bucket.

Version Cloud GPU VM Wheels
2.5 (CUDA 12.1 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl
nightly (Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl
nightly (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl
nightly (CUDA 12.1 + Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl

Use nightly build before 08/13/2024

You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel of a specified date. Here is an example:

pip3 install torch==2.6.0.dev20240925+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly%2B20240925-cp310-cp310-linux_x86_64.whl

The torch wheel version 2.6.0.dev20240925+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.

Use nightly build after 08/20/2024

You can also add yyyymmdd after torch_xla-2.6.0.dev to get the nightly wheel of a specified date. Here is an example:

pip3 install torch==2.5.0.dev20240820+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240820-cp310-cp310-linux_x86_64.whl

The torch wheel version 2.5.0.dev20240820+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.

Use nightly build with C++11 ABI after 10/28/2024

By default, torch is built with the pre-C++11 ABI (see https://github.com/pytorch/pytorch/issues/51039). torch_xla follows that and ships pre-C++11 builds by default. However, lazy tensor tracing performance can be improved by building the code with the C++11 ABI. As a result, we provide C++11 ABI builds for interested users to try, especially if you find your model performance bottlenecked in Python lazy tensor tracing.

You can add .cxx11 after yyyymmdd to get the C++11 ABI variant of a specific nightly wheel. Here is an example to install nightly builds from 10/28/2024:

pip3 install torch==2.6.0.dev20241028+cpu.cxx11.abi --index-url https://download.pytorch.org/whl/nightly
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028.cxx11-cp310-cp310-linux_x86_64.whl

As of 12/11/2024, the torch_xla C++11 ABI wheel is named differently and can be installed as follows:

pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241211+cxx11-cp310-cp310-linux_x86_64.whl

The torch wheel version 2.6.0.dev20241028+cpu.cxx11.abi can be found at https://download.pytorch.org/whl/nightly/torch/.
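
The naming patterns above can be summarized in a small URL builder. This is a sketch for illustration only: the function names are hypothetical, the release version must be supplied by the caller (it is not inferred from the date), and the 12/11/2024 cutoff for the "+cxx11" suffix is taken from the text above as an assumption about when the change took effect. A generated URL is not guaranteed to exist for every date.

```python
BASE_URL = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"


def legacy_nightly_url(date: str, py_tag: str = "cp310") -> str:
    """Older scheme: torch_xla-nightly+yyyymmdd, with '+' percent-encoded as %2B."""
    return f"{BASE_URL}/torch_xla-nightly%2B{date}-{py_tag}-{py_tag}-linux_x86_64.whl"


def versioned_nightly_url(date: str, version: str, py_tag: str = "cp310",
                          cxx11: bool = False) -> str:
    """Newer scheme: torch_xla-<version>.dev<yyyymmdd>, optionally with a C++11 ABI suffix."""
    suffix = ""
    if cxx11:
        # Assumed cutoff: ".cxx11" before 12/11/2024, "+cxx11" from that date on.
        # yyyymmdd strings compare correctly as plain strings.
        suffix = "+cxx11" if date >= "20241211" else ".cxx11"
    return (f"{BASE_URL}/torch_xla-{version}.dev{date}{suffix}"
            f"-{py_tag}-{py_tag}-linux_x86_64.whl")


# Reproduce the four example URLs from this section.
print(legacy_nightly_url("20240925"))
print(versioned_nightly_url("20240820", version="2.5.0"))
print(versioned_nightly_url("20241028", version="2.6.0", cxx11=True))
print(versioned_nightly_url("20241211", version="2.6.0", cxx11=True))
```

Each printed URL can be passed directly to pip3 install, together with the matching torch nightly for the same date.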

Older versions
Version Cloud TPU VMs Wheel
2.5 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.4 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.3 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.2 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 (XRT + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 (Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl

Version GPU Wheel
2.5 (CUDA 12.1 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.4 (CUDA 12.1 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.4 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.4 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.3 (CUDA 12.1 + Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl
2.3 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.3 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.2 (CUDA 12.1 + Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl
2.2 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 + CUDA 11.8 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl
nightly + CUDA 12.0 >= 2023/06/27 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl

Docker

Version Cloud TPU VMs Docker
2.6 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
2.3 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm
2.2 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm
nightly python us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm
nightly python (C++11 ABI) us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_cxx11

To use the docker images above, pass --privileged --net host --shm-size=16G to docker run. Here is an example:

docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
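
The image tags in this section appear to follow a release_python_accelerator pattern, with an optional _cxx11 suffix. The sketch below composes an image reference and the docker run command line from those parts. The helper names are hypothetical and the pattern is inferred from the tags listed here, so a composed tag may not correspond to a published image.

```python
import shlex
from typing import List

REGISTRY = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"


def image_ref(version: str, python: str = "3.10", accelerator: str = "tpuvm",
              cxx11: bool = False) -> str:
    """Compose an image reference, e.g. ...:r2.6.0_3.10_tpuvm or ...:nightly_3.10_tpuvm_cxx11."""
    release = "nightly" if version == "nightly" else f"r{version}"
    tag = f"{release}_{python}_{accelerator}"
    if cxx11:
        tag += "_cxx11"
    return f"{REGISTRY}:{tag}"


def docker_run_argv(image: str, command: str = "/bin/bash") -> List[str]:
    """Build the argv for an interactive container with the flags recommended above."""
    return ["docker", "run", "--privileged", "--net", "host",
            "--shm-size=16G", "-it", image, command]


argv = docker_run_argv(image_ref("nightly"))
print(shlex.join(argv))
```

Building the command as a list keeps it safe to hand to subprocess.run(argv) without shell quoting concerns; for the GPU images, pass an accelerator such as "cuda_12.4".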

Version GPU CUDA 12.4 Docker
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4

Version GPU CUDA 12.1 Docker
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1
2.3 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1
2.2 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1
nightly us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1
nightly at date us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD

Version GPU CUDA 11.8 Docker
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8
2.0 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8

To run on compute instances with GPUs.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Meta and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Meta, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open up an issue in this repository here.

Additional Reads

You can find additional useful reading materials in

Related Projects
