Fast Sampled Gromov-Wasserstein optimal transport solver — pure PyTorch, scalable, differentiable
Fast Sampled Gromov-Wasserstein Optimal Transport
Pure PyTorch | Triton GPU Kernels | Differentiable | Up to 175x faster than POT
TorchGW aligns two point clouds by matching their internal distance structures -- even when they live in different dimensions. Instead of the full O(NK(N+K)) GW cost, it samples M anchor pairs each iteration and approximates the cost in O(NKM), enabling GPU-accelerated alignment at scales where standard solvers are impractical.
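To make the O(NKM) idea concrete, here is a minimal NumPy sketch of a sampled square-loss GW cost. It is not TorchGW's implementation: it assumes dense intra-domain distance matrices `D_x` (N×N) and `D_y` (K×K) and a current plan `T`, and the helper names are hypothetical. Drawing M anchor pairs (j, l) with probability proportional to T[j, l] gives an unbiased Monte Carlo estimate of L[i, k] = sum_{j,l} (D_x[i, j] - D_y[k, l])^2 T[j, l].

```python
import numpy as np

# Sketch only (not TorchGW's code): square-loss GW cost for a plan T, given
# dense intra-domain distance matrices D_x (N, N) and D_y (K, K).


def exact_gw_cost(D_x, D_y, T):
    """L[i, k] = sum_{j,l} (D_x[i, j] - D_y[k, l])**2 * T[j, l], O(NK(N+K))."""
    p, q = T.sum(axis=1), T.sum(axis=0)  # marginals of the current plan
    # Expand (a - b)^2 = a^2 + b^2 - 2ab so no (N, K, N, K) tensor is needed.
    return ((D_x**2) @ p)[:, None] + ((D_y**2) @ q)[None, :] - 2.0 * D_x @ T @ D_y.T


def sampled_gw_cost(D_x, D_y, T, M, rng):
    """Unbiased Monte Carlo estimate of the same matrix from M anchors, O(NKM)."""
    K = T.shape[1]
    probs = T.ravel() / T.sum()
    idx = rng.choice(probs.size, size=M, p=probs)  # anchor pairs (j, l) ~ T
    j, l = np.divmod(idx, K)
    # D_x[:, j] is (N, M) and D_y[:, l] is (K, M); broadcast to (N, K, M).
    diff = D_x[:, j][:, None, :] - D_y[:, l][None, :, :]
    return T.sum() * (diff**2).mean(axis=2)


rng = np.random.default_rng(0)
X, Y = rng.normal(size=(30, 3)), rng.normal(size=(40, 5))  # different dimensions
D_x = np.linalg.norm(X[:, None] - X[None], axis=-1)
D_y = np.linalg.norm(Y[:, None] - Y[None], axis=-1)
T = np.full((30, 40), 1.0 / (30 * 40))  # independent coupling as the current plan

exact = exact_gw_cost(D_x, D_y, T)
est = sampled_gw_cost(D_x, D_y, T, M=3000, rng=rng)
print("max relative error:", np.abs(est - exact).max() / exact.max())
```

The estimate only ever touches the M sampled columns of each distance matrix, which is why the per-iteration cost scales with M rather than with N + K.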
Use cases: single-cell multi-omics integration, cross-domain graph matching, shape correspondence, manifold alignment.
News
- v0.4.1 (2026-04-09) -- Exact differentiable gradients via implicit differentiation. The previous "envelope theorem" backward pass was a frozen-potentials approximation with up to 30x gradient error; it is now replaced by an adjoint system solved via a Schur complement on the Sinkhorn Jacobian. New grad_mode parameter ("implicit" by default, "unrolled" as an alternative). See the algorithm docs for the math.
- v0.4.0 (2026-04-07) -- Triton fused Sinkhorn (2-5x GPU speedup), mixed precision, smart early stopping, Sinkhorn warm-start, Dijkstra caching, and 15 numerical stability fixes. See CHANGELOG.md.
Installation
pip install -e .
Requirements: numpy, scipy, scikit-learn, torch>=2.0, joblib.
Triton ships with PyTorch's Linux CUDA builds and enables GPU kernel fusion automatically. POT is not required.
Quick Start
from torchgw import sampled_gw
# Basic usage
T = sampled_gw(X_source, X_target)
# Recommended for large-scale (fastest)
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)
Minimal working example
import torch
from torchgw import sampled_gw
X = torch.randn(500, 3) # source: 500 points in 3D
Y = torch.randn(600, 5) # target: 600 points in 5D (dimensions may differ)
T = sampled_gw(X, Y, epsilon=0.005, M=80, max_iter=200)
# T is a (500, 600) transport plan: T[i,j] = coupling weight between X[i] and Y[j]
print(f"Transport plan: {T.shape}, total mass: {T.sum():.4f}")
Benchmark
Spiral (2D) to Swiss roll (3D) alignment on NVIDIA L40S:
| Scale | Method | Time | Spearman rho |
|---|---|---|---|
| 400 vs 500 | POT ot.gromov_wasserstein | 1.6 s | 0.999 |
| 400 vs 500 | TorchGW | 0.46 s | 0.999 |
| 4000 vs 5000 | POT ot.gromov_wasserstein | 183 s | 0.999 |
| 4000 vs 5000 | TorchGW precomputed | 5.1 s | 0.998 |
| 4000 vs 5000 | TorchGW landmark | 1.0 s | 0.999 |
At 4000x5000 with landmark distances, TorchGW is up to ~175x faster than POT with equal quality.
[Figure: benchmark plots, panels "400 vs 500" and "4000 vs 5000"]
Distance Modes
Choose based on your data scale:
| Mode | Best for | Per-iteration | Memory | Notes |
|---|---|---|---|---|
| "precomputed" | N < 5k | O(NM) lookup | O(N^2) | All-pairs Dijkstra upfront |
| "dijkstra" | 5k-50k | O(MN log N) | O(NM) | On-the-fly with caching |
| "landmark" | any scale | O(NMd) GPU | O(Nd) | Recommended (the API default is "dijkstra") |
# Small scale: precompute all distances once
T = sampled_gw(X, Y, distance_mode="precomputed")
# Bring your own distance matrices
T = sampled_gw(dist_source=D_X, dist_target=D_Y, distance_mode="precomputed")
# Large scale (recommended)
T = sampled_gw(X, Y, distance_mode="landmark", n_landmarks=50)
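The modes differ mainly in how many shortest-path runs they need. Below is a minimal sketch, not TorchGW's code, assuming a symmetrized Euclidean kNN graph and plain Dijkstra: "precomputed" corresponds to one Dijkstra per node (an N×N table), while a landmark scheme runs Dijkstra only from d landmark nodes and bounds d(i, j) <= min_l [d(l, i) + d(l, j)] by the triangle inequality. That estimator is an assumption; the one TorchGW uses in "landmark" mode may differ. `knn_graph`, `dijkstra`, and `landmark_estimate` are hypothetical names.

```python
import heapq

import numpy as np

# Sketch only: geodesic distances on a symmetrized kNN graph. The landmark
# estimator below (triangle-inequality upper bound) is an assumption and may
# differ from what TorchGW's "landmark" mode computes.


def knn_graph(X, k):
    """Adjacency list {neighbor: Euclidean weight} of the symmetrized kNN graph."""
    D = np.linalg.norm(X[:, None] - X[None], axis=-1)
    adj = [dict() for _ in range(len(X))]
    for i in range(len(X)):
        for j in np.argsort(D[i])[1:k + 1]:  # position 0 is the point itself
            j = int(j)
            adj[i][j] = adj[j][i] = float(D[i, j])
    return adj


def dijkstra(adj, src):
    """Single-source shortest-path lengths using a binary heap."""
    dist = np.full(len(adj), np.inf)
    dist[src] = 0.0
    heap = [(0.0, src)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue  # stale entry: u was already settled with a shorter path
        for v, w in adj[u].items():
            if d + w < dist[v]:
                dist[v] = d + w
                heapq.heappush(heap, (dist[v], v))
    return dist


def landmark_estimate(D_land):
    """d(i, j) <= min_l d(l, i) + d(l, j), from a (d, N) landmark distance table."""
    # Dense (d, N, N) for clarity; a solver would only evaluate (point, anchor) pairs.
    return (D_land[:, :, None] + D_land[:, None, :]).min(axis=0)


t = np.linspace(0, 3 * np.pi, 80)
X = np.c_[t * np.cos(t), t * np.sin(t)]  # 2D spiral, connected under k=4
adj = knn_graph(X, k=4)

D_full = np.stack([dijkstra(adj, s) for s in range(len(X))])  # N runs, O(N^2) memory
landmarks = np.arange(0, len(X), 10)                           # d = 8 landmarks
D_land = np.stack([dijkstra(adj, s) for s in landmarks])       # d runs, O(Nd) memory
D_est = landmark_estimate(D_land)
```

The estimate is exact on rows that are themselves landmarks and an upper bound elsewhere; storing only the (d, N) table is what replaces O(N^2) memory with O(Nd).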
Usage Guide
Best performance settings
T = sampled_gw(
X, Y,
distance_mode="landmark", # avoids expensive all-pairs Dijkstra
mixed_precision=True, # float32 Sinkhorn (2x faster on GPU)
M=80, # more samples = better cost estimate
epsilon=0.005, # moderate regularization
)
Fused Gromov-Wasserstein
Blend structural (graph distance) and feature (linear) costs:
C_feat = torch.cdist(features_src, features_tgt)
T = sampled_gw(X, Y, fgw_alpha=0.5, C_linear=C_feat)
# Pure Wasserstein (no graph distances needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat)
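Judging by the documented endpoints of fgw_alpha (0 = pure GW, 1 = Wasserstein), the fused cost is presumably a convex combination of the structural and feature costs. The sketch below assumes exactly that form; `fused_cost` is a hypothetical helper, not part of TorchGW. Note that POT's `alpha` in its fused GW solver weights the structural term instead, so values are not interchangeable between the two libraries.

```python
import numpy as np

# Sketch only: the blend implied by the documented fgw_alpha endpoints
# (0 = pure GW, 1 = Wasserstein). The exact weighting TorchGW uses is assumed.


def fused_cost(C_gw, C_linear, fgw_alpha):
    if not 0.0 <= fgw_alpha <= 1.0:
        raise ValueError("fgw_alpha must be in [0, 1]")
    if fgw_alpha == 1.0:
        return C_linear  # pure Wasserstein: the structural term is never needed
    return (1.0 - fgw_alpha) * C_gw + fgw_alpha * C_linear


rng = np.random.default_rng(0)
C_gw = rng.random((4, 5))    # stand-in for the sampled GW cost matrix
C_feat = rng.random((4, 5))  # stand-in for torch.cdist(features_src, features_tgt)
C = fused_cost(C_gw, C_feat, 0.5)
```

The early return for fgw_alpha=1.0 mirrors why the pure-Wasserstein call above needs no X or Y: no graph distances are ever computed.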
Semi-relaxed transport
For unbalanced datasets (e.g., cell types present in one sample but not the other):
T = sampled_gw(X, Y, semi_relaxed=True, rho=1.0)
# Source marginal enforced, target marginal relaxed via KL penalty
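In the standard unbalanced-Sinkhorn formulation, relaxing one marginal with a KL penalty of weight rho amounts to damping that side's log-domain update by the factor rho / (rho + epsilon). The NumPy sketch below assumes that formulation; it is not TorchGW's implementation, and `semi_relaxed_sinkhorn` is a hypothetical name. As rho grows the factor approaches 1 and balanced Sinkhorn is recovered.

```python
import numpy as np

# Sketch only (not TorchGW's implementation): log-domain Sinkhorn with the source
# marginal enforced exactly and the target marginal penalized by rho * KL.


def lse(A, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = A.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(A - m).sum(axis=axis, keepdims=True)), axis=axis)


def semi_relaxed_sinkhorn(C, a, b, eps, rho, n_iter=500):
    f, g = np.zeros(len(a)), np.zeros(len(b))
    damp = rho / (rho + eps)  # -> 1 as rho -> infinity (balanced Sinkhorn)
    for _ in range(n_iter):
        # Relaxed target side: the usual update, shrunk by the damping factor.
        g = damp * (eps * np.log(b) - eps * lse((f[:, None] - C) / eps, axis=0))
        # Hard source side: the usual balanced update, so T.sum(axis=1) == a.
        f = eps * np.log(a) - eps * lse((g[None, :] - C) / eps, axis=1)
    return np.exp((f[:, None] + g[None, :] - C) / eps)


rng = np.random.default_rng(0)
C = rng.random((6, 8))
a, b = np.full(6, 1 / 6), np.full(8, 1 / 8)
T_relaxed = semi_relaxed_sinkhorn(C, a, b, eps=0.5, rho=1.0)
T_balanced = semi_relaxed_sinkhorn(C, a, b, eps=0.5, rho=1e6)
print("target marginal, rho=1:  ", T_relaxed.sum(axis=0).round(4))
print("target marginal, rho=1e6:", T_balanced.sum(axis=0).round(4))
```

Smaller rho lets target mass move toward cheaper columns, which is the behavior wanted when a cell type exists on only one side.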
Multi-scale warm start
Speeds up convergence by solving a coarse problem first:
T = sampled_gw(X, Y, multiscale=True, n_coarse=200)
Note: GW has symmetric local optima. Works best on data without strong symmetries.
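One simple way to turn a coarse solution into a warm start is sketched below; it is not necessarily what multiscale=True does internally. Each full-resolution point is assigned to its nearest coarse point, and every coarse coupling weight is spread uniformly over the members of the two clusters. `nearest_coarse` and `lift_plan` are hypothetical helpers, and an independent coupling of the cluster masses stands in for an actual coarse GW solve.

```python
import numpy as np

# Sketch only: lifting a coarse plan to a full-resolution warm start. This is one
# simple scheme and not necessarily what multiscale=True does internally.


def nearest_coarse(X, coarse_idx):
    """For each row of X, the position (in coarse_idx) of its nearest coarse point."""
    D = np.linalg.norm(X[:, None] - X[coarse_idx][None], axis=-1)
    return D.argmin(axis=1)  # each coarse point is nearest to itself, so no empty clusters


def lift_plan(T_coarse, labels_src, labels_tgt):
    """Spread each coarse weight uniformly over its source and target clusters."""
    n_src = np.bincount(labels_src, minlength=T_coarse.shape[0])
    n_tgt = np.bincount(labels_tgt, minlength=T_coarse.shape[1])
    W = T_coarse / np.outer(n_src, n_tgt)  # weight per fine pair inside each block
    return W[labels_src][:, labels_tgt]     # (N, K) fine plan


rng = np.random.default_rng(0)
X, Y = rng.normal(size=(500, 3)), rng.normal(size=(600, 5))
cx = rng.choice(len(X), size=50, replace=False)  # n_coarse = 50 per side
cy = rng.choice(len(Y), size=50, replace=False)
lx, ly = nearest_coarse(X, cx), nearest_coarse(Y, cy)

# Stand-in for the coarse GW solution: any plan whose marginals are the cluster masses.
p = np.bincount(lx, minlength=50) / len(X)
q = np.bincount(ly, minlength=50) / len(Y)
T_coarse = np.outer(p, q)
T_init = lift_plan(T_coarse, lx, ly)
```

Whatever matching the coarse solve picks is inherited by the fine solve, so if the coarse problem lands on the wrong one of several symmetric optima, the warm start pulls the full problem toward that same wrong optimum.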
Differentiable mode
Use GW transport as a differentiable layer (exact gradients via implicit differentiation):
C_feat = torch.cdist(encoder(X), encoder(Y))
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)
loss = (C_feat.detach() * T).sum()
loss.backward() # exact gradients flow to encoder parameters
# For memory-constrained settings, unrolled autograd is also available:
T = sampled_gw(..., differentiable=True, grad_mode="unrolled")
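The idea behind grad_mode="implicit" can be illustrated on a single balanced entropic OT projection. For a converged plan T_ij = exp((f_i + g_j - C_ij) / eps) and upstream gradient G = dL/dT, the implicit function theorem gives dL/dC_ij = -T_ij (G_ij - alpha_i - beta_j) / eps, where the adjoint (alpha, beta) solves a linear system whose K×K Schur complement is S = diag(b) - T^T diag(1/a) T. The NumPy sketch below uses hypothetical helpers and is not TorchGW's code (its backward pass works on its own Sinkhorn Jacobian and may differ in details); it solves that system and checks the result against central finite differences.

```python
import numpy as np

# Sketch only: exact gradient of L(T(C)) through a converged balanced Sinkhorn plan
# via the implicit function theorem (not TorchGW's code; names are hypothetical).


def lse(A, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = A.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(A - m).sum(axis=axis, keepdims=True)), axis=axis)


def sinkhorn_plan(C, a, b, eps, n_iter=200):
    """Balanced log-domain Sinkhorn; T_ij = exp((f_i + g_j - C_ij) / eps)."""
    f, g = np.zeros(len(a)), np.zeros(len(b))
    for _ in range(n_iter):
        f = eps * np.log(a) - eps * lse((g[None, :] - C) / eps, axis=1)
        g = eps * np.log(b) - eps * lse((f[:, None] - C) / eps, axis=0)
    return np.exp((f[:, None] + g[None, :] - C) / eps)


def implicit_grad_C(T, G, eps):
    """dL/dC given G = dL/dT, by solving the adjoint system for (alpha, beta)."""
    a, b = T.sum(axis=1), T.sum(axis=0)
    r, s = (T * G).sum(axis=1), (T * G).sum(axis=0)
    # Adjoint system: a*alpha + T @ beta = r and T.T @ alpha + b*beta = s.
    # Eliminating alpha leaves the K x K Schur complement S @ beta = rhs.
    S = np.diag(b) - T.T @ (T / a[:, None])
    rhs = s - T.T @ (r / a)
    # S is singular ((alpha + c, beta - c) also solves the system); lstsq picks one.
    beta = np.linalg.lstsq(S, rhs, rcond=None)[0]
    alpha = (r - T @ beta) / a
    return -T * (G - alpha[:, None] - beta[None, :]) / eps


rng = np.random.default_rng(0)
N, K, eps = 5, 6, 0.5
C = rng.random((N, K))
a, b = np.full(N, 1 / N), np.full(K, 1 / K)
W = rng.random((N, K))  # loss L(T) = sum(W * T), so dL/dT = W

T = sinkhorn_plan(C, a, b, eps)
grad = implicit_grad_C(T, W, eps)

# Central finite differences, re-solving Sinkhorn for each perturbed cost.
h, fd = 1e-5, np.zeros_like(C)
for i in range(N):
    for j in range(K):
        E = np.zeros_like(C)
        E[i, j] = h
        fd[i, j] = ((W * sinkhorn_plan(C + E, a, b, eps)).sum()
                    - (W * sinkhorn_plan(C - E, a, b, eps)).sum()) / (2 * h)
print("max |implicit - finite difference|:", np.abs(grad - fd).max())
```

In this sketch the backward cost is one K×K solve regardless of how many Sinkhorn iterations the forward pass ran; unrolled autograd instead records and replays every iteration.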
Low-rank Sinkhorn (N, K > 50k)
For very large problems where the N*K transport plan does not fit in memory:
from torchgw import sampled_lowrank_gw
T = sampled_lowrank_gw(X, Y, rank=30, distance_mode="landmark", n_landmarks=50)
Memory: O((N+K)*rank) instead of O(NK).
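In the Scetbon, Cuturi & Peyré (2021) factorization the plan is kept as T = Q diag(1/g) R^T with Q (N×r), R (K×r) and g (r), so a product such as T v costs O((N+K)r) and T is never formed. The sketch below uses hypothetical helper names and float32 byte counts to show the memory arithmetic, then checks the factorized matrix-vector product against the dense plan on a toy example. The factors are random matrices rescaled to the required marginals, not the output of an actual GW solve.

```python
import numpy as np

# Sketch only: memory arithmetic and a factorized matrix-vector product for a
# low-rank plan T = Q diag(1/g) R^T (helper names are hypothetical).


def plan_bytes(N, K, rank=None, bytes_per_entry=4):
    """Bytes for a dense (N, K) plan, or for the factors Q, R, g at a given rank."""
    if rank is None:
        return N * K * bytes_per_entry
    return ((N + K) * rank + rank) * bytes_per_entry


def scale_to_marginals(A, row, col, n_iter=500):
    """Alternately rescale rows and columns of a positive matrix to the targets."""
    for _ in range(n_iter):
        A = A * (row / A.sum(axis=1))[:, None]
        A = A * (col / A.sum(axis=0))[None, :]
    return A


print("dense float32 plan, 100k x 100k:", plan_bytes(100_000, 100_000) / 1e9, "GB")
print("rank-30 factors:", plan_bytes(100_000, 100_000, rank=30) / 1e6, "MB")

# Toy factors with the right marginals (random, not the output of a GW solve).
rng = np.random.default_rng(0)
N, K, r = 50, 60, 5
a, b, g = np.full(N, 1 / N), np.full(K, 1 / K), np.full(r, 1 / r)
Q = scale_to_marginals(rng.random((N, r)) + 0.1, a, g)  # Q @ 1 = a, Q.T @ 1 = g
R = scale_to_marginals(rng.random((K, r)) + 0.1, b, g)  # R @ 1 = b, R.T @ 1 = g

v = rng.random(K)
Tv = Q @ ((R.T @ v) / g)            # O((N + K) * r); T itself is never built
T_dense = Q @ np.diag(1 / g) @ R.T  # formed here only to check the result
```

At N = K = 100k the dense float32 plan alone needs 40 GB, while the rank-30 factors fit in about 24 MB.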
API
sampled_gw
sampled_gw(
X_source, X_target, # (N, D) and (K, D') feature matrices
*,
distance_mode="dijkstra", # "precomputed" | "dijkstra" | "landmark"
fgw_alpha=0.0, # 0 = pure GW, 1 = Wasserstein, (0,1) = Fused GW
C_linear=None, # (N, K) feature cost matrix for FGW
M=50, # anchor pairs per iteration
epsilon=0.001, # entropic regularization
max_iter=500, tol=1e-5, # convergence control
mixed_precision=False, # float32 Sinkhorn for GPU speed
semi_relaxed=False, # relax target marginal
differentiable=False, # keep autograd graph
multiscale=False, # coarse-to-fine warm start
log=False, # return (T, info_dict)
... # see docs for full parameter list
) -> Tensor # (N, K) transport plan
sampled_lowrank_gw
Same interface as sampled_gw, plus rank, lr_max_iter, and lr_dykstra_max_iter.
Uses the low-rank factorization of Scetbon, Cuturi & Peyré (2021).
When to use: only when the N*K transport plan exceeds GPU memory. At smaller scales,
sampled_gw is faster.
Full API documentation: chansigit.github.io/torchgw
How It Works
                  ┌─────────────────────────────────────────────┐
                  │                GW Main Loop                 │
                  │                                             │
T_init ──────────►│ 1. GPU multinomial sampling (M anchors)     │
                  │ 2. Distance computation (Dijkstra/landmark) │
                  │ 3. GW cost matrix assembly                  │
                  │ 4. Triton fused Sinkhorn projection         │
                  │ 5. Momentum update + warm-start             │
                  │ 6. Cost plateau convergence check           │
                  │                                             │
                  │          ↺ repeat until converged           │
                  └──────────────────────┬──────────────────────┘
                                         │
                                         ▼
                                     T* (N × K)
Acceleration stack:
- Triton kernels -- fused row/column logsumexp, fused T materialization, fused marginal check
- Warm-start -- reuse Sinkhorn potentials across iterations
- Mixed precision -- float32 log-domain + float64 output
- Dijkstra cache -- avoid redundant SSSP across iterations
- Cost plateau early stopping -- stop when converged, not at max_iter
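The six steps of the main loop, together with the warm-start, momentum, and plateau items above, can be sketched end to end in NumPy. This is a toy illustration, not TorchGW's implementation. It assumes dense precomputed distance matrices (so step 2 is a lookup), uniform marginals, a plain log-domain Sinkhorn in place of the Triton kernels, a fixed momentum weight, and a relative plateau test. All names (`sampled_gw_sketch`, `momentum`, and so on) are hypothetical. Because the sampled cost is noisy, this toy version may simply run to max_iter.

```python
import numpy as np

# Toy sketch of the loop above (not TorchGW's code). Assumptions: dense precomputed
# distances, uniform marginals, fixed momentum, 50 warm-started Sinkhorn sweeps.


def lse(A, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = A.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(A - m).sum(axis=axis, keepdims=True)), axis=axis)


def sampled_gw_sketch(D_x, D_y, M=50, eps=0.05, momentum=0.5,
                      max_iter=200, tol=1e-5, seed=0):
    rng = np.random.default_rng(seed)
    N, K = len(D_x), len(D_y)
    a, b = np.full(N, 1 / N), np.full(K, 1 / K)
    T = np.outer(a, b)                   # T_init: independent coupling
    f, g = np.zeros(N), np.zeros(K)      # Sinkhorn potentials, reused across iterations
    prev_cost = np.inf
    for it in range(max_iter):
        # 1. Sample M anchor pairs (j, l) with probability proportional to T.
        idx = rng.choice(N * K, size=M, p=T.ravel() / T.sum())
        j, l = np.divmod(idx, K)
        # 2.-3. Distances to the anchors (a lookup here) and the sampled GW cost.
        diff = D_x[:, j][:, None, :] - D_y[:, l][None, :, :]  # (N, K, M)
        C = (diff**2).mean(axis=2)
        C = C / C.max()                  # rescale so eps is scale-free (sketch choice)
        # 4. Entropic projection, warm-started from the previous potentials f, g.
        for _ in range(50):
            f = eps * np.log(a) - eps * lse((g[None, :] - C) / eps, axis=1)
            g = eps * np.log(b) - eps * lse((f[:, None] - C) / eps, axis=0)
        T_new = np.exp((f[:, None] + g[None, :] - C) / eps)
        # 5. Momentum: convex blend of the previous and the projected plan.
        T = (1 - momentum) * T + momentum * T_new
        # 6. Plateau check on the (sampled) transport cost.
        cost = float((C * T).sum())
        if abs(prev_cost - cost) <= tol * max(abs(cost), 1e-12):
            break
        prev_cost = cost
    return T, it + 1


rng = np.random.default_rng(1)
X, Y = rng.normal(size=(40, 2)), rng.normal(size=(45, 3))
D_x = np.linalg.norm(X[:, None] - X[None], axis=-1)
D_y = np.linalg.norm(Y[:, None] - Y[None], axis=-1)
T, n_iter = sampled_gw_sketch(D_x, D_y)
print(T.shape, round(T.sum(), 6), "iterations:", n_iter)
```

Reusing f and g means each outer iteration's Sinkhorn starts near the previous fixed point, which is why a fixed, small number of inner sweeps can suffice once the plan stops changing much.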
Development
git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e ".[dev]"
pytest tests/ -v # 72 tests, ~18s
Citation
If you use TorchGW in your research, please cite:
@software{torchgw,
author = {Sijie Chen},
title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
url = {https://github.com/chansigit/torchgw},
version = {0.4.1},
year = {2026},
}
License
This project is source-available.
It is free for academic and other non-commercial research and educational use under the terms of the included LICENSE.
Any commercial use — including any use by or on behalf of a for-profit entity, internal commercial research, product development, consulting, paid services, or deployment in commercial settings — requires a separate paid commercial license.
Copyright (c) 2026 The Board of Trustees of the Leland Stanford Junior University.
For commercial licensing inquiries, contact Stanford Office of Technology Licensing: otl@stanford.edu
See COMMERCIAL_LICENSE.md for details.
Download files
File details: torchgw-0.4.1.tar.gz
- Download URL: torchgw-0.4.1.tar.gz
- Upload date:
- Size: 44.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.1

File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c30ea124508351b94114ae14481122d8727b63b8e8f62f71cda6acf58fbaec5a |
| MD5 | 4e56aa553ae0938ec0624a3aba37a0fd |
| BLAKE2b-256 | 2647fdf980896c566d9629e21af0d0766ecb30e81b518f5f035463eb3c6a1b22 |
File details: torchgw-0.4.1-py3-none-any.whl
- Download URL: torchgw-0.4.1-py3-none-any.whl
- Upload date:
- Size: 29.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.1

File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 15d428139fe8311e8de3535215d9534cd97f07471faf73df05aa2f66f7373f92 |
| MD5 | bfa89457bcd8bfa7c919ac15fff6dae8 |
| BLAKE2b-256 | e085e3bf2b9f3a5d64f5dc29d87309da5657487ab64a73d8198d662f1834eaf6 |