# TorchBridge

Hardware abstraction layer for PyTorch across NVIDIA, AMD, Trainium, and TPU backends.
Your PyTorch code is locked to one GPU vendor. CUDA calls, NCCL hardcoding, vendor-specific precision tricks -- they break the moment you switch hardware. TorchBridge is a hardware abstraction layer that makes your models run on NVIDIA, AMD, Trainium, and TPU without code changes, and validates that outputs match across backends.
## What is TorchBridge?
PyTorch lets you build models. TorchBridge lets you run them anywhere.
Most teams write hardware-specific code -- CUDA calls for NVIDIA, ROCm setup for AMD, NeuronX setup for Trainium, XLA boilerplate for TPU. When the hardware changes, the code breaks. TorchBridge eliminates that problem with a unified API that detects your hardware and adapts automatically.
```
            Your model code
                   |
            TorchBridge HAL
                   |
+---------+---------+-----------+---------+
| NVIDIA  |   AMD   | Trainium  |   TPU   |
|  CUDA   |  ROCm   |  NeuronX  |   XLA   |
+---------+---------+-----------+---------+
```
What it does:
- Backend detection -- automatically identifies available accelerators
- Vendor adapters -- translates unified API calls to vendor-specific operations
- Precision management -- handles FP32/FP16/BF16/FP8 across backends
- Memory optimization -- gradient checkpointing, activation offloading, memory pooling
- Checkpoint portability -- save on one backend, load on another
- Distributed training -- tensor/pipeline/data parallelism across backend types
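Conceptually, backend detection is a priority scan over available runtimes. A minimal pure-Python sketch of that idea (the probe callables are stand-ins for real runtime checks such as `torch.cuda.is_available()`, not TorchBridge's actual internals):

```python
# Sketch of priority-ordered backend detection. Each probe is a callable
# that reports whether that vendor's hardware/runtime is usable.
def detect_best_backend(probes):
    """Return the first backend whose probe reports available hardware."""
    priority = ["nvidia", "amd", "trainium", "tpu", "cpu"]
    for name in priority:
        if probes.get(name, lambda: False)():
            return name
    return "cpu"  # always-available fallback

# Example: only AMD hardware is "present" in this fake environment.
probes = {"nvidia": lambda: False, "amd": lambda: True}
print(detect_best_backend(probes))  # amd
```

The fixed priority order here is an assumption for illustration; the real library may weigh device capabilities rather than a static list.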
## Quick Start

```bash
pip install torchbridge-ml

# Verify
python3 -c "import torchbridge; print(f'TorchBridge v{torchbridge.__version__} ready')"
```

For development:

```bash
git clone https://github.com/CloudlyIO/torchbridge.git
cd torchbridge
pip install -e ".[dev]"
```
### Detect Hardware

```python
from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())
```
### Run on Any Backend

```python
import torch
from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)
optimized_model = manager.optimize(model)
```
### Validate

```python
from torchbridge import UnifiedValidator

validator = UnifiedValidator()
results = validator.validate_model(optimized_model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")
```
## Supported Backends
| Backend | Hardware | Precision | Status |
|---|---|---|---|
| NVIDIA | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| AMD | MI350X, MI325X, MI300X, MI200 | FP4, FP8, BF16, FP16, FP32 | Production |
| Trainium | Trn1, Trn2, Trn3 (AWS NeuronX) | BF16, FP16, FP32 | Supported |
| TPU | v4, v5e, v5p, v6e, v7 | BF16, FP32 | Production |
| CPU | x86, ARM (Apple Silicon) | FP32, BF16 | Fallback |
See Hardware Matrix for full details.
## Key Features

### Backend Detection and Adaptation
Automatically identifies available hardware and selects the optimal backend. No code changes needed when moving between GPU vendors or cloud providers.
### Vendor Adapters
Each backend implements a common BaseBackend interface. Your code calls manager.optimize(model) and the correct vendor-specific operations execute underneath -- CUDA on NVIDIA, HIP on AMD, NeuronX on Trainium, XLA on TPU.
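The adapter pattern can be sketched with an abstract base class. The names below (`BaseBackend`, `device_name`, the no-op `optimize`) are illustrative stand-ins, not TorchBridge's exact interface:

```python
from abc import ABC, abstractmethod

class BaseBackend(ABC):
    """Common interface every vendor adapter implements (illustrative)."""

    @abstractmethod
    def device_name(self) -> str:
        """Device string the vendor runtime expects, e.g. 'cuda' or 'xla'."""

    @abstractmethod
    def optimize(self, model):
        """Apply vendor-specific optimizations and return the model."""

class NvidiaBackend(BaseBackend):
    def device_name(self) -> str:
        return "cuda"

    def optimize(self, model):
        # A real adapter would enable TF32, pick fused kernels, etc.
        return model

backend = NvidiaBackend()
print(backend.device_name())  # cuda
```

Because every adapter satisfies the same contract, calling code never branches on vendor; swapping `NvidiaBackend` for an AMD or TPU adapter changes nothing upstream.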
### Precision Management
Configure precision once. TorchBridge handles the details per backend -- FP8 on H100, BF16 where supported, FP16 as fallback. Mixed-precision training with torch.amp autocast works across all backends.
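The per-backend behavior amounts to choosing the first requested dtype the hardware supports. A hedged sketch with a made-up capability table (the entries are assumptions for illustration, not an authoritative feature matrix):

```python
# Illustrative capability table -- not a real hardware matrix.
SUPPORTED = {
    "h100": {"fp8", "bf16", "fp16", "fp32"},
    "tpu_v5e": {"bf16", "fp32"},
    "t4": {"fp16", "fp32"},
}

def resolve_precision(device, preference=("fp8", "bf16", "fp16", "fp32")):
    """Pick the first preferred dtype the device supports; fall back to fp32."""
    caps = SUPPORTED.get(device, {"fp32"})
    for dtype in preference:
        if dtype in caps:
            return dtype
    return "fp32"

print(resolve_precision("h100"))     # fp8
print(resolve_precision("tpu_v5e"))  # bf16
print(resolve_precision("t4"))       # fp16
```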
### Memory Optimization
Gradient checkpointing, activation offloading, optimizer state sharding, and memory pooling. These work consistently whether you're on a single GPU or a multi-node cluster.
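Gradient checkpointing trades compute for memory: only some activations are cached for the backward pass, and the rest are recomputed from the nearest stored one when needed. A toy pure-Python illustration of that bookkeeping (not TorchBridge code, and no autograd involved):

```python
def forward_with_checkpoints(x, layers, checkpoint_every=2):
    """Run layers in order, storing only every Nth activation."""
    checkpoints = {0: x}
    for i, layer in enumerate(layers, start=1):
        x = layer(x)
        if i % checkpoint_every == 0:
            checkpoints[i] = x
    return x, checkpoints

def recompute_activation(target, checkpoints, layers):
    """Rebuild a dropped activation from the nearest earlier checkpoint."""
    start = max(i for i in checkpoints if i <= target)
    x = checkpoints[start]
    for layer in layers[start:target]:
        x = layer(x)
    return x

layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3, lambda v: v * v]
out, ckpts = forward_with_checkpoints(0, layers)
# The activation after layer 3 was never stored; rebuild it from checkpoint 2.
print(recompute_activation(3, ckpts, layers))  # -1
```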
### Checkpoint Portability
Save a checkpoint on NVIDIA hardware, load it on AMD, Trainium, or TPU. TorchBridge handles device mapping and dtype conversion.
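On load, the essential work is rewriting each tensor's device tag and widening dtypes the target backend cannot represent. A pure-Python sketch over checkpoint metadata (the metadata shape and fallback table are assumptions for illustration, not TorchBridge's on-disk format):

```python
# Example dtype fallback: a backend without FP8 widens those tensors to BF16.
FALLBACK = {"fp8": "bf16"}

def remap_checkpoint(meta, target_device, supported_dtypes):
    """Rewrite device tags and apply dtype fallbacks for the target backend."""
    remapped = {}
    for name, (device, dtype) in meta.items():
        while dtype not in supported_dtypes and dtype in FALLBACK:
            dtype = FALLBACK[dtype]
        remapped[name] = (target_device, dtype)
    return remapped

# FP8 tensor saved on NVIDIA, loaded on a TPU that only has BF16/FP32.
meta = {"w1": ("cuda:0", "fp8"), "w2": ("cuda:0", "fp32")}
print(remap_checkpoint(meta, "xla:0", {"bf16", "fp32"}))
```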
### Distributed Training
Tensor parallelism, pipeline parallelism, and FSDP with a unified API. The same distributed training script runs on NVIDIA DGX, AMD Instinct, Trainium instances, or TPU pods.
## Code Examples

### Backend-Agnostic Training
```python
import torch
from torchbridge.backends import BackendFactory, detect_best_backend

backend = BackendFactory.create(detect_best_backend())
device = backend.device

model = YourModel().to(device)
criterion = torch.nn.CrossEntropyLoss()  # any loss suited to your task
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Use PyTorch native AMP -- works on any backend
scaler = torch.amp.GradScaler(device.type)
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    with torch.amp.autocast(device.type):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```
### Hardware Capability Queries

```python
from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
print(nvidia.get_device_info())  # GPU model, compute capability, memory
print(nvidia.supports_fp8())     # True on H100+
```
### Cross-Backend Model Export

```python
import torch
from torchbridge.deployment import export_to_torchscript, export_to_onnx, export_to_safetensors

sample_input = torch.randn(1, 768)
export_to_torchscript(model, output_path="model.pt", sample_input=sample_input)
export_to_onnx(model, output_path="model.onnx", sample_input=sample_input)
export_to_safetensors(model, output_path="model.safetensors")
```
## Project Structure

```
src/torchbridge/
├── backends/            # Vendor-specific backend implementations
│   ├── nvidia/          # NVIDIA CUDA backend
│   ├── amd/             # AMD ROCm backend
│   ├── trainium/        # AWS Trainium/NeuronX backend
│   └── tpu/             # Google TPU/XLA backend
├── hardware/            # Hardware detection and abstraction
├── precision/           # FP8 training and precision management
├── attention/           # Attention mechanisms (unified API)
├── advanced_memory/     # Memory optimization strategies
├── distributed_scale/   # Distributed training
├── deployment/          # Model export and serving
├── monitoring/          # Metrics, logging, health checks
├── optimizations/       # Optimization patterns and strategies
├── core/                # Core config, management, optimized layers
├── cli/                 # Command-line tools
├── models/              # Model implementations
├── mixture_of_experts/  # MoE layer support
├── validation/          # Cross-backend validation
└── utils/               # Utilities and profiling
```
## Cloud GPU Validation
All 5 use cases validated on real GPU hardware across AWS and GCP:
| Use Case | AWS A10G | GCP L4 | Description |
|---|---|---|---|
| Export Pipeline | PASS | PASS | TorchScript, ONNX, SafeTensors export with validation |
| LLM Optimization | PASS | PASS | Qwen3/DeepSeek optimization with backend-specific tuning |
| CI/CD Validation | PASS | PASS | Diagnostics, benchmarks, cross-backend checks |
| Backend Training | PASS | PASS | AMP training with auto backend detection |
| Cross-Backend Validation | PASS | PASS | Model, hardware, config, and output consistency |
Platforms tested:
- AWS g5.xlarge -- NVIDIA A10G 24GB, PyTorch 2.9.1+cu130
- GCP n1-standard-4 -- NVIDIA T4 16GB, PyTorch 2.7.1+cu128
See full validation report for detailed benchmarks and results.
## Quality
- 1,464 tests passing across all modules
- 0 ruff violations -- clean linting
- 0 mypy errors -- full type coverage
- Cloud validated on NVIDIA A10G (AWS), L4 (GCP), and AMD MI300X -- 5/5 use cases pass
- Cross-platform tested on macOS, Linux, AWS, GCP, AMD Developer Cloud
```bash
python3 -m pytest tests/ -q
ruff check src/ tests/
```
## Use Cases

- **Cross-vendor training** -- Train on NVIDIA in the cloud, fine-tune on AMD on-prem, deploy on Trainium or TPU. Same code throughout.
- **Cost optimization** -- Switch between cloud GPU types based on spot pricing without rewriting training scripts.
- **Hardware migration** -- Move from one GPU vendor to another without a code rewrite.
- **Research portability** -- Share models and training code that colleagues can run on whatever hardware they have.
## Documentation
| Document | Description |
|---|---|
| Installation | Setup and requirements |
| Quick Start | First steps with TorchBridge |
| Troubleshooting | Common issues and fixes |
| Backends Overview | How the HAL works |
| Backend Selection | Choosing the right backend |
| Hardware Setup | Driver and toolkit installation |
| Distributed Training | Multi-GPU and multi-node |
| Deployment | Export, serve, containerize |
| CLI Reference | Command-line tools |
| Hardware Matrix | Full hardware support table |
| Contributing | Development and contribution guide |
| Changelog | Version history |
## License
See LICENSE file for licensing details.