Fast Sampled Gromov-Wasserstein optimal transport solver — pure PyTorch, scalable, differentiable

Fast Sampled Gromov-Wasserstein Optimal Transport

Pure PyTorch  |  Triton GPU Kernels  |  Differentiable  |  Up to 175x faster than POT


TorchGW aligns two point clouds by matching their internal distance structures -- even when they live in different dimensions. Instead of the full O(NK(N+K)) GW cost, it samples M anchor pairs each iteration and approximates the cost in O(NKM), enabling GPU-accelerated alignment at scales where standard solvers are impractical.
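
To make the two complexity figures concrete, here is a small back-of-the-envelope sketch. The helper names are hypothetical and the counts ignore constant factors, so read them as orders of magnitude rather than measured costs.

```python
def full_gw_cost_ops(n: int, k: int) -> int:
    """Order-of-magnitude operation count for the full cost, O(NK(N+K))."""
    return n * k * (n + k)


def sampled_gw_cost_ops(n: int, k: int, m: int) -> int:
    """Order-of-magnitude operation count for the sampled cost, O(NKM)."""
    return n * k * m


n, k, m = 4_000, 5_000, 80
full = full_gw_cost_ops(n, k)
sampled = sampled_gw_cost_ops(n, k, m)
print(f"full: {full:.2e}, sampled: {sampled:.2e}, ratio: {full / sampled:.1f}x")
```

The ratio simplifies to (N+K)/M, so the saving per iteration grows with the problem size as long as M stays fixed.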

Use cases: single-cell multi-omics integration, cross-domain graph matching, shape correspondence, manifold alignment.

Highlights

Performance

  • Up to 175x faster than POT on typical workloads
  • Triton fused Sinkhorn -- single-pass logsumexp, zero N*K intermediates
  • Mixed precision: float32 Sinkhorn + float64 output
  • Smart early stopping via cost plateau detection

Features

  • Pure GW, Fused GW, and semi-relaxed transport
  • Three distance modes: precomputed, Dijkstra, landmark
  • Differentiable transport plans (autograd support)
  • Low-rank Sinkhorn for N, K > 50k
  • Multi-scale coarse-to-fine warm start

News

v0.4.1 (2026-04-09) -- Exact differentiable gradients via implicit differentiation. The previous "envelope theorem" backward was a frozen-potentials approximation with up to 30x gradient error; now replaced by an adjoint system solved via Schur complement on the Sinkhorn Jacobian. New grad_mode parameter ("implicit" default, "unrolled" alternative). See algorithm docs for the math.

v0.4.0 (2026-04-07) -- Triton fused Sinkhorn (2-5x GPU speedup), mixed precision, smart early stopping, Sinkhorn warm-start, Dijkstra caching, and 15 numerical stability fixes. See CHANGELOG.md.


Installation

pip install torchgw        # from PyPI

# or, from a clone of the repository (editable install):
pip install -e .

Requirements: numpy, scipy, scikit-learn, torch>=2.0, joblib. Triton ships with PyTorch and enables GPU kernel fusion automatically. No POT needed.

Quick Start

from torchgw import sampled_gw

# Basic usage
T = sampled_gw(X_source, X_target)

# Recommended for large-scale (fastest)
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)

Minimal working example:

import torch
from torchgw import sampled_gw

X = torch.randn(500, 3)   # source: 500 points in 3D
Y = torch.randn(600, 5)   # target: 600 points in 5D (dimensions may differ)

T = sampled_gw(X, Y, epsilon=0.005, M=80, max_iter=200)
# T is a (500, 600) transport plan: T[i,j] = coupling weight between X[i] and Y[j]
print(f"Transport plan: {T.shape}, total mass: {T.sum():.4f}")
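
Once a plan is available, two common ways to read it are a hard assignment (the most strongly coupled target for each source point) and a barycentric projection (a coupling-weighted average of target coordinates). The sketch below uses a small hand-written NumPy array as a stand-in for the solver output; for a real run one would presumably convert first, for example with T.cpu().numpy().

```python
import numpy as np

# Toy 3 x 4 plan standing in for the solver output (rows: source, columns: target).
T = np.array([
    [0.20, 0.05, 0.00, 0.00],
    [0.00, 0.05, 0.25, 0.00],
    [0.00, 0.15, 0.00, 0.30],
])
Y = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # target coordinates

# Hard assignment: index of the most strongly coupled target per source point.
best_target = T.argmax(axis=1)

# Barycentric projection: coupling-weighted average of target coordinates.
row_mass = T.sum(axis=1, keepdims=True)
X_mapped = (T @ Y) / row_mass

print(best_target)
print(X_mapped)
```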

Benchmark

Spiral (2D) to Swiss roll (3D) alignment, mixed_precision=True, landmark distances:

NVIDIA H100 80GB HBM3:

| Scale | Time | Spearman rho | GPU Memory |
|---|---|---|---|
| 4,000 x 5,000 | 0.8 s | 0.999 | 0.7 GB |
| 10,000 x 12,000 | 4.1 s | 0.999 | 3.9 GB |
| 20,000 x 25,000 | 4.6 s | 0.999 | 16 GB |
| 30,000 x 35,000 | 9.3 s | 0.999 | 34 GB |
| 40,000 x 50,000 | 17 s | 0.999 | 64 GB |
| 45,000 x 45,000 | 18 s | 0.999 | 65 GB |

NVIDIA L40S 48GB:

| Scale | Time | Spearman rho | GPU Memory |
|---|---|---|---|
| 4,000 x 5,000 | 2.4 s | 0.999 | 1.1 GB |
| 10,000 x 12,000 | 3.0 s | 0.999 | 6.7 GB |
| 20,000 x 25,000 | 12 s | 0.999 | 18 GB |
| 30,000 x 35,000 | 25 s | 0.999 | 34 GB |
| 35,000 x 40,000 | 34 s | 0.999 | 45 GB |

Alignment quality (Spearman >= 0.999) is maintained across all scales. At 4000x5000, TorchGW is ~175x faster than POT (1.0s vs 183s). Max scale is bounded by GPU memory for the N*K transport plan (~80% VRAM utilization).
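
As a rough guide to the memory bound, the sketch below computes the size of a single dense N x K buffer. The helper name is hypothetical, and it counts one float64 buffer only; the solver's actual peak usage depends on how many such buffers are alive at once, which is not specified here.

```python
def dense_plan_gib(n: int, k: int, bytes_per_element: int = 8) -> float:
    """Size of one dense N x K buffer in GiB (8 bytes per element = float64)."""
    return n * k * bytes_per_element / 2**30


for n, k in [(4_000, 5_000), (20_000, 25_000), (40_000, 50_000)]:
    print(f"{n:>6,} x {k:<6,}: {dense_plan_gib(n, k):6.2f} GiB per float64 buffer")
```

At the larger scales the H100 figures in the table are roughly four times this single-buffer size, which would be consistent with several N x K buffers (plan, cost, and intermediates) coexisting, though that reading is an assumption.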

Reproduce:

python examples/benchmark_scale.py

[Figure: benchmark plots for the 400 vs 500 and 4000 vs 5000 problem sizes]

Distance Modes

Choose based on your data scale:

| Mode | Best for | Per-iteration | Memory | Notes |
|---|---|---|---|---|
| "precomputed" | N < 5k | O(NM) lookup | O(N^2) | All-pairs Dijkstra upfront |
| "dijkstra" | 5k-50k | O(MN log N) | O(NM) | On-the-fly with caching |
| "landmark" | any scale | O(NMd) GPU | O(Nd) | Recommended default |

# Small scale: precompute all distances once
T = sampled_gw(X, Y, distance_mode="precomputed")

# Bring your own distance matrices
T = sampled_gw(dist_source=D_X, dist_target=D_Y, distance_mode="precomputed")

# Large scale (recommended)
T = sampled_gw(X, Y, distance_mode="landmark", n_landmarks=50)
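
To illustrate what "on-the-fly with caching" can look like, here is a minimal standard-library sketch: single-source Dijkstra is run only from the sampled anchor nodes, and the resulting distance rows are kept in a small FIFO cache. The class and function names are hypothetical, and this is not TorchGW's implementation.

```python
import heapq
from collections import OrderedDict


def dijkstra(adj, source):
    """Single-source shortest paths on a weighted adjacency list."""
    dist = [float("inf")] * len(adj)
    dist[source] = 0.0
    heap = [(0.0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue  # stale heap entry
        for v, w in adj[u]:
            nd = d + w
            if nd < dist[v]:
                dist[v] = nd
                heapq.heappush(heap, (nd, v))
    return dist


class FifoDistanceCache:
    """Keeps at most `capacity` distance rows; evicts the oldest insertion."""

    def __init__(self, adj, capacity):
        self.adj, self.capacity = adj, capacity
        self.rows = OrderedDict()
        self.misses = 0

    def row(self, node):
        if node not in self.rows:
            self.misses += 1
            if len(self.rows) >= self.capacity:
                self.rows.popitem(last=False)  # drop the oldest entry (FIFO)
            self.rows[node] = dijkstra(self.adj, node)
        return self.rows[node]


# Path graph 0 - 1 - 2 - 3 with unit edge weights.
adj = [[(1, 1.0)], [(0, 1.0), (2, 1.0)], [(1, 1.0), (3, 1.0)], [(2, 1.0)]]
cache = FifoDistanceCache(adj, capacity=2)
anchors = [0, 3, 0]                        # node 0 is sampled twice
columns = [cache.row(a) for a in anchors]  # one distance row per anchor
print(columns[0], cache.misses)
```

Because anchors are drawn from the current plan, the same nodes tend to recur across iterations, which is where a cache of this kind pays off.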

Usage Guide

Best performance settings

T = sampled_gw(
    X, Y,
    distance_mode="landmark",   # avoids expensive all-pairs Dijkstra
    mixed_precision=True,       # float32 Sinkhorn (2x faster on GPU)
    M=80,                       # more samples = better cost estimate
    epsilon=0.005,              # moderate regularization
)

Fused Gromov-Wasserstein

Blend structural (graph distance) and feature (linear) costs:

C_feat = torch.cdist(features_src, features_tgt)
T = sampled_gw(X, Y, fgw_alpha=0.5, C_linear=C_feat)

# Pure Wasserstein (no graph distances needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat)
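
One natural reading of fgw_alpha is a convex combination of the structural cost and the feature cost. The sketch below shows that blend in NumPy; the function name is hypothetical, and the exact weighting or scaling TorchGW applies internally is an assumption.

```python
import numpy as np


def blend_costs(C_gw, C_linear, fgw_alpha):
    """Convex combination of a structural cost and a feature cost."""
    if not 0.0 <= fgw_alpha <= 1.0:
        raise ValueError("fgw_alpha must lie in [0, 1]")
    return (1.0 - fgw_alpha) * C_gw + fgw_alpha * C_linear


rng = np.random.default_rng(0)
C_gw = rng.random((4, 5))       # stand-in for the sampled GW cost
C_linear = rng.random((4, 5))   # stand-in for the feature cost
C_half = blend_costs(C_gw, C_linear, 0.5)
print(C_half.shape)
```

With fgw_alpha=0 the blend reduces to the structural cost alone, and with fgw_alpha=1 to the feature cost alone, matching the two endpoints described above.
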

Semi-relaxed transport

For unbalanced datasets (e.g., cell types present in one sample but not the other):

T = sampled_gw(X, Y, semi_relaxed=True, rho=1.0)
# Source marginal enforced, target marginal relaxed via KL penalty
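
For intuition, here is a minimal NumPy sketch of a semi-relaxed Sinkhorn iteration using the standard unbalanced-OT scaling update: the target update is softened by the exponent rho / (rho + epsilon), while the source update stays exact. It works on a fixed cost matrix and is not TorchGW's solver; the function name and default values are assumptions.

```python
import numpy as np


def semi_relaxed_sinkhorn(C, a, b, epsilon=0.1, rho=1.0, n_iter=200):
    """Entropic OT with a hard source marginal and a KL-penalized target marginal."""
    kernel = np.exp(-C / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    power = rho / (rho + epsilon)  # tends to 1 (balanced update) as rho grows
    for _ in range(n_iter):
        v = (b / (kernel.T @ u)) ** power  # softened target update
        u = a / (kernel @ v)               # exact source update, done last
    return u[:, None] * kernel * v[None, :]


rng = np.random.default_rng(0)
C = rng.random((5, 7))
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)
T = semi_relaxed_sinkhorn(C, a, b)
print("max source marginal error:", np.abs(T.sum(axis=1) - a).max())
print("max target marginal deviation:", np.abs(T.sum(axis=0) - b).max())
```

A smaller rho lets the column sums drift further from b, which is what allows mass to avoid target points that have no counterpart in the source.
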

Multi-scale warm start

Speeds up convergence by solving a coarse problem first:

T = sampled_gw(X, Y, multiscale=True, n_coarse=200)

Note: GW has symmetric local optima, so a coarse solution can lock the fine solve into a mirrored alignment. The warm start works best on data without strong symmetries.

Differentiable mode

Use GW transport as a differentiable layer (exact gradients via implicit differentiation):

C_feat = torch.cdist(encoder(X), encoder(Y))
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)
loss = (C_feat.detach() * T).sum()
loss.backward()  # exact gradients flow to encoder parameters

# For memory-constrained settings, unrolled autograd is also available:
T = sampled_gw(..., differentiable=True, grad_mode="unrolled")

Low-rank Sinkhorn (N, K > 50k)

For very large problems where the N*K transport plan does not fit in memory:

from torchgw import sampled_lowrank_gw
T = sampled_lowrank_gw(X, Y, rank=30, distance_mode="landmark", n_landmarks=50)

Memory: O((N+K)*rank) instead of O(NK).
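
The memory saving comes from never forming the N x K matrix. The sketch below illustrates this with the factored form T = Q diag(1/g) R^T used in low-rank Sinkhorn. The factors here are random placeholders chosen only to show the shapes, and the way sampled_lowrank_gw exposes its factors is not something this sketch assumes.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, rank = 300, 400, 5

# Random non-negative factors, used only to illustrate the shapes involved.
Q = rng.random((N, rank))
R = rng.random((K, rank))
g = rng.random(rank) + 0.5


def apply_plan(Q, R, g, x):
    """Compute T @ x for T = Q diag(1/g) R^T without forming the N x K matrix."""
    return Q @ ((R.T @ x) / g)


x = rng.random(K)
y_factored = apply_plan(Q, R, g, x)
y_dense = (Q / g) @ R.T @ x  # materializes N x K; fine only at toy scale

stored_factored = Q.size + R.size + g.size  # (N + K + 1) * rank numbers
stored_dense = N * K
print(stored_factored, stored_dense)
```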


API

sampled_gw

sampled_gw(
    X_source, X_target,         # (N, D) and (K, D') feature matrices
    *,
    distance_mode="dijkstra",   # "precomputed" | "dijkstra" | "landmark"
    fgw_alpha=0.0,              # 0 = pure GW, 1 = Wasserstein, (0,1) = Fused GW
    C_linear=None,              # (N, K) feature cost matrix for FGW
    M=50,                       # anchor pairs per iteration
    epsilon=0.001,              # entropic regularization
    max_iter=500, tol=1e-5,     # convergence control
    mixed_precision=False,      # float32 Sinkhorn for GPU speed
    semi_relaxed=False,         # relax target marginal
    differentiable=False,       # keep autograd graph
    multiscale=False,           # coarse-to-fine warm start
    log=False,                  # return (T, info_dict)
    ...                         # see docs for full parameter list
) -> Tensor                     # (N, K) transport plan

sampled_lowrank_gw

Same interface plus rank, lr_max_iter, lr_dykstra_max_iter. Uses the Scetbon, Cuturi & Peyré (2021) factorization.

When to use: only when N*K exceeds GPU memory. At smaller scales, sampled_gw is faster.

Full API documentation: chansigit.github.io/torchgw


How It Works

The idea

Gromov-Wasserstein finds a coupling between two point clouds by comparing distances within each space rather than distances across spaces. This means the two inputs can live in completely different dimensions -- a 2D spiral can be aligned to a 3D Swiss roll, or a gene expression matrix to a chromatin accessibility matrix.

Standard GW solvers compute the full N×N and K×K pairwise distance matrices and an O(NK(N+K)) cost tensor at each step, which is prohibitive beyond a few thousand points. TorchGW replaces this with a stochastic approximation: sample M anchor pairs from the current transport plan, compute distances only for those anchors, and build a low-variance cost estimate in O(NKM) -- making each iteration orders of magnitude cheaper.
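
The stochastic approximation can be checked on a toy problem. The NumPy sketch below compares the exact linearized GW cost with a plain Monte Carlo estimate built from M anchor pairs drawn from the plan. It uses Euclidean distances and a uniform plan for simplicity; those choices, and the variable names, are assumptions made for illustration rather than TorchGW's exact estimator.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, M = 20, 25, 5_000


def pairwise_dist(P):
    """Euclidean distance matrix between the rows of P."""
    diff = P[:, None, :] - P[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))


Dx = pairwise_dist(rng.normal(size=(N, 2)))  # intra-source distances
Dy = pairwise_dist(rng.normal(size=(K, 3)))  # intra-target distances
T = np.full((N, K), 1.0 / (N * K))           # current coupling (uniform here)

# Exact linearized cost: C[i, k] = sum_{j, l} (Dx[i, j] - Dy[k, l])^2 * T[j, l]
p, q = T.sum(axis=1), T.sum(axis=0)
exact = (Dx**2 @ p)[:, None] + (Dy**2 @ q)[None, :] - 2.0 * Dx @ T @ Dy.T

# Sampled version: draw M anchor pairs (j, l) with probability T[j, l].
flat = rng.choice(N * K, size=M, p=T.ravel() / T.sum())
j, l = np.unravel_index(flat, (N, K))
estimate = T.sum() * ((Dx[:, j][:, None, :] - Dy[:, l][None, :, :]) ** 2).mean(-1)

rel_err = np.abs(estimate - exact).max() / exact.max()
print(f"max relative error with M={M}: {rel_err:.3f}")
```

In practice M is far smaller than in this check (the default shown in the API section is 50), so each individual estimate is noisier than the one above.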

Algorithm

flowchart TB
    subgraph inputs [" "]
        direction LR
        X["Source X\n(N points, D dims)"]
        Y["Target Y\n(K points, D' dims)"]
    end

    X --> G1["Build kNN graph"]
    Y --> G2["Build kNN graph"]

    G1 --> loop
    G2 --> loop

    subgraph loop ["GW Main Loop — repeat until converged"]
        S["1. Sample M anchor pairs (i,j)\nfrom current T\n(GPU multinomial)"]
        D["2. Compute graph distances\nfrom anchors\nD_left (N×M), D_tgt (K×M)"]
        C["3. Assemble GW cost matrix (N×K)\nΛ = mean(D²_left) − 2/M · D_left·D_tgt' + mean(D²_tgt)"]
        K["4. Sinkhorn projection → T_new\n(Triton fused kernels, log-domain)"]
        M["5. Momentum blend\nT ← (1−α)T + α·T_new\n+ warm-start potentials"]
        CV["6. Converged?\n(cost plateau detection)"]

        S --> D --> C --> K --> M --> CV
        CV -- "no" --> S
    end

    CV -- "yes" --> T["T* (N × K)\noptimal transport plan"]

    style inputs fill:none,stroke:none
    style loop fill:#f8f9fa,stroke:#dee2e6,stroke-width:2px
    style T fill:#d4edda,stroke:#28a745,stroke-width:2px
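
Step 3 of the flowchart expands a mean squared discrepancy into three terms so that the cross term becomes a single matrix product. The NumPy sketch below checks that expansion against the direct definition on random inputs. The array names follow the flowchart labels, and the random matrices stand in for real anchor distances.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, M = 6, 7, 5
D_left = rng.random((N, M))  # distances from each source point to the M source anchors
D_tgt = rng.random((K, M))   # distances from each target point to the M target anchors

# Step 3, using one matrix product instead of an (N, K, M) tensor.
cost = (
    (D_left**2).mean(axis=1)[:, None]
    - (2.0 / M) * D_left @ D_tgt.T
    + (D_tgt**2).mean(axis=1)[None, :]
)

# Reference: mean squared distance discrepancy over the anchor pairs.
reference = ((D_left[:, None, :] - D_tgt[None, :, :]) ** 2).mean(axis=-1)

print(np.abs(cost - reference).max())
```

The expanded form needs only O(NM + KM) inputs and one (N, M) by (M, K) product, which is where the O(NKM) per-iteration cost comes from.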

Why it's fast

| Technique | What it does | Speedup |
|---|---|---|
| Sampled cost | O(NKM) instead of O(NK(N+K)) per iteration | 10-100x |
| Triton Sinkhorn | Fused GPU kernels: single-pass logsumexp, no intermediate N*K allocations | 2-5x |
| Warm-start | Reuse Sinkhorn potentials (log u, log v) across GW iterations | 2-3x fewer Sinkhorn steps |
| Mixed precision | float32 Sinkhorn in log domain (numerically safe), float64 output | up to 2x on consumer GPUs |
| Dijkstra cache | Cache per-node shortest paths, FIFO eviction | avoids redundant graph traversals |
| Cost plateau detection | Stop when GW cost EMA plateaus, not when noisy ‖T-T_prev‖ < tol | saves 50-80% of max_iter |
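
The warm-start row can be illustrated with a plain NumPy log-domain Sinkhorn that accepts initial potentials. This sketch is a reference-style implementation that does allocate N x K intermediates, unlike the fused kernels described above; the function names, tolerance, and test problem are assumptions.

```python
import numpy as np


def logsumexp(A, axis):
    """Numerically stable log(sum(exp(A))) along one axis."""
    m = A.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(A - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn_log(C, a, b, epsilon, f=None, g=None, tol=1e-9, max_iter=10_000):
    """Log-domain Sinkhorn; returns the plan, the potentials, and the step count."""
    f = np.zeros_like(a) if f is None else f
    g = np.zeros_like(b) if g is None else g
    log_a, log_b = np.log(a), np.log(b)
    for step in range(1, max_iter + 1):
        f = epsilon * (log_a - logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = epsilon * (log_b - logsumexp((f[:, None] - C) / epsilon, axis=0))
        T = np.exp((f[:, None] + g[None, :] - C) / epsilon)
        if np.abs(T.sum(axis=1) - a).max() < tol:
            break
    return T, f, g, step


rng = np.random.default_rng(0)
C = rng.random((8, 10))
a = np.full(8, 1 / 8)
b = np.full(10, 1 / 10)

# First outer iteration: cold start.
_, f, g, _ = sinkhorn_log(C, a, b, epsilon=0.1)

# Next outer iteration: the cost has changed only slightly.
C_next = C + 1e-3 * rng.random((8, 10))
T_cold, _, _, cold_steps = sinkhorn_log(C_next, a, b, epsilon=0.1)
T_warm, _, _, warm_steps = sinkhorn_log(C_next, a, b, epsilon=0.1, f=f, g=g)
print(f"cold start: {cold_steps} steps, warm start: {warm_steps} steps")
```

Since the cost matrix typically changes little between consecutive GW iterations, the previous potentials are already close to the new solution, and fewer Sinkhorn steps are needed to reach the same tolerance.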

See the algorithm documentation for the full mathematical formulation, including the semi-relaxed extension and differentiable gradient computation.
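
As a rough illustration of the cost plateau criterion from the table, the sketch below tracks an exponential moving average (EMA) of the cost and stops after several consecutive iterations without a meaningful relative improvement. The function name, smoothing factor, tolerance, and patience are all hypothetical values, not TorchGW's settings, and the criterion assumes positive costs.

```python
def plateau_stop(costs, beta=0.9, rel_tol=1e-3, patience=5):
    """Return the iteration index at which an EMA of the cost has stopped improving."""
    ema, best, stale = None, float("inf"), 0
    for it, c in enumerate(costs):
        ema = c if ema is None else beta * ema + (1.0 - beta) * c
        if ema < best * (1.0 - rel_tol):
            best, stale = ema, 0  # meaningful improvement: reset the counter
        else:
            stale += 1
        if stale >= patience:
            return it
    return len(costs) - 1


# Synthetic cost curve that decays geometrically towards 0.5.
costs = [0.5 + 0.8**t for t in range(200)]
stop = plateau_stop(costs)
print(f"stopped at iteration {stop} of {len(costs)}")
```

Smoothing the cost before testing for improvement is what makes this kind of rule robust to the sampling noise that would otherwise keep a plan-difference test from ever triggering.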


Development

git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e ".[dev]"
pytest tests/ -v          # 72 tests, ~18s

Citation

If you use TorchGW in your research, please cite:

@software{torchgw,
  author = {Sijie Chen},
  title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
  url = {https://github.com/chansigit/torchgw},
  version = {0.4.1},
  year = {2026},
}

License

This project is source-available.

It is free for academic and other non-commercial research and educational use under the terms of the included LICENSE.

Any commercial use — including any use by or on behalf of a for-profit entity, internal commercial research, product development, consulting, paid services, or deployment in commercial settings — requires a separate paid commercial license.

Copyright (c) 2026 The Board of Trustees of the Leland Stanford Junior University.

For commercial licensing inquiries, contact Stanford Office of Technology Licensing: otl@stanford.edu

See COMMERCIAL_LICENSE.md for details.
