XLA bridge for PyTorch

PyTorch/XLA

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Take a look at one of our Kaggle notebooks to get started:

Installation

TPU

To install PyTorch/XLA stable build in a new TPU VM:

pip install torch==2.5.1 torch_xla[tpu]==2.5.1 -f https://storage.googleapis.com/libtpu-releases/index.html

To install PyTorch/XLA nightly build in a new TPU VM:

pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html

GPU Plugin

PyTorch/XLA now provides GPU support through a plugin package similar to libtpu:

pip install torch==2.5.1 torch_xla==2.5.1 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.1-py3-none-any.whl

Getting Started

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xla.device())

   for inputs, labels in train_loader:
+    with xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xla.launch automatically selects the correct world size
+  xla.launch(_mp_fn, args=())
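The comment on `xm.optimizer_step` says it combines gradients across replicas before updating the parameters. As a rough mental model only, the pure-Python simulation below averages each parameter's gradient over N replicas (sum, then divide by N) and applies the same SGD update on every replica, which keeps all replicas' weights identical. It does not use torch_xla, the helper names are hypothetical, and averaging (rather than summing) is an assumption about the reduction, not a description of PyTorch/XLA's implementation.

```python
from typing import List


def all_reduce_mean(replica_grads: List[List[float]]) -> List[float]:
    """Average each parameter's gradient across replicas (sum, then divide by N)."""
    world_size = len(replica_grads)
    # zip(*...) groups the i-th gradient from every replica together.
    return [sum(per_param) / world_size for per_param in zip(*replica_grads)]


def sgd_update(params: List[float], grads: List[float], lr: float) -> List[float]:
    """Plain SGD update: p <- p - lr * g."""
    return [p - lr * g for p, g in zip(params, grads)]


def data_parallel_step(
    params: List[float], replica_grads: List[List[float]], lr: float
) -> List[List[float]]:
    """Simulate one synchronous data-parallel training step.

    Every replica starts from the same params but has computed local
    gradients on its own shard of the batch. After the all-reduce, each
    replica applies the identical averaged gradient.
    """
    averaged = all_reduce_mean(replica_grads)
    return [sgd_update(params, averaged, lr) for _ in replica_grads]


if __name__ == "__main__":
    # Two hypothetical replicas, each holding gradients for two parameters.
    print(data_parallel_step([1.0, 2.0], [[1.0, 0.0], [3.0, 2.0]], lr=0.5))
```

Because every replica ends the step with the same parameters, no separate weight broadcast is needed between steps in this model.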

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xla.device())
+  # `gradient_as_bucket_view=True` required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])

   for inputs, labels in train_loader:
+    with xla.step():
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xla.launch(_mp_fn, args=())

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Available docker images and wheels

Python packages

PyTorch/XLA releases starting with version r2.1 are available on PyPI. You can install the main build with `pip install torch_xla`. To also install the Cloud TPU plugin that corresponds to your installed torch_xla, install the optional `tpu` dependencies after installing the main build:

pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
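The install commands on this page pin torch and torch_xla to the same release (for example `torch==2.5.1 torch_xla==2.5.1`). A quick post-install sanity check is to compare the two installed versions. The sketch below uses only the standard library (`importlib.metadata`); the helper names are hypothetical, and treating an identical major.minor.patch as "compatible" is an assumption drawn from those pins, not an official compatibility rule.

```python
from __future__ import annotations

import re
from importlib import metadata


def release_triplet(version: str) -> tuple[int, int, int]:
    """Return (major, minor, patch), ignoring dev and local suffixes
    such as '.dev20240820' or '+cpu'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def versions_match(torch_version: str, xla_version: str) -> bool:
    """True when both packages share the same major.minor.patch release."""
    return release_triplet(torch_version) == release_triplet(xla_version)


def installed_versions_match() -> bool | None:
    """Compare installed torch and torch_xla; None if either is missing."""
    try:
        torch_version = metadata.version("torch")
        xla_version = metadata.version("torch_xla")
    except metadata.PackageNotFoundError:
        return None
    return versions_match(torch_version, xla_version)


if __name__ == "__main__":
    print(installed_versions_match())
```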

GPU and nightly builds are available in our public GCS bucket.

Version Cloud GPU VM Wheels
2.5.1 (CUDA 12.1 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.4 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.4 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.4 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl
nightly (Cloud TPU VM + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl
nightly (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl
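The release rows in the GPU wheel table follow a regular pattern: the CUDA version appears in the path and the CPython tag is repeated in the file name. The helper below is a sketch that rebuilds those release URLs from a version, a CUDA version and the running interpreter. The pattern is inferred from the listed rows only, so it does not cover the nightly wheels (which end in `linux_x86_64`), and the result should still be checked against the table; `gpu_wheel_url` and `cpython_tag` are hypothetical names.

```python
from __future__ import annotations

import sys

WHEEL_ROOT = "https://storage.googleapis.com/pytorch-xla-releases/wheels"


def cpython_tag(version_info=sys.version_info) -> str:
    """'cp310' for Python 3.10, 'cp311' for Python 3.11, and so on."""
    return f"cp{version_info[0]}{version_info[1]}"


def gpu_wheel_url(version: str, cuda: str, py_tag: str | None = None) -> str:
    """Rebuild a release GPU wheel URL from the table's naming pattern.

    Assumes the manylinux_2_28 naming used by the release rows; nightly
    rows use a different suffix and are not handled here.
    """
    tag = py_tag or cpython_tag()
    return (
        f"{WHEEL_ROOT}/cuda/{cuda}/"
        f"torch_xla-{version}-{tag}-{tag}-manylinux_2_28_x86_64.whl"
    )


if __name__ == "__main__":
    # Wheel for the interpreter running this script, CUDA 12.1 build.
    print(gpu_wheel_url("2.5.1", "12.1"))
```

The printed URL can be passed to `pip install` alongside the matching `torch==2.5.1` pin.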
Use nightly build before 08/13/2024

You can also add `+yyyymmdd` after `torch_xla-nightly` to get the nightly wheel of a specified date. Here is an example:

pip3 install torch==2.5.0.dev20240613+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly%2B20240613-cp310-cp310-linux_x86_64.whl

The torch wheel version 2.5.0.dev20240613+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.

Use nightly build after 08/20/2024

You can also add yyyymmdd after torch_xla-2.5.0.dev to get the nightly wheel of a specified date. Here is an example:

pip3 install torch==2.5.0.dev20240820+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240820-cp310-cp310-linux_x86_64.whl

The torch wheel version 2.5.0.dev20240820+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.
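The two dated-nightly naming schemes above can be captured in a small helper. This is a sketch under stated assumptions: "before 08/13/2024" is read as strictly earlier, "after 08/20/2024" is read as including 08/20 (the example uses 20240820), the `2.5.0` dev prefix is taken from the examples and will differ for later nightlies (such as the 2.6.0.dev wheel listed earlier), and dates in between are rejected because the text does not cover them. Function names are hypothetical.

```python
from datetime import date

TPUVM_WHEELS = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"


def dated_nightly_wheel_url(day: date, py_tag: str = "cp310", dev_version: str = "2.5.0") -> str:
    """TPU VM nightly wheel URL for a given build date."""
    stamp = day.strftime("%Y%m%d")
    suffix = f"{py_tag}-{py_tag}-linux_x86_64.whl"
    if day < date(2024, 8, 13):
        # Old scheme: torch_xla-nightly+YYYYMMDD, with '+' URL-encoded as %2B.
        return f"{TPUVM_WHEELS}/torch_xla-nightly%2B{stamp}-{suffix}"
    if day >= date(2024, 8, 20):
        # New scheme: torch_xla-<dev_version>.devYYYYMMDD.
        return f"{TPUVM_WHEELS}/torch_xla-{dev_version}.dev{stamp}-{suffix}"
    raise ValueError("naming between 2024-08-13 and 2024-08-20 is not documented")


def matching_torch_pin(day: date, dev_version: str = "2.5.0") -> str:
    """pip requirement for the CPU torch nightly built on the same date."""
    return f"torch=={dev_version}.dev{day.strftime('%Y%m%d')}+cpu"


if __name__ == "__main__":
    day = date(2024, 8, 20)
    print(matching_torch_pin(day))
    print(dated_nightly_wheel_url(day))
```

The torch pin is installed from `https://download.pytorch.org/whl/nightly/cpu`, as in the commands above.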

Older versions
Version Cloud TPU VMs Wheel
2.5 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.4 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.3 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.2 (Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 (XRT + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 (Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl

Version GPU Wheel
2.5.1 (CUDA 12.1 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.4 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.4 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl
2.5.1 (CUDA 12.4 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.1 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.5 (CUDA 12.4 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.4 (CUDA 12.1 + Python 3.9) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl
2.4 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.4 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.3 (CUDA 12.1 + Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl
2.3 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.3 (CUDA 12.1 + Python 3.11) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl
2.2 (CUDA 12.1 + Python 3.8) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl
2.2 (CUDA 12.1 + Python 3.10) https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl
2.1 + CUDA 11.8 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl
nightly + CUDA 12.0 >= 2023/06/27 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl

Docker

Version Cloud TPU VMs Docker
2.5.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
2.3 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm
2.2 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm
nightly python us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm

To use the above Docker images, pass `--privileged --net host --shm-size=16G` to `docker run`. For example:

docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
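When launching the container from a script, building the argument list explicitly avoids shell-quoting mistakes. The sketch below reproduces the command above; the tag pattern (`r<version>_<python>_tpuvm` for releases, `nightly_<python>_tpuvm` for nightly) is read off the Cloud TPU VM table, release tags spell out the full version (`2.5.0`, not `2.5`), and `tpu_image` and `docker_run_args` are hypothetical helper names.

```python
from __future__ import annotations

import shlex

XLA_IMAGES = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"


def tpu_image(version: str = "nightly", python: str = "3.10") -> str:
    """Image name for a Cloud TPU VM tag from the table above."""
    if version == "nightly":
        return f"{XLA_IMAGES}:nightly_{python}_tpuvm"
    return f"{XLA_IMAGES}:r{version}_{python}_tpuvm"


def docker_run_args(
    image: str, shm_size: str = "16G", command: tuple[str, ...] = ("/bin/bash",)
) -> list[str]:
    """argv for `docker run` with the flags recommended above."""
    return [
        "docker", "run",
        "--privileged",
        "--net", "host",
        f"--shm-size={shm_size}",
        "-it",  # interactive session with a TTY
        image,
        *command,
    ]


if __name__ == "__main__":
    print(shlex.join(docker_run_args(tpu_image())))
```

The list can be handed to `subprocess.run(args, check=True)`; note that `-it` expects an interactive terminal.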

Version GPU CUDA 12.4 Docker
2.5.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_cuda_12.4
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4

Version GPU CUDA 12.1 Docker
2.5.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_cuda_12.1
2.5 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1
2.4 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1
2.3 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1
2.2 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1
nightly us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1
nightly at date us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD

Version GPU CUDA 11.8 Docker
2.1 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8
2.0 us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8

Use the GPU images above to run on compute instances with GPUs.
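The GPU image tags follow `r<version>_<python>_cuda_<cuda>` for releases and `nightly_<python>_cuda_<cuda>` for nightly builds, with an optional `_YYYYMMDD` suffix for a dated nightly. Because the Python version differs between rows (3.10 for most releases, 3.8 for nightly and 2.0), the sketch below takes it explicitly rather than guessing. It only reproduces the naming pattern seen in the tables; `gpu_image` is a hypothetical helper, and whether a given tag exists should be checked against the tables.

```python
from __future__ import annotations

from datetime import date

XLA_IMAGES = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"


def gpu_image(version: str, cuda: str, python: str, day: date | None = None) -> str:
    """Image name for a GPU tag built from the table's naming pattern."""
    if version == "nightly":
        tag = f"nightly_{python}_cuda_{cuda}"
        if day is not None:
            # Dated nightly: append _YYYYMMDD.
            tag += "_" + day.strftime("%Y%m%d")
    else:
        if day is not None:
            raise ValueError("only nightly images carry a date suffix")
        tag = f"r{version}_{python}_cuda_{cuda}"
    return f"{XLA_IMAGES}:{tag}"


if __name__ == "__main__":
    print(gpu_image("2.5.1", "12.4", "3.10"))
    print(gpu_image("nightly", "12.1", "3.8", date(2024, 6, 13)))
```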

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue in this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Meta and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Meta, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open an issue in this repository.

Additional Reads

You can find additional useful reading materials in

Related Projects

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl (90.6 MB)

Uploaded CPython 3.11 manylinux: glibc 2.28+ x86-64

torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl (90.6 MB)

Uploaded CPython 3.10 manylinux: glibc 2.28+ x86-64

torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl (90.6 MB)

Uploaded CPython 3.9 manylinux: glibc 2.28+ x86-64

File details

Details for the file torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl.

File hashes

Hashes for torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 1c425b57c5636d0088153ce4b049fd36426fdd0e72abed80a259974b382a19a6
MD5 3e2908a27d5ab20f4a7980a10583257b
BLAKE2b-256 18d5785af7f9c944a3c34cabece0f3ec640f55480a9a92fa926dba4472613957

File details

Details for the file torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl.

File hashes

Hashes for torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 b922c3fc76232fe5f69ededd580e85696bcd6029f32c8a964a0fc614fa5594b8
MD5 b420b1d1e1cf79711035e60c9f3f7bc2
BLAKE2b-256 5bc35022a4c9215032bea1b0cb88cde322b4cb45961f3ac45ec737578dfae945

File details

Details for the file torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl.

File hashes

Hashes for torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 7519ef6a94120c6c02cdfb0c7588cfd6523bc6c67287b2d2183039fc61b99299
MD5 c7a21ad09f6c0ba1b496a18daa0105b6
BLAKE2b-256 8d8cd001a5b9c432f563e4875d14e3e377ea72f07ccabccb272d2cdfdd5b4046
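To check a downloaded wheel against the SHA256 digests listed above, hash the file in chunks and compare. The sketch below copies the three published digests into a dict keyed by file name; `sha256_of`, `verify_wheel` and `PUBLISHED_SHA256` are hypothetical names, and only the standard library `hashlib` is used.

```python
import hashlib
import os

# SHA256 digests copied from the hash tables above.
PUBLISHED_SHA256 = {
    "torch_xla-2.5.1-cp311-cp311-manylinux_2_28_x86_64.whl": "1c425b57c5636d0088153ce4b049fd36426fdd0e72abed80a259974b382a19a6",
    "torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl": "b922c3fc76232fe5f69ededd580e85696bcd6029f32c8a964a0fc614fa5594b8",
    "torch_xla-2.5.1-cp39-cp39-manylinux_2_28_x86_64.whl": "7519ef6a94120c6c02cdfb0c7588cfd6523bc6c67287b2d2183039fc61b99299",
}


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 of a file, read in 1 MiB chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_wheel(path: str, expected_sha256: str) -> bool:
    """True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected_sha256.lower()


if __name__ == "__main__":
    wheel = "torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl"
    if os.path.exists(wheel):
        print(verify_wheel(wheel, PUBLISHED_SHA256[os.path.basename(wheel)]))
```

pip can perform the same check at install time in hash-checking mode, by listing the wheel with a `--hash=sha256:<digest>` option in a requirements file and installing with `--require-hashes`.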
