Adapter package for torch_musa to act exactly like PyTorch CUDA
torchada
Run your CUDA code on Moore Threads GPUs — zero code changes required
torchada is an adapter that makes torch_musa (Moore Threads GPU support for PyTorch) compatible with standard PyTorch CUDA APIs. Import it once, and your existing torch.cuda.* code works on MUSA hardware.
Why torchada?
Many PyTorch projects are written for NVIDIA GPUs using torch.cuda.* APIs. To run these on Moore Threads GPUs, you would normally need to change every cuda reference to musa. torchada eliminates this by automatically translating CUDA API calls to MUSA equivalents at runtime.
Prerequisites
- torch_musa: You must have torch_musa installed (this provides MUSA support for PyTorch)
- Moore Threads GPU: A Moore Threads GPU with the proper driver installed
Installation
pip install torchada
# Or install from source
git clone https://github.com/MooreThreads/torchada.git
cd torchada
pip install -e .
Quick Start
import torchada # ← Add this one line at the top
import torch
# Your existing CUDA code works unchanged:
x = torch.randn(10, 10).cuda()
print(torch.cuda.device_count())
torch.cuda.synchronize()
That's it! All torch.cuda.* APIs are automatically redirected to torch.musa.*.
What Works
| Feature | Example |
|---|---|
| Device operations | tensor.cuda(), model.cuda(), torch.device("cuda") |
| Memory management | torch.cuda.memory_allocated(), empty_cache() |
| Synchronization | torch.cuda.synchronize(), Stream, Event |
| Mixed precision | torch.cuda.amp.autocast(), GradScaler() |
| CUDA Graphs | torch.cuda.CUDAGraph, torch.cuda.graph() |
| CUDA Runtime | torch.cuda.cudart() → uses MUSA runtime |
| Profiler | ProfilerActivity.CUDA → uses PrivateUse1 |
| Custom Ops | Library.impl(..., "CUDA") → uses PrivateUse1 |
| Distributed | dist.init_process_group(backend='nccl') → uses MCCL |
| torch.compile | torch.compile(model) with all backends |
| C++ Extensions | CUDAExtension, BuildExtension, load() |
| FlexAttention | torch.nn.attention.flex_attention works on MUSA |
| ctypes Libraries | ctypes.CDLL with CUDA function names → MUSA equivalents |
| Unified Accelerator API | torch.accelerator.empty_cache(), memory_stats(), Stream, Event, ... |
Examples
Mixed Precision Training
import torchada
import torch
model = MyModel().cuda()
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
    output = model(data.cuda())
    loss = criterion(output, target.cuda())
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Distributed Training
import torchada
import torch.distributed as dist
# 'nccl' is automatically mapped to 'mccl' on MUSA
dist.init_process_group(backend='nccl')
CUDA Graphs
import torchada
import torch
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g):  # cuda_graph= keyword works on MUSA
    y = model(x)
torch.compile
import torchada
import torch
compiled_model = torch.compile(model.cuda(), backend='inductor')
Building C++ Extensions
import torchada # Must import before torch.utils.cpp_extension
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
# Standard CUDAExtension works — torchada handles CUDA→MUSA translation
ext = CUDAExtension("my_ext", sources=["kernel.cu"])
Custom Ops
import torchada
import torch
my_lib = torch.library.Library("my_lib", "DEF")
my_lib.define("my_op(Tensor x) -> Tensor")
my_lib.impl("my_op", my_func, "CUDA") # Works on MUSA!
Profiler
import torchada
import torch
# ProfilerActivity.CUDA works on MUSA
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
) as prof:
    model(x)
ctypes Library Loading
import torchada
import ctypes
# Load MUSA runtime library with CUDA function names
lib = ctypes.CDLL("libmusart.so")
func = lib.cudaMalloc # Automatically translates to musaMalloc
# Works with MCCL too
nccl_lib = ctypes.CDLL("libmccl.so")
func = nccl_lib.ncclAllReduce # Automatically translates to mcclAllReduce
Unified Accelerator API (torch.accelerator)
torch.accelerator is PyTorch's unified backend-agnostic entry point. Its API
surface is expanding across PyTorch releases, so APIs such as empty_cache(),
memory_stats(), Stream, and Event are not yet present in torch 2.7 even
though they exist on torch.musa. torchada wraps torch.accelerator so code
written against the newer unified API works today:
import torchada
import torch
# APIs that exist in torch 2.7 keep their official implementation
torch.accelerator.is_available()
torch.accelerator.device_count()
# APIs missing from torch 2.7 transparently fall back to torch.musa
torch.accelerator.empty_cache()
torch.accelerator.memory_allocated()
torch.accelerator.memory_stats()
torch.accelerator.manual_seed(42)
s = torch.accelerator.Stream()
e = torch.accelerator.Event()
# Patched to delegate to torch.musa.synchronize() (the default MUSA
# implementation does not support synchronizing all streams on a device)
torch.accelerator.synchronize()
# Context managers for forward compatibility with PyTorch 2.9+
with torch.accelerator.device_index(0):
    ...
with torch.accelerator.stream(torch.musa.Stream()):
    ...
Forward compatibility: The wrapper always prefers the real
torch.accelerator implementation and only falls back to torch.musa when an
attribute is missing, so upgrading to a future PyTorch release that ships
official implementations requires no changes on your side — you will
automatically get the upstream version.
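The prefer-then-fall-back behavior described above can be sketched in a few lines. This is an illustration only (the class name and the stand-in namespaces are hypothetical, not torchada internals): the wrapper tries the primary namespace first and consults the fallback only when an attribute is missing.

```python
# Illustrative sketch of the fallback strategy: prefer the primary
# namespace, fall back to the secondary one only on AttributeError.
from types import SimpleNamespace

class FallbackNamespace:
    def __init__(self, primary, fallback):
        self._primary = primary
        self._fallback = fallback

    def __getattr__(self, name):
        try:
            return getattr(self._primary, name)
        except AttributeError:
            return getattr(self._fallback, name)

# Stand-ins for torch.accelerator (primary) and torch.musa (fallback):
accelerator = SimpleNamespace(device_count=lambda: 1)
musa = SimpleNamespace(device_count=lambda: 99,
                       empty_cache=lambda: "cache cleared")

wrapped = FallbackNamespace(accelerator, musa)
print(wrapped.device_count())  # 1 — the official implementation wins
print(wrapped.empty_cache())   # "cache cleared" — falls back
```

Because the primary namespace always wins, upgrading it to one that ships its own `empty_cache` would transparently shadow the fallback, which is the forward-compatibility property the wrapper relies on.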
Platform Detection
import torchada
from torchada import detect_platform, Platform
platform = detect_platform()
if platform == Platform.MUSA:
    print("Running on Moore Threads GPU")
elif platform == Platform.CUDA:
    print("Running on NVIDIA GPU")
# Or use torch.version-based detection
def is_musa():
    import torch
    return hasattr(torch.version, 'musa') and torch.version.musa is not None
Performance
torchada uses aggressive caching to minimize runtime overhead. All frequently called operations complete in under 200 nanoseconds:
| Operation | Overhead |
|---|---|
| torch.cuda.device_count() | ~140ns |
| torch.cuda.Stream (attribute access) | ~130ns |
| torch.cuda.Event (attribute access) | ~130ns |
| _translate_device('cuda') | ~140ns |
| torch.backends.cuda.is_built() | ~155ns |
For comparison, a typical GPU kernel launch takes 5,000-20,000ns. The patching overhead is negligible for real-world applications.
Operations with inherent costs (runtime calls, object creation) take 300-600ns but cannot be optimized further without changing behavior.
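Overheads like these can be measured with a simple `timeit` micro-benchmark. The snippet below is a sketch of the measurement pattern only — it times a hypothetical stand-in attribute, not torchada itself, and absolute numbers will vary by machine:

```python
# Measure per-call overhead of a cached attribute, in the spirit of
# the table above. The _Patched class is a stand-in, not torchada code.
import timeit

class _Patched:
    """Stand-in for a patched namespace with a cached callable."""
    device_count = staticmethod(lambda: 4)

N = 1_000_000
ns_per_call = timeit.timeit(_Patched.device_count, number=N) * 1e9 / N
print(f"~{ns_per_call:.0f} ns per call")
```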
Known Limitation
Device type string comparisons fail on MUSA:
device = torch.device("cuda:0") # On MUSA, this becomes musa:0
device.type == "cuda" # Returns False!
Solution: Use torchada.is_gpu_device():
import torchada
if torchada.is_gpu_device(device):  # Works on both CUDA and MUSA
    ...
# Or: device.type in ("cuda", "musa")
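For projects that cannot take a torchada dependency, an equivalent check is easy to write by hand. The helper below is a hypothetical pure-Python sketch, not torchada's actual `is_gpu_device` implementation; it accepts either an object with a `.type` attribute (like `torch.device`) or a plain `"type:index"` string:

```python
# Hypothetical device-type check covering both CUDA and MUSA; the real
# torchada.is_gpu_device may differ in details.
GPU_DEVICE_TYPES = ("cuda", "musa")

def is_gpu_device(device) -> bool:
    """True if the device is a CUDA or MUSA device."""
    dev_type = getattr(device, "type", None) or str(device).split(":")[0]
    return dev_type in GPU_DEVICE_TYPES

print(is_gpu_device("musa:0"))  # True
print(is_gpu_device("cpu"))     # False
```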
API Reference
| Function | Description |
|---|---|
| detect_platform() | Returns Platform.CUDA, Platform.MUSA, or Platform.CPU |
| is_musa_platform() | Returns True if running on MUSA |
| is_cuda_platform() | Returns True if running on CUDA |
| is_gpu_device(device) | Returns True if the device is CUDA or MUSA |
| CUDA_HOME | Path to the CUDA/MUSA installation |
| cuda_to_musa_name(name) | Converts cudaXxx → musaXxx |
| nccl_to_mccl_name(name) | Converts ncclXxx → mcclXxx |
| cublas_to_mublas_name(name) | Converts cublasXxx → mublasXxx |
| curand_to_murand_name(name) | Converts curandXxx → murandXxx |
Note: torch.cuda.is_available() is intentionally NOT redirected — it returns False on MUSA. This allows proper platform detection. For GPU availability checks, see the has_gpu() pattern in examples/migrate_existing_project.md.
Note: The name conversion utilities are exported for manual use, but ctypes.CDLL is automatically patched to translate function names when loading MUSA libraries.
C++ Extension Symbol Mapping
When building C++ extensions, torchada automatically translates CUDA symbols to MUSA:
| CUDA | MUSA |
|---|---|
| cudaMalloc | musaMalloc |
| cudaStream_t | musaStream_t |
| cublasHandle_t | mublasHandle_t |
| at::cuda | at::musa |
| c10::cuda | c10::musa |
| #include <cuda/*> | #include <musa/*> |
See src/torchada/_mapping.py for the complete mapping table (380+ mappings).
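A table-driven translation like this can be sketched as a single regex substitution over the source text. The snippet below is illustrative only — it uses a tiny subset of the mapping table and is not torchada's actual `_mapping.py` logic:

```python
# Hypothetical sketch of source-level CUDA→MUSA symbol translation using
# a small subset of the mapping table above.
import re

_SYMBOL_MAP = {
    "cudaMalloc": "musaMalloc",
    "cudaStream_t": "musaStream_t",
    "cublasHandle_t": "mublasHandle_t",
    "at::cuda": "at::musa",
    "c10::cuda": "c10::musa",
}

# Longest keys first, so longer symbols are never shadowed by shorter ones.
_pattern = re.compile(
    "|".join(re.escape(k) for k in sorted(_SYMBOL_MAP, key=len, reverse=True))
)

def translate_source(src: str) -> str:
    """Replace known CUDA symbols in a source string with MUSA equivalents."""
    return _pattern.sub(lambda m: _SYMBOL_MAP[m.group(0)], src)

print(translate_source("cudaStream_t s; cudaMalloc(&p, n);"))
# → musaStream_t s; musaMalloc(&p, n);
```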
Integrating torchada into Your Project
Step 1: Add Dependency
# pyproject.toml or requirements.txt
torchada>=0.1.52
Step 2: Conditional Import
# At your application entry point
def is_musa():
    import torch
    return hasattr(torch.version, "musa") and torch.version.musa is not None

if is_musa():
    import torchada  # noqa: F401
# Rest of your code uses torch.cuda.* as normal
Step 3: Extend Feature Flags (if applicable)
# Include MUSA in GPU capability checks
if is_nvidia() or is_musa():
    ENABLE_FLASH_ATTENTION = True
Step 4: Fix Device Type Checks (if applicable)
# Instead of: device.type == "cuda"
# Use: device.type in ("cuda", "musa")
# Or: torchada.is_gpu_device(device)
Projects Using torchada
| Project | Category | Status | Tracking |
|---|---|---|---|
| Xinference | Model Serving | ✅ Merged | — |
| LightLLM | Model Serving | ✅ Merged | — |
| LightX2V | Image/Video Generation | ✅ Merged | — |
| Chitu | Model Serving | ✅ Merged | — |
| vLLM-MUSA | Model Serving | ✅ Merged | — |
| SGLang | Model Serving | 🚧 In Progress | SGLang#16565 |
| ComfyUI | Image/Video Generation | 🚧 In Progress | ComfyUI#11618 |
| vLLM-Omni | Model Serving (Omni) | 🚧 In Progress | vLLM-Omni#2347 |
License
MIT License