XLA bridge for PyTorch

Project description

PyTorch/XLA

Current CI status: GitHub Actions

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!

Please find tutorials on our GitHub page for the latest release.

Getting Started

PyTorch/XLA is now on PyPI!

To install PyTorch/XLA on a new TPU VM:

pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 -f https://storage.googleapis.com/libtpu-releases/index.html
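
Once installed, a quick sanity check is to create a tensor on an XLA device. A minimal sketch, assuming it is run on the TPU VM where the package was just installed:

import torch
import torch_xla.core.xla_model as xm

# Acquire the default XLA device (a TPU core on a TPU VM).
device = xm.xla_device()
print(device)  # e.g. xla:0

# Run a tiny computation on the device to confirm the runtime works.
t = torch.randn(2, 2, device=device)
print(t + t)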

To update your existing training loop, make the following changes:

-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xla.device())

   for inputs, labels in train_loader:
+    with xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xmp.spawn automatically selects the correct world size
+  xmp.spawn(_mp_fn, args=())
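
Putting the pieces above together, a minimal end-to-end sketch of the updated loop could look like the following. The model, optimizer, loss function, and data are placeholders; only the torch_xla calls are taken from the snippet above.

import torch
import torch.nn as nn
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  # Placeholder model, optimizer, loss, and data; substitute your own.
  model = nn.Linear(10, 2).to(xla.device())
  optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
  loss_fn = nn.CrossEntropyLoss()
  train_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,)))
                  for _ in range(10)]

  for inputs, labels in train_loader:
    with xla.step():
      # Transfer data to the XLA device. This happens asynchronously.
      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
      optimizer.zero_grad()
      outputs = model(inputs)
      loss = loss_fn(outputs, labels)
      loss.backward()
      # `xm.optimizer_step` combines gradients across replicas.
      xm.optimizer_step(optimizer)


if __name__ == '__main__':
  # `xmp.spawn` automatically selects the correct world size.
  xmp.spawn(_mp_fn, args=())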

If you're using DistributedDataParallel, make the following changes:

 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xla.device())
+  # `gradient_as_bucket_view=True` is required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])

   for inputs, labels in train_loader:
+    with xla.step():
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xmp.spawn(_mp_fn, args=())
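
The essential DDP-specific changes are the "xla" process-group backend initialized with the xla:// init method and the gradient_as_bucket_view=True flag. A condensed sketch of just the setup, with a placeholder model; the training loop itself is the same as above:

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

import torch_xla as xla
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the "xla" backend


def _mp_fn(rank):
  # Rank and world size are inferred from the XLA device runtime.
  dist.init_process_group("xla", init_method='xla://')

  model = nn.Linear(10, 2).to(xla.device())  # placeholder model
  # `gradient_as_bucket_view=True` is required for XLA.
  ddp_model = DDP(model, gradient_as_bucket_view=True)

  # ... training loop as in the previous example, stepping the optimizer
  # directly with `optimizer.step()`: DDP already all-reduces gradients.


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())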

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).

Our comprehensive user guides are available at:

Documentation for the latest release

Documentation for master branch

PyTorch/XLA tutorials

Available Docker images and wheels

For all builds and all versions of torch-xla, see our main GitHub README.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
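
One common first step from the troubleshooting guide is to print the client-side metrics report after running a few steps; frequent recompilations or aten:: fallback counters usually point at the bottleneck. A minimal sketch:

import torch_xla.debug.metrics as met

# Print counters and metrics (compile time, device transfers, op
# fallbacks, ...) accumulated so far in this process.
print(met.metrics_report())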

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.

Disclaimer

This repository is jointly operated and maintained by Google, Meta and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Meta, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open an issue in this repository.

Additional Reads

You can find additional useful reading materials in our main GitHub README.

Related Projects

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distributions

No source distribution files available for this release. See tutorial on generating distribution archives.

Built Distributions

If you're not sure about the file name format, learn more about wheel file names.

torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl (82.2 MB)

Uploaded CPython 3.11, manylinux: glibc 2.28+, x86-64

torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl (82.3 MB)

Uploaded CPython 3.10, manylinux: glibc 2.28+, x86-64

torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl (82.3 MB)

Uploaded CPython 3.9, manylinux: glibc 2.28+, x86-64

File details

Details for the file torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl
Algorithm    Hash digest
SHA256       b92b91993e21c017ddec7d36dfb629ef868ce718a49c3f4eddcf8f435497dc06
MD5          c01f88f05e6f7b906a45d1cab461728d
BLAKE2b-256  7d5467718dbbc9a45a49a00dda58ce1fcaf1e45dd1cf7969492ee8166270400a

See more details on using hashes here.
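
If you want to verify a downloaded wheel against the digest listed above yourself, a small standard-library sketch (assuming the CPython 3.11 wheel is in the current directory) is:

import hashlib

wheel = "torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl"
expected = "b92b91993e21c017ddec7d36dfb629ef868ce718a49c3f4eddcf8f435497dc06"

# Hash the wheel file and compare against the published SHA256 digest.
with open(wheel, "rb") as f:
  digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == expected else "hash mismatch")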

File details

Details for the file torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl
Algorithm    Hash digest
SHA256       4645398328e38c4552f5a7ccf269ad6fbffb0207b69a4c68eb9570c20ed2da0e
MD5          0c4d9123b1c1582c9e60fcfb6e16184f
BLAKE2b-256  86c694d65a951738d49c3707a0df6e558133469cdfdc10cbeed8d4156e27ebec

See more details on using hashes here.

File details

Details for the file torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl.

File metadata

File hashes

Hashes for torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl
Algorithm    Hash digest
SHA256       3428ca4a207f300243ad6ac2afd3f0e71598b5dc21806b5f7201b6e170e8d4ab
MD5          9c598b56d0a6f005fa9988fb50fa95fb
BLAKE2b-256  dda15ba9990fbb2f798fa7585dc3845f11ee51da2b745dad50e9cadf4beb1be9

See more details on using hashes here.
