
Hardware abstraction layer for PyTorch across NVIDIA, AMD, and TPU backends

Project description

TorchBridge

Your PyTorch code is locked to one GPU vendor. CUDA calls, NCCL hardcoding, vendor-specific precision tricks -- they break the moment you switch hardware. TorchBridge is a hardware abstraction layer that makes your models run on NVIDIA, AMD, and TPU without code changes, and validates that outputs match across backends.


What is TorchBridge?

PyTorch lets you build models. TorchBridge lets you run them anywhere.

Most teams write hardware-specific code -- CUDA calls for NVIDIA, ROCm setup for AMD, XLA boilerplate for TPU. When the hardware changes, the code breaks. TorchBridge eliminates that problem with a unified API that detects your hardware and adapts automatically.

Your model code
      |
  TorchBridge HAL
      |
  +---------+---------+---------+
  | NVIDIA  |   AMD   |   TPU   |
  | CUDA    |  ROCm   |   XLA   |
  +---------+---------+---------+

What it does:

  • Backend detection -- automatically identifies available accelerators
  • Vendor adapters -- translates unified API calls to vendor-specific operations
  • Precision management -- handles FP32/FP16/BF16/FP8 across backends
  • Memory optimization -- gradient checkpointing, activation offloading, memory pooling
  • Checkpoint portability -- save on one backend, load on another
  • Distributed training -- tensor/pipeline/data parallelism across backend types

Quick Start

git clone https://github.com/CloudlyIO/torchbridge.git
cd torchbridge
pip install -r requirements.txt

# Verify
PYTHONPATH=src python3 -c "import torchbridge; print(f'TorchBridge v{torchbridge.__version__} ready')"

Detect Hardware

from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())

Run on Any Backend

import torch
from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)

optimized_model = manager.optimize(model)

Validate

from torchbridge import UnifiedValidator

validator = UnifiedValidator()
results = validator.validate_model(optimized_model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")

Supported Backends

| Backend | Hardware | Precision | Status |
| --- | --- | --- | --- |
| NVIDIA | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| AMD | MI350X, MI325X, MI300X, MI200 | FP4, FP8, BF16, FP16, FP32 | Production |
| TPU | v4, v5e, v5p, v6e, v7 | BF16, FP32 | Production |
| CPU | x86, ARM (Apple Silicon) | FP32, BF16 | Fallback |

See Hardware Matrix for full details.

Key Features

Backend Detection and Adaptation

Automatically identifies available hardware and selects the optimal backend. No code changes needed when moving between GPU vendors or cloud providers.

Vendor Adapters

Each backend implements a common BaseBackend interface. Your code calls manager.optimize(model) and the correct vendor-specific operations execute underneath -- CUDA on NVIDIA, HIP on AMD, XLA on TPU.
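As a rough illustration of coding against that common interface, here is a minimal sketch. The helper function is our own; only BackendFactory, detect_best_backend, get_device_info(), and the backend's device attribute are taken from the examples elsewhere on this page.

# Minimal sketch of coding against the common backend interface.
# describe_and_place() is our own helper, not part of TorchBridge.
import torch
from torchbridge.backends import BackendFactory, detect_best_backend

def describe_and_place(model: torch.nn.Module, backend) -> torch.nn.Module:
    """Report the active accelerator and move the model onto it."""
    print(backend.get_device_info())   # CUDA, ROCm/HIP, or XLA details underneath
    return model.to(backend.device)

backend = BackendFactory.create(detect_best_backend())
model = describe_and_place(torch.nn.Linear(768, 768), backend)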

Precision Management

Configure precision once. TorchBridge handles the details per backend -- FP8 on H100, BF16 where supported, FP16 as fallback. Mixed-precision training with torch.amp autocast works across all backends.
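For illustration, a hedged sketch that combines the Quick Start config with PyTorch-native autocast. Picking the autocast dtype by hand below is our own choice for the example; the README only states that TorchBridge selects precision per backend.

# Illustrative only: reuse the Quick Start config and manager, then rely on
# torch.amp autocast. The explicit dtype selection here is an assumption.
import torch
from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)
model = manager.optimize(torch.nn.Linear(768, 768))

if torch.cuda.is_available():
    device_type = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    device_type = "cpu"
    dtype = torch.bfloat16          # CPU autocast supports bfloat16

model = model.to(device_type)
with torch.amp.autocast(device_type, dtype=dtype):
    out = model(torch.randn(8, 768, device=device_type))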

Memory Optimization

Gradient checkpointing, activation offloading, optimizer state sharding, and memory pooling. These work consistently whether you're on a single GPU or a multi-node cluster.
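The exact TorchBridge switches for these strategies are not spelled out on this page, so the sketch below uses PyTorch's own checkpointing utility to show the underlying technique.

# Hedged sketch: gradient checkpointing via torch.utils.checkpoint, standing
# in for TorchBridge's own memory-optimization knobs (not documented here).
import torch
from torch.utils.checkpoint import checkpoint_sequential

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072), torch.nn.GELU(),
    torch.nn.Linear(3072, 768), torch.nn.GELU(),
)
x = torch.randn(4, 768, requires_grad=True)

# Split the stack into 2 segments; activations inside each segment are
# recomputed during backward instead of stored, trading compute for memory.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()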

Checkpoint Portability

Save a checkpoint on NVIDIA hardware, load it on AMD or TPU. TorchBridge handles device mapping and dtype conversion.
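A rough sketch of the pattern, assuming nothing beyond the factory API shown above: TorchBridge's own save/load helpers are not documented on this page, so plain torch.save / torch.load with map_location stands in for them.

# Illustrative only: save on one machine, load on another whose accelerator
# may differ. map_location="cpu" is generic PyTorch; TorchBridge's device
# mapping and dtype conversion helpers are not shown here.
import torch
from torchbridge.backends import BackendFactory, detect_best_backend

model = torch.nn.Linear(768, 768)
torch.save(model.state_dict(), "model.ckpt")             # e.g. written on an NVIDIA box

backend = BackendFactory.create(detect_best_backend())   # e.g. AMD or TPU on the new machine
state = torch.load("model.ckpt", map_location="cpu")     # never assume the saving device exists
model.load_state_dict(state)
model.to(backend.device)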

Distributed Training

Tensor parallelism, pipeline parallelism, and FSDP with a unified API. The same distributed training script runs on NVIDIA DGX, AMD Instinct, or TPU pods.
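TorchBridge's unified distributed API is not shown on this page, so the sketch below uses PyTorch-native FSDP to illustrate the kind of wrapping it abstracts; launch it with torchrun.

# Hedged sketch with PyTorch-native FSDP (not TorchBridge's wrapper), run as:
#   torchrun --nproc_per_node=4 train_fsdp.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")        # "nccl" on NVIDIA; other vendors use other backends
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072), torch.nn.GELU(), torch.nn.Linear(3072, 768),
).cuda()
model = FSDP(model)                    # shard parameters, gradients, and optimizer state

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss = model(torch.randn(8, 768, device="cuda")).sum()
loss.backward()
optimizer.step()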

Code Examples

Backend-Agnostic Training

import torch
from torchbridge.backends import BackendFactory, detect_best_backend

backend = BackendFactory.create(detect_best_backend())
device = backend.device

model = YourModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()  # any loss works; defined so the snippet runs standalone

# Use PyTorch native AMP -- works on any backend
scaler = torch.amp.GradScaler(device.type)
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.amp.autocast(device.type):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Hardware Capability Queries

from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
print(nvidia.get_device_info())  # GPU model, compute capability, memory
print(nvidia.supports_fp8())     # True on H100+
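
Capability queries like this can gate which precision path a training script takes. Only supports_fp8() and get_device_info() appear on this page; the branching below is our own illustration.

# Hedged usage example: branch on a capability query.
from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
if nvidia.supports_fp8():
    print("Enabling the FP8 training path")            # H100 and newer
else:
    print("Falling back to BF16/FP16 mixed precision")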

Cross-Backend Model Export

import torch
from torchbridge.deployment import export_to_torchscript, export_to_onnx, export_to_safetensors

sample_input = torch.randn(1, 768)

export_to_torchscript(model, output_path="model.pt", sample_input=sample_input)
export_to_onnx(model, output_path="model.onnx", sample_input=sample_input)
export_to_safetensors(model, output_path="model.safetensors")

Project Structure

src/torchbridge/
├── backends/          # Vendor-specific backend implementations
│   ├── nvidia/        #   NVIDIA CUDA backend
│   ├── amd/           #   AMD ROCm backend
│   └── tpu/           #   Google TPU/XLA backend
├── hardware/          # Hardware detection and abstraction
├── precision/         # FP8 training and precision management
├── attention/         # Attention mechanisms (unified API)
├── advanced_memory/   # Memory optimization strategies
├── distributed_scale/ # Distributed training
├── deployment/        # Model export and serving
├── monitoring/        # Metrics, logging, health checks
├── optimizations/     # Optimization patterns and strategies
├── core/              # Core config, management, optimized layers
├── cli/               # Command-line tools
├── models/            # Model implementations
├── mixture_of_experts/ # MoE layer support
├── validation/        # Cross-backend validation
└── utils/             # Utilities and profiling

Cloud GPU Validation

All 5 use cases validated on real GPU hardware across AWS and GCP:

| Use Case | AWS A10G | GCP L4 | Description |
| --- | --- | --- | --- |
| Export Pipeline | PASS | PASS | TorchScript, ONNX, SafeTensors export with validation |
| LLM Optimization | PASS | PASS | Qwen3/DeepSeek optimization with backend-specific tuning |
| CI/CD Validation | PASS | PASS | Diagnostics, benchmarks, cross-backend checks |
| Backend Training | PASS | PASS | AMP training with auto backend detection |
| Cross-Backend Validation | PASS | PASS | Model, hardware, config, and output consistency |

Platforms tested:

  • AWS g5.xlarge -- NVIDIA A10G 24GB, PyTorch 2.9.1+cu130
  • GCP g2-standard-4 -- NVIDIA L4 24GB, PyTorch 2.7.1+cu128

See full validation report for detailed benchmarks and results.

Quality

  • 1,270+ tests passing across all modules
  • 0 ruff violations -- clean linting
  • 0 mypy errors -- full type coverage
  • Cloud validated on NVIDIA A10G (AWS), L4 (GCP), and AMD MI300X -- 5/5 use cases pass
  • Cross-platform tested on macOS, Linux, AWS, GCP, AMD Developer Cloud

Run the test suite and linter locally:

PYTHONPATH=src python3 -m pytest tests/ -q
ruff check src/ tests/

Use Cases

Cross-vendor training -- Train on NVIDIA in the cloud, fine-tune on AMD on-prem, deploy on TPU. Same code throughout.

Cost optimization -- Switch between cloud GPU types based on spot pricing without rewriting training scripts.

Hardware migration -- Move from one GPU vendor to another without a code rewrite.

Research portability -- Share models and training code that colleagues can run on whatever hardware they have.

Documentation

| Document | Description |
| --- | --- |
| Installation | Setup and requirements |
| Quick Start | First steps with TorchBridge |
| Troubleshooting | Common issues and fixes |
| Backends Overview | How the HAL works |
| Backend Selection | Choosing the right backend |
| Hardware Setup | Driver and toolkit installation |
| Distributed Training | Multi-GPU and multi-node |
| Deployment | Export, serve, containerize |
| CLI Reference | Command-line tools |
| Hardware Matrix | Full hardware support table |
| Contributing | Development and contribution guide |
| Changelog | Version history |

License

See LICENSE file for licensing details.

Download files

Download the file for your platform.

Source Distribution

torchbridge_ml-0.5.18.tar.gz (550.3 kB)

Uploaded Source

Built Distribution


torchbridge_ml-0.5.18-py3-none-any.whl (659.9 kB)

Uploaded Python 3

File details

Details for the file torchbridge_ml-0.5.18.tar.gz.

File metadata

  • Download URL: torchbridge_ml-0.5.18.tar.gz
  • Upload date:
  • Size: 550.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchbridge_ml-0.5.18.tar.gz

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 2b49b1c20d55137349ee7437b124358d8c3d313f905e5eb181118fa51a78a91c |
| MD5 | b55447ed1bbe813ae101134c38327585 |
| BLAKE2b-256 | 6c70346fba2e27198fa9ddcb34feb30691e575887acef8a2567c9c2dd21958ed |


File details

Details for the file torchbridge_ml-0.5.18-py3-none-any.whl.

File metadata

  • Download URL: torchbridge_ml-0.5.18-py3-none-any.whl
  • Upload date:
  • Size: 659.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for torchbridge_ml-0.5.18-py3-none-any.whl

| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 7c0848673cdd1a045b1b0cc4390eb638adf4a8663aba3c61b312696ea02a6425 |
| MD5 | b3189c0ef2c47faafa335c7935d6d3a5 |
| BLAKE2b-256 | d493abe9d64c0bb4a33fbace60051169c577c38282fa303c04ad3191fd0b15ce |

