XLA bridge for PyTorch

Project description

PyTorch/XLA

Note: PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package. See our main README for all available builds, including GPU builds.

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Take a look at one of our Kaggle notebooks to get started.

Getting Started

To install PyTorch/XLA on a new TPU VM:

pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
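After installation, you can sanity-check the setup by creating a tensor on an XLA device. This is a minimal sketch (run it on the TPU VM itself); the tensor values are arbitrary:

import torch
import torch_xla.core.xla_model as xm

# Acquire the default XLA device (a TPU core on a TPU VM)
device = xm.xla_device()

# Operations on XLA tensors are traced and compiled by XLA;
# printing forces execution of the pending graph
t = torch.randn(2, 2, device=device)
print(t.device)  # e.g. xla:0
print(t + t)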

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xm.xla_device())
+
+  # MpDeviceLoader preloads data to the XLA device
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
-    optimizer.step()
+
+    # `xm.optimizer_step` combines gradients across replicas
+    xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())
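Assembled into one script, the updated loop looks roughly like the sketch below. The model, loss, optimizer, and dataset here are stand-in placeholders, not part of PyTorch/XLA:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()

  # Placeholder model, loss, optimizer, and data
  model = nn.Linear(10, 2).to(device)
  loss_fn = nn.CrossEntropyLoss()
  optimizer = optim.SGD(model.parameters(), lr=0.01)
  train_loader = DataLoader(
      TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))),
      batch_size=8)

  # MpDeviceLoader preloads each batch onto the XLA device
  xla_train_loader = pl.MpDeviceLoader(train_loader, device)

  for inputs, labels in xla_train_loader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    # Reduces gradients across replicas before stepping
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())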

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend

-def _mp_fn(rank, world_size):
+def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xm.xla_device())
+  # `gradient_as_bucket_view=True` is required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])
+  xla_train_loader = pl.MpDeviceLoader(train_loader, xm.xla_device())

-  for inputs, labels in train_loader:
+  for inputs, labels in xla_train_loader:
     optimizer.zero_grad()
     outputs = ddp_model(inputs)
     loss = loss_fn(outputs, labels)
     loss.backward()
     optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())
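A corresponding end-to-end DDP sketch, under the same placeholder assumptions for the model and data:

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend

def _mp_fn(rank):
  # Rank and world size are inferred from the XLA runtime
  dist.init_process_group('xla', init_method='xla://')

  device = xm.xla_device()
  model = nn.Linear(10, 2).to(device)
  ddp_model = DDP(model, gradient_as_bucket_view=True)

  loss_fn = nn.CrossEntropyLoss()
  optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
  train_loader = DataLoader(
      TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))),
      batch_size=8)
  xla_train_loader = pl.MpDeviceLoader(train_loader, device)

  for inputs, labels in xla_train_loader:
    optimizer.zero_grad()
    outputs = ddp_model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    # DDP all-reduces gradients itself, so a plain step suffices
    optimizer.step()

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())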

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, GPU, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Available docker images and wheels

For all builds and all versions of torch-xla, see our main GitHub README.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
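A common first debugging step is to dump PyTorch/XLA's metrics report after a few training steps, using the torch_xla.debug.metrics module:

import torch_xla.debug.metrics as met

# Print counters and timers collected by PyTorch/XLA. Frequent
# `CompileTime` entries usually indicate graph recompilation;
# device-to-host transfer metrics point at tensors being pulled
# back to the host inside the training loop.
print(met.metrics_report())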

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Facebook, and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Facebook, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open an issue in this repository.

Additional Reads

You can find additional useful reading materials in our main GitHub README.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files are available for this release. See the tutorial on generating distribution archives.

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

torch_xla-2.1.0-cp311-cp311-manylinux_2_28_x86_64.whl (81.1 MB)
Uploaded: CPython 3.11, manylinux: glibc 2.28+, x86-64

torch_xla-2.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (81.1 MB)
Uploaded: CPython 3.10, manylinux: glibc 2.28+, x86-64

torch_xla-2.1.0-cp39-cp39-manylinux_2_28_x86_64.whl (81.1 MB)
Uploaded: CPython 3.9, manylinux: glibc 2.28+, x86-64

torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl (81.1 MB)
Uploaded: CPython 3.8, manylinux: glibc 2.28+, x86-64

File details

Hashes for torch_xla-2.1.0-cp311-cp311-manylinux_2_28_x86_64.whl:

SHA256: f6b19c3f0b5c6ce0e1e1c13bdca1471e42cdae7220eaa17697335db7c879365f
MD5: 7e8f0283e9ff7b613478240a37f019bd
BLAKE2b-256: 5f780a7b28e27134c923477d85dfc9f5e76d616978c3607a2a16957a969267c9

Hashes for torch_xla-2.1.0-cp310-cp310-manylinux_2_28_x86_64.whl:

SHA256: 2f5873eb62f8d1388368a4fdefd116e6fc742f767d269ccab65c709d72c83ca7
MD5: e594bc2d7a3b42f8415c55bb63ba79c5
BLAKE2b-256: 795c829e31b746a484451736c520187ab715a087f6b716f61c09d2ada25c137e

Hashes for torch_xla-2.1.0-cp39-cp39-manylinux_2_28_x86_64.whl:

SHA256: 4a1ac188f7ae563e056c321e8504e50a3d18ca21f894da806bdb5634b96ec817
MD5: c50824b73274433f5a26c441cc4b7963
BLAKE2b-256: 1ae9322559b09f3544565d6cfc3e033a3dee457691471b25702260fce35b1c47

Hashes for torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl:

SHA256: 0a29ef8ca1235f644417aa288a23d3e98cbd0558de4743e2a853fe6944a059d9
MD5: 3da5344340a285c786aaeb2a9e8ffa71
BLAKE2b-256: a11075fa8344d54a0f716720f98d3528e71d2ceb5baba9782a96c920d45c10f7

See more details on using hashes here.
