XLA bridge for PyTorch

Project description

PyTorch/XLA

Note: PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package.

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Please find tutorials on our GitHub page for the latest release.

Installation

TPU

To install PyTorch/XLA on a new TPU VM:

pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 -f https://storage.googleapis.com/libtpu-releases/index.html
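To confirm the installation worked, you can create a tensor on the XLA device. A quick sanity check to run on the TPU VM:

import torch
import torch_xla.core.xla_model as xm

# Creating a tensor on the XLA device confirms the runtime is set up correctly
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)  # e.g. xla:0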

GPU Plugin (beta)

PyTorch/XLA now provides GPU support through a plugin package similar to libtpu:

pip install torch~=2.3.0 torch_xla~=2.3.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.3.0-py3-none-any.whl

To use the plugin, set XLA_REGISTER_INSTALLED_PLUGINS=1 or call torch_xla.experimental.plugins.use_dynamic_plugins() in your script.
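For example, either approach below enables the plugin (a minimal sketch of the two options; option 1 is set in the shell, option 2 runs in Python before any XLA devices are used):

# Option 1: set the environment variable when launching your script:
#   XLA_REGISTER_INSTALLED_PLUGINS=1 python train.py

# Option 2: enable dynamic plugins from Python before initializing any XLA devices
import torch_xla.experimental.plugins
torch_xla.experimental.plugins.use_dynamic_plugins()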

The CUDA plugin is considered beta for the 2.3 release, and it will become standard in the near future. For more information about device plugins, see issue #6242.

For a list of all torch_xla packages with statically-linked CUDA support, see our main README.

Getting Started

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xm.xla_device())
+
+  # MpDeviceLoader preloads data to the XLA device
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
-    optimizer.step()
+
+    # `xm.optimizer_step` combines gradients across replicas
+    xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())
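Putting the pieces together, here is a minimal, self-contained sketch of the updated loop. The toy linear model, random data, and hyperparameters are placeholders for illustration only:

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()

  # Toy model and random data; placeholders for your own model and DataLoader
  model = nn.Linear(10, 2).to(device)
  optimizer = optim.SGD(model.parameters(), lr=0.01)
  loss_fn = nn.CrossEntropyLoss()
  data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(10)]

  # MpDeviceLoader moves each batch to the XLA device as it is consumed
  train_loader = pl.MpDeviceLoader(data, device)

  for inputs, labels in train_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    # Reduces gradients across replicas before applying the step
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())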

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xm.xla_device())
+  # `gradient_as_bucket_view=True` required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = ddp_model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
     optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())
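As above, a self-contained sketch of the DDP variant (the toy model and random data are again placeholders):

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the "xla" backend with torch.distributed

def _mp_fn(rank):
  # Rank and world size are inferred from the XLA runtime
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()

  # Toy model and random data; placeholders for your own model and DataLoader
  model = nn.Linear(10, 2).to(device)
  ddp_model = DDP(model, gradient_as_bucket_view=True)
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
  loss_fn = nn.CrossEntropyLoss()
  data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(10)]
  train_loader = pl.MpDeviceLoader(data, device)

  for inputs, labels in train_loader:
    optimizer.zero_grad()
    loss = loss_fn(ddp_model(inputs), labels)
    loss.backward()
    # DDP all-reduces gradients itself, so a plain optimizer.step() is used here
    optimizer.step()

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())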

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Available Docker images and wheels

For all builds and all versions of torch-xla, see our main GitHub README.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
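One common first step is to print the client-side metrics report, which surfaces compilation counts and host-to-device transfer statistics (frequent recompilations often indicate dynamic shapes in the model):

import torch_xla.debug.metrics as met

# Prints counters and metrics, e.g. compile time and device data transfers
print(met.metrics_report())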

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on GitHub. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Facebook, and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open an issue in this repository.

Additional Reads

You can find additional useful reading materials in our main GitHub README.

Download files

Download the file for your platform.

Source Distributions

No source distribution files are available for this release.

Built Distributions

torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl (84.3 MB)
Uploaded: CPython 3.11, manylinux (glibc 2.28+), x86-64

torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl (84.3 MB)
Uploaded: CPython 3.10, manylinux (glibc 2.28+), x86-64

torch_xla-2.3.0-cp39-cp39-manylinux_2_28_x86_64.whl (84.3 MB)
Uploaded: CPython 3.9, manylinux (glibc 2.28+), x86-64

torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl (84.2 MB)
Uploaded: CPython 3.8, manylinux (glibc 2.28+), x86-64

File details

Hashes for torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl:

SHA256: e0b2f88baf3373b9c0a4f351488dbb9b4b007b52c1c66f65b65e1984b5f0f227
MD5: f933027832c26a47db44de98afe25877
BLAKE2b-256: b0fcf80ff75f213c362e9de06593cefdd587d490a4e1ff5d0ac8303bb8adba7b

Hashes for torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl:

SHA256: 262876ab0e95a4ecd131afa33a89ad7f94544f878a74198ee52fcf723af39e6f
MD5: 18345d8611e3c2ff46f5238eb6f9dcc7
BLAKE2b-256: 17cd7b7b1ef8189804deaa8ee5c13693496115d8cb05956908f60ba5ad1ff141

Hashes for torch_xla-2.3.0-cp39-cp39-manylinux_2_28_x86_64.whl:

SHA256: 8282e0ff92f42e18e22f65c0ec5a17acd5bc51728b1fdeb6b4ccade3a313c6ac
MD5: d0d73a75dc7f281fb6e19d30e2f7fe3b
BLAKE2b-256: f13e58ec36daa5172e5589fca5346f251a36d07fbff5dcd6436a7ff4ca2beb86

Hashes for torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl:

SHA256: 6678b2bea3baeda916cdb314d5ad190eeb388e71a4de04ccfa948ab74d6d4c72
MD5: f5f0e39e7c2ea5f25d130af18e70cfc0
BLAKE2b-256: 04c909dc136615bc2875ea10a802453507950a3f2c6e2b75cfcb49d104c4c78e
