TorchBridge
Cross-backend validation and configuration intelligence for PyTorch — NVIDIA, AMD, Trainium, and TPU
TorchBridge validates that your model produces correct outputs across PyTorch backends and recommends optimal hardware configurations. It answers two questions no other tool answers in a single command:
- "Does my model produce correct outputs across backends?" — Run it on CUDA and ROCm and get max_diff, cosine_sim, per-layer divergence, pass/fail against empirical tolerances.
- "What's the optimal configuration for my model on this hardware?" — Compatibility matrices that translate (backend, architecture) → format/kernel/method, with fallback chains.
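The fallback-chain idea behind those matrices can be pictured with a small sketch. Everything here — the `QUANT_MATRIX` table, its entries, and `pick_format` — is illustrative and hypothetical, not TorchBridge's actual data or API:

```python
# Hypothetical (backend, architecture) -> format lookup with a fallback chain.
# Entries are illustrative, not TorchBridge's real compatibility matrix.
QUANT_MATRIX = {
    ("cuda", "hopper"): ["fp8", "int8", "bf16"],  # preferred first, then fallbacks
    ("rocm", "cdna3"): ["fp8", "bf16"],
    ("cpu", None): ["bf16"],
}

def pick_format(backend, arch, supported):
    """Walk the fallback chain; return the first format the runtime supports."""
    chain = QUANT_MATRIX.get((backend, arch)) or QUANT_MATRIX[("cpu", None)]
    for fmt in chain:
        if fmt in supported:
            return fmt
    return "fp32"  # universal last resort

print(pick_format("cuda", "hopper", {"int8", "bf16"}))  # fp8 unsupported -> int8
```

An unknown (backend, architecture) pair falls through to the CPU chain, and an empty intersection lands on FP32, so a lookup always yields a usable answer.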
Quick Start
pip install torchbridge-ml
Cross-Backend Validation (the hero command)
# Compare CUDA vs ROCm outputs — max_diff, cosine_sim, pass/fail
tb-validate --compare cuda rocm --model ./model.pt
# Per-layer divergence report
tb-validate --compare cuda rocm --model ./model.pt --per-layer
# Multi-step agentic trace — track divergence amplification across 50 steps
tb-validate --compare cuda rocm --model ./model.pt --trace --steps 50 --autoregressive
# Compliance certificate + OTel span export (Langfuse, W&B, etc.)
tb-validate --compare cuda rocm --model ./model.pt --cert --otel
# CI mode — exits non-zero if max_diff exceeds tolerance
tb-validate --compare cuda rocm --model ./model.pt --ci
Hardware Configuration Advisor
# What's the optimal config for this hardware?
tb-advisor
# Disaggregated prefill/decode fleet config
tb-advisor --mode disaggregated --prefill-backend nvidia --decode-backend amd
# Heterogeneous cluster training config (NVIDIA + AMD mixed)
tb-advisor --mode heterogeneous --nvidia hopper:8 --amd cdna3:4
# Doctor — diagnose your hardware setup
tb-doctor
Python API
from torchbridge.backends import BackendFactory, detect_best_backend
backend_type = detect_best_backend() # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())
# Cross-backend validation
from torchbridge import UnifiedValidator
validator = UnifiedValidator()
results = validator.validate_model(model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")
What TorchBridge Does
| Capability | What TorchBridge adds |
|---|---|
| Cross-backend validation | tb-validate --compare cuda rocm — per-layer divergence, empirical tolerance DB (5 model families × 5 backends × 3 dtypes), CI-ready JSON |
| Multi-step agentic trace | tb-validate --trace --steps 50 --autoregressive — tracks how max_diff amplifies across N autoregressive steps; reports first-divergence-step and amplification factor |
| Compliance certificates | tb-validate --cert — SHA256-signed pass/fail certificate for KV handoff physical spec (page size, alignment, layout) |
| Observability integration | tb-validate --otel — emits validation spans (max_diff, cosine_sim, per-layer child spans) to any OTLP endpoint (Langfuse, W&B Weave, Honeycomb) |
| Compatibility matrices | 12 empirically-sourced matrices: (backend, architecture) → optimal quant format / attention kernel / adapter method / FSDP strategy |
| Config advisory | tb-advisor — FSDP, quantization, KV cache, speculative decoding, disaggregated fleet (--mode disaggregated), heterogeneous clusters (--mode heterogeneous) |
| Backend detection | Hardware identification, capability queries, priority chain across NVIDIA/AMD/Trainium/TPU/CPU |
| Tolerance DB | 80 empirical entries, 3-level fallback, --model-family flag — tolerances sourced from real Qwen3-0.6B runs across 6 GPU platforms |
| CLI diagnostics | tb-doctor, tb-validate, tb-advisor, tb-quantize, tb-migrate, tb-benchmark |
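The 3-level fallback of the tolerance DB can be sketched as a keyed lookup that degrades from an exact entry to a family default to a global bound. Keys, values, and the `lookup_tolerance` helper are illustrative, not the entries shipped in `testing/tolerance_db.py`:

```python
# Hypothetical 3-level tolerance lookup: exact (family, backend, dtype) entry,
# then a family-level default, then a global fallback bound.
TOLERANCES = {
    ("qwen3", "rocm", "bf16"): 5e-5,  # "measured" on real hardware
    ("qwen3", None, None): 1e-4,      # "derived" family default
    (None, None, None): 1e-3,         # "fallback" global bound
}

def lookup_tolerance(family, backend, dtype):
    """Return the most specific tolerance entry available."""
    for key in [(family, backend, dtype), (family, None, None), (None, None, None)]:
        if key in TOLERANCES:
            return TOLERANCES[key]
    raise KeyError("tolerance DB has no global fallback entry")

print(lookup_tolerance("qwen3", "rocm", "bf16"))  # exact hit: 5e-05
```

A hardware/dtype combination with no measured entry still gets a defensible bound, which is what lets `--ci` mode fail closed rather than crash on unseen platforms.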
What TorchBridge Is NOT
- Not a quantization library — dispatches format selection to torchao; TorchBridge adds the compatibility matrix
- Not a serving runtime — the inference server is a validation demo, not a production serving replacement for vLLM or TGI
- Not a training framework — adapter math (LoRA/QLoRA) is correct and kept; use PEFT for full training workflows
- Not a PyTorch wrapper — if a method body is return torch.something(...) with no selection logic, it doesn't belong here
Supported Backends
| Backend | Hardware | Precision | Status |
|---|---|---|---|
| NVIDIA | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| AMD | MI350X, MI325X, MI300X, MI200 | FP8, BF16, FP16, FP32 | Production |
| Trainium | Trn1, Trn2, Trn3 (AWS NeuronX) | BF16, FP16, FP32 | Supported |
| TPU | v4, v5e, v5p, v6e, v7 | BF16, FP32 | Production |
| CPU | x86, ARM (Apple Silicon) | FP32, BF16 | Fallback |
See Hardware Matrix for full details.
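The priority chain across these backends can be sketched as a first-match scan. Note the real `torchbridge.backends.detect_best_backend` takes no arguments; this stand-in injects availability flags so the example runs on any machine:

```python
# Sketch of a backend priority chain. Probe results are injected so the example
# stays hardware-independent; real detection would query driver/runtime APIs.
PRIORITY = ["nvidia", "amd", "trainium", "tpu", "cpu"]

def detect_best_backend(available):
    """Return the highest-priority backend whose availability probe succeeded."""
    for name in PRIORITY:
        if available.get(name, False):
            return name
    return "cpu"  # CPU is the unconditional fallback

print(detect_best_backend({"nvidia": False, "amd": True}))  # amd
```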
Cloud Hardware Validation
Cross-backend numerical consistency validated on 8 platforms (6 real GPU/accelerator, 2 CPU-fallback†) using Qwen3-0.6B:
| Platform | Hardware | Max Diff | Cosine Sim | Latency | Status |
|---|---|---|---|---|---|
| AWS | NVIDIA A10G (24GB) | 1.96e-05 | 1.000001 | 41.8 ms | PASS |
| GCP | NVIDIA T4 (16GB) | 2.67e-05 | 1.000001 | 50.8 ms | PASS |
| RunPod | NVIDIA H100 NVL (100GB) | 2.29e-05 | 1.000001 | 18.8 ms | PASS |
| AMD DevCloud | AMD MI300X (192GB) | 4.82e-05 | 1.000001 | 30.0 ms | PASS |
| GCP | TPU v5e | 1.08e-01 | 0.999980 | 47.5 ms | PASS |
| Local | Apple Silicon (MPS) | 4.58e-05 | 1.000002 | 27.8 ms | PASS |
| AWS Trainium† | Trn1.2xlarge (NeuronX) | 0.00e+00 | 1.000001 | 103.3 ms (CPU) | PASS |
| AWS Inferentia2† | inf2.xlarge (NeuronX) | 0.00e+00 | 1.000001 | 321.7 ms (CPU) | PASS |
† CPU fallback: NeuronX SDK compilation requires quota-enabled Trn1/Inf2 instances not available in the validation environment. These rows confirm correct CPU-path behavior. Real NeuronX accelerator validation is pending quota approval.
All GPU/accelerator backends produce semantically identical outputs (cosine similarity > 0.999).
See full validation report for detailed benchmarks and results.
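The two headline metrics in the table are standard tensor comparisons. A minimal NumPy sketch, using random stand-in vectors rather than real model logits:

```python
import numpy as np

# Stand-in "backend outputs": a base vector plus a tiny per-element perturbation,
# mimicking the small numerical drift between two backends.
rng = np.random.default_rng(0)
out_cuda = rng.standard_normal(768).astype(np.float32)
out_rocm = out_cuda + rng.normal(scale=1e-5, size=768).astype(np.float32)

# max_diff: worst-case elementwise divergence
max_diff = float(np.max(np.abs(out_cuda - out_rocm)))

# cosine_sim: directional agreement of the two output vectors
cosine_sim = float(
    np.dot(out_cuda, out_rocm)
    / (np.linalg.norm(out_cuda) * np.linalg.norm(out_rocm))
)

print(f"max_diff={max_diff:.2e} cosine_sim={cosine_sim:.6f}")
```

A validation pass then reduces to comparing `max_diff` against the tolerance DB entry for the (model family, backend, dtype) under test.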
Project Structure
src/torchbridge/
├── backends/ # Vendor-specific backend implementations
│ ├── nvidia/ # NVIDIA CUDA backend
│ ├── amd/ # AMD ROCm backend
│ ├── trainium/ # AWS Trainium/NeuronX backend
│ └── tpu/ # Google TPU/XLA backend
├── precision/ # Quantization compatibility matrix + torchao dispatch
├── attention/ # Attention kernel compatibility matrix + dispatcher
├── distributed/ # FSDP/pipeline config advisor
├── adapters/ # LoRA/QLoRA adapter injection (correct math)
├── inference/ # Speculative decoding compatibility matrix
├── checkpoint/ # DCP wrapper with cross-backend metadata
├── testing/ # DivergenceTracer, ToleranceDB, MultiStepTracer, @cross_backend
├── validation/ # UnifiedValidator — model structure, hardware, numerical stability
├── cli/ # Command-line tools (13 entry points)
├── models/ # LLM KV cache advisor
└── utils/ # Utilities
Quality
- 2,223 tests passing (hardware-gated skips on non-GPU environments)
- 0 ruff violations — clean linting
- 0 mypy errors — full type coverage
- Cloud validated on 8 platforms (6 GPU-validated: A10G, T4, H100 NVL, MI300X, TPU v5e, MPS; 2 CPU-fallback†: Trainium, Inferentia2)
python3 -m pytest tests/ -q
ruff check src/ tests/
Documentation
| Document | Description |
|---|---|
| Installation | Setup and requirements |
| Quick Start | First steps with TorchBridge |
| Troubleshooting | Common issues and fixes |
| Backends Overview | How the backend system works |
| Backend Selection | Choosing the right backend |
| Hardware Setup | Driver and toolkit installation |
| Distributed Training | Multi-GPU and multi-node |
| Deployment | Serving and containerization |
| CLI Reference | Command-line tools |
| Hardware Matrix | Full hardware support table |
| Changelog | Version history |
Community
The empirical tolerance database (testing/tolerance_db.py) is only as strong as the hardware it has been measured on. Contributions that add or correct tolerance entries for hardware you have access to — AMD MI350X, Trainium2, TPU v7 Ironwood, new PyTorch versions — directly expand the validation coverage for everyone. See CONTRIBUTING.md for how to add entries and the source-label conventions ("measured", "derived", "fallback").
License
Licensed under the Apache License, Version 2.0. See LICENSE for the full text.
Project details
Download files
Source Distribution
Built Distribution
File details
Details for the file torchbridge_ml-0.5.80.tar.gz.
File metadata
- Download URL: torchbridge_ml-0.5.80.tar.gz
- Upload date:
- Size: 223.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 086c363c29c9380227dd70f97d74ae2fce224c0d9f2c88552ecf7e38012704a3 |
| MD5 | 5da6a9e643e2f487fd759356863f971f |
| BLAKE2b-256 | e75f845c1cfd1c37406da4c91b2c80fe111885f630fbdd4a2e70421b6bb813a7 |
File details
Details for the file torchbridge_ml-0.5.80-py3-none-any.whl.
File metadata
- Download URL: torchbridge_ml-0.5.80-py3-none-any.whl
- Upload date:
- Size: 274.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 25e639ddcc2c3e33e9efe68c852ccf419ae5048f26716667cab088da24c9f837 |
| MD5 | cdeb94e8f73f8c923099f6619e75a0e7 |
| BLAKE2b-256 | 409122004b24c5679e281f3c0326a37b0d210eda7abfb25c4ff142dfce180521 |