A library to execute PyTorch on TPU

Project description

torchxla2

Install

Currently this is only source-installable. Requires Python version >= 3.10.

NOTE:

Please don't install torch-xla following the instructions in https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md . In particular, the following are not needed:

  • There is no need to build pytorch/pytorch from source.
  • There is no need to clone pytorch/xla project inside of pytorch/pytorch git checkout.

TorchXLA2 and torch-xla have different installation instructions. Please follow the instructions below from scratch, in a fresh venv or conda environment.

1. Installing torch_xla2

First, clone the repository and change into the torch_xla2 directory:

$ git clone https://github.com/pytorch/xla.git
$ cd xla/experimental/torch_xla2

1.0 (recommended) Make a virtualenv / conda env

If you are using VSCode, you can create a new environment from the UI. Select dev-requirements.txt when asked to install project dependencies.

Otherwise create a new environment from the command line.

# Option 1: venv
python -m venv my_venv
source my_venv/bin/activate

# Option 2: conda
conda create --name <your_name> python=3.10
conda activate <your_name>

# Either way, install the dev requirements.
pip install -r dev-requirements.txt

Note: dev-requirements.txt will install the CPU-only version of PyTorch.

1.1 Install this package

If you want to install torch_xla2 without the jax dependency and use the jax dependency from torch_xla:

pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -e .

Otherwise, install torch_xla2 from source for your platform:

pip install -e .[cpu]
pip install -e .[cuda]
pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

1.2 (optional) verify installation by running tests

pip install -r test-requirements.txt
pytest test

Run a model

Now let's execute a model under torch_xla2. We'll start with a simple feed-forward model; in theory, it can be any instance of torch.nn.Module.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))

This model m contains two parts: the computation described in its forward method, and the weights stored in the model and its submodules (nn.Linear).

To execute this model with torch_xla2, we need to construct and run the model inside an environment that captures PyTorch ops and swaps them with their TPU equivalents.

To create this environment, use:

import torch_xla2

env = torch_xla2.default_env() 

Then, run both the instantiation of the model and the evaluation of the model using env as a context manager:

with env:
  inputs = torch.randn(3, 3, 28, 28)
  m = MyModel()
  res = m(inputs)
  print(type(res))  # outputs XLATensor2

What is happening behind the scenes:

When a torch op is executed inside the env context manager, its implementation is swapped out for a version that runs on TPU. When a model's constructor runs, it calls tensor constructors such as torch.rand, torch.ones, or torch.zeros to create its weights. Those ops are captured by env too, and the resulting tensors are placed directly on TPU.

See more at how_it_works and ops registry.

What if I created the model outside of env?

So if you have

m = MyModel()

outside of env, then regular torch ops will run when creating this model, and the model's weights will be on CPU (as instances of torch.Tensor).

To move this model to the XLA device, use the env.to_xla() function.

i.e.

m2 = env.to_xla(m)
inputs = env.to_xla(inputs)

with env:
  res = m2(inputs)

NOTE that we also need to move inputs to the XLA device using to_xla. to_xla works with any pytree of torch.Tensor.
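The pytree behaviour can be sketched in plain Python: a function that walks nested dicts, lists, and tuples and applies a conversion to every leaf. The `to_device` function below is a hypothetical stand-in for the real per-tensor move, not the torch_xla2 implementation:

```python
def tree_map(fn, tree):
    # Recursively apply fn to every leaf of a nested
    # structure of dicts, lists, and tuples.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)  # leaf: convert it


def to_device(leaf):
    # Stand-in for moving one tensor to the XLA device.
    return ("xla", leaf)


batch = {"pixels": [1, 2], "label": 7}
print(tree_map(to_device, batch))
# {'pixels': [('xla', 1), ('xla', 2)], 'label': ('xla', 7)}
```

This is why to_xla accepts a bare tensor, a model's state dict, or an arbitrarily nested batch structure: leaves are converted, containers are rebuilt unchanged.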

Executing with jax.jit

The above script executes the model using eager-mode Jax as the backend. This does allow executing torch models on TPU, but it is often slower than what we can achieve with jax.jit.

jax.jit is a function that transforms a Jax function (i.e. a function that takes jax arrays and returns jax arrays) into the same function, but compiled, and therefore usually faster.

We have made the jax_jit decorator to accomplish the same for functions that take and return torch.Tensors. To use it, the first step is to create a functional version of the model: the parameters should be passed in as inputs instead of being attributes of the class:

def model_func(param, inputs):
  return torch.func.functional_call(m, param, inputs)

Here we use torch.func.functional_call from PyTorch to replace the model weights with param, then call the model. This is equivalent to:

def model_func(param, inputs):
  m.load_state_dict(param)
  return m(*inputs)
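The idea of "functionalizing" a stateful model can be shown without PyTorch at all: state that used to live on the object is passed in explicitly, so the output depends only on the arguments, which is what a tracing compiler like jax.jit needs. A minimal sketch, with made-up names:

```python
class Model:
    # Stateful version: the weight lives on the object.
    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return self.w * x


def model_func(params, x):
    # Functional version: the same computation, but the
    # weight is an explicit argument instead of object state.
    return params["w"] * x


m = Model(3)
params = {"w": m.w}
print(m(5))                   # stateful call
print(model_func(params, 5))  # functional call, same result
```

torch.func.functional_call does this mechanically for any nn.Module, using the state dict as the explicit parameter argument.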

Now, we can apply jax_jit:

from torch_xla2.extra import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(m.state_dict(), inputs))

See more examples at eager_mode.py and the [examples folder](examples/).
