Fast Sampled Gromov-Wasserstein Optimal Transport
Pure PyTorch | Triton GPU Kernels | Differentiable | Up to 175x faster than POT
TorchGW aligns two point clouds by matching their internal distance structures -- even when they live in different dimensions. Instead of the full O(NK(N+K)) GW cost, it samples M anchor pairs each iteration and approximates the cost in O(NKM), enabling GPU-accelerated alignment at scales where standard solvers are impractical.
Use cases: single-cell multi-omics integration, cross-domain graph matching, shape correspondence, manifold alignment.
Highlights | Performance | Features
News
- v0.4.1 (2026-04-09) -- Exact differentiable gradients via implicit differentiation. The previous "envelope theorem" backward was a frozen-potentials approximation with up to 30x gradient error; it is now replaced by an adjoint system solved via a Schur complement on the Sinkhorn Jacobian. New `grad_mode` parameter (`"implicit"` default, `"unrolled"` alternative). See the algorithm docs for the math.
- v0.4.0 (2026-04-07) -- Triton fused Sinkhorn (2-5x GPU speedup), mixed precision, smart early stopping, Sinkhorn warm-start, Dijkstra caching, and 15 numerical stability fixes. See CHANGELOG.md.
Installation
```shell
pip install -e .
```
Requirements: numpy, scipy, scikit-learn, torch>=2.0, joblib.
Triton ships with PyTorch and enables GPU kernel fusion automatically. No POT needed.
Quick Start
```python
from torchgw import sampled_gw

# Basic usage
T = sampled_gw(X_source, X_target)

# Recommended for large-scale problems (fastest)
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)
```
Minimal working example:

```python
import torch
from torchgw import sampled_gw

X = torch.randn(500, 3)  # source: 500 points in 3D
Y = torch.randn(600, 5)  # target: 600 points in 5D (dimensions may differ)

T = sampled_gw(X, Y, epsilon=0.005, M=80, max_iter=200)

# T is a (500, 600) transport plan: T[i, j] = coupling weight between X[i] and Y[j]
print(f"Transport plan: {T.shape}, total mass: {T.sum():.4f}")
```
Benchmark
Spiral (2D) to Swiss roll (3D) alignment, mixed_precision=True, landmark distances:
NVIDIA H100 80GB HBM3:
| Scale | Time | Spearman rho | GPU Memory |
|---|---|---|---|
| 4,000 x 5,000 | 0.8 s | 0.999 | 0.7 GB |
| 10,000 x 12,000 | 4.1 s | 0.999 | 3.9 GB |
| 20,000 x 25,000 | 4.6 s | 0.999 | 16 GB |
| 30,000 x 35,000 | 9.3 s | 0.999 | 34 GB |
| 40,000 x 50,000 | 17 s | 0.999 | 64 GB |
| 45,000 x 45,000 | 18 s | 0.999 | 65 GB |
NVIDIA L40S 48GB:
| Scale | Time | Spearman rho | GPU Memory |
|---|---|---|---|
| 4,000 x 5,000 | 2.4 s | 0.999 | 1.1 GB |
| 10,000 x 12,000 | 3.0 s | 0.999 | 6.7 GB |
| 20,000 x 25,000 | 12 s | 0.999 | 18 GB |
| 30,000 x 35,000 | 25 s | 0.999 | 34 GB |
| 35,000 x 40,000 | 34 s | 0.999 | 45 GB |
Alignment quality (Spearman >= 0.999) is maintained across all scales. At 4000x5000, TorchGW is ~175x faster than POT (1.0s vs 183s). Max scale is bounded by GPU memory for the N*K transport plan (~80% VRAM utilization).
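The memory bound is easy to estimate up front. A back-of-envelope sketch for the largest row in the H100 table, assuming a float64 plan (per the mixed-precision note below, output is float64) and ignoring Sinkhorn intermediates, which is why the measured peak is higher:

```python
# Back-of-envelope VRAM for a dense (N, K) transport plan at float64.
# Ignores Sinkhorn intermediates, so the measured peak (65 GB above) is larger.
N, K = 45_000, 45_000
plan_gib = N * K * 8 / 2**30  # 8 bytes per float64 entry
print(f"{plan_gib:.1f} GiB")  # prints "15.1 GiB" for the plan alone
```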
Reproduce
python examples/benchmark_scale.py
Benchmark plots: 400 vs 500 and 4,000 vs 5,000 point alignments.
Distance Modes
Choose based on your data scale:
| Mode | Best for | Per-iteration | Memory | Notes |
|---|---|---|---|---|
| `"precomputed"` | N < 5k | O(NM) lookup | O(N^2) | All-pairs Dijkstra upfront |
| `"dijkstra"` | 5k-50k | O(MN log N) | O(NM) | On-the-fly with caching |
| `"landmark"` | any scale | O(NMd) GPU | O(Nd) | Recommended default |
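The table above boils down to a simple heuristic; here it is as a tiny helper (hypothetical, not part of the torchgw API):

```python
def pick_distance_mode(n_points: int) -> str:
    """Heuristic mode choice following the table above (illustrative only)."""
    if n_points < 5_000:
        return "precomputed"  # all-pairs Dijkstra is affordable
    if n_points <= 50_000:
        return "dijkstra"     # on-the-fly distances with caching
    return "landmark"         # O(Nd) memory, works at any scale
```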
```python
# Small scale: precompute all distances once
T = sampled_gw(X, Y, distance_mode="precomputed")

# Bring your own distance matrices
T = sampled_gw(dist_source=D_X, dist_target=D_Y, distance_mode="precomputed")

# Large scale (recommended)
T = sampled_gw(X, Y, distance_mode="landmark", n_landmarks=50)
```
Usage Guide
Best performance settings
```python
T = sampled_gw(
    X, Y,
    distance_mode="landmark",  # avoids expensive all-pairs Dijkstra
    mixed_precision=True,      # float32 Sinkhorn (2x faster on GPU)
    M=80,                      # more samples = better cost estimate
    epsilon=0.005,             # moderate regularization
)
```
Fused Gromov-Wasserstein
Blend structural (graph distance) and feature (linear) costs:
```python
C_feat = torch.cdist(features_src, features_tgt)
T = sampled_gw(X, Y, fgw_alpha=0.5, C_linear=C_feat)

# Pure Wasserstein (no graph distances needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat)
```
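A minimal sketch of what the blend computes, assuming the convention implied by `fgw_alpha` (0 = pure GW, 1 = pure Wasserstein), i.e. that alpha weights the linear feature cost; the actual torchgw internals may differ:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.5
C_gw = rng.random((4, 5))    # structural (graph-distance) cost, stand-in values
C_feat = rng.random((4, 5))  # linear feature cost, stand-in values

# Convex combination of the two cost matrices
C_blend = (1 - alpha) * C_gw + alpha * C_feat
```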
Semi-relaxed transport
For unbalanced datasets (e.g., cell types present in one sample but not the other):
```python
T = sampled_gw(X, Y, semi_relaxed=True, rho=1.0)
# Source marginal enforced; target marginal relaxed via a KL penalty
```
Multi-scale warm start
Speeds up convergence by solving a coarse problem first:
```python
T = sampled_gw(X, Y, multiscale=True, n_coarse=200)
```
Note: GW has symmetric local optima. Works best on data without strong symmetries.
Differentiable mode
Use GW transport as a differentiable layer (exact gradients via implicit differentiation):
```python
C_feat = torch.cdist(encoder(X), encoder(Y))
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)
loss = (C_feat.detach() * T).sum()
loss.backward()  # exact gradients flow to encoder parameters

# For memory-constrained settings, unrolled autograd is also available:
T = sampled_gw(..., differentiable=True, grad_mode="unrolled")
```
Low-rank Sinkhorn (N, K > 50k)
For very large problems where the N*K transport plan does not fit in memory:
```python
from torchgw import sampled_lowrank_gw

T = sampled_lowrank_gw(X, Y, rank=30, distance_mode="landmark", n_landmarks=50)
```
Memory: O((N+K)*rank) instead of O(NK).
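The gap between the two footprints is easy to quantify; an illustrative comparison at float32 (the factor-storage estimate follows directly from the stated O((N+K)*rank) bound):

```python
# Dense plan vs. low-rank factors at float32 (4 bytes/entry), illustrative only.
N, K, r = 100_000, 100_000, 30
dense_gib = N * K * 4 / 2**30          # full (N, K) plan: too large for most GPUs
lowrank_gib = (N + K) * r * 4 / 2**30  # O((N + K) * rank) factor storage
print(f"dense: {dense_gib:.1f} GiB, low-rank: {lowrank_gib * 1024:.1f} MiB")
# prints "dense: 37.3 GiB, low-rank: 22.9 MiB"
```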
API
sampled_gw
```python
sampled_gw(
    X_source, X_target,        # (N, D) and (K, D') feature matrices
    *,
    distance_mode="dijkstra",  # "precomputed" | "dijkstra" | "landmark"
    fgw_alpha=0.0,             # 0 = pure GW, 1 = Wasserstein, (0, 1) = Fused GW
    C_linear=None,             # (N, K) feature cost matrix for FGW
    M=50,                      # anchor pairs per iteration
    epsilon=0.001,             # entropic regularization
    max_iter=500, tol=1e-5,    # convergence control
    mixed_precision=False,     # float32 Sinkhorn for GPU speed
    semi_relaxed=False,        # relax target marginal
    differentiable=False,      # keep autograd graph
    multiscale=False,          # coarse-to-fine warm start
    log=False,                 # return (T, info_dict)
    ...                        # see docs for full parameter list
) -> Tensor                    # (N, K) transport plan
```
sampled_lowrank_gw
Same interface, plus rank, lr_max_iter, and lr_dykstra_max_iter.
Uses the Scetbon, Cuturi & Peyre (2021) low-rank factorization.
When to use: only when the N*K plan exceeds GPU memory. At smaller scales,
`sampled_gw` is faster.
Full API documentation: chansigit.github.io/torchgw
How It Works
The idea
Gromov-Wasserstein finds a coupling between two point clouds by comparing distances within each space rather than distances across spaces. This means the two inputs can live in completely different dimensions -- a 2D spiral can be aligned to a 3D Swiss roll, or a gene expression matrix to a chromatin accessibility matrix.
Standard GW solvers compute the full N x N and K x K pairwise distance matrices and an O(NK(N+K)) cost tensor at each step, which is prohibitive beyond a few thousand points. TorchGW replaces this with a stochastic approximation: sample M anchor pairs from the current transport plan, compute distances only for those anchors, and build a low-variance cost estimate in O(NKM) -- making each iteration orders of magnitude cheaper.
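The sampled cost assembly reduces to one matrix product. A NumPy sketch with random stand-in anchor distances (illustrative only, not the library code):

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, M = 500, 600, 80

# Stand-in anchor distances: D_left[i, m] = graph distance from source point i
# to the m-th sampled anchor; D_tgt[j, m] the same on the target side.
D_left = rng.random((N, M))
D_tgt = rng.random((K, M))

# Lam[i, j] = (1/M) * sum_m (D_left[i, m] - D_tgt[j, m])**2, expanded so the
# cross term is a single (N, M) x (M, K) matrix product.
Lam = (
    (D_left**2).mean(axis=1, keepdims=True)
    - (2.0 / M) * (D_left @ D_tgt.T)
    + (D_tgt**2).mean(axis=1)[None, :]
)
```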
Algorithm
```mermaid
flowchart TB
    subgraph inputs [" "]
        direction LR
        X["Source X\n(N points, D dims)"]
        Y["Target Y\n(K points, D' dims)"]
    end
    X --> G1["Build kNN graph"]
    Y --> G2["Build kNN graph"]
    G1 --> loop
    G2 --> loop
    subgraph loop ["GW Main Loop — repeat until converged"]
        S["1. Sample M anchor pairs (i,j)\nfrom current T\n(GPU multinomial)"]
        D["2. Compute graph distances\nfrom anchors\nD_left (N×M), D_tgt (K×M)"]
        C["3. Assemble GW cost matrix (N×K)\nΛ = mean(D²_left) − 2/M · D_left·D_tgt' + mean(D²_tgt)"]
        K["4. Sinkhorn projection → T_new\n(Triton fused kernels, log-domain)"]
        M["5. Momentum blend\nT ← (1−α)T + α·T_new\n+ warm-start potentials"]
        CV["6. Converged?\n(cost plateau detection)"]
        S --> D --> C --> K --> M --> CV
        CV -- "no" --> S
    end
    CV -- "yes" --> T["T* (N × K)\noptimal transport plan"]
    style inputs fill:none,stroke:none
    style loop fill:#f8f9fa,stroke:#dee2e6,stroke-width:2px
    style T fill:#d4edda,stroke:#28a745,stroke-width:2px
```
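Step 4 of the loop can be prototyped in a few lines of log-domain Sinkhorn. This NumPy sketch (using scipy's `logsumexp`; scipy is already a stated requirement) is a reference version of the projection, not the Triton-fused kernel:

```python
import numpy as np
from scipy.special import logsumexp

def sinkhorn_log(C, a, b, eps=0.2, n_iter=500):
    """Entropic OT projection in the log domain (illustrative reference)."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iter):
        # Alternate dual updates; logsumexp keeps everything overflow-safe.
        f = eps * (np.log(a) - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - C) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / eps)

C = np.random.default_rng(0).random((5, 6))
a = np.full(5, 1 / 5)
b = np.full(6, 1 / 6)
T = sinkhorn_log(C, a, b)
# Row and column marginals of T converge to a and b
```

Warm-starting amounts to carrying `f` and `g` between GW iterations instead of resetting them to zero.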
Why it's fast
| Technique | What it does | Speedup |
|---|---|---|
| Sampled cost | O(NKM) instead of O(NK(N+K)) per iteration | 10-100x |
| Triton Sinkhorn | Fused GPU kernels: single-pass logsumexp, no intermediate N*K allocations | 2-5x |
| Warm-start | Reuse Sinkhorn potentials (log u, log v) across GW iterations | 2-3x fewer Sinkhorn steps |
| Mixed precision | float32 Sinkhorn in log domain (numerically safe), float64 output | up to 2x on consumer GPUs |
| Dijkstra cache | Cache per-node shortest paths, FIFO eviction | avoids redundant graph traversals |
| Cost plateau detection | Stop when GW cost EMA plateaus, not when noisy ‖T-T_prev‖ < tol | saves 50-80% of max_iter |
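The plateau criterion in the last row can be sketched as a simple EMA check (hypothetical helper; torchgw's actual detector may use different smoothing and thresholds):

```python
def cost_plateaued(costs, alpha=0.1, rel_tol=1e-4, patience=20):
    """Return True once the EMA of the GW cost stops improving (illustrative)."""
    ema = costs[0]
    stall = 0
    for c in costs[1:]:
        prev = ema
        ema = (1 - alpha) * ema + alpha * c  # smooth out sampling noise
        if prev - ema < rel_tol * abs(prev):
            stall += 1
            if stall >= patience:
                return True  # cost has plateaued: stop early
        else:
            stall = 0
    return False
```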
See the algorithm documentation for the full mathematical formulation, including the semi-relaxed extension and differentiable gradient computation.
Development
```shell
git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e ".[dev]"
pytest tests/ -v   # 72 tests, ~18s
```
Citation
If you use TorchGW in your research, please cite:
```bibtex
@software{torchgw,
  author  = {Sijie Chen},
  title   = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
  url     = {https://github.com/chansigit/torchgw},
  version = {0.4.1},
  year    = {2026},
}
```
License
This project is source-available.
It is free for academic and other non-commercial research and educational use under the terms of the included LICENSE.
Any commercial use — including any use by or on behalf of a for-profit entity, internal commercial research, product development, consulting, paid services, or deployment in commercial settings — requires a separate paid commercial license.
Copyright (c) 2026 The Board of Trustees of the Leland Stanford Junior University.
For commercial licensing inquiries, contact Stanford Office of Technology Licensing: otl@stanford.edu
See COMMERCIAL_LICENSE.md for details.
File details
Details for the file torchgw-0.4.2.tar.gz.
File metadata
- Download URL: torchgw-0.4.2.tar.gz
- Upload date:
- Size: 47.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `882a8f7344a42bf5bbabaee5b8951eeeb627ed711fdd4df2fccc57484429a7a6` |
| MD5 | `b068a21ee4c97c304197338485cff219` |
| BLAKE2b-256 | `82f78a1da4dfb3ba8fe98c8b8e40fbfce7656837d89dc7a8be2172dbc889ef4d` |
File details
Details for the file torchgw-0.4.2-py3-none-any.whl.
File metadata
- Download URL: torchgw-0.4.2-py3-none-any.whl
- Upload date:
- Size: 30.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `2c5ce20930a3b4f94addce7b746919f7bc03f7322a7b1244fb55102540044221` |
| MD5 | `bdc6bfe93fa494034c0909807fb136db` |
| BLAKE2b-256 | `dd88c9484aa337324f8e325cb59972a25f16d12b014ffd9cf9a4818c42e6a7db` |