Torch's MAX Backend
A PyTorch backend implemented using Modular's MAX framework.
Simply use torch.compile, but with Modular's MAX backend.
Installation
pip install git+https://github.com/gabrieldemarmiesse/max-torch-backend.git
Quick Start
Basic Usage
from torch_max_backend import max_backend
import torch
# Compile your model with MAX backend
model = YourModel()
compiled_model = torch.compile(model, backend=max_backend)
# Use normally - now accelerated by MAX
output = compiled_model(input_tensor)
Simple Function Example
import torch
from torch_max_backend import max_backend
@torch.compile(backend=max_backend)
def simple_math(x, y):
return x + y * 2
# Usage
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
print(simple_math(a, b)) # Accelerated execution
Training
Training works as expected:
from torch_max_backend import max_backend
import torch
import torch.nn
import torch.optim
import torch.nn.functional as F
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 2)
def forward(self, x):
return self.linear(x)
device = "cuda"
model = MyModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@torch.compile(backend=max_backend)
def train_step(x, y):
model.train()
optimizer.zero_grad()
output = model(x)
loss = F.mse_loss(output, y)
loss.backward()
optimizer.step()
return loss
a = torch.randn(5, 3).to(device)
b = torch.randn(5, 2).to(device)
print(train_step(a, b).cpu().detach().numpy())
Device Selection
Note that the MAX backend does not currently support some older NVIDIA/AMD GPUs, so check with MAX whether your GPU is supported before targeting it.
from torch_max_backend import get_accelerators
# Check available accelerators
# The CPU is necessarily included in the list of accelerators
accelerators = get_accelerators()
device = "cuda" if len(list(accelerators)) >= 2 else "cpu"
model = model.to(device)
Supported Operations
The backend currently supports operations defined in aten_functions.py. You can view the mapping dictionary by importing MAPPING_TORCH_ATEN_TO_MAX.
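For example, since the mapping is a dictionary keyed by aten ops, a quick way to list every op that currently has a MAX implementation:
from torch_max_backend import MAPPING_TORCH_ATEN_TO_MAX
# Print every aten op that has a MAX implementation
for aten_op in MAPPING_TORCH_ATEN_TO_MAX:
    print(aten_op)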
Extending the Backend
You can add support for new PyTorch operations without cloning the repository by creating custom mappings:
from torch_max_backend import MAPPING_TORCH_ATEN_TO_MAX
from torch.ops import aten
from max.graph import ops as max_ops
# Example: Add support for a new operation
def my_custom_tanh(x):
return max_ops.tanh(x)
# Register the operation
MAPPING_TORCH_ATEN_TO_MAX[aten.tanh] = my_custom_tanh
# Now you can use it with torch.compile
import torch
from torch_max_backend import max_backend
@torch.compile(backend=max_backend)
def my_function(x):
return torch.tanh(x) # Will now use your custom implementation
This approach allows you to:
- Add missing operations your models need
- Override existing implementations with optimized versions
- Prototype new MAX operations before contributing them back
Performance Tips
Dynamic Shapes
For variable input sizes, mark dynamic dimensions to avoid recompiling:
from torch._dynamo import mark_dynamic
mark_dynamic(input_tensor, 0) # batch dimension
mark_dynamic(input_tensor, 1) # sequence length
If you don't, PyTorch will compile a second time when it sees a different shape, which can be costly. You can find more information about dynamic shapes in the PyTorch documentation.
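As a minimal sketch, assuming the backend handles the dynamic dimension (scale is a stand-in for your own function), marking the batch dimension before the first call lets later calls with different batch sizes reuse the compiled graph:
import torch
from torch._dynamo import mark_dynamic
from torch_max_backend import max_backend

@torch.compile(backend=max_backend)
def scale(x):
    return x * 2

x = torch.randn(8, 16)
mark_dynamic(x, 0)  # treat the batch dimension as dynamic
scale(x)  # compiles once with a symbolic batch size
scale(torch.randn(32, 16))  # different batch size, no recompilation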
Compilation Strategy
- Use fullgraph=True when possible for better optimization, as shown below. You'll get an error message if PyTorch has to trigger a graph break, making it easy to fix.
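For instance, reusing the model and backend from the basic usage example:
compiled_model = torch.compile(model, backend=max_backend, fullgraph=True)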
Debugging
You can get various information with the following environment variables:
- TORCH_MAX_BACKEND_PROFILE=1 to get various information about timing (time to compile, time to run, ...)
- TORCH_MAX_BACKEND_VERBOSE=1 to display the graph(s) made by PyTorch and various other information.
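For example, to enable both for a single run (train.py stands in for your own script):
TORCH_MAX_BACKEND_PROFILE=1 TORCH_MAX_BACKEND_VERBOSE=1 python train.py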
Contributing
Testing
# Run all tests (the first run is slow; caching kicks in afterwards)
uv run pytest -v -n 2 --forked
# Lint and format
uvx pre-commit run --all-files
# Or install the pre-commit hook
uvx pre-commit install
Aten ops
The list of aten ops can be found in ressources/aten_ops.txt. Sadly, the documentation for those ops is non-existent; reverse-engineering is the only thing that works.