# TorchBridge

Cross-backend validation and configuration intelligence for PyTorch -- NVIDIA, AMD, Trainium, and TPU.
Your PyTorch code is locked to one GPU vendor. CUDA calls, NCCL hardcoding, vendor-specific precision tricks -- they break the moment you switch hardware. TorchBridge is a cross-backend validation and configuration intelligence layer for PyTorch: it validates that outputs match across backends and generates optimal configurations for NVIDIA, AMD, Trainium, and TPU hardware.
## What is TorchBridge?
PyTorch lets you build models. TorchBridge lets you run them anywhere.
Most teams write hardware-specific code -- CUDA calls for NVIDIA, ROCm setup for AMD, NeuronX setup for Trainium, XLA boilerplate for TPU. When the hardware changes, the code breaks. TorchBridge eliminates that problem with a unified API that detects your hardware and adapts automatically.
```
              Your model code
                    |
               TorchBridge
                    |
+---------+---------+-----------+---------+
| NVIDIA  |   AMD   | Trainium  |   TPU   |
|  CUDA   |  ROCm   |  NeuronX  |   XLA   |
+---------+---------+-----------+---------+
```
What it does:

- **Backend detection** -- automatically identifies available accelerators
- **Vendor adapters** -- translate unified API calls into vendor-specific operations
- **Precision management** -- handles FP32/FP16/BF16/FP8 across backends with compatibility matrices
- **Quantization** -- backend-aware format selection with automatic fallback chains
- **Attention dispatch** -- selects the best attention kernel (FlexAttention, Flash, Triton, etc.) per hardware
- **Checkpoint portability** -- save on one backend, load on another with dtype normalization
- **Distributed config** -- generates FSDP2, pipeline, and collective configs from detected topology
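To make the fallback-chain idea above concrete, here is a minimal sketch in plain Python. The capability table, device names, and `select_format` function are illustrative assumptions, not TorchBridge's actual API:

```python
# Illustrative sketch only -- the capability table and function name are
# hypothetical, not TorchBridge's real API.
SUPPORTED_FORMATS = {
    "nvidia-h100": {"fp8", "bf16", "fp16", "fp32"},
    "amd-mi300x": {"fp8", "bf16", "fp16", "fp32"},
    "trainium-trn1": {"bf16", "fp16", "fp32"},
    "tpu-v5e": {"bf16", "fp32"},
}

# Preference order: try the most compact format first, fall back as needed.
FALLBACK_CHAIN = ["fp8", "bf16", "fp16", "fp32"]

def select_format(device: str) -> str:
    """Return the first format in the fallback chain the device supports."""
    supported = SUPPORTED_FORMATS.get(device, {"fp32"})
    for fmt in FALLBACK_CHAIN:
        if fmt in supported:
            return fmt
    return "fp32"

print(select_format("nvidia-h100"))  # fp8
print(select_format("tpu-v5e"))      # bf16
```

The same pattern generalizes to attention-kernel dispatch: an ordered preference list filtered by a per-device capability set.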
## Quick Start

```bash
pip install torchbridge-ml

# Verify the install
python3 -c "import torchbridge; print(f'TorchBridge v{torchbridge.__version__} ready')"
```
For development:

```bash
git clone https://github.com/CloudlyIO/torchbridge.git
cd torchbridge
pip install -e ".[dev]"
```
### Detect Hardware

```python
from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())
```
### Run on Any Backend

```python
import torch

from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)
optimized_model = manager.optimize(model)
```
### Validate

```python
from torchbridge import UnifiedValidator

validator = UnifiedValidator()
results = validator.validate_model(optimized_model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")
```
## Supported Backends
| Backend | Hardware | Precision | Status |
|---|---|---|---|
| NVIDIA | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| AMD | MI350X, MI325X, MI300X, MI200 | FP8, BF16, FP16, FP32 | Production |
| Trainium | Trn1, Trn2, Trn3 (AWS NeuronX) | BF16, FP16, FP32 | Supported |
| TPU | v4, v5e, v5p, v6e, v7 | BF16, FP32 | Production |
| CPU | x86, ARM (Apple Silicon) | FP32, BF16 | Fallback |
See Hardware Matrix for full details.
## Key Features
### Backend Detection and Adaptation
Automatically identifies available hardware and selects the optimal backend. No code changes needed when moving between GPU vendors or cloud providers.
### Vendor Adapters
Each backend implements a common BaseBackend interface. Your code calls manager.optimize(model) and the correct vendor-specific operations execute underneath -- CUDA on NVIDIA, HIP on AMD, NeuronX on Trainium, XLA on TPU.
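The adapter pattern behind this can be pictured roughly as follows. This is a simplified sketch of the idea only -- the class and method names here are illustrative, not the actual `BaseBackend` interface:

```python
from abc import ABC, abstractmethod

# Simplified sketch of the adapter idea -- names are illustrative,
# not TorchBridge's actual BaseBackend interface.
class Backend(ABC):
    @abstractmethod
    def device_name(self) -> str: ...

    @abstractmethod
    def optimize(self, model):
        """Apply vendor-specific optimizations and return the model."""

class NvidiaBackend(Backend):
    def device_name(self) -> str:
        return "cuda"

    def optimize(self, model):
        # e.g. enable TF32, select CUDA attention kernels
        return model

class TPUBackend(Backend):
    def device_name(self) -> str:
        return "xla"

    def optimize(self, model):
        # e.g. mark the model for XLA compilation
        return model

def run(backend: Backend, model):
    # Caller code never branches on the vendor.
    return backend.optimize(model)
```

The point of the common interface is that swapping hardware changes which concrete adapter is constructed, not any line of the calling code.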
### Precision Management
Configure precision once. TorchBridge handles the details per backend -- FP8 on H100, BF16 where supported, FP16 as fallback. Mixed-precision training with torch.amp autocast works across all backends.
### Checkpoint Portability
Save a checkpoint on NVIDIA hardware, load it on AMD, Trainium, or TPU. TorchBridge handles device mapping, dtype normalization, and FP8-to-FP16 conversion via PyTorch Distributed Checkpoint (DCP).
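The dtype-normalization step can be illustrated with a toy state dict where dtypes are tracked as strings. This is a sketch of the concept only; TorchBridge applies it to real tensors during checkpoint load, and the mapping table here is an assumption:

```python
# Toy sketch: dtypes tracked as strings to illustrate normalization.
# The mapping below is illustrative, not TorchBridge's actual table.
PORTABLE = {
    "float8_e4m3fn": "float16",  # FP8 formats are not loadable everywhere
    "float8_e5m2": "float16",
}

def normalize_dtypes(state_dict: dict) -> dict:
    """Map backend-specific low-precision dtypes to portable ones."""
    return {
        name: {**entry, "dtype": PORTABLE.get(entry["dtype"], entry["dtype"])}
        for name, entry in state_dict.items()
    }

ckpt = {
    "w1": {"shape": (768, 3072), "dtype": "float8_e4m3fn"},
    "w2": {"shape": (3072, 768), "dtype": "bfloat16"},
}
print(normalize_dtypes(ckpt)["w1"]["dtype"])  # float16
```

Dtypes already supported on the target backend (here `bfloat16`) pass through unchanged.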
### Distributed Configuration
Generates FSDP2 sharding strategies, pipeline schedules, and collective backend configs based on detected cluster topology. TorchBridge produces config objects that you pass to PyTorch's native distributed primitives -- it does not implement distributed training itself.
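A sketch of the topology-to-config idea, assuming a toy plan object -- the field names, strategy strings, and function are hypothetical, not TorchBridge's real config objects:

```python
from dataclasses import dataclass

# Illustrative sketch -- field and strategy names are hypothetical,
# not TorchBridge's actual config objects.
@dataclass
class DistributedPlan:
    sharding: str        # FSDP-style sharding strategy
    collective: str      # communication backend
    pipeline_stages: int

def plan_from_topology(num_nodes: int, gpus_per_node: int, vendor: str) -> DistributedPlan:
    # NCCL on NVIDIA, RCCL on AMD, gloo as a generic fallback.
    collective = {"nvidia": "nccl", "amd": "rccl"}.get(vendor, "gloo")
    # Shard within a node when possible; go hybrid across nodes.
    sharding = "hybrid_shard" if num_nodes > 1 else "full_shard"
    return DistributedPlan(sharding, collective, pipeline_stages=1)

print(plan_from_topology(4, 8, "nvidia"))
# DistributedPlan(sharding='hybrid_shard', collective='nccl', pipeline_stages=1)
```

The resulting plan would then be handed to PyTorch's own distributed primitives, consistent with TorchBridge generating configs rather than implementing distributed training itself.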
## Code Examples
### Backend-Agnostic Training

```python
import torch

from torchbridge.backends import BackendFactory, detect_best_backend

backend = BackendFactory.create(detect_best_backend())
device = backend.device

model = YourModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Use PyTorch native AMP -- works on any backend
scaler = torch.amp.GradScaler(device.type)

for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.amp.autocast(device.type):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```
### Hardware Capability Queries

```python
from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
print(nvidia.get_device_info())  # GPU model, compute capability, memory
print(nvidia.supports_fp8())     # True on H100 and newer
```
### Cross-Backend Model Export

```python
import torch

from torchbridge.deployment import export_to_torchscript, export_to_onnx, export_to_safetensors

sample_input = torch.randn(1, 768)
export_to_torchscript(model, output_path="model.pt", sample_input=sample_input)
export_to_onnx(model, output_path="model.onnx", sample_input=sample_input)
export_to_safetensors(model, output_path="model.safetensors")
```
## Project Structure

```
src/torchbridge/
├── backends/            # Vendor-specific backend implementations
│   ├── nvidia/          # NVIDIA CUDA backend
│   ├── amd/             # AMD ROCm backend
│   ├── trainium/        # AWS Trainium/NeuronX backend
│   └── tpu/             # Google TPU/XLA backend
├── hardware/            # Hardware detection and abstraction
├── precision/           # FP8 training and precision management
├── attention/           # Attention mechanisms (unified API)
├── advanced_memory/     # Memory optimization strategies
├── distributed_scale/   # Distributed training
├── deployment/          # Model export and serving
├── monitoring/          # Metrics, logging, health checks
├── optimizations/       # Optimization patterns and strategies
├── core/                # Core config, management, optimized layers
├── cli/                 # Command-line tools
├── models/              # Model implementations
├── mixture_of_experts/  # MoE layer support
├── validation/          # Cross-backend validation
└── utils/               # Utilities and profiling
```
## Cloud Hardware Validation
Cross-backend numerical consistency validated on 8 hardware platforms using Qwen3-0.6B:
| Platform | Hardware | Max Diff | Cosine Sim | Latency | Status |
|---|---|---|---|---|---|
| AWS | NVIDIA A10G (24GB) | 1.96e-05 | 1.000001 | 41.8 ms | PASS |
| GCP | NVIDIA T4 (16GB) | 2.67e-05 | 1.000001 | 50.8 ms | PASS |
| RunPod | NVIDIA H100 NVL (100GB) | 2.29e-05 | 1.000001 | 18.8 ms | PASS |
| AMD DevCloud | AMD MI300X (192GB) | 4.82e-05 | 1.000001 | 30.0 ms | PASS |
| GCP | TPU v5e | 1.08e-01 | 0.999980 | 47.5 ms | PASS |
| Local | Apple Silicon (MPS) | 4.58e-05 | 1.000002 | 27.8 ms | PASS |
| AWS Trainium | Trn1.2xlarge (NeuronX) | 0.00e+00 | 1.000001 | 103.3 ms (CPU) | PASS |
| AWS Inferentia2 | inf2.xlarge (NeuronX) | 0.00e+00 | 1.000001 | 321.7 ms (CPU) | PASS |
All backends produce semantically identical outputs (cosine similarity > 0.999).
See full validation report for detailed benchmarks and results.
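The two consistency metrics in the table above (max absolute difference and cosine similarity) reduce to a few lines of plain Python; the sample vectors below are made up for illustration:

```python
import math

def max_abs_diff(a, b):
    # Largest element-wise absolute deviation between two outputs.
    return max(abs(x - y) for x, y in zip(a, b))

def cosine_similarity(a, b):
    # Dot product of the outputs normalized by their magnitudes.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

ref = [0.12, -0.53, 0.98]      # e.g. logits from a reference run
out = [0.12, -0.53, 0.98001]   # same prompt on another backend
print(max_abs_diff(ref, out))       # ~1e-05
print(cosine_similarity(ref, out))  # ~1.0
```

Cosine similarity is mathematically at most 1.0; table values marginally above 1.0 reflect floating-point rounding in the accumulation.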
## Quality
- 2,392 tests collected (hardware-gated skips on non-GPU environments)
- 0 ruff violations -- clean linting
- 0 mypy errors -- full type coverage
- Cloud validated on 8 hardware platforms: NVIDIA A10G (AWS), T4 (GCP), H100 NVL (RunPod), AMD MI300X, GCP TPU v5e, Apple MPS, AWS Trainium, AWS Inferentia2
- Cross-platform tested on macOS, Linux, AWS, GCP, AMD Developer Cloud, RunPod
```bash
python3 -m pytest tests/ -q
ruff check src/ tests/
```
## Use Cases

- **Cross-vendor training** -- train on NVIDIA in the cloud, fine-tune on AMD on-prem, deploy on Trainium or TPU. Same code throughout.
- **Cost optimization** -- switch between cloud GPU types based on spot pricing without rewriting training scripts.
- **Hardware migration** -- move from one GPU vendor to another without a code rewrite.
- **Research portability** -- share models and training code that colleagues can run on whatever hardware they have.
## Documentation
| Document | Description |
|---|---|
| Installation | Setup and requirements |
| Quick Start | First steps with TorchBridge |
| Troubleshooting | Common issues and fixes |
| Backends Overview | How the backend system works |
| Backend Selection | Choosing the right backend |
| Hardware Setup | Driver and toolkit installation |
| Distributed Training | Multi-GPU and multi-node |
| Deployment | Export, serve, containerize |
| CLI Reference | Command-line tools |
| Hardware Matrix | Full hardware support table |
| Contributing | Development and contribution guide |
| Changelog | Version history |
## License

See the LICENSE file for licensing details.
## File details

### torchbridge_ml-0.5.35.tar.gz (source distribution)

- Size: 628.7 kB
- Uploaded via: twine/6.2.0 on CPython 3.11.0
- Trusted Publishing: No

| Algorithm | Hash digest |
|---|---|
| SHA256 | 7a7bb7623a2fde215b8e9b357ebd5f81a35a66ca5f6974e70a824278d7bea361 |
| MD5 | d481ee4d302726e5f71e10a1d6bb6f38 |
| BLAKE2b-256 | efa48c08c15be155d93d37e1138e5476f3d68f75106a59ddeac295d54ffaa436 |
### torchbridge_ml-0.5.35-py3-none-any.whl (built distribution, Python 3)

- Size: 761.6 kB
- Uploaded via: twine/6.2.0 on CPython 3.11.0
- Trusted Publishing: No

| Algorithm | Hash digest |
|---|---|
| SHA256 | dd3315bd2b8984d0cef8b8baaec55e81114d78f08fd8e4cf7e66aeab7eb178d5 |
| MD5 | 11292169258cb666eab6f2d8d46260ef |
| BLAKE2b-256 | 3bed92c9601edff068ae7d175eea2993e4e53edb4ad1e6134660f5fd9c49ef6a |