
XLA bridge for PyTorch

Project description

PyTorch/XLA


Note: PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package.

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Take a look at one of our Kaggle notebooks to get started:

Getting Started

To install PyTorch/XLA on a new TPU VM:

pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
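
A quick way to confirm the install worked (a minimal sketch, not from the official docs) is to create a tensor on the XLA device and run an op on it:

import torch
import torch_xla.core.xla_model as xm

# Allocates a tensor directly on the XLA device (TPU on a TPU VM)
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)  # e.g. xla:0
print(t + t)     # forces compilation and execution on the device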

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xm.xla_device())
+
+  # MpDeviceLoader preloads data to the XLA device
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
-    optimizer.step()
+
+    # `xm.optimizer_step` combines gradients across replicas
+    xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())
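
Assembled into a self-contained script, the changes above might look like the following sketch; the linear model, synthetic data, and hyperparameters are placeholders rather than part of the original example:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()

    # Placeholder model, loss, and synthetic data; substitute your own
    model = nn.Linear(16, 2).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
    train_loader = DataLoader(dataset, batch_size=32)

    # MpDeviceLoader preloads batches onto the XLA device
    xla_train_loader = pl.MpDeviceLoader(train_loader, device)

    for inputs, labels in xla_train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        # Reduces gradients across replicas and applies the update
        xm.optimizer_step(optimizer)


if __name__ == '__main__':
    # Spawns one process per local device
    xmp.spawn(_mp_fn, args=())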

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend

-def _mp_fn(rank, world_size):
+def _mp_fn(index):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xm.xla_device())
+  # `gradient_as_bucket_view=True` required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = ddp_model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
     optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())
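
For comparison, a self-contained sketch of the DDP variant, again with a placeholder model and synthetic data standing in for your own:

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the "xla" backend


def _mp_fn(index):
    # Rank and world size are inferred from the XLA device runtime
    dist.init_process_group("xla", init_method='xla://')

    device = xm.xla_device()
    model = nn.Linear(16, 2).to(device)
    # gradient_as_bucket_view=True is required for XLA per the notes above
    ddp_model = DDP(model, gradient_as_bucket_view=True)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()
    dataset = TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,)))
    train_loader = DataLoader(dataset, batch_size=32)
    xla_train_loader = pl.MpDeviceLoader(train_loader, device)

    for inputs, labels in xla_train_loader:
        optimizer.zero_grad()
        loss = loss_fn(ddp_model(inputs), labels)
        loss.backward()
        optimizer.step()  # DDP has already all-reduced the gradients


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())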

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Available docker images and wheels

For all builds and all versions of torch-xla, see our main GitHub README.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
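
A common first step when debugging (a sketch using the torch_xla.debug.metrics module) is to print the runtime metrics report, which surfaces compile counts and transfer times:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = (x @ x).sum()
xm.mark_step()  # materialize the pending graph

# Counters such as CompileTime help spot excessive recompilation
print(met.metrics_report())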

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this Github. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Facebook and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open up an issue in this repository here.

Additional Reads

You can find additional useful reading materials in

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

torch_xla-2.2.0-cp311-cp311-manylinux_2_28_x86_64.whl (85.5 MB)

Uploaded: CPython 3.11, manylinux: glibc 2.28+, x86-64

torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl (85.5 MB)

Uploaded: CPython 3.10, manylinux: glibc 2.28+, x86-64

torch_xla-2.2.0-cp39-cp39-manylinux_2_28_x86_64.whl (85.5 MB)

Uploaded: CPython 3.9, manylinux: glibc 2.28+, x86-64

torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl (85.5 MB)

Uploaded: CPython 3.8, manylinux: glibc 2.28+, x86-64

File details

Details for the file torch_xla-2.2.0-cp311-cp311-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torch_xla-2.2.0-cp311-cp311-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 a7c188d7e9cee3c764c81152c28702413a235003cdd29663321ea7f43564d425
MD5 73ce3e68f97891d72700a75ee79bfcdf
BLAKE2b-256 b4b87487b37f0c8d66beddf047045ae1e7a2b89905081c20104cef3af022d83e

See more details on using hashes here.
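
As an illustration (not part of the release metadata), a downloaded wheel can be checked against the SHA256 digest above using only the standard library:

import hashlib

# Digest copied from the table above; adjust the path to your download location
EXPECTED_SHA256 = "a7c188d7e9cee3c764c81152c28702413a235003cdd29663321ea7f43564d425"
WHEEL_PATH = "torch_xla-2.2.0-cp311-cp311-manylinux_2_28_x86_64.whl"

digest = hashlib.sha256()
with open(WHEEL_PATH, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

assert digest.hexdigest() == EXPECTED_SHA256, "hash mismatch: do not install"
print("SHA256 verified")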

File details

Details for the file torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 da91dec2aad9194cedf7e70a7058ef8a37f80a6bcab3f54e644793b5fe17e358
MD5 955b971a9f6b6d8374616d4163ef102d
BLAKE2b-256 72d475b9b9fe171e75aaf15c890716f5a9e25fab659b9cc31a559e525a421062

See more details on using hashes here.

File details

Details for the file torch_xla-2.2.0-cp39-cp39-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torch_xla-2.2.0-cp39-cp39-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 f2563b55621e29f00cb947ca394e1414e09b22212d795eb5c7114d94de587b1e
MD5 7d8cbb84738656fed2a17816a799d24b
BLAKE2b-256 060762514264fb76e885527c43e2d2d32b3555505ff0faa3500a1ccabf42b145

See more details on using hashes here.

File details

Details for the file torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl
Algorithm Hash digest
SHA256 bf3fef6f56bb52b9180be5463334ac9fc69cb83e4bd6de2ef82a320bcc3012fd
MD5 76765278384947b9a97dde319be72c7d
BLAKE2b-256 e15e0e1a467c2eae69e6c27ddd2ef56a5971a2fc6712530e53f026b59976bc8c

See more details on using hashes here.
