A library to execute PyTorch on TPU
torchxla2
Install
Currently, torch_xla2 can only be installed from source. It requires Python 3.10 or newer.
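A quick way to confirm that your interpreter meets this requirement before creating an environment is a small standard-library check like the sketch below. The helper name python_is_supported is hypothetical and is not part of torch_xla2.

```python
import sys

# Minimum version from the install requirement (Python >= 3.10).
MIN_PYTHON = (3, 10)


def python_is_supported(version_info=None, minimum=MIN_PYTHON):
    """Return True if (major, minor) of the interpreter is at least `minimum`."""
    if version_info is None:
        version_info = sys.version_info
    # Tuple comparison is lexicographic, so (3, 9) < (3, 10) < (3, 12).
    return tuple(version_info[:2]) >= tuple(minimum)


print("Python", sys.version.split()[0], "supported:", python_is_supported())
```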
NOTE:
Please do not install torch-xla by following the instructions in https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md. In particular, the following steps are not needed:
- There is no need to build pytorch/pytorch from source.
- There is no need to clone the pytorch/xla project inside a pytorch/pytorch git checkout.
TorchXLA2 and torch-xla have different installation instructions. Please follow the instructions below from scratch, in a fresh venv or conda environment.
1. Installing torch_xla2
Clone the repository and change into the torch_xla2 directory; the remaining instructions assume you are in that directory:
$ git clone https://github.com/pytorch/xla.git
$ cd xla/experimental/torch_xla2
1.0 (recommended) Make a virtualenv / conda env
If you are using VSCode, you can create a new environment from the UI. Select dev-requirements.txt when asked to install project dependencies.
Otherwise, create a new environment from the command line:
# Option 1: venv
python -m venv my_venv
source my_venv/bin/activate
# Option 2: conda
conda create --name <your_name> python=3.10
conda activate <your_name>
# Either way, install the dev requirements.
pip install -r dev-requirements.txt
Note: dev-requirements.txt will install the CPU-only version of PyTorch.
1.1 Install this package
If you want to install torch_xla2 without its own jax dependency and instead use the jax dependency that comes with torch_xla:
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -e .
Otherwise, install torch_xla2 from source with the extra that matches your platform (run only one of the following):
pip install -e .[cpu]
pip install -e .[cuda]
pip install -e .[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
1.2 (optional) Verify the installation by running tests
pip install -r test-requirements.txt
pytest test
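Besides the test suite, a lighter sanity check is to confirm that the key packages can be found in the active environment. The sketch below uses only importlib.util.find_spec, which locates a module without importing it. The module list (torch, jax, torch_xla2) is an assumption based on the steps above, and check_imports is a hypothetical helper.

```python
import importlib.util

# Assumed top-level modules after the steps above: torch (dev requirements),
# jax (the platform extra or torch_xla[pallas]) and torch_xla2 (pip install -e .).
EXPECTED_MODULES = ("torch", "jax", "torch_xla2")


def check_imports(names=EXPECTED_MODULES):
    """Map each module name to True if it can be located (it is not imported)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


for name, found in check_imports().items():
    print(f"{name:<12} {'ok' if found else 'MISSING'}")
```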
Run a model
Now let's execute a model under torch_xla2. We'll start with a simple 3-layer model; in theory, it can be any instance of torch.nn.Module.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

m = MyModel()

# Execute this model using torch
inputs = torch.randn(3, 3, 28, 28)
print(m(inputs))
This model m contains two parts: the weights, which are stored inside the model and its submodules (the nn.Linear layers), and the forward computation that uses them.
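In PyTorch these weights can be listed with m.named_parameters(). The standard-library sketch below does not import torch; it only reproduces, from the layer sizes in MyModel, the parameter names and shapes you should expect to see, using the fact that nn.Linear stores weight as (out_features, in_features) and bias as (out_features,). The helpers expected_parameters and numel are hypothetical.

```python
# (in_features, out_features) for each nn.Linear submodule of MyModel.
LAYERS = {"fc1": (28 * 28, 120), "fc2": (120, 84), "fc3": (84, 10)}


def expected_parameters(layers=LAYERS):
    """Return {"<submodule>.<param>": shape}, mirroring m.named_parameters()."""
    shapes = {}
    for name, (in_f, out_f) in layers.items():
        shapes[f"{name}.weight"] = (out_f, in_f)  # weight: (out_features, in_features)
        shapes[f"{name}.bias"] = (out_f,)         # bias: (out_features,)
    return shapes


def numel(shape):
    """Number of scalars in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


params = expected_parameters()
for name, shape in params.items():
    print(f"{name:<10} {shape}")
print("total parameters:", sum(numel(s) for s in params.values()))  # 105214
```

With torch installed, sum(p.numel() for p in m.parameters()) should report the same total.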
To execute this model with torch_xla2, we need to construct and run the model under an environment that captures PyTorch ops and swaps them with their TPU equivalents.
To create this environment, use:
import torch_xla2
env = torch_xla2.default_env()
Then, run both the instantiation of the model and its evaluation using env as a context manager:
with env:
    inputs = torch.randn(3, 3, 28, 28)
    m = MyModel()
    res = m(inputs)
    print(type(res))  # outputs XLATensor2
What is happening behind the scenes:
When a torch op is executed inside the env context manager, we swap out the
implementation of that op with a version that runs on TPU.
When a model's constructor runs, it calls tensor constructors such as
torch.rand, torch.ones or torch.zeros to create its weights. Those
ops are captured by env too, so the weights are placed directly on the TPU.
See how_it_works and ops registry for more details.
What if I created the model outside of env?
If you have
m = MyModel()
outside of env, then regular torch ops run when the model is created, so its weights will typically be on the CPU (as instances of torch.Tensor).
To move this model to the XLA device, use the env.to_xla() function:
m2 = env.to_xla(m)
inputs = env.to_xla(inputs)
with env:
    res = m2(inputs)
Note that we also need to move inputs to XLA using env.to_xla.
to_xla works with any pytree of torch.Tensor (for example, nested dicts, lists and tuples of tensors).
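A pytree here is any nesting of dicts, lists and tuples whose leaves are tensors. The sketch below is a simplified, standard-library version of the tree-map pattern behind a function like to_xla: tensor leaves are converted and the container structure is rebuilt unchanged. ToyTensor, tree_map and to_xla_toy are hypothetical names. Real implementations include jax.tree_util.tree_map and PyTorch's internal torch.utils._pytree.tree_map, which handle more container types (in this sketch, namedtuples come back as plain tuples).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical tensor leaf that only records its value and device."""
    value: float
    device: str = "cpu"


def tree_map(fn, tree):
    """Apply fn to every ToyTensor leaf, rebuilding dicts, lists and tuples."""
    if isinstance(tree, ToyTensor):
        return fn(tree)
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, item) for item in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, item) for item in tree)
    return tree  # non-tensor leaves (ints, strings, None, ...) pass through as-is


def to_xla_toy(tree):
    """Return a copy of tree with every tensor leaf on the (pretend) "xla" device."""
    return tree_map(lambda t: ToyTensor(t.value, device="xla"), tree)


batch = {"weights": [ToyTensor(1.0), ToyTensor(2.0)], "meta": (ToyTensor(3.0), "label", 7)}
print(to_xla_toy(batch))
```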
Executing with jax.jit
The script above executes the model using JAX in eager mode as the backend. This
does allow running torch models on TPU, but it is often slower than what we can
achieve with jax.jit.
jax.jit is a function that takes a JAX function (i.e. a function that takes JAX arrays
and returns JAX arrays) and returns a compiled version of the same function that is typically faster.
We have made a jax_jit decorator that accomplishes the same for functions
that take and return torch.Tensor. To use it, the first step is to create
a functional version of the model: the parameters are passed in
as inputs instead of being attributes on the class:
def model_func(param, inputs):
    return torch.func.functional_call(m, param, inputs)
Here we use torch.func.functional_call from PyTorch to replace the model
weights with param and then call the model. This is roughly equivalent to the following,
except that functional_call does not permanently modify m:
def model_func(param, inputs):
    m.load_state_dict(param)
    return m(inputs)
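To see why the functional form matters, here is a plain-Python analogue: a stateful object that stores its parameters, and a helper that calls it with substituted parameters and then restores the originals, so the caller's object is left unchanged. This is similar in spirit to torch.func.functional_call, but ToyLinear, functional_call_toy and m_toy are hypothetical illustrations, not PyTorch code.

```python
class ToyLinear:
    """Stateful 1-D 'model': y = w * x + b, with parameters stored on the object."""

    def __init__(self, w, b):
        self.params = {"w": w, "b": b}

    def __call__(self, x):
        return self.params["w"] * x + self.params["b"]


def functional_call_toy(model, params, x):
    """Run model(x) with `params` swapped in, then restore the model's own params."""
    saved = model.params
    model.params = params
    try:
        return model(x)
    finally:
        model.params = saved


m_toy = ToyLinear(w=2.0, b=1.0)


def model_func(params, x):
    # From the caller's view the output depends only on (params, x).
    return functional_call_toy(m_toy, params, x)


print(model_func({"w": 10.0, "b": 0.5}, 3.0))  # 30.5
print(m_toy(3.0))  # 7.0, the stored parameters are unchanged
```

Passing parameters explicitly is what lets a tracing compiler such as jax.jit treat them as inputs rather than as hidden state baked into the compiled function.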
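Now we can apply jax_jit. In the snippet below, new_state_dict stands for a dict mapping parameter names to XLA tensors (for example, m's weights moved with env.to_xla), and inputs are on the XLA device as well: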
from torch_xla2.extra import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
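Conceptually, a decorator like jax_jit has to bridge two worlds: the user's function works on torch.Tensor, while jax.jit expects a function on JAX arrays. One way to build such a bridge is to convert inputs and outputs on both sides of the transform, as in the sketch below. To keep it runnable without torch or JAX, str plays the role of torch.Tensor, bytes plays the role of a JAX array, and a recording wrapper stands in for jax.jit. lift_transform and recording_transform are hypothetical names, and this is not torch_xla2's implementation.

```python
import functools


def lift_transform(transform, to_inner, to_outer):
    """Turn a transform on inner-typed functions into a decorator on outer-typed ones.

    Analogy: outer = torch.Tensor, inner = JAX array, transform = jax.jit.
    """
    def decorator(outer_fn):
        def inner_fn(*inner_args):
            # Convert inner -> outer, run the user's code, convert the result back.
            return to_inner(outer_fn(*map(to_outer, inner_args)))

        transformed = transform(inner_fn)

        @functools.wraps(outer_fn)
        def wrapper(*outer_args):
            # Callers pass and receive outer-typed values only.
            return to_outer(transformed(*map(to_inner, outer_args)))

        return wrapper

    return decorator


seen_types = []


def recording_transform(fn):
    """Stand-in for jax.jit: records the argument types it receives, then calls fn."""
    def wrapped(*args):
        seen_types.append(tuple(type(a).__name__ for a in args))
        return fn(*args)
    return wrapped


@lift_transform(recording_transform, to_inner=str.encode, to_outer=bytes.decode)
def shout(text):
    return text.upper() + "!"


print(shout("hi"), seen_types)
```

The real transform differs in an important way: jax.jit traces the inner function with abstract values and compiles it, rather than simply calling it on every invocation.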
See eager_mode.py and the examples folder (examples/) for more examples.
Download files
Source Distribution: hantoshi-0.0.2.tar.gz
Built Distribution: hantoshi-0.0.2-py3-none-any.whl
File details
Details for the file hantoshi-0.0.2.tar.gz.
File metadata
- Download URL: hantoshi-0.0.2.tar.gz
- Upload date:
- Size: 94.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | e650889eccfa46bbf94a73482c872a92a5eca82dce850170d40718228737fb7d |
| MD5 | eb0b167d0db420063b587525ad1d271c |
| BLAKE2b-256 | 517faf3a118a38eab779c60580def66e66661c4df12c54d0abecdb5e6063b71f |
File details
Details for the file hantoshi-0.0.2-py3-none-any.whl.
File metadata
- Download URL: hantoshi-0.0.2-py3-none-any.whl
- Upload date:
- Size: 4.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.10.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | bdba68e7a480fa665500182a73829116b3ce835d4b2a80cc85c3cf3d575a0d35 |
| MD5 | be36c8b6473d0da002bd7e576ef53fce |
| BLAKE2b-256 | 7d7ae3cb8765a20a11e29a354379e41c1d92aaf2a171df63ca407a595a34ba29 |