
Adapter package that makes torch_musa act exactly like PyTorch CUDA

torchada

torchada provides a unified interface that works transparently on both NVIDIA GPUs (CUDA) and Moore Threads GPUs (MUSA). Write your code once using standard PyTorch CUDA APIs, and it will run on MUSA hardware without any changes.

Features

  • Zero Code Changes: Just import torchada once, then use standard torch.cuda.* APIs
  • Automatic Platform Detection: Detects whether you're running on CUDA or MUSA
  • Transparent Device Mapping: tensor.cuda() and tensor.to("cuda") work on MUSA
  • Extension Building: Standard torch.utils.cpp_extension works on MUSA after importing torchada
  • Source Code Porting: Automatic CUDA → MUSA symbol mapping for C++/CUDA extensions

Installation

pip install torchada

# Or install from source
git clone https://github.com/yeahdongcn/torchada.git
cd torchada
pip install -e .
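To confirm the installation, a small diagnostic script can import torchada and report which platform it detected. This is a minimal sketch that relies only on the is_musa_platform() and is_cuda_platform() helpers listed in the API reference; the torchada_status() function name is hypothetical.

```python
import importlib.util


def torchada_status() -> str:
    """Return a one-line summary of the torchada installation (hypothetical helper)."""
    if importlib.util.find_spec("torchada") is None:
        return "torchada is not installed"
    try:
        import torchada  # Importing applies the patches
    except Exception as exc:  # Diagnostic script: report any import failure
        return f"torchada is installed but failed to import: {exc}"
    if torchada.is_musa_platform():
        return "torchada OK: MUSA platform detected"
    if torchada.is_cuda_platform():
        return "torchada OK: CUDA platform detected"
    return "torchada OK: no GPU platform detected (CPU)"


if __name__ == "__main__":
    print(torchada_status())
```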

Quick Start

Basic Usage

import torchada  # Import once to apply patches - that's it!
import torch

# Use standard torch.cuda APIs - they work on both CUDA and MUSA:
if torch.cuda.is_available():
    device = torch.device("cuda")
    tensor = torch.randn(10, 10).cuda()
    model = torch.nn.Linear(10, 10).cuda()  # Stand-in for your own model

    # All torch.cuda.* APIs work transparently
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Device name: {torch.cuda.get_device_name()}")
    torch.cuda.synchronize()
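Because torch.cuda.is_available() behaves the same way on both platforms after the import, the usual device-selection pattern from NVIDIA code also covers MUSA. The sketch below, built around a hypothetical pick_device() helper, additionally falls back to the CPU when neither torchada nor a GPU is present, so the same script still runs on machines without an accelerator.

```python
try:
    import torchada  # noqa: F401  # Applies the MUSA patches when installed
except ImportError:
    pass  # Plain PyTorch: torch.cuda keeps its usual meaning

try:
    import torch
except ImportError:
    torch = None


def pick_device() -> str:
    """Return "cuda" when a GPU is usable through torch.cuda, else "cpu"."""
    if torch is not None and torch.cuda.is_available():
        return "cuda"
    return "cpu"


if __name__ == "__main__":
    device = pick_device()
    print(f"Using device: {device}")
    if torch is not None:
        x = torch.randn(4, 4).to(device)  # "cuda" also works on MUSA per the features list
        print(x.device)
```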

Building C++ Extensions

# setup.py - Use standard torch imports!
import torchada  # Import first to apply patches
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CUDA_HOME

print(f"Building with CUDA/MUSA home: {CUDA_HOME}")

ext_modules = [
    CUDAExtension(
        name="my_extension",
        sources=[
            "my_extension.cpp",
            "my_extension_kernel.cu",
        ],
        extra_compile_args={
            "cxx": ["-O3"],
            "nvcc": ["-O3"],  # Automatically mapped to mcc on MUSA
        },
    ),
]

setup(
    name="my_package",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
)

JIT Compilation

import torchada  # Import first to apply patches
from torch.utils.cpp_extension import load

# Load extension at runtime (works on both CUDA and MUSA)
my_extension = load(
    name="my_extension",
    sources=["my_extension.cpp", "my_extension_kernel.cu"],
    verbose=True,
)

Mixed Precision Training

import torchada  # Import first to apply patches
import torch

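model = torch.nn.Linear(10, 2).cuda()  # Stand-in for your own model
criterion = torch.nn.CrossEntropyLoss()
# Stand-in for a real DataLoader: 4 batches of 8 samples each
dataloader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(4)]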
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for data, target in dataloader:
    data, target = data.cuda(), target.cuda()

    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Distributed Training

import torchada  # Import first to apply patches
import torch.distributed as dist

# Use 'nccl' backend as usual - torchada maps it to 'mccl' on MUSA
dist.init_process_group(backend='nccl')

CUDA Graphs

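import torchada  # Import first to apply patches
import torch

model = torch.nn.Linear(16, 16).cuda()  # Stand-in for your own model
x = torch.randn(8, 16).cuda()  # Static input buffer reused on every replay

with torch.no_grad():
    model(x)  # Warm-up run before capture
    # Use standard torch.cuda.CUDAGraph - works on MUSA too
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        y = model(x)

x.copy_(torch.randn(8, 16))  # Refill the static input in place
g.replay()  # y now holds the output for the new input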

Platform Detection

torchada automatically detects the platform:

import torchada
from torchada import detect_platform, Platform

platform = detect_platform()
if platform == Platform.MUSA:
    print("Running on Moore Threads GPU")
elif platform == Platform.CUDA:
    print("Running on NVIDIA GPU")

# Or use convenience functions
if torchada.is_musa_platform():
    print("MUSA platform detected")

What Gets Patched

After importing torchada, the following standard PyTorch APIs work on MUSA:

  • torch.cuda.*: ✅ all APIs
  • torch.cuda.amp.*: ✅ autocast, GradScaler
  • torch.cuda.CUDAGraph: ✅ maps to MUSAGraph
  • torch.distributed (backend='nccl'): ✅ uses MCCL
  • torch.utils.cpp_extension.*: ✅ CUDAExtension, BuildExtension

API Reference

torchada

  • detect_platform(): returns the detected platform (CUDA, MUSA, or CPU)
  • is_musa_platform(): checks whether running on MUSA
  • is_cuda_platform(): checks whether running on CUDA
  • get_device_name(): returns the device name string ("cuda", "musa", or "cpu")

torch.cuda (after importing torchada)

All standard torch.cuda APIs work, including:

  • is_available(), device_count(), current_device(), set_device()
  • memory_allocated(), memory_reserved(), empty_cache()
  • synchronize(), Stream, Event, CUDAGraph
  • amp.autocast(), amp.GradScaler()
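As an illustration of several APIs from the list above (Event, synchronize(), memory_allocated(), memory_reserved()), the following sketch times a matrix multiply with CUDA events. It assumes, as the list states, that these calls behave on MUSA as they do on CUDA; the time_matmul() helper is hypothetical and returns None when no GPU is available.

```python
from typing import Optional

try:
    import torchada  # noqa: F401  # Applies the MUSA patches when installed
except ImportError:
    pass

try:
    import torch
except ImportError:
    torch = None


def time_matmul(n: int = 1024, repeats: int = 10) -> Optional[dict]:
    """Time an n x n matmul with torch.cuda.Event; return None without a GPU."""
    if torch is None or not torch.cuda.is_available():
        return None
    a = torch.randn(n, n).cuda()
    b = torch.randn(n, n).cuda()
    torch.matmul(a, b)  # Warm-up so one-time library setup is not timed
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeats):
        torch.matmul(a, b)
    end.record()
    torch.cuda.synchronize()  # Wait until both events have completed
    return {
        "avg_ms": start.elapsed_time(end) / repeats,  # elapsed_time is in milliseconds
        "allocated_bytes": torch.cuda.memory_allocated(),
        "reserved_bytes": torch.cuda.memory_reserved(),
    }


if __name__ == "__main__":
    result = time_matmul()
    print(result if result is not None else "No GPU available")
```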

torch.utils.cpp_extension (after importing torchada)

  • CUDAExtension: creates a CUDA or MUSA extension based on the platform
  • CppExtension: creates a C++ extension (no GPU code)
  • BuildExtension: build command for extensions
  • CUDA_HOME: path to the CUDA/MUSA installation
  • load(): JIT-compiles and loads an extension

Symbol Mapping

torchada automatically maps CUDA symbols to MUSA equivalents when building extensions:

  • cudaMalloc → musaMalloc
  • cudaMemcpy → musaMemcpy
  • cudaStream_t → musaStream_t
  • cublasHandle_t → mublasHandle_t
  • curandState → murandState
  • at::cuda → at::musa
  • c10::cuda → c10::musa
  • ...

See src/torchada/_mapping.py for the complete mapping table.
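The token-level rewriting this table describes can be sketched with a single regular expression. This is not torchada's implementation: the dictionary below is a hypothetical subset copied from the table above, and the word-boundary rule is an assumption chosen so that longer identifiers such as cudaMallocAsync are left untouched rather than half-rewritten.

```python
import re

# Hypothetical subset of the CUDA -> MUSA table above; torchada's real
# table lives in src/torchada/_mapping.py and may differ.
CUDA_TO_MUSA = {
    "cudaMalloc": "musaMalloc",
    "cudaMemcpy": "musaMemcpy",
    "cudaStream_t": "musaStream_t",
    "cublasHandle_t": "mublasHandle_t",
    "curandState": "murandState",
    "at::cuda": "at::musa",
    "c10::cuda": "c10::musa",
}

# Try longer keys first; \b on both sides means a key only matches as a whole
# identifier, so "cudaMalloc" does not match inside "cudaMallocAsync".
_PATTERN = re.compile(
    r"\b("
    + "|".join(re.escape(k) for k in sorted(CUDA_TO_MUSA, key=len, reverse=True))
    + r")\b"
)


def port_source(code: str) -> str:
    """Replace each known CUDA symbol in `code` with its MUSA equivalent."""
    return _PATTERN.sub(lambda m: CUDA_TO_MUSA[m.group(1)], code)


if __name__ == "__main__":
    print(port_source("cudaStream_t s; cudaMalloc(&p, n);"))
```

Word boundaries trade coverage for safety: an identifier missing from the table is left as-is instead of being partially rewritten into a name that does not exist.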

License

MIT License

Download files

Download the file for your platform.

Source Distribution

torchada-0.1.0.tar.gz (30.8 kB)

Built Distribution

torchada-0.1.0-py3-none-any.whl (22.6 kB)

File details

Details for the file torchada-0.1.0.tar.gz.

File metadata

  • Download URL: torchada-0.1.0.tar.gz
  • Size: 30.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for torchada-0.1.0.tar.gz
Algorithm Hash digest
SHA256 21301b7ec22c78a28d7753a833aa2315a2dc34d2ae220cb16494202b760eb8e4
MD5 c9820b4818adae06bc9e20874664b624
BLAKE2b-256 27d06960fb5a7fdce1dd8a73e8e30980b2cf1f611a2f751582d20cb77a1660ca

File details

Details for the file torchada-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torchada-0.1.0-py3-none-any.whl
  • Size: 22.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for torchada-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 8446879728c26ba3fa18821385246f5c5a2a548c5ac5a956af7441578ccf0bb4
MD5 116eecea56ae01e1a2900f295aee7511
BLAKE2b-256 61a5909b1b01e211bedb770db22e36b5e8ad69976e5e26677459fed92c145464
