
Cross-backend validation and configuration intelligence for PyTorch — NVIDIA, AMD, Trainium, and TPU


TorchBridge

Your PyTorch code is locked to one GPU vendor. CUDA calls, NCCL hardcoding, vendor-specific precision tricks -- they break the moment you switch hardware. TorchBridge is a cross-backend validation and configuration intelligence layer for PyTorch: it validates that outputs match across backends and generates optimal configurations for NVIDIA, AMD, Trainium, and TPU hardware.


What is TorchBridge?

PyTorch lets you build models. TorchBridge lets you run them anywhere.

Most teams write hardware-specific code -- CUDA calls for NVIDIA, ROCm setup for AMD, NeuronX setup for Trainium, XLA boilerplate for TPU. When the hardware changes, the code breaks. TorchBridge eliminates that problem with a unified API that detects your hardware and adapts automatically.

Your model code
      |
  TorchBridge
      |
  +---------+---------+-----------+---------+
  | NVIDIA  |   AMD   | Trainium  |   TPU   |
  | CUDA    |  ROCm   |  NeuronX  |   XLA   |
  +---------+---------+-----------+---------+

What it does:

  • Backend detection -- automatically identifies available accelerators
  • Vendor adapters -- translates unified API calls to vendor-specific operations
  • Precision management -- handles FP32/FP16/BF16/FP8 across backends with compatibility matrices
  • Quantization -- backend-aware format selection with automatic fallback chains
  • Attention dispatch -- selects the best attention kernel (FlexAttention, Flash, Triton, etc.) per hardware
  • Checkpoint portability -- save on one backend, load on another with dtype normalization
  • Distributed config -- generates FSDP2, pipeline, and collective configs from detected topology

Quick Start

pip install torchbridge-ml

# Verify
python3 -c "import torchbridge; print(f'TorchBridge v{torchbridge.__version__} ready')"

For development:

git clone https://github.com/CloudlyIO/torchbridge.git
cd torchbridge
pip install -e ".[dev]"

Detect Hardware

from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())

Run on Any Backend

import torch
from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)

optimized_model = manager.optimize(model)

Validate

from torchbridge import UnifiedValidator

validator = UnifiedValidator()
results = validator.validate_model(optimized_model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")

Supported Backends

| Backend  | Hardware                       | Precision                  | Status     |
|----------|--------------------------------|----------------------------|------------|
| NVIDIA   | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| AMD      | MI350X, MI325X, MI300X, MI200  | FP8, BF16, FP16, FP32      | Production |
| Trainium | Trn1, Trn2, Trn3 (AWS NeuronX) | BF16, FP16, FP32           | Supported  |
| TPU      | v4, v5e, v5p, v6e, v7          | BF16, FP32                 | Production |
| CPU      | x86, ARM (Apple Silicon)       | FP32, BF16                 | Fallback   |

See Hardware Matrix for full details.

Key Features

Backend Detection and Adaptation

Automatically identifies available hardware and selects the optimal backend. No code changes needed when moving between GPU vendors or cloud providers.
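The detection logic can be sketched as a priority probe over installed vendor runtimes. This is a simplified illustration, not TorchBridge's implementation (which also inspects drivers, device counts, and environment variables):

```python
import importlib.util

def detect_best_backend() -> str:
    """Probe for vendor runtimes in priority order; fall back to CPU."""
    if importlib.util.find_spec("torch") is not None:
        import torch
        if torch.cuda.is_available():
            # ROCm builds of PyTorch report a HIP version; CUDA builds do not.
            return "amd" if getattr(torch.version, "hip", None) else "nvidia"
    if importlib.util.find_spec("torch_xla") is not None:
        return "tpu"
    if importlib.util.find_spec("torch_neuronx") is not None:
        return "trainium"
    return "cpu"
```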

Vendor Adapters

Each backend implements a common BaseBackend interface. Your code calls manager.optimize(model) and the correct vendor-specific operations execute underneath -- CUDA on NVIDIA, HIP on AMD, NeuronX on Trainium, XLA on TPU.
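The adapter pattern behind this can be sketched as an abstract base class with one concrete fallback. Method names here mirror the ones described above but the class bodies are illustrative, not TorchBridge's source:

```python
from abc import ABC, abstractmethod
from typing import Any

class BaseBackend(ABC):
    """Common interface every vendor adapter implements (sketch)."""

    @abstractmethod
    def get_device_info(self) -> dict:
        ...

    @abstractmethod
    def optimize(self, model: Any) -> Any:
        ...

class CPUBackend(BaseBackend):
    def get_device_info(self) -> dict:
        return {"vendor": "cpu", "device": "cpu"}

    def optimize(self, model: Any) -> Any:
        # CPU fallback: nothing vendor-specific to apply.
        return model
```

Because callers only see `BaseBackend`, swapping NVIDIA for AMD (or any other vendor) is invisible to model code.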

Precision Management

Configure precision once. TorchBridge handles the details per backend -- FP8 on H100, BF16 where supported, FP16 as fallback. Mixed-precision training with torch.amp autocast works across all backends.
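The fallback behavior described above amounts to walking a chain until a backend-supported dtype is found. A minimal sketch (the chain and function name are illustrative assumptions):

```python
def resolve_precision(requested: str, supported: set[str]) -> str:
    """Walk a fallback chain until a backend-supported dtype is found.

    FP8 widens to BF16, BF16 to FP16, and everything bottoms out at FP32.
    """
    fallback = {"fp8": "bf16", "bf16": "fp16", "fp16": "fp32"}
    dtype = requested
    while dtype not in supported:
        dtype = fallback.get(dtype, "fp32")
    return dtype
```

For example, requesting FP8 on a backend that only supports BF16/FP16/FP32 resolves to BF16; on an FP32-only backend it resolves all the way down to FP32.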

Checkpoint Portability

Save a checkpoint on NVIDIA hardware, load it on AMD, Trainium, or TPU. TorchBridge handles device mapping, dtype normalization, and FP8-to-FP16 conversion via PyTorch Distributed Checkpoint (DCP).
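The dtype-normalization step can be sketched with plain dicts standing in for tensor state dicts. This is an illustration of the idea (FP8 parameters widened before loading on a backend without FP8), not the DCP-based implementation:

```python
def normalize_state_dict_dtypes(state: dict, target_supported: set[str]) -> dict:
    """Widen dtypes the target backend cannot represent.

    `state` maps parameter names to (dtype_name, data) pairs -- a
    stand-in for a real tensor state dict.
    """
    widen = {"float8_e4m3fn": "float16", "float8_e5m2": "float16"}
    normalized = {}
    for name, (dtype, data) in state.items():
        if dtype not in target_supported:
            dtype = widen.get(dtype, "float32")
        normalized[name] = (dtype, data)
    return normalized
```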

Distributed Configuration

Generates FSDP2 sharding strategies, pipeline schedules, and collective backend configs based on detected cluster topology. TorchBridge produces config objects that you pass to PyTorch's native distributed primitives -- it does not implement distributed training itself.
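A minimal sketch of topology-driven config generation, assuming a simple vendor-to-collective mapping (the dataclass and function are hypothetical, not TorchBridge's API):

```python
from dataclasses import dataclass

@dataclass
class DistributedConfig:
    collective_backend: str
    fsdp_shard_degree: int

def make_distributed_config(vendor: str, world_size: int) -> DistributedConfig:
    # NVIDIA and AMD builds of PyTorch both use the "nccl" process-group
    # backend (RCCL underneath on ROCm); XLA devices use "xla".
    collectives = {"nvidia": "nccl", "amd": "nccl", "tpu": "xla", "trainium": "xla"}
    return DistributedConfig(
        collective_backend=collectives.get(vendor, "gloo"),
        fsdp_shard_degree=world_size,
    )
```

The resulting object would then be handed to PyTorch's own `init_process_group` and FSDP2 APIs, consistent with the note above that TorchBridge generates configs rather than implementing distributed training.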

Code Examples

Backend-Agnostic Training

import torch
from torchbridge.backends import BackendFactory, detect_best_backend

backend = BackendFactory.create(detect_best_backend())
device = backend.device

model = YourModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Use PyTorch native AMP -- works on any backend
scaler = torch.amp.GradScaler(device.type)
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.amp.autocast(device.type):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Hardware Capability Queries

from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
print(nvidia.get_device_info())  # GPU model, compute capability, memory
print(nvidia.supports_fp8())     # True on H100+

Cross-Backend Model Export

from torchbridge.deployment import export_to_torchscript, export_to_onnx, export_to_safetensors

sample_input = torch.randn(1, 768)

export_to_torchscript(model, output_path="model.pt", sample_input=sample_input)
export_to_onnx(model, output_path="model.onnx", sample_input=sample_input)
export_to_safetensors(model, output_path="model.safetensors")

Project Structure

src/torchbridge/
├── backends/          # Vendor-specific backend implementations
│   ├── nvidia/        #   NVIDIA CUDA backend
│   ├── amd/           #   AMD ROCm backend
│   ├── trainium/      #   AWS Trainium/NeuronX backend
│   └── tpu/           #   Google TPU/XLA backend
├── hardware/          # Hardware detection and abstraction
├── precision/         # FP8 training and precision management
├── attention/         # Attention mechanisms (unified API)
├── advanced_memory/   # Memory optimization strategies
├── distributed_scale/ # Distributed training
├── deployment/        # Model export and serving
├── monitoring/        # Metrics, logging, health checks
├── optimizations/     # Optimization patterns and strategies
├── core/              # Core config, management, optimized layers
├── cli/               # Command-line tools
├── models/            # Model implementations
├── mixture_of_experts/ # MoE layer support
├── validation/        # Cross-backend validation
└── utils/             # Utilities and profiling

Cloud Hardware Validation

Cross-backend numerical consistency validated on 8 hardware platforms using Qwen3-0.6B:

| Platform     | Hardware                          | Max Diff | Cosine Sim | Latency        | Status |
|--------------|-----------------------------------|----------|------------|----------------|--------|
| AWS          | NVIDIA A10G (24GB)                | 1.96e-05 | 1.000001   | 41.8 ms        | PASS   |
| GCP          | NVIDIA T4 (16GB)                  | 2.67e-05 | 1.000001   | 50.8 ms        | PASS   |
| RunPod       | NVIDIA H100 NVL (100GB)           | 2.29e-05 | 1.000001   | 18.8 ms        | PASS   |
| AMD DevCloud | AMD MI300X (192GB)                | 4.82e-05 | 1.000001   | 30.0 ms        | PASS   |
| GCP          | TPU v5e                           | 1.08e-01 | 0.999980   | 47.5 ms        | PASS   |
| Local        | Apple Silicon (MPS)               | 4.58e-05 | 1.000002   | 27.8 ms        | PASS   |
| AWS          | Trainium Trn1.2xlarge (NeuronX)   | 0.00e+00 | 1.000001   | 103.3 ms (CPU) | PASS   |
| AWS          | Inferentia2 inf2.xlarge (NeuronX) | 0.00e+00 | 1.000001   | 321.7 ms (CPU) | PASS   |

All backends produce semantically identical outputs (cosine similarity > 0.999).
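The two metrics in the table can be computed with a few lines. A minimal sketch, using plain lists as stand-ins for flattened output tensors (the function name is illustrative, not TorchBridge's validation API):

```python
import math

def compare_outputs(a: list[float], b: list[float]) -> tuple[float, float]:
    """Max absolute difference and cosine similarity of two flattened
    output vectors from different backends."""
    max_diff = max(abs(x - y) for x, y in zip(a, b))
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return max_diff, dot / norm
```

A run passes when the cosine similarity exceeds the 0.999 threshold and the max difference stays within the tolerance expected for the dtype in use.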

See full validation report for detailed benchmarks and results.

Quality

  • 2,392 tests collected (hardware-gated skips on non-GPU environments)
  • 0 ruff violations -- clean linting
  • 0 mypy errors -- full type coverage
  • Cloud validated on 8 hardware platforms: NVIDIA A10G (AWS), T4 (GCP), H100 NVL (RunPod), AMD MI300X, GCP TPU v5e, Apple MPS, AWS Trainium, AWS Inferentia2
  • Cross-platform tested on macOS, Linux, AWS, GCP, AMD Developer Cloud, RunPod

Run the test suite and linter locally:

python3 -m pytest tests/ -q
ruff check src/ tests/

Use Cases

Cross-vendor training -- Train on NVIDIA in the cloud, fine-tune on AMD on-prem, deploy on Trainium or TPU. Same code throughout.

Cost optimization -- Switch between cloud GPU types based on spot pricing without rewriting training scripts.

Hardware migration -- Move from one GPU vendor to another without a code rewrite.

Research portability -- Share models and training code that colleagues can run on whatever hardware they have.

Documentation

| Document             | Description                        |
|----------------------|------------------------------------|
| Installation         | Setup and requirements             |
| Quick Start          | First steps with TorchBridge       |
| Troubleshooting      | Common issues and fixes            |
| Backends Overview    | How the backend system works       |
| Backend Selection    | Choosing the right backend         |
| Hardware Setup       | Driver and toolkit installation    |
| Distributed Training | Multi-GPU and multi-node           |
| Deployment           | Export, serve, containerize        |
| CLI Reference        | Command-line tools                 |
| Hardware Matrix      | Full hardware support table        |
| Contributing         | Development and contribution guide |
| Changelog            | Version history                    |

License

See LICENSE file for licensing details.
