
Torch's MAX Backend

A PyTorch backend implemented using Modular's MAX framework. Simply use torch.compile, but with the MAX backend.

Installation

pip install git+https://github.com/gabrieldemarmiesse/max-torch-backend.git

Quick Start

Basic Usage

from torch_max_backend import max_backend
import torch

# Compile your model with MAX backend
model = YourModel()
compiled_model = torch.compile(model, backend=max_backend)

# Use normally - now accelerated by MAX
output = compiled_model(input_tensor)

Simple Function Example

import torch
from torch_max_backend import max_backend

@torch.compile(backend=max_backend)
def simple_math(x, y):
    return x + y * 2

# Usage
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
print(simple_math(a, b))  # Accelerated execution

Training

Training works as expected:

from torch_max_backend import max_backend
import torch
import torch.nn
import torch.optim
import torch.nn.functional as F

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)

device = "cuda"
model = MyModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

@torch.compile(backend=max_backend)
def train_step(x, y):
    model.train()
    optimizer.zero_grad()
    output = model(x)
    loss = F.mse_loss(output, y)
    loss.backward()
    optimizer.step()
    return loss

a = torch.randn(5, 3).to(device)
b = torch.randn(5, 2).to(device)

print(train_step(a, b).cpu().detach().numpy())

Device Selection

Note that the MAX backend does not currently support some older NVIDIA/AMD GPUs, so you'll need to ask MAX whether your GPU is supported before using it.

from torch_max_backend import get_accelerators

# Check available accelerators. The CPU is always included in the
# list, so more than one entry means a supported GPU is present.
accelerators = get_accelerators()
device = "cuda" if len(list(accelerators)) >= 2 else "cpu"
model = model.to(device)

Supported Operations

The backend currently supports operations defined in aten_functions.py. You can view the mapping dictionary by importing MAPPING_TORCH_ATEN_TO_MAX.
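
For example, you can print the list of currently supported aten ops:

from torch_max_backend import MAPPING_TORCH_ATEN_TO_MAX

# Each key is an aten op, each value its MAX implementation
for aten_op in sorted(MAPPING_TORCH_ATEN_TO_MAX, key=str):
    print(aten_op)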

Extending the Backend

You can add support for new PyTorch operations without cloning the repository by creating custom mappings:

from torch_max_backend import MAPPING_TORCH_ATEN_TO_MAX
from max.graph import ops as max_ops
import torch

aten = torch.ops.aten

# Example: add support for a new operation using MAX's graph ops
def my_custom_tanh(x):
    return max_ops.tanh(x)

# Register the operation
MAPPING_TORCH_ATEN_TO_MAX[aten.tanh] = my_custom_tanh

# Now you can use it with torch.compile
import torch
from torch_max_backend import max_backend

@torch.compile(backend=max_backend)
def my_function(x):
    return torch.tanh(x)  # Will now use your custom implementation

This approach allows you to:

  • Add missing operations your models need
  • Override existing implementations with optimized versions (see the sketch after this list)
  • Prototype new MAX operations before contributing them back
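
For example, to wrap an existing implementation (a minimal sketch; it assumes aten.relu is already in the mapping):

from torch_max_backend import MAPPING_TORCH_ATEN_TO_MAX
import torch

aten = torch.ops.aten

# Keep a handle to the original implementation (assumed to exist)
original_relu = MAPPING_TORCH_ATEN_TO_MAX[aten.relu]

def my_relu(x):
    # Delegate to the original; swap in an optimized MAX version here
    return original_relu(x)

MAPPING_TORCH_ATEN_TO_MAX[aten.relu] = my_relu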

Performance Tips

Dynamic Shapes

For variable input sizes, mark dynamic dimensions to avoid recompiling:

import torch
from torch._dynamo import mark_dynamic

input_tensor = torch.randn(8, 32, 128)  # example input
mark_dynamic(input_tensor, 0)  # batch dimension
mark_dynamic(input_tensor, 1)  # sequence length

If you don't do so, PyTorch will compile a second time when it sees a different shape, which can be costly. You can find more information about dynamic shapes in the PyTorch documentation.
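
As a minimal sketch, marking the batch dimension lets one compiled graph serve several batch sizes:

import torch
from torch._dynamo import mark_dynamic
from torch_max_backend import max_backend

@torch.compile(backend=max_backend)
def double(x):
    return x * 2

for batch_size in (4, 8, 16):
    x = torch.randn(batch_size, 3)
    mark_dynamic(x, 0)  # the batch dimension may vary between calls
    double(x)  # compiled once; later shapes reuse the same graph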

Compilation Strategy

  • Use fullgraph=True when possible for better optimization. You'll get an error message if PyTorch has to trigger a graph break, making it easy to fix; see the sketch below.
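
A minimal sketch, reusing MyModel from the training example above:

import torch
from torch_max_backend import max_backend

# fullgraph=True makes torch.compile raise an error on graph breaks
# instead of silently splitting the model into several graphs.
compiled_model = torch.compile(MyModel(), backend=max_backend, fullgraph=True)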

Debugging

You can get various diagnostics with the following environment variables:

  • TORCH_MAX_BACKEND_PROFILE=1 to get timing information (time to compile, time to run, ...)
  • TORCH_MAX_BACKEND_VERBOSE=1 to display the graph(s) made by PyTorch and various other information.
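
For example (your_script.py is a placeholder for your own entry point):

TORCH_MAX_BACKEND_PROFILE=1 TORCH_MAX_BACKEND_VERBOSE=1 python your_script.py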

Contributing

Testing

# Run all tests (the first run is slow; caching kicks in afterwards)
uv run pytest -v -n 2 --forked

# Lint and format
uvx pre-commit run --all-files
# Or install the pre-commit hook
uvx pre-commit install

Aten ops

The list of aten ops can be found in ressources/aten_ops.txt. Sadly, the documentation for those ops is non-existent; reverse-engineering is the only approach that works.
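
One way to reverse-engineer an op is to print its overload schemas, which list argument names, types, and defaults. A minimal sketch:

import torch

aten = torch.ops.aten
# Print every overload schema for a given aten op
for overload_name in aten.tanh.overloads():
    print(getattr(aten.tanh, overload_name)._schema)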

