
Adapter package for torch_musa to act exactly like PyTorch CUDA



torchada


Run your CUDA code on Moore Threads GPUs — zero code changes required

torchada is an adapter that makes torch_musa (Moore Threads GPU support for PyTorch) compatible with standard PyTorch CUDA APIs. Import it once, and your existing torch.cuda.* code works on MUSA hardware.

Why torchada?

Many PyTorch projects are written for NVIDIA GPUs using torch.cuda.* APIs. To run these on Moore Threads GPUs, you would normally need to change every cuda reference to musa. torchada eliminates this by automatically translating CUDA API calls to MUSA equivalents at runtime.

Prerequisites

  • torch_musa: You must have torch_musa installed (this provides MUSA support for PyTorch)
  • Moore Threads GPU: A Moore Threads GPU with its driver properly installed

Installation

pip install torchada

# Or install from source
git clone https://github.com/MooreThreads/torchada.git
cd torchada
pip install -e .

Quick Start

import torchada  # ← Add this one line at the top
import torch

# Your existing CUDA code works unchanged:
x = torch.randn(10, 10).cuda()
print(torch.cuda.device_count())
torch.cuda.synchronize()

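That's it! Nearly all torch.cuda.* APIs are redirected to torch.musa.* automatically. The deliberate exception is torch.cuda.is_available(), which still returns False on MUSA (see the note under API Reference).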
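Conceptually, this kind of redirection amounts to forwarding attribute lookups from one namespace to another while leaving a short list of names alone. The sketch below illustrates that idea with plain Python objects. SelectiveRedirect and its keep parameter are hypothetical names, not torchada's internals, and the stand-in namespaces replace the real torch.cuda and torch.musa modules so it runs without a GPU.

```python
from types import SimpleNamespace


class SelectiveRedirect:
    """Forward attribute lookups to `target`, except names in `keep`,
    which are still served by `original`. Hypothetical illustration."""

    def __init__(self, original, target, keep=()):
        self._original = original
        self._target = target
        self._keep = frozenset(keep)

    def __getattr__(self, name):
        # Python calls __getattr__ only when normal lookup fails; guard the
        # internal fields so a half-initialized object cannot recurse forever.
        if name in ("_original", "_target", "_keep"):
            raise AttributeError(name)
        source = self._original if name in self._keep else self._target
        return getattr(source, name)


# Stand-ins for torch.cuda and torch.musa on a MUSA machine
fake_cuda = SimpleNamespace(is_available=lambda: False, device_count=lambda: 0)
fake_musa = SimpleNamespace(is_available=lambda: True, device_count=lambda: 8)
cuda = SelectiveRedirect(fake_cuda, fake_musa, keep={"is_available"})
```

Here cuda.device_count() reaches the MUSA stand-in, while cuda.is_available() still answers from the original namespace, mirroring the is_available() behavior described under API Reference.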

What Works

| Feature | Example |
|---|---|
| Device operations | tensor.cuda(), model.cuda(), torch.device("cuda") |
| Memory management | torch.cuda.memory_allocated(), empty_cache() |
| Synchronization | torch.cuda.synchronize(), Stream, Event |
| Mixed precision | torch.cuda.amp.autocast(), GradScaler() |
| CUDA Graphs | torch.cuda.CUDAGraph, torch.cuda.graph() |
| CUDA Runtime | torch.cuda.cudart() → uses MUSA runtime |
| Profiler | ProfilerActivity.CUDA → uses PrivateUse1 |
| Custom Ops | Library.impl(..., "CUDA") → uses PrivateUse1 |
| Distributed | dist.init_process_group(backend='nccl') → uses MCCL |
| torch.compile | torch.compile(model) with all backends |
| C++ Extensions | CUDAExtension, BuildExtension, load() |
| FlexAttention | torch.nn.attention.flex_attention works on MUSA |
| ctypes Libraries | ctypes.CDLL with CUDA function names → MUSA equivalents |
| Unified Accelerator API | torch.accelerator.empty_cache(), memory_stats(), Stream, Event, ... |
| Triton CUDA Extra | tl.extra.cuda → tl.extra.musa compatibility on MUSA |

Examples

Mixed Precision Training

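import torchada
import torch
import torch.nn as nn

# Placeholder model, loss, optimizer, and batch; substitute your own
model = nn.Linear(10, 2).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data, target = torch.randn(8, 10), torch.randint(0, 2, (8,))

scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    output = model(data.cuda())
    loss = criterion(output, target.cuda())

# Scale the loss to avoid fp16 gradient underflow; step() unscales before updating
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()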

Distributed Training

import torchada
import torch.distributed as dist

# 'nccl' is automatically mapped to 'mccl' on MUSA
dist.init_process_group(backend='nccl')

CUDA Graphs

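import torchada
import torch

# model and x are placeholders for your own module and a GPU input tensor.
# PyTorch recommends a few warm-up iterations on a side stream before capture.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g):  # cuda_graph= keyword works on MUSA
    y = model(x)

# To rerun: copy new data into x in place (x.copy_(...)), then replay; y is refreshed
g.replay()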

torch.compile

import torchada
import torch

compiled_model = torch.compile(model.cuda(), backend='inductor')

Building C++ Extensions

import torchada  # Must import before torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

# Standard CUDAExtension works — torchada handles CUDA→MUSA translation
ext = CUDAExtension("my_ext", sources=["kernel.cu"])

Custom Ops

import torchada
import torch

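def my_func(x):
    return x * 2  # placeholder kernel implementation

my_lib = torch.library.Library("my_lib", "DEF")
my_lib.define("my_op(Tensor x) -> Tensor")
my_lib.impl("my_op", my_func, "CUDA")  # Works on MUSA!

y = torch.ops.my_lib.my_op(torch.randn(4).cuda())  # dispatches to my_func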

Profiler

import torchada
import torch

# ProfilerActivity.CUDA works on MUSA
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
) as prof:
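    model(x)  # model and x are placeholders for your module and GPU input

print(prof.key_averages().table(row_limit=10))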

ctypes Library Loading

import torchada
import ctypes

# Load MUSA runtime library with CUDA function names
lib = ctypes.CDLL("libmusart.so")
func = lib.cudaMalloc  # Automatically translates to musaMalloc

# Works with MCCL too
nccl_lib = ctypes.CDLL("libmccl.so")
func = nccl_lib.ncclAllReduce  # Automatically translates to mcclAllReduce
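torchada patches ctypes.CDLL itself, so user code needs nothing beyond import torchada. The sketch below only illustrates the idea: a CDLL subclass that rewrites a CUDA-family prefix before resolving the symbol. TranslatingCDLL and translate_symbol are hypothetical names, and the four-prefix table (the same families as the name-conversion helpers under API Reference) is a simplification of torchada's real mapping.

```python
import ctypes

# Simplified prefix table (hypothetical); torchada's own mapping is larger.
_PREFIXES = (
    ("cublas", "mublas"),
    ("curand", "murand"),
    ("cuda", "musa"),
    ("nccl", "mccl"),
)


def translate_symbol(name: str) -> str:
    """Map a CUDA-family symbol name to its MUSA counterpart by prefix."""
    for cuda_prefix, musa_prefix in _PREFIXES:
        if name.startswith(cuda_prefix):
            return musa_prefix + name[len(cuda_prefix):]
    return name  # names outside these families pass through unchanged


class TranslatingCDLL(ctypes.CDLL):
    """CDLL whose attribute lookups translate CUDA names first (hypothetical)."""

    def __getattr__(self, name):
        # CDLL.__getattr__ resolves the (translated) symbol and caches it
        return super().__getattr__(translate_symbol(name))
```

With a MUSA runtime present, TranslatingCDLL("libmusart.so").cudaMalloc would resolve the musaMalloc symbol; on other machines only translate_symbol is usable.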

Unified Accelerator API (torch.accelerator)

torch.accelerator is PyTorch's unified backend-agnostic entry point. Its API surface is expanding across PyTorch releases, so APIs such as empty_cache(), memory_stats(), Stream, and Event are not yet present in torch 2.7 even though they exist on torch.musa. torchada wraps torch.accelerator so code written against the newer unified API works today:

import torchada
import torch

# APIs that exist in torch 2.7 keep their official implementation
torch.accelerator.is_available()
torch.accelerator.device_count()

# APIs missing from torch 2.7 transparently fall back to torch.musa
torch.accelerator.empty_cache()
torch.accelerator.memory_allocated()
torch.accelerator.memory_stats()
torch.accelerator.manual_seed(42)
s = torch.accelerator.Stream()
e = torch.accelerator.Event()

# Patched to delegate to torch.musa.synchronize() (the default MUSA
# implementation does not support synchronizing all streams on a device)
torch.accelerator.synchronize()

# Context managers for forward compatibility with PyTorch 2.9+
with torch.accelerator.device_index(0):
    ...
with torch.accelerator.stream(torch.musa.Stream()):
    ...

Forward compatibility: The wrapper always prefers the real torch.accelerator implementation and only falls back to torch.musa when an attribute is missing, so upgrading to a future PyTorch release that ships official implementations requires no changes on your side — you will automatically get the upstream version.
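The "prefer upstream, fall back otherwise" rule can be sketched as a namespace whose attribute lookup tries a primary object first and a fallback second. This is an illustrative sketch, not torchada's wrapper: FallbackNamespace is a hypothetical name, and caching each resolved attribute on the instance is an assumption that fits the performance goals described below, since the installed torch version does not change while a process runs.

```python
from types import SimpleNamespace


class FallbackNamespace:
    """Resolve attributes from `primary` first and `fallback` second,
    caching each hit on the instance. Hypothetical illustration."""

    def __init__(self, primary, fallback):
        self._primary = primary
        self._fallback = fallback

    def __getattr__(self, name):
        if name in ("_primary", "_fallback"):
            raise AttributeError(name)
        try:
            value = getattr(self._primary, name)
        except AttributeError:
            # Raises AttributeError itself if the fallback lacks the name too
            value = getattr(self._fallback, name)
        # Cache: later lookups find the attribute directly and skip __getattr__
        setattr(self, name, value)
        return value


# Stand-ins: an "older" torch.accelerator and torch.musa
old_accelerator = SimpleNamespace(device_count=lambda: 1)
musa = SimpleNamespace(device_count=lambda: 99, empty_cache=lambda: "musa.empty_cache")
accelerator = FallbackNamespace(old_accelerator, musa)
```

device_count() comes from the primary namespace even though the fallback also defines it, which is the property that makes a future PyTorch upgrade transparent.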

Platform Detection

import torchada
from torchada import detect_platform, Platform

platform = detect_platform()
if platform == Platform.MUSA:
    print("Running on Moore Threads GPU")
elif platform == Platform.CUDA:
    print("Running on NVIDIA GPU")

# Or use torch.version-based detection
def is_musa():
    import torch
    return hasattr(torch.version, 'musa') and torch.version.musa is not None

Performance

torchada uses aggressive caching to minimize runtime overhead. For every frequently called operation, the overhead it adds stays under 200 nanoseconds:

| Operation | Overhead |
|---|---|
| torch.cuda.device_count() | ~140ns |
| torch.cuda.Stream (attribute access) | ~130ns |
| torch.cuda.Event (attribute access) | ~130ns |
| _translate_device('cuda') | ~140ns |
| torch.backends.cuda.is_built() | ~155ns |

For comparison, a typical GPU kernel launch takes 5,000-20,000ns. The patching overhead is negligible for real-world applications.

Operations with inherent costs (runtime calls, object creation) take 300-600ns but cannot be optimized further without changing behavior.
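The comparison above is simple arithmetic: 140 ns against a 5,000 ns launch is 2.8%, and against a 20,000 ns launch it is 0.7%. To reproduce per-call numbers on your own hardware, a small timeit harness is enough. per_call_ns and relative_overhead are hypothetical helper names, and this is not the harness torchada used for the table.

```python
import timeit


def per_call_ns(fn, number=100_000, repeat=5):
    """Best-of-`repeat` average wall time of one fn() call, in nanoseconds."""
    best_total = min(timeit.repeat(fn, number=number, repeat=repeat))
    return best_total / number * 1e9


def relative_overhead(overhead_ns, kernel_launch_ns):
    """Fraction of a kernel launch's cost that the patching overhead represents."""
    return overhead_ns / kernel_launch_ns


# ~140 ns overhead measured against the 5,000-20,000 ns launch range quoted above
worst = relative_overhead(140, 5_000)   # 2.8%
best = relative_overhead(140, 20_000)   # 0.7%
```

On a MUSA machine you might compare per_call_ns(torch.cuda.device_count) with per_call_ns(torch.musa.device_count); the difference approximates the patching overhead, though absolute timings vary with CPU and Python version.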

Known Limitation

Comparing device.type against the string "cuda" fails on MUSA:

device = torch.device("cuda:0")  # On MUSA, this becomes musa:0
device.type == "cuda"  # Returns False!

Solution: Use torchada.is_gpu_device():

import torchada

if torchada.is_gpu_device(device):  # Works on both CUDA and MUSA
    ...
# Or: device.type in ("cuda", "musa")
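If torchada is imported only on MUSA (as in the conditional-import pattern under "Integrating torchada into Your Project"), torchada.is_gpu_device may not be importable on NVIDIA machines. A small local helper avoids that dependency. is_gpu_like is a hypothetical name; it follows the documented behavior of is_gpu_device but is not its implementation.

```python
GPU_DEVICE_TYPES = frozenset({"cuda", "musa"})


def is_gpu_like(device) -> bool:
    """Return True for a CUDA or MUSA device.

    Accepts a device string such as "cuda:0" or any object with a `.type`
    attribute (for example torch.device). Hypothetical helper.
    """
    dev_type = device if isinstance(device, str) else getattr(device, "type", "")
    # Drop an optional ":index" suffix before comparing
    return dev_type.split(":", 1)[0] in GPU_DEVICE_TYPES
```

Because it reads only device.type, the same helper works whether or not torchada is loaded.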

API Reference

| Name | Description |
|---|---|
| detect_platform() | Returns Platform.CUDA, Platform.MUSA, or Platform.CPU |
| is_musa_platform() | Returns True if running on MUSA |
| is_cuda_platform() | Returns True if running on CUDA |
| is_gpu_device(device) | Returns True if device is CUDA or MUSA |
| CUDA_HOME | Path to the CUDA/MUSA installation |
| cuda_to_musa_name(name) | Converts cudaXxx → musaXxx |
| nccl_to_mccl_name(name) | Converts ncclXxx → mcclXxx |
| cublas_to_mublas_name(name) | Converts cublasXxx → mublasXxx |
| curand_to_murand_name(name) | Converts curandXxx → murandXxx |

Note: torch.cuda.is_available() is intentionally NOT redirected — it returns False on MUSA. This allows proper platform detection. For GPU availability checks, see the has_gpu() pattern in examples/migrate_existing_project.md.

Note: The name conversion utilities are exported for manual use, but ctypes.CDLL is automatically patched to translate function names when loading MUSA libraries.

C++ Extension Symbol Mapping

When building C++ extensions, torchada automatically translates CUDA symbols to MUSA:

| CUDA | MUSA |
|---|---|
| cudaMalloc | musaMalloc |
| cudaStream_t | musaStream_t |
| cublasHandle_t | mublasHandle_t |
| at::cuda | at::musa |
| c10::cuda | c10::musa |
| #include <cuda/*> | #include <musa/*> |

See src/torchada/_mapping.py for the complete mapping table (380+ mappings).
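Conceptually, applying such a table is a find-and-replace over extension sources that must respect identifier boundaries, so a rule for cudaMalloc does not rewrite part of a longer name that has its own entry. The sketch below shows one way to do this with a single regular expression. SYMBOL_MAP holds a few rows from the table above, translate_source is a hypothetical name, and this is not necessarily how torchada applies _mapping.py. Include-path rewrites would need a separate rule, because "<cuda/" does not end on an identifier character.

```python
import re

# A few rows from the symbol table above; torchada's own table has 380+ entries.
SYMBOL_MAP = {
    "cudaMalloc": "musaMalloc",
    "cudaStream_t": "musaStream_t",
    "cublasHandle_t": "mublasHandle_t",
    "at::cuda": "at::musa",
    "c10::cuda": "c10::musa",
}

# Longest keys first; the lookarounds stop matches inside longer identifiers.
_PATTERN = re.compile(
    r"(?<![A-Za-z0-9_])(?:"
    + "|".join(re.escape(k) for k in sorted(SYMBOL_MAP, key=len, reverse=True))
    + r")(?![A-Za-z0-9_])"
)


def translate_source(code: str) -> str:
    """Rewrite mapped CUDA symbols in C++ source text to their MUSA names."""
    return _PATTERN.sub(lambda m: SYMBOL_MAP[m.group(0)], code)
```

In this toy table cudaMallocHost has no entry, so it is left untouched rather than half-translated into musaMallocHost by the cudaMalloc rule.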

Integrating torchada into Your Project

Step 1: Add Dependency

# requirements.txt (or the dependencies list in pyproject.toml)
torchada>=0.1.60

Step 2: Conditional Import

# At your application entry point
def is_musa():
    import torch
    return hasattr(torch.version, "musa") and torch.version.musa is not None

if is_musa():
    import torchada  # noqa: F401

# Rest of your code uses torch.cuda.* as normal

Step 3: Extend Feature Flags (if applicable)

# Include MUSA in GPU capability checks
if is_nvidia() or is_musa():
    ENABLE_FLASH_ATTENTION = True

Step 4: Fix Device Type Checks (if applicable)

# Instead of: device.type == "cuda"
# Use: device.type in ("cuda", "musa")
# Or: torchada.is_gpu_device(device)

Projects Using torchada

| Project | Category | Status | Tracking |
|---|---|---|---|
| SGLang | Model Serving | ✅ Merged | - |
| vLLM-MUSA | Model Serving | ✅ Merged | |
| vLLM-Omni | Model Serving (Omni) | ✅ Merged | - |
| Xinference | Model Serving | ✅ Merged | |
| LightLLM | Model Serving | ✅ Merged | |
| LightX2V | Image/Video Generation | ✅ Merged | |
| Chitu | Model Serving | ✅ Merged | |
| Mooncake | KVCache | ✅ Merged | - |
| ComfyUI | Image/Video Generation | 🚧 In Progress | ComfyUI#11618 |

License

MIT License



Download files


Source Distribution

torchada-0.1.60.tar.gz (128.4 kB)

Uploaded Source

Built Distribution


torchada-0.1.60-py3-none-any.whl (136.8 kB)

Uploaded Python 3

File details

Details for the file torchada-0.1.60.tar.gz.

File metadata

  • Download URL: torchada-0.1.60.tar.gz
  • Size: 128.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.16

File hashes

Hashes for torchada-0.1.60.tar.gz
Algorithm Hash digest
SHA256 f7e867a809bd3ea9b46126b6f7ef8a4871047e9f3164538a5aba234948a01690
MD5 d92d40972086c41be47f5201251e3acb
BLAKE2b-256 fa16cdea397e4363fc20e41cfe6872c10dd3f171e61255fd648dcef7aba87b22


File details

Details for the file torchada-0.1.60-py3-none-any.whl.

File metadata

  • Download URL: torchada-0.1.60-py3-none-any.whl
  • Size: 136.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.16

File hashes

Hashes for torchada-0.1.60-py3-none-any.whl
Algorithm Hash digest
SHA256 522d186ab59ad890ece903cfe665ae34e4fd2bfa0a14ed9b7f7e42a66efe8dcb
MD5 6cfe3e44fd37946369a7813966b72980
BLAKE2b-256 daf24c44dafc5af3f4ca7405a0a9e9ab719492b8f6e2d3e867a09fb2dc10179e

