
Cross-backend validation and configuration intelligence for PyTorch — NVIDIA, AMD, Trainium, and TPU


TorchBridge

TorchBridge validates that your model produces correct outputs across PyTorch backends and recommends optimal hardware configurations. In a single command, it answers two questions no other tool covers:

  1. "Does my model produce correct outputs across backends?" — Run it on CUDA and ROCm and get max_diff, cosine_sim, per-layer divergence, pass/fail against empirical tolerances.
  2. "What's the optimal configuration for my model on this hardware?" — Compatibility matrices that translate (backend, architecture) → format/kernel/method with fallback chains.


Quick Start

pip install torchbridge-ml

Cross-Backend Validation (the hero command)

# Compare CUDA vs ROCm outputs — max_diff, cosine_sim, pass/fail
tb-validate --compare cuda rocm --model ./model.pt

# Per-layer divergence report
tb-validate --compare cuda rocm --model ./model.pt --per-layer

# Multi-step agentic trace — track divergence amplification across 50 steps
tb-validate --compare cuda rocm --model ./model.pt --trace --steps 50 --autoregressive

# Compliance certificate + OTel span export (Langfuse, W&B, etc.)
tb-validate --compare cuda rocm --model ./model.pt --cert --otel

# CI mode — exits non-zero if max_diff exceeds tolerance
tb-validate --compare cuda rocm --model ./model.pt --ci
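The metrics these commands report are straightforward to state. Below is a minimal sketch in plain Python of how a max_diff / cosine_sim comparison with a pass/fail tolerance might work; the example outputs and the tolerance value are hypothetical, and this is not TorchBridge's actual implementation:

```python
import math

def max_diff(a, b):
    # Largest element-wise absolute difference between two output vectors.
    return max(abs(x - y) for x, y in zip(a, b))

def cosine_sim(a, b):
    # Cosine similarity between two flattened output vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical logits from the same model run on two backends
cuda_out = [0.1002, -1.5001, 3.2000]
rocm_out = [0.1001, -1.5003, 3.2001]

diff = max_diff(cuda_out, rocm_out)
sim = cosine_sim(cuda_out, rocm_out)
tolerance = 1e-3  # hypothetical per-dtype tolerance, not from the tolerance DB
print("PASS" if diff <= tolerance else "FAIL", diff, round(sim, 6))
```

In CI mode the same comparison would simply drive the process exit code, which is what lets the command gate a pipeline.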

Hardware Configuration Advisor

# What's the optimal config for a 7B model on this hardware?
tb-advisor --model-params 7e9

# Disaggregated prefill/decode fleet config
tb-advisor --mode disaggregated --model-params 7e9 --prefill nvidia:hopper --decode amd:cdna3

# Heterogeneous cluster training config (NVIDIA + AMD mixed)
tb-advisor --mode heterogeneous --model-params 7e9 --nvidia hopper:8 --amd cdna3:4

# Doctor — diagnose your hardware setup
tb-doctor
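Under the hood, advisory like this amounts to a (backend, architecture) table lookup with a fallback chain. A sketch of how such a matrix could be structured, assuming made-up entries and format names that are illustrative only, not the project's real matrices:

```python
# Hypothetical compatibility matrix: (backend, architecture) -> ordered
# fallback chain of quantization formats. Entries are illustrative.
QUANT_MATRIX = {
    ("nvidia", "hopper"): ["fp8", "int8", "bf16"],
    ("nvidia", "ampere"): ["int8", "bf16"],
    ("amd", "cdna3"): ["fp8", "bf16"],
}

def recommend_format(backend, arch):
    # Walk the fallback chain: exact match first, then a conservative
    # backend-level default, then a universal default.
    chain = QUANT_MATRIX.get((backend, arch))
    if chain:
        return chain[0]
    for (b, _), formats in QUANT_MATRIX.items():
        if b == backend:
            return formats[-1]  # most conservative format for this backend
    return "bf16"  # universal fallback

print(recommend_format("nvidia", "hopper"))     # fp8
print(recommend_format("nvidia", "blackwell"))  # bf16 (backend-level fallback)
```

The point of the chain is that an unknown architecture still yields a safe recommendation rather than an error.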

Python API

from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())
# Cross-backend validation
from torchbridge import UnifiedValidator

validator = UnifiedValidator()
# `model` is any torch.nn.Module you want to validate
results = validator.validate_model(model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")

What TorchBridge Does

Capability What TorchBridge adds
Cross-backend validation tb-validate --compare cuda rocm — per-layer divergence, empirical tolerance DB (5 model families × 5 backends × 3 dtypes), CI-ready JSON
Multi-step agentic trace tb-validate --trace --steps 50 --autoregressive — tracks how max_diff amplifies across N autoregressive steps; reports first-divergence-step and amplification factor
Compliance certificates tb-validate --cert — SHA256-signed pass/fail certificate for KV handoff physical spec (page size, alignment, layout)
Observability integration tb-validate --otel — emits validation spans (max_diff, cosine_sim, per-layer child spans) to any OTLP endpoint (Langfuse, W&B Weave, Honeycomb)
Compatibility matrices 13 empirically-sourced matrices: (backend, architecture) → optimal quant format / attention kernel / adapter method / FSDP strategy / torch.compile mode
Config advisory tb-advisor — FSDP, quantization, KV cache, speculative decoding, disaggregated fleet (--mode disaggregated), heterogeneous clusters (--mode heterogeneous)
Backend detection Hardware identification, capability queries, priority chain across NVIDIA/AMD/Trainium/TPU/CPU
Tolerance DB 80 empirical entries, 3-level fallback, --model-family flag — tolerances sourced from real Qwen3-0.6B runs across 6 GPU platforms
CLI diagnostics tb-doctor, tb-validate, tb-advisor, tb-speculate, tb-cache, tb-adapter, tb-quantize, tb-migrate, tb-benchmark, tb-checkpoint
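The multi-step trace row above tracks how small numerical differences compound over autoregressive generation. A sketch of what first-divergence-step and amplification-factor reporting could look like; the per-step values below are made up for illustration and the function is not TorchBridge's MultiStepTracer:

```python
def trace_divergence(diffs, tolerance):
    # Given per-step max_diff values from an autoregressive trace, report
    # the first step that exceeds tolerance and the overall amplification
    # factor (last diff divided by the first nonzero diff).
    first_divergence = next(
        (i for i, d in enumerate(diffs) if d > tolerance), None)
    baseline = next((d for d in diffs if d > 0), None)
    amplification = diffs[-1] / baseline if baseline else 1.0
    return first_divergence, amplification

# Hypothetical per-step max_diff from a 6-step trace
diffs = [1e-6, 2e-6, 5e-6, 1.2e-5, 3e-5, 8e-5]
step, factor = trace_divergence(diffs, tolerance=2e-5)
print(step, round(factor, 1))  # step 4 first exceeds tolerance; ~80x amplification
```

Amplification is the interesting number here: a per-step diff that passes tolerance in isolation can still compound into a semantic divergence over a long generation.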

What TorchBridge Is NOT

  • Not a quantization library — dispatches format selection to torchao; TorchBridge adds the compatibility matrix
  • Not a serving runtime — use vLLM, TGI, or similar for production inference serving; TorchBridge validates correctness and advises configuration, it does not serve requests
  • Not a training framework — the LoRA/QLoRA adapter math is implemented correctly and retained, but for full training workflows use PEFT
  • Not a PyTorch wrapper — if a method body is return torch.something(...) with no selection logic, it doesn't belong here

Supported Backends

Backend Hardware Precision Status
NVIDIA B200, H100, H200, A100, L4, T4 FP4, FP8, BF16, FP16, FP32 Production
AMD MI350X, MI325X, MI300X, MI200 FP8, BF16, FP16, FP32 Production
Trainium Trn1, Trn2, Trn3 (AWS NeuronX) BF16, FP16, FP32 Supported
TPU v4, v5e, v5p, v6e, v7 BF16, FP32 Production
CPU x86, ARM (Apple Silicon) FP32, BF16 Fallback

See Hardware Matrix for full details.

Cloud Hardware Validation

Cross-backend numerical consistency validated on 6/8 platforms using Qwen3-0.6B (v0.5.80, re-validated 2026-03-30):

Platform Hardware Max Diff Cosine Sim Latency Status
AWS NVIDIA A10G (24GB) 2.10e-05 1.000001 40.0 ms PASS
GCP NVIDIA T4 (16GB) 2.67e-05 1.000001 50.7 ms PASS
RunPod NVIDIA H100 NVL (100GB) 1.67e-05 1.000001 16.2 ms PASS
Local Apple Silicon (MPS) 0.00e+00 1.000000 118.9 ms PASS
AWS Trainium† Trn1.2xlarge (NeuronX) 0.00e+00 1.000000 115.8 ms (CPU) PASS
AWS Inferentia2† inf2.xlarge (NeuronX) 0.00e+00 1.000000 321.8 ms (CPU) PASS
AMD DevCloud AMD MI300X (192GB) SKIPPED‡
GCP TPU v5e SKIPPED‡

† CPU fallback: NeuronX SDK compilation requires quota-enabled Trn1/Inf2 instances. These rows confirm correct CPU-path behavior; accelerator validation is pending. ‡ Capacity unavailable: AMD MI300X was out of capacity at validation time; TPU v5e was exhausted globally across 18 zones.

All tested GPU backends produce semantically identical outputs (cosine similarity > 0.999).

See full validation report for detailed benchmarks and results.

Project Structure

src/torchbridge/
├── backends/          # Vendor-specific backend implementations
│   ├── nvidia/        #   NVIDIA CUDA backend
│   ├── amd/           #   AMD ROCm backend
│   ├── trainium/      #   AWS Trainium/NeuronX backend
│   └── tpu/           #   Google TPU/XLA backend
├── precision/         # Quantization compatibility matrix + torchao dispatch
├── attention/         # Attention kernel compatibility matrix + dispatcher
├── distributed/       # FSDP/pipeline config advisor
├── adapters/          # LoRA/QLoRA adapter injection (correct math)
├── inference/         # Speculative decoding compatibility matrix
├── checkpoint/        # DCP wrapper with cross-backend metadata
├── testing/           # DivergenceTracer, ToleranceDB, MultiStepTracer, @cross_backend
├── validation/        # UnifiedValidator — model structure, hardware, numerical stability
├── cli/               # Command-line tools (11 CLI commands)
├── models/            # LLM KV cache advisor
└── utils/             # Utilities

Quality

  • 2,187 tests passing (hardware-gated skips on non-GPU environments)
  • 0 ruff violations -- clean linting
  • 0 mypy errors -- full type coverage
  • Cloud validated on 6/8 platforms: A10G, T4, H100 NVL, MPS (GPU); Trainium, Inferentia2 (CPU-fallback†)
python3 -m pytest tests/ -q
ruff check src/ tests/

Documentation

Document Description
Installation Setup and requirements
Quick Start First steps with TorchBridge
Troubleshooting Common issues and fixes
Backends Overview How the backend system works
Backend Selection Choosing the right backend and driver setup
Distributed Training Multi-GPU and multi-node
Testing Guide DivergenceTracer, @cross_backend, ToleranceDB
Deployment Serving and containerization
CLI Reference Command-line tools
Hardware Matrix Full hardware support table
Changelog Version history

Community

The empirical tolerance database (testing/tolerance_db.py) is only as strong as the hardware it has been measured on. Contributions that add or correct tolerance entries for hardware you have access to — AMD MI350X, Trainium2, TPU v7 Ironwood, new PyTorch versions — directly expand the validation coverage for everyone. See CONTRIBUTING.md for how to add entries and the source-label conventions ("measured", "derived", "fallback").
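A sketch of what a contributed entry and the three-level fallback might look like; the field names, keys, and values here are hypothetical, not the actual schema in testing/tolerance_db.py:

```python
def lookup_tolerance(db, family, backend_pair, dtype):
    # Three-level fallback: exact (family, pair, dtype) entry, then a
    # family-level default, then a conservative global default labeled
    # "fallback" per the source-label conventions.
    for key in ((family, backend_pair, dtype),
                (family, None, dtype),
                (None, None, dtype)):
        if key in db:
            return db[key]
    return {"max_diff": 1e-3, "source": "fallback"}

# Hypothetical entries; "source" carries the measured/derived/fallback label
db = {
    ("qwen3", ("cuda", "rocm"), "bf16"): {"max_diff": 2.7e-5, "source": "measured"},
    ("qwen3", None, "bf16"): {"max_diff": 5e-5, "source": "derived"},
}

print(lookup_tolerance(db, "qwen3", ("cuda", "rocm"), "bf16")["source"])  # measured
print(lookup_tolerance(db, "llama3", None, "bf16")["source"])             # fallback
```

The source label matters because a "measured" entry constrains validation far more tightly than a "derived" or "fallback" one, and contributions upgrade entries along that axis.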

Versioning

v0.5.93 is the first public release. The v0.5.x series represents an extended private development and validation phase: building the backend abstraction layer, validating numerical consistency on real GPU hardware across 6 platforms, and reaching a quality bar suitable for open source. The version number reflects the maturity of the implementation, not the release count.

License

Licensed under the Apache License, Version 2.0. See LICENSE for the full text.

Download files

Download the file for your platform.

Source Distribution

torchbridge_ml-0.5.95.tar.gz (223.9 kB)

Uploaded Source

Built Distribution


torchbridge_ml-0.5.95-py3-none-any.whl (274.9 kB)

Uploaded Python 3

File details

Details for the file torchbridge_ml-0.5.95.tar.gz.

File metadata

  • Download URL: torchbridge_ml-0.5.95.tar.gz
  • Upload date:
  • Size: 223.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.11.0

File hashes

Hashes for torchbridge_ml-0.5.95.tar.gz
Algorithm Hash digest
SHA256 1617aa0f370d70139d8a0f459bac9e5321024f07361a4e0f0186f5660a69accb
MD5 ee1548ad29ec4a2f389cd9ed2d5534ce
BLAKE2b-256 9045fe01db99f0800b74c03865d2f5166ca75c302cb5da4cb17a7ae13e5d7af3


File details

Details for the file torchbridge_ml-0.5.95-py3-none-any.whl.

File metadata

File hashes

Hashes for torchbridge_ml-0.5.95-py3-none-any.whl
Algorithm Hash digest
SHA256 13e0c67f396f727b06c92f9dc42f55d82e305dc169b7905bedd41f8f81634173
MD5 cc714e724dd6be41df01287c28024d3a
BLAKE2b-256 4ddeca439f96746249f4dd9471bd345041fefa771eb8cab11445e8693248876b

