Fast Sampled Gromov-Wasserstein optimal transport solver — pure PyTorch, scalable, differentiable

Fast Sampled Gromov-Wasserstein Optimal Transport

Pure PyTorch  |  Triton GPU Kernels  |  Differentiable  |  Up to 175x faster than POT


TorchGW aligns two point clouds by matching their internal distance structures -- even when they live in different dimensions. Instead of the full O(NK(N+K)) GW cost, it samples M anchor pairs each iteration and approximates the cost in O(NKM), enabling GPU-accelerated alignment at scales where standard solvers are impractical.
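
For intuition, the estimator can be sketched in a few lines of plain PyTorch: sample M anchor pairs (i', j') in proportion to the current plan and average the squared distance discrepancies over those anchors. This is an illustrative sketch, not TorchGW's internal implementation, and the function name is made up.

import torch

def sampled_gw_cost(D_x, D_y, T, M=50):
    """Monte Carlo estimate of the plan-weighted GW cost matrix.

    D_x: (N, N) intra-source distances, D_y: (K, K) intra-target distances,
    T:   (N, K) current transport plan (entries sum to 1).
    """
    N, K = T.shape
    # Sample M anchor pairs (i', j') with probability proportional to T
    flat = torch.multinomial(T.flatten(), M, replacement=True)
    i_idx, j_idx = flat // K, flat % K
    # Average |D_x[i, i'] - D_y[j, j']|^2 over the sampled anchors -> (N, K)
    diff = D_x[:, i_idx].unsqueeze(1) - D_y[:, j_idx].unsqueeze(0)   # (N, K, M)
    return (diff ** 2).mean(dim=-1)

In this sketch, because the anchors are drawn from T itself, the sampled average is an unbiased estimate of the full plan-weighted cost, at O(NKM) work instead of O(NK(N+K)).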

Use cases: single-cell multi-omics integration, cross-domain graph matching, shape correspondence, manifold alignment.

Highlights

Performance

  • Up to 175x faster than POT on typical workloads
  • Triton fused Sinkhorn -- single-pass logsumexp, zero N*K intermediates
  • Mixed precision: float32 Sinkhorn + float64 output
  • Smart early stopping via cost plateau detection

Features

  • Pure GW, Fused GW, and semi-relaxed transport
  • Three distance modes: precomputed, Dijkstra, landmark
  • Differentiable transport plans (autograd support)
  • Low-rank Sinkhorn for N, K > 50k
  • Multi-scale coarse-to-fine warm start

News

v0.4.1 (2026-04-09) -- Exact differentiable gradients via implicit differentiation. The previous "envelope theorem" backward was a frozen-potentials approximation with up to 30x gradient error; it has been replaced by an adjoint system solved via a Schur complement on the Sinkhorn Jacobian. New grad_mode parameter ("implicit" by default, "unrolled" as an alternative). See the algorithm docs for the math.

v0.4.0 (2026-04-07) -- Triton fused Sinkhorn (2-5x GPU speedup), mixed precision, smart early stopping, Sinkhorn warm-start, Dijkstra caching, and 15 numerical stability fixes. See CHANGELOG.md.


Installation

pip install torchgw

Requirements: numpy, scipy, scikit-learn, torch>=2.0, joblib. Triton ships with PyTorch and enables GPU kernel fusion automatically. No POT needed.

Quick Start

from torchgw import sampled_gw

# Basic usage
T = sampled_gw(X_source, X_target)

# Recommended for large-scale (fastest)
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)

Minimal working example:

import torch
from torchgw import sampled_gw

X = torch.randn(500, 3)   # source: 500 points in 3D
Y = torch.randn(600, 5)   # target: 600 points in 5D (dimensions may differ)

T = sampled_gw(X, Y, epsilon=0.005, M=80, max_iter=200)
# T is a (500, 600) transport plan: T[i,j] = coupling weight between X[i] and Y[j]
print(f"Transport plan: {T.shape}, total mass: {T.sum():.4f}")

Benchmark

Spiral (2D) to Swiss roll (3D) alignment on NVIDIA L40S:

Scale          Method                       Time     Spearman rho
400 vs 500     POT ot.gromov_wasserstein    1.6 s    0.999
400 vs 500     TorchGW                      0.46 s   0.999
4000 vs 5000   POT ot.gromov_wasserstein    183 s    0.999
4000 vs 5000   TorchGW precomputed          5.1 s    0.998
4000 vs 5000   TorchGW landmark             1.0 s    0.999

At 4000x5000 with landmark distances, TorchGW is up to ~175x faster than POT with equal quality.

Benchmark plots
[Figures: 400 vs 500 and 4000 vs 5000 alignment results]

Distance Modes

Choose based on your data scale:

Mode            Best for     Per-iteration     Memory    Notes
"precomputed"   N < 5k       O(NM) lookup      O(N^2)    All-pairs Dijkstra upfront
"dijkstra"      5k-50k       O(MN log N)       O(NM)     On-the-fly with caching
"landmark"      any scale    O(NMd) on GPU     O(Nd)     Recommended default

# Small scale: precompute all distances once
T = sampled_gw(X, Y, distance_mode="precomputed")

# Bring your own distance matrices
T = sampled_gw(dist_source=D_X, dist_target=D_Y, distance_mode="precomputed")

# Large scale (recommended)
T = sampled_gw(X, Y, distance_mode="landmark", n_landmarks=50)
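
For intuition, one common landmark construction (a sketch only; not necessarily the exact scheme TorchGW uses) represents each point by its distances to a small set of randomly chosen landmarks, so structure can be compared through (N, n_landmarks) profiles instead of an (N, N) matrix:

import torch

def landmark_profiles(X, n_landmarks=50):
    """Embed each point by its Euclidean distances to randomly chosen landmarks.

    Memory is O(N * n_landmarks) rather than O(N^2), which is what makes the
    "landmark" mode usable at any scale.
    """
    idx = torch.randperm(X.shape[0])[:n_landmarks]
    return torch.cdist(X, X[idx])   # (N, n_landmarks)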

Usage Guide

Best performance settings
T = sampled_gw(
    X, Y,
    distance_mode="landmark",   # avoids expensive all-pairs Dijkstra
    mixed_precision=True,       # float32 Sinkhorn (2x faster on GPU)
    M=80,                       # more samples = better cost estimate
    epsilon=0.005,              # moderate regularization
)
Fused Gromov-Wasserstein

Blend structural (graph distance) and feature (linear) costs:

C_feat = torch.cdist(features_src, features_tgt)
T = sampled_gw(X, Y, fgw_alpha=0.5, C_linear=C_feat)

# Pure Wasserstein (no graph distances needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat)
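
As the fgw_alpha semantics suggest, the fused cost interpolates between the structural GW term and the linear feature term. Schematically, with alpha = fgw_alpha and the usual squared structural loss (a sketch of the convention, not a verbatim statement of TorchGW's objective):

\min_{T \in \Pi(a, b)} \; (1 - \alpha) \sum_{i, j, i', j'} \big( D_X(i, i') - D_Y(j, j') \big)^2 \, T_{ij} T_{i'j'} \;+\; \alpha \sum_{i, j} C^{\mathrm{lin}}_{ij} \, T_{ij}
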
Semi-relaxed transport

For unbalanced datasets (e.g., cell types present in one sample but not the other):

T = sampled_gw(X, Y, semi_relaxed=True, rho=1.0)
# Source marginal enforced, target marginal relaxed via KL penalty
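
Schematically, the source marginal stays a hard constraint while the target marginal enters only through a KL penalty weighted by rho (a sketch of the formulation described above, omitting the entropic term):

\min_{T \ge 0, \; T \mathbf{1}_K = a} \; \mathcal{E}_{\mathrm{GW}}(T) \;+\; \rho \, \mathrm{KL}\!\left( T^{\top} \mathbf{1}_N \,\|\, b \right)
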
Multi-scale warm start

Speeds up convergence by solving a coarse problem first:

T = sampled_gw(X, Y, multiscale=True, n_coarse=200)

Note: the GW objective has symmetric local optima, so a coarse solution can lock in a mirrored matching; the multi-scale warm start works best on data without strong symmetries.

Differentiable mode

Use GW transport as a differentiable layer (exact gradients via implicit differentiation):

C_feat = torch.cdist(encoder(X), encoder(Y))
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)
loss = (C_feat.detach() * T).sum()
loss.backward()  # exact gradients flow to encoder parameters

# For memory-constrained settings, unrolled autograd is also available:
T = sampled_gw(..., differentiable=True, grad_mode="unrolled")
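
In broad strokes, "implicit" gradients come from the implicit function theorem rather than from backpropagating through every Sinkhorn iteration: if the converged potentials u* satisfy an optimality condition F(u*, theta) = 0 for cost parameters theta, then (a sketch of the standard argument, not TorchGW's exact derivation)

\frac{\partial u^*}{\partial \theta} \;=\; -\left( \frac{\partial F}{\partial u} \right)^{-1} \frac{\partial F}{\partial \theta},

so gradients of any loss of the plan follow by the chain rule, with one linear solve replacing the stored iteration history that "unrolled" mode would need.
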
Low-rank Sinkhorn (N, K > 50k)

For very large problems where the N*K transport plan does not fit in memory:

from torchgw import sampled_lowrank_gw
T = sampled_lowrank_gw(X, Y, rank=30, distance_mode="landmark", n_landmarks=50)

Memory: O((N+K)*rank) instead of O(NK).
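
For intuition, the Scetbon, Cuturi & Peyré (2021) coupling is stored in factored form T ≈ Q diag(1/g) R^T rather than as a dense (N, K) matrix. The snippet below only illustrates the storage cost; the values are random and the variable names are not TorchGW's return types.

import torch

N, K, rank = 60_000, 80_000, 30

Q = torch.rand(N, rank)   # left factor
R = torch.rand(K, rank)   # right factor
g = torch.rand(rank)      # inner weights shared by the two factors

# O((N + K) * rank) floats instead of O(N * K):
print(Q.numel() + R.numel() + g.numel())   # -> 4200030
print(N * K)                               # -> 4800000000

# Any slice of the full plan can be materialized lazily when needed:
T_block = Q @ torch.diag(1.0 / g) @ R[:1000].T   # (N, 1000) block of T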


API

sampled_gw

sampled_gw(
    X_source, X_target,         # (N, D) and (K, D') feature matrices
    *,
    distance_mode="dijkstra",   # "precomputed" | "dijkstra" | "landmark"
    fgw_alpha=0.0,              # 0 = pure GW, 1 = Wasserstein, (0,1) = Fused GW
    C_linear=None,              # (N, K) feature cost matrix for FGW
    M=50,                       # anchor pairs per iteration
    epsilon=0.001,              # entropic regularization
    max_iter=500, tol=1e-5,     # convergence control
    mixed_precision=False,      # float32 Sinkhorn for GPU speed
    semi_relaxed=False,         # relax target marginal
    differentiable=False,       # keep autograd graph
    multiscale=False,           # coarse-to-fine warm start
    log=False,                  # return (T, info_dict)
    ...                         # see docs for full parameter list
) -> Tensor                     # (N, K) transport plan
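
For example, log=True returns the plan together with a dictionary of run diagnostics (the exact keys are listed in the full API documentation):

import torch
from torchgw import sampled_gw

X, Y = torch.randn(500, 3), torch.randn(600, 5)
T, info = sampled_gw(X, Y, distance_mode="landmark", log=True)
print(T.shape)        # torch.Size([500, 600])
print(sorted(info))   # convergence diagnostics; see the docs for the key names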

sampled_lowrank_gw

Same interface plus rank, lr_max_iter, lr_dykstra_max_iter. Uses the Scetbon, Cuturi & Peyré (2021) factorization.

When to use: only when N*K exceeds GPU memory. At smaller scales, sampled_gw is faster.

Full API documentation: chansigit.github.io/torchgw


How It Works

                    ┌─────────────────────────────────────────────┐
                    │              GW Main Loop                   │
                    │                                             │
  T_init ──────────►│  1. GPU multinomial sampling (M anchors)    │
                    │  2. Distance computation (Dijkstra/landmark)│
                    │  3. GW cost matrix assembly                 │
                    │  4. Triton fused Sinkhorn projection        │
                    │  5. Momentum update + warm-start            │
                    │  6. Cost plateau convergence check          │
                    │                                             │
                    │  ↺ repeat until converged                   │
                    └──────────────────────┬──────────────────────┘
                                           │
                                           ▼
                                     T* (N × K)
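
Step 4 dominates the GPU time. For reference, a plain-PyTorch log-domain Sinkhorn of the kind the Triton kernels fuse looks roughly like this (a sketch, not the fused implementation; it shows the row/column logsumexp passes and the warm-started potentials):

import torch

def log_domain_sinkhorn(C, a, b, epsilon, n_iter=100, init=None):
    """Reference log-domain Sinkhorn; the Triton kernels fuse these passes so
    the (N, K) intermediates are never materialized."""
    f = torch.zeros_like(a) if init is None else init[0]   # warm-started potentials
    g = torch.zeros_like(b) if init is None else init[1]
    for _ in range(n_iter):
        # row / column logsumexp updates (numerically stable in the log domain)
        f = epsilon * torch.log(a) - epsilon * torch.logsumexp((g[None, :] - C) / epsilon, dim=1)
        g = epsilon * torch.log(b) - epsilon * torch.logsumexp((f[:, None] - C) / epsilon, dim=0)
    T = torch.exp((f[:, None] + g[None, :] - C) / epsilon)  # transport plan
    return T, (f, g)   # potentials can warm-start the next GW iteration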

Acceleration stack:

  • Triton kernels -- fused row/column logsumexp, fused T materialization, fused marginal check
  • Warm-start -- reuse Sinkhorn potentials across iterations
  • Mixed precision -- float32 log-domain + float64 output
  • Dijkstra cache -- avoid redundant SSSP across iterations
  • Cost plateau early stopping -- stop when converged, not at max_iter

Development

git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e ".[dev]"
pytest tests/ -v          # 72 tests, ~18s

Citation

If you use TorchGW in your research, please cite:

@software{torchgw,
  author = {Sijie Chen},
  title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
  url = {https://github.com/chansigit/torchgw},
  version = {0.4.1},
  year = {2026},
}

License

This project is source-available.

It is free for academic and other non-commercial research and educational use under the terms of the included LICENSE.

Any commercial use — including any use by or on behalf of a for-profit entity, internal commercial research, product development, consulting, paid services, or deployment in commercial settings — requires a separate paid commercial license.

Copyright (c) 2026 The Board of Trustees of the Leland Stanford Junior University.

For commercial licensing inquiries, contact Stanford Office of Technology Licensing: otl@stanford.edu

See COMMERCIAL_LICENSE.md for details.
