TorchBridge
Cross-backend validation and configuration intelligence for PyTorch — NVIDIA, AMD, Trainium, and TPU
TorchBridge validates that your model produces correct outputs across PyTorch backends and recommends optimal hardware configurations. It answers two questions that no other tool covers in a single command:
- "Does my model produce correct outputs across backends?" — Run it on CUDA and ROCm and get max_diff, cosine_sim, per-layer divergence, pass/fail against empirical tolerances.
- "What's the optimal configuration for my model on this hardware?" — Compatibility matrices that translate (backend, architecture) → format/kernel/method, with fallback chains.
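Conceptually, the two core validation metrics reduce to an element-wise maximum absolute difference and a cosine similarity between the outputs of the same model on two backends. A minimal sketch in plain Python (the sample logits are made-up numbers, not TorchBridge output):

```python
import math

def max_diff(a, b):
    """Largest element-wise absolute difference between two flat outputs."""
    return max(abs(x - y) for x, y in zip(a, b))

def cosine_sim(a, b):
    """Cosine similarity between two flat outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

# Hypothetical logits from the same model run on two backends
cuda_out = [0.120, -0.530, 0.880]
rocm_out = [0.120, -0.531, 0.880]

assert max_diff(cuda_out, rocm_out) <= 1e-2    # within a coarse tolerance
assert cosine_sim(cuda_out, rocm_out) > 0.999  # semantically identical
```

A pass/fail verdict then falls out of comparing `max_diff` against a per-(model, backend, dtype) tolerance.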
Quick Start
pip install torchbridge-ml
Cross-Backend Validation (the hero command)
# Compare CUDA vs ROCm outputs — max_diff, cosine_sim, pass/fail
tb-validate --compare cuda rocm --model ./model.pt
# Per-layer divergence report
tb-validate --compare cuda rocm --model ./model.pt --per-layer
# Multi-step agentic trace — track divergence amplification across 50 steps
tb-validate --compare cuda rocm --model ./model.pt --trace --steps 50 --autoregressive
# Compliance certificate + OTel span export (Langfuse, W&B, etc.)
tb-validate --compare cuda rocm --model ./model.pt --cert --otel
# CI mode — exits non-zero if max_diff exceeds tolerance
tb-validate --compare cuda rocm --model ./model.pt --ci
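The `--trace` report summarizes multi-step behavior with a first-divergence-step and an amplification factor. A sketch of how such a summary can be computed from a per-step `max_diff` series (the trace values below are illustrative numbers, not TorchBridge internals):

```python
def summarize_trace(max_diffs, tol=1e-5):
    """Summarize a per-step max_diff trace from an autoregressive run."""
    # First step whose divergence exceeds the tolerance (None if never)
    first = next((i for i, d in enumerate(max_diffs) if d > tol), None)
    # How much the initial divergence grew by the final step
    amplification = max_diffs[-1] / max_diffs[0] if max_diffs[0] > 0 else float("inf")
    return {"first_divergence_step": first, "amplification": amplification}

# Made-up trace: divergence grows across autoregressive steps
per_step_max_diff = [2e-6, 8e-6, 3e-5, 9e-5, 2e-4]
summary = summarize_trace(per_step_max_diff)
print(summary)
```

The CI flag builds on the same idea: exceed the tolerance anywhere and the process exits non-zero.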
Hardware Configuration Advisor
# What's the optimal config for a 7B model on this hardware?
tb-advisor --model-params 7e9
# Disaggregated prefill/decode fleet config
tb-advisor --mode disaggregated --model-params 7e9 --prefill nvidia:hopper --decode amd:cdna3
# Heterogeneous cluster training config (NVIDIA + AMD mixed)
tb-advisor --mode heterogeneous --model-params 7e9 --nvidia hopper:8 --amd cdna3:4
# Doctor — diagnose your hardware setup
tb-doctor
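The kind of reasoning a config advisor performs starts from a back-of-envelope memory estimate: parameter count times bytes per parameter. A rough sketch under standard rules of thumb (2 bytes for bf16/fp16, 1 byte for int8/fp8); this is illustrative arithmetic, not TorchBridge's actual model:

```python
def weight_memory_gb(n_params, bytes_per_param=2):
    """Rough weight-memory footprint in GB (bf16/fp16 = 2 bytes per param)."""
    return n_params * bytes_per_param / 1e9

print(weight_memory_gb(7e9))     # 14.0 GB for bf16 weights
print(weight_memory_gb(7e9, 1))  # 7.0 GB at int8/fp8
```

KV cache, optimizer state, and activation memory come on top, which is why the advisor also takes sequence length and training/inference mode into account.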
Python API
from torchbridge.backends import BackendFactory, detect_best_backend
backend_type = detect_best_backend() # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())
# Cross-backend validation
from torchbridge import UnifiedValidator
validator = UnifiedValidator()
results = validator.validate_model(model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")
What TorchBridge Does
| Capability | What TorchBridge adds |
|---|---|
| Cross-backend validation | tb-validate --compare cuda rocm — per-layer divergence, empirical tolerance DB (5 model families × 5 backends × 3 dtypes), CI-ready JSON |
| Multi-step agentic trace | tb-validate --trace --steps 50 --autoregressive — tracks how max_diff amplifies across N autoregressive steps; reports first-divergence-step and amplification factor |
| Compliance certificates | tb-validate --cert — SHA256-signed pass/fail certificate for KV handoff physical spec (page size, alignment, layout) |
| Observability integration | tb-validate --otel — emits validation spans (max_diff, cosine_sim, per-layer child spans) to any OTLP endpoint (Langfuse, W&B Weave, Honeycomb) |
| Compatibility matrices | 13 empirically-sourced matrices: (backend, architecture) → optimal quant format / attention kernel / adapter method / FSDP strategy / torch.compile mode |
| Config advisory | tb-advisor — FSDP, quantization, KV cache, speculative decoding, disaggregated fleet (--mode disaggregated), heterogeneous clusters (--mode heterogeneous) |
| Backend detection | Hardware identification, capability queries, priority chain across NVIDIA/AMD/Trainium/TPU/CPU |
| Tolerance DB | 80 empirical entries, 3-level fallback, --model-family flag — tolerances sourced from real Qwen3-0.6B runs across 6 GPU platforms |
| CLI diagnostics | tb-doctor, tb-validate, tb-advisor, tb-speculate, tb-cache, tb-adapter, tb-quantize, tb-migrate, tb-benchmark, tb-checkpoint |
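The compatibility-matrix-with-fallback-chain pattern above can be sketched as a small lookup table. This miniature is hypothetical (the matrix contents and function name are illustrative, not TorchBridge's API):

```python
# Hypothetical miniature of a (backend, architecture) -> format matrix;
# each value is a fallback chain, best format first.
QUANT_MATRIX = {
    ("nvidia", "hopper"): ["fp8", "int8", "bf16"],
    ("amd", "cdna3"):     ["fp8", "bf16"],
    ("cpu", "x86"):       ["bf16", "fp32"],
}

def pick_format(backend, arch, supported):
    """Walk the fallback chain and return the first format the runtime supports."""
    for fmt in QUANT_MATRIX.get((backend, arch), ["fp32"]):
        if fmt in supported:
            return fmt
    return "fp32"  # last-resort fallback

print(pick_format("nvidia", "hopper", {"int8", "bf16"}))  # "int8" (fp8 unavailable)
```

The real matrices cover quantization formats, attention kernels, adapter methods, FSDP strategies, and `torch.compile` modes, but the selection logic follows this shape.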
What TorchBridge Is NOT
- Not a quantization library — dispatches format selection to torchao; TorchBridge adds the compatibility matrix
- Not a serving runtime — use vLLM, TGI, or similar for production inference serving; TorchBridge validates correctness and advises configuration, it does not serve requests
- Not a training framework — adapter math (LoRA/QLoRA) is correct and kept; use PEFT for full training workflows
- Not a PyTorch wrapper — if a method body is return torch.something(...) with no selection logic, it doesn't belong here
Supported Backends
| Backend | Hardware | Precision | Status |
|---|---|---|---|
| NVIDIA | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| AMD | MI350X, MI325X, MI300X, MI200 | FP8, BF16, FP16, FP32 | Production |
| Trainium | Trn1, Trn2, Trn3 (AWS NeuronX) | BF16, FP16, FP32 | Supported |
| TPU | v4, v5e, v5p, v6e, v7 | BF16, FP32 | Production |
| CPU | x86, ARM (Apple Silicon) | FP32, BF16 | Fallback |
See Hardware Matrix for full details.
Cloud Hardware Validation
Cross-backend numerical consistency validated on 6/8 platforms using Qwen3-0.6B (v0.5.80, re-validated 2026-03-30):
| Platform | Hardware | Max Diff | Cosine Sim | Latency | Status |
|---|---|---|---|---|---|
| AWS | NVIDIA A10G (24GB) | 2.10e-05 | 1.000001 | 40.0 ms | PASS |
| GCP | NVIDIA T4 (16GB) | 2.67e-05 | 1.000001 | 50.7 ms | PASS |
| RunPod | NVIDIA H100 NVL (100GB) | 1.67e-05 | 1.000001 | 16.2 ms | PASS |
| Local | Apple Silicon (MPS) | 0.00e+00 | 1.000000 | 118.9 ms | PASS |
| AWS Trainium† | Trn1.2xlarge (NeuronX) | 0.00e+00 | 1.000000 | 115.8 ms (CPU) | PASS |
| AWS Inferentia2† | inf2.xlarge (NeuronX) | 0.00e+00 | 1.000000 | 321.8 ms (CPU) | PASS |
| AMD DevCloud | AMD MI300X (192GB) | — | — | — | SKIPPED‡ |
| GCP | TPU v5e | — | — | — | SKIPPED‡ |
† CPU fallback: NeuronX SDK compilation requires quota-enabled Trn1/Inf2 instances. These rows confirm correct CPU-path behavior; accelerator validation is pending.
‡ Capacity unavailable: AMD MI300X out of capacity at validation time; TPU v5e exhausted globally across 18 zones.
All tested GPU backends produce semantically identical outputs (cosine similarity > 0.999).
See full validation report for detailed benchmarks and results.
Project Structure
src/torchbridge/
├── backends/ # Vendor-specific backend implementations
│ ├── nvidia/ # NVIDIA CUDA backend
│ ├── amd/ # AMD ROCm backend
│ ├── trainium/ # AWS Trainium/NeuronX backend
│ └── tpu/ # Google TPU/XLA backend
├── precision/ # Quantization compatibility matrix + torchao dispatch
├── attention/ # Attention kernel compatibility matrix + dispatcher
├── distributed/ # FSDP/pipeline config advisor
├── adapters/ # LoRA/QLoRA adapter injection (correct math)
├── inference/ # Speculative decoding compatibility matrix
├── checkpoint/ # DCP wrapper with cross-backend metadata
├── testing/ # DivergenceTracer, ToleranceDB, MultiStepTracer, @cross_backend
├── validation/ # UnifiedValidator — model structure, hardware, numerical stability
├── cli/ # Command-line tools (11 CLI commands)
├── models/ # LLM KV cache advisor
└── utils/ # Utilities
Quality
- 2,187 tests passing (hardware-gated skips on non-GPU environments)
- 0 ruff violations -- clean linting
- 0 mypy errors -- full type coverage
- Cloud validated on 6/8 platforms: A10G, T4, H100 NVL, MPS (GPU); Trainium, Inferentia2 (CPU-fallback†)
python3 -m pytest tests/ -q
ruff check src/ tests/
Documentation
| Document | Description |
|---|---|
| Installation | Setup and requirements |
| Quick Start | First steps with TorchBridge |
| Troubleshooting | Common issues and fixes |
| Backends Overview | How the backend system works |
| Backend Selection | Choosing the right backend + driver setup |
| Distributed Training | Multi-GPU and multi-node |
| Testing Guide | DivergenceTracer, @cross_backend, ToleranceDB |
| Deployment | Serving and containerization |
| CLI Reference | Command-line tools |
| Hardware Matrix | Full hardware support table |
| Changelog | Version history |
Community
The empirical tolerance database (testing/tolerance_db.py) is only as strong as the hardware it has been measured on. Contributions that add or correct tolerance entries for hardware you have access to — AMD MI350X, Trainium2, TPU v7 Ironwood, new PyTorch versions — directly expand the validation coverage for everyone. See CONTRIBUTING.md for how to add entries and the source-label conventions ("measured", "derived", "fallback").
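A contributed entry pairs a (model family, backend, dtype) key with an empirical bound and a source label. A hypothetical sketch of the shape such an entry might take (field names are illustrative; see testing/tolerance_db.py for the actual schema):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ToleranceEntry:
    """Illustrative tolerance-DB record: one empirical bound per key."""
    model_family: str  # e.g. "qwen3"
    backend: str       # e.g. "rocm"
    dtype: str         # e.g. "bf16"
    max_diff: float    # empirical upper bound observed on real hardware
    source: str        # "measured" | "derived" | "fallback"

entry = ToleranceEntry("qwen3", "rocm", "bf16", 5e-4, "measured")
assert entry.source in {"measured", "derived", "fallback"}
```

The source label is what lets the validator prefer measured bounds and fall back to derived or default ones when no measurement exists for a key.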
Versioning
v0.5.93 is the first public release. The v0.5.x series represents an extended private development and validation phase: building the backend abstraction layer, validating numerical consistency on real GPU hardware across 6 platforms, and reaching a quality bar suitable for open source. The version number reflects the maturity of the implementation, not the release count.
License
Licensed under the Apache License, Version 2.0. See LICENSE for the full text.
File details
Details for the file torchbridge_ml-0.5.95.tar.gz.
File metadata
- Download URL: torchbridge_ml-0.5.95.tar.gz
- Upload date:
- Size: 223.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1617aa0f370d70139d8a0f459bac9e5321024f07361a4e0f0186f5660a69accb |
| MD5 | ee1548ad29ec4a2f389cd9ed2d5534ce |
| BLAKE2b-256 | 9045fe01db99f0800b74c03865d2f5166ca75c302cb5da4cb17a7ae13e5d7af3 |
File details
Details for the file torchbridge_ml-0.5.95-py3-none-any.whl.
File metadata
- Download URL: torchbridge_ml-0.5.95-py3-none-any.whl
- Upload date:
- Size: 274.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 13e0c67f396f727b06c92f9dc42f55d82e305dc169b7905bedd41f8f81634173 |
| MD5 | cc714e724dd6be41df01287c28024d3a |
| BLAKE2b-256 | 4ddeca439f96746249f4dd9471bd345041fefa771eb8cab11445e8693248876b |