A PyTorch backend implemented using Modular's MAX framework

Torch's MAX Backend

Simply use torch.compile, but with Modular's MAX backend.

Installation

pip install torch-max-backend

Quick Start

Basic Usage

from torch_max_backend import max_backend
import torch

# Compile your model with MAX backend
model = YourModel()
compiled_model = torch.compile(model, backend=max_backend)

# Use normally - now accelerated by MAX
output = compiled_model(input_tensor)

Simple Function Example

import torch
from torch_max_backend import max_backend

@torch.compile(backend=max_backend)
def simple_math(x, y):
    return x + y * 2

# Usage
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
print(simple_math(a, b))  # Accelerated execution

Training

Training works as expected:

from torch_max_backend import max_backend
import torch
import torch.nn
import torch.optim
import torch.nn.functional as F

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)

device = "cuda"
model = MyModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

@torch.compile(backend=max_backend)
def train_step(x, y):
    model.train()
    optimizer.zero_grad()
    output = model(x)
    loss = F.mse_loss(output, y)
    loss.backward()
    optimizer.step()
    return loss

a = torch.randn(5, 3).to(device)
b = torch.randn(5, 2).to(device)

print(train_step(a, b).cpu().detach().numpy())

Device Selection

Note that the MAX backend does not currently support some older NVIDIA/AMD GPUs, so check whether MAX supports your GPU before selecting it as the device.

from torch_max_backend import get_accelerators

# Check available accelerators
# The CPU is always included in the list of accelerators
accelerators = get_accelerators()
device = "cuda" if len(list(accelerators)) >= 2 else "cpu"
model = model.to(device)

Supported Operations

The backend currently supports operations defined in aten_functions.py. You can view the mapping dictionary by importing MAPPING_TORCH_ATEN_TO_MAX.

Extending the Backend

You can add support for new PyTorch operations without cloning the repository by creating custom mappings:

from torch_max_backend import MAPPING_TORCH_ATEN_TO_MAX
from torch.ops import aten
from max.graph import ops as max_ops

# Example: Add support for a new operation
def my_custom_tanh(x):
    return max_ops.tanh(x)

# Register the operation
MAPPING_TORCH_ATEN_TO_MAX[aten.tanh] = my_custom_tanh

# Now you can use it with torch.compile
import torch
from torch_max_backend import max_backend

@torch.compile(backend=max_backend)
def my_function(x):
    return torch.tanh(x)  # Will now use your custom implementation

This approach allows you to:

  • Add missing operations your models need
  • Override existing implementations with optimized versions
  • Prototype new MAX operations before contributing them back

Performance Tips

Dynamic Shapes

For variable input sizes, mark dynamic dimensions to avoid recompiling:

from torch._dynamo import mark_dynamic

mark_dynamic(input_tensor, 0)  # batch dimension
mark_dynamic(input_tensor, 1)  # sequence length

If you don't, PyTorch will recompile when it sees a different shape, which can be costly. You can find more information about dynamic shapes in the PyTorch documentation.

Compilation Strategy

  • Use fullgraph=True when possible for better optimization. You'll get an error message if PyTorch has to trigger a graph break, making breaks easy to find and fix.

Debugging

You can get diagnostic information with the following environment variables:

  • TORCH_MAX_BACKEND_PROFILE=1 reports timing information (time to compile, time to run, ...).
  • TORCH_MAX_BACKEND_VERBOSE=1 displays the graph(s) produced by PyTorch and various other information.
  • TORCH_MAX_BACKEND_BEARTYPE=0 disables runtime type checking. By default, everything in the package is type-checked at runtime, which can raise errors even when the code is actually valid (because a type hint is wrong). Try disabling type checking to see if the bug goes away, and feel free to open a bug report in any case: type errors should never happen and are a sign of an internal bug.

Contributing

Testing

# Run all tests (the first run is slow; caching kicks in after)
uv run pytest -v -n 2 --forked

# Lint and format
uvx pre-commit run --all-files
# Or install the pre-commit hook
uvx pre-commit install

You can try running all the pretrained models to make sure we're compatible with them:

./pretrained_models/run_all.sh
# or run a single model, for example:
uv run pretrained_models/gpt2.sh
