
Cross-backend validation and configuration intelligence for PyTorch — NVIDIA, AMD, Trainium, and TPU

Project description

TorchBridge

Your PyTorch code is locked to one GPU vendor. CUDA calls, NCCL hardcoding, vendor-specific precision tricks -- they break the moment you switch hardware. TorchBridge is a cross-backend validation and configuration intelligence layer for PyTorch: it validates that outputs match across backends and generates optimal configurations for NVIDIA, AMD, Trainium, and TPU hardware.


What is TorchBridge?

PyTorch lets you build models. TorchBridge lets you run them anywhere.

Most teams write hardware-specific code -- CUDA calls for NVIDIA, ROCm setup for AMD, NeuronX setup for Trainium, XLA boilerplate for TPU. When the hardware changes, the code breaks. TorchBridge eliminates that problem with a unified API that detects your hardware and adapts automatically.

Your model code
      |
  TorchBridge
      |
  +---------+---------+-----------+---------+
  | NVIDIA  |   AMD   | Trainium  |   TPU   |
  | CUDA    |  ROCm   |  NeuronX  |   XLA   |
  +---------+---------+-----------+---------+

What it does:

  • Backend detection -- automatically identifies available accelerators
  • Vendor adapters -- translates unified API calls to vendor-specific operations
  • Precision management -- handles FP32/FP16/BF16/FP8 across backends with compatibility matrices
  • Quantization -- backend-aware format selection with automatic fallback chains
  • Attention dispatch -- selects the best attention kernel (FlexAttention, Flash, Triton, etc.) per hardware
  • Checkpoint portability -- save on one backend, load on another with dtype normalization
  • Distributed config -- generates FSDP, pipeline, and collective configs from detected topology
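
The fallback-chain idea behind the quantization bullet can be sketched in plain Python. Everything below -- the format table and the pick_format helper -- is an illustrative assumption, not TorchBridge's actual internals:

```python
# Hypothetical per-backend format support table (assumption for this sketch).
SUPPORTED_FORMATS = {
    "nvidia": {"fp8", "int8", "bf16", "fp16"},
    "amd": {"fp8", "int8", "bf16", "fp16"},
    "trainium": {"bf16", "fp16"},
    "tpu": {"bf16"},
    "cpu": {"bf16"},
}

def pick_format(backend: str, preference: list[str]) -> str:
    """Walk the preference chain and return the first format the backend supports."""
    supported = SUPPORTED_FORMATS[backend]
    for fmt in preference:
        if fmt in supported:
            return fmt
    return "fp32"  # universal fallback

# The same request resolves differently per backend:
chain = ["fp8", "int8", "bf16", "fp16"]
print(pick_format("nvidia", chain))    # fp8
print(pick_format("trainium", chain))  # bf16
```

The point is that the caller states a preference once and never branches on hardware; the chain degrades gracefully down to FP32.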

Quick Start

pip install torchbridge-ml

# Verify
python3 -c "import torchbridge; print(f'TorchBridge v{torchbridge.__version__} ready')"

For development:

git clone https://github.com/CloudlyIO/torchbridge.git
cd torchbridge
pip install -e ".[dev]"

Detect Hardware

from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())

Run on Any Backend

import torch
from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)

optimized_model = manager.optimize(model)

Validate

from torchbridge import UnifiedValidator

validator = UnifiedValidator()
results = validator.validate_model(optimized_model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")

Supported Backends

Backend    Hardware                          Precision                    Status
NVIDIA     B200, H100, H200, A100, L4, T4    FP4, FP8, BF16, FP16, FP32   Production
AMD        MI350X, MI325X, MI300X, MI200     FP8, BF16, FP16, FP32        Production
Trainium   Trn1, Trn2, Trn3 (AWS NeuronX)    BF16, FP16, FP32             Supported
TPU        v4, v5e, v5p, v6e, v7             BF16, FP32                   Production
CPU        x86, ARM (Apple Silicon)          FP32, BF16                   Fallback

See Hardware Matrix for full details.

Key Features

Backend Detection and Adaptation

Automatically identifies available hardware and selects the optimal backend. No code changes needed when moving between GPU vendors or cloud providers.
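
A minimal sketch of priority-ordered detection, using a hypothetical set of probed backends (the real detect_best_backend presumably probes installed drivers and runtimes rather than taking a set):

```python
def detect_best_backend(available: set[str]) -> str:
    """Return the highest-priority accelerator present, falling back to CPU.

    Illustrative only: the priority order below is an assumption for this sketch.
    """
    priority = ["nvidia", "amd", "trainium", "tpu"]
    for backend in priority:
        if backend in available:
            return backend
    return "cpu"

print(detect_best_backend({"amd", "tpu"}))  # amd
print(detect_best_backend(set()))           # cpu
```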

Vendor Adapters

Each backend implements a common BaseBackend interface. Your code calls manager.optimize(model) and the correct vendor-specific operations execute underneath -- CUDA on NVIDIA, HIP on AMD, NeuronX on Trainium, XLA on TPU.
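
The adapter pattern described here can be sketched with Python's abc module. Only get_device_info appears in this document's examples; the optimize method and the return shapes are assumptions for illustration:

```python
from abc import ABC, abstractmethod

class BaseBackend(ABC):
    """Common interface every vendor adapter implements (sketch)."""

    @abstractmethod
    def get_device_info(self) -> dict: ...

    @abstractmethod
    def optimize(self, model): ...

class CPUBackend(BaseBackend):
    """Toy adapter: the CPU fallback path."""

    def get_device_info(self) -> dict:
        return {"vendor": "cpu", "device": "cpu"}

    def optimize(self, model):
        # A real adapter would apply vendor-specific kernels here.
        return model

backend = CPUBackend()
print(backend.get_device_info()["vendor"])  # cpu
```

Caller code depends only on BaseBackend, so swapping NVIDIA for AMD is a change of concrete class, not of call sites.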

Precision Management

Configure precision once. TorchBridge handles the details per backend -- FP8 on H100, BF16 where supported, FP16 as fallback. Mixed-precision training with torch.amp autocast works across all backends.
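
The FP8-first policy above reduces to a small resolution function. The capability flags are assumptions for this sketch, not TorchBridge's API:

```python
def resolve_training_dtype(supports_fp8: bool, supports_bf16: bool) -> str:
    """FP8 where available, else BF16, else FP16 -- mirroring the policy above."""
    if supports_fp8:
        return "fp8"
    if supports_bf16:
        return "bf16"
    return "fp16"

print(resolve_training_dtype(True, True))    # fp8  (e.g. H100)
print(resolve_training_dtype(False, True))   # bf16 (e.g. A100)
print(resolve_training_dtype(False, False))  # fp16
```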

Checkpoint Portability

Save a checkpoint on NVIDIA hardware, load it on AMD, Trainium, or TPU. TorchBridge handles device mapping, dtype normalization, and FP8-to-FP16 conversion via PyTorch Distributed Checkpoint (DCP).
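
A minimal sketch of the dtype-normalization step, on a toy state dict mapping parameter names to dtype strings. TorchBridge's real path goes through DCP and actual tensor dtypes; the mapping here is an assumption:

```python
# Hypothetical FP8-to-FP16 fallback mapping (assumption for this sketch).
FP8_FALLBACK = {
    "float8_e4m3fn": "float16",
    "float8_e5m2": "float16",
}

def normalize_state_dict(state_dict: dict, target_supports_fp8: bool) -> dict:
    """Rewrite FP8 dtypes to FP16 when the target backend lacks FP8."""
    if target_supports_fp8:
        return dict(state_dict)
    return {name: FP8_FALLBACK.get(dtype, dtype) for name, dtype in state_dict.items()}

ckpt = {"w1": "float8_e4m3fn", "w2": "bfloat16"}
print(normalize_state_dict(ckpt, target_supports_fp8=False))
# {'w1': 'float16', 'w2': 'bfloat16'}
```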

Distributed Configuration

Generates FSDP sharding strategies, pipeline schedules, and collective backend configs based on detected cluster topology. TorchBridge produces config objects that you pass to PyTorch's native distributed primitives -- it does not implement distributed training itself.
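
A hedged sketch of what "produces config objects" might look like. The dataclass fields and heuristics below are assumptions, not TorchBridge's schema:

```python
from dataclasses import dataclass

@dataclass
class FSDPPlan:
    sharding_strategy: str
    world_size: int
    collective_backend: str

def plan_from_topology(num_nodes: int, gpus_per_node: int, vendor: str) -> FSDPPlan:
    """Derive an FSDP config from detected topology (toy heuristics)."""
    world_size = num_nodes * gpus_per_node
    # Full sharding across nodes; within one node, shard only grads and optimizer state.
    strategy = "FULL_SHARD" if num_nodes > 1 else "SHARD_GRAD_OP"
    # NCCL on NVIDIA (RCCL surfaces as "nccl" in ROCm builds of PyTorch); gloo is a
    # toy fallback here -- real TPU/Trainium paths use XLA/Neuron collectives.
    collective = "nccl" if vendor in {"nvidia", "amd"} else "gloo"
    return FSDPPlan(strategy, world_size, collective)

print(plan_from_topology(4, 8, "nvidia"))
```

The plan object is then handed to PyTorch's own FSDP wrapper, matching the document's point that TorchBridge configures rather than reimplements distributed training.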

Code Examples

Backend-Agnostic Training

import torch
from torchbridge.backends import BackendFactory, detect_best_backend

backend = BackendFactory.create(detect_best_backend())
device = backend.device

model = YourModel().to(device)  # any torch.nn.Module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()  # any loss matching your task

# Use PyTorch native AMP -- works on any backend
scaler = torch.amp.GradScaler(device.type)
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.amp.autocast(device.type):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

Hardware Capability Queries

from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
print(nvidia.get_device_info())  # GPU model, compute capability, memory
print(nvidia.supports_fp8())     # True on H100+

Cross-Backend Model Export

from torchbridge.deployment import export_to_torchscript, export_to_onnx, export_to_safetensors

sample_input = torch.randn(1, 768)

export_to_torchscript(model, output_path="model.pt", sample_input=sample_input)
export_to_onnx(model, output_path="model.onnx", sample_input=sample_input)
export_to_safetensors(model, output_path="model.safetensors")

Project Structure

src/torchbridge/
├── backends/          # Vendor-specific backend implementations
│   ├── nvidia/        #   NVIDIA CUDA backend
│   ├── amd/           #   AMD ROCm backend
│   ├── trainium/      #   AWS Trainium/NeuronX backend
│   └── tpu/           #   Google TPU/XLA backend
├── hardware/          # Hardware detection and abstraction
├── precision/         # FP8 training and precision management
├── attention/         # Attention mechanisms (unified API)
├── advanced_memory/   # Memory optimization strategies
├── distributed_scale/ # Distributed training
├── deployment/        # Model export and serving
├── monitoring/        # Metrics, logging, health checks
├── optimizations/     # Optimization patterns and strategies
├── core/              # Core config, management, optimized layers
├── cli/               # Command-line tools
├── models/            # Model implementations
├── mixture_of_experts/ # MoE layer support
├── validation/        # Cross-backend validation
└── utils/             # Utilities and profiling

Cloud Hardware Validation

Cross-backend numerical consistency validated on 8 platforms (6 real GPU/accelerator, 2 CPU-fallback†) using Qwen3-0.6B:

Platform          Hardware                  Max Diff   Cosine Sim   Latency          Status
AWS               NVIDIA A10G (24GB)        1.96e-05   1.000001     41.8 ms          PASS
GCP               NVIDIA T4 (16GB)          2.67e-05   1.000001     50.8 ms          PASS
RunPod            NVIDIA H100 NVL (100GB)   2.29e-05   1.000001     18.8 ms          PASS
AMD DevCloud      AMD MI300X (192GB)        4.82e-05   1.000001     30.0 ms          PASS
GCP               TPU v5e                   1.08e-01   0.999980     47.5 ms          PASS
Local             Apple Silicon (MPS)       4.58e-05   1.000002     27.8 ms          PASS
AWS Trainium†     Trn1.2xlarge (NeuronX)    0.00e+00   1.000001     103.3 ms (CPU)   PASS
AWS Inferentia2†  inf2.xlarge (NeuronX)     0.00e+00   1.000001     321.7 ms (CPU)   PASS

† CPU fallback: NeuronX SDK compilation requires quota-enabled Trn1/Inf2 instances, which were not available in the validation environment. These rows confirm correct CPU-path behavior (max_diff = 0.00e+00 is CPU-vs-CPU, not accelerator validation). Real NeuronX validation is pending quota approval.

All GPU/accelerator backends produce semantically identical outputs (cosine similarity > 0.999).
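
The two consistency metrics in the table above (max diff and cosine similarity) are standard; a minimal sketch of how they are computed, on plain Python lists rather than real model outputs:

```python
import math

def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest element-wise absolute difference between two outputs."""
    return max(abs(x - y) for x, y in zip(a, b))

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Dot product of the two outputs divided by the product of their norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

ref = [0.5, -1.25, 2.0]   # e.g. reference-backend logits
out = [0.5, -1.25, 2.1]   # e.g. logits from another backend
print(f"max_diff={max_abs_diff(ref, out):.2e}")
print(f"cos_sim={cosine_similarity(ref, out):.4f}")
```

Max diff catches any single divergent element, while cosine similarity summarizes whether the outputs point the same way overall, which is why both are reported per platform.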

See full validation report for detailed benchmarks and results.

Quality

  • 2,563 tests collected (hardware-gated skips on non-GPU environments)
  • 0 ruff violations -- clean linting
  • 0 mypy errors -- full type coverage
  • Cloud validated on 8 platforms (6 GPU-validated: A10G, T4, H100 NVL, MI300X, TPU v5e, MPS; 2 CPU-fallback†: Trainium, Inferentia2)
  • Cross-platform tested on macOS, Linux, AWS, GCP, AMD Developer Cloud, RunPod

Reproduce the checks locally:

python3 -m pytest tests/ -q
ruff check src/ tests/

Use Cases

Cross-vendor training -- Train on NVIDIA in the cloud, fine-tune on AMD on-prem, deploy on Trainium or TPU. Same code throughout.

Cost optimization -- Switch between cloud GPU types based on spot pricing without rewriting training scripts.

Hardware migration -- Move from one GPU vendor to another without a code rewrite.

Research portability -- Share models and training code that colleagues can run on whatever hardware they have.

Documentation

Document              Description
Installation          Setup and requirements
Quick Start           First steps with TorchBridge
Troubleshooting       Common issues and fixes
Backends Overview     How the backend system works
Backend Selection     Choosing the right backend
Hardware Setup        Driver and toolkit installation
Distributed Training  Multi-GPU and multi-node
Deployment            Export, serve, containerize
CLI Reference         Command-line tools
Hardware Matrix       Full hardware support table
Contributing          Development and contribution guide
Changelog             Version history

License

See LICENSE file for licensing details.
