FlashSinkhorn
Streaming Entropic Optimal Transport in PyTorch + Triton
FlashSinkhorn computes Sinkhorn optimal transport using FlashAttention-style streaming: the n×m cost matrix is never materialized, reducing memory from O(nm) to O(nd).
Features
- FlashSinkhorn kernels: a shifted-potential formulation inspired by FlashAttention. On A100 GPUs it achieves up to 32× forward-pass and 161× end-to-end speedups over state-of-the-art online baselines on point-cloud OT
- Fused Triton kernels for forward, gradient, and HVP
- GeomLoss-compatible API (`SamplesLoss`)
- Analytic gradients (no backprop through Sinkhorn iterations)
- Hessian-vector products via streaming CG solver
- Half-cost support (`half_cost=True`) for exact GeomLoss parity
- Unbalanced/semi-unbalanced OT via `reach` parameter
- Large-D support (d > 1024) with tiled gradient kernel
- Early stopping with convergence threshold
Install
```bash
pip install flash-sinkhorn

# From source (development)
pip install -e ".[dev]"
```
Requirements: PyTorch ≥2.5, Triton ≥3.1, CUDA 12.x
Quick Start
Basic Usage
```python
import torch
from flash_sinkhorn import SamplesLoss

x = torch.randn(4096, 64, device="cuda")
y = torch.randn(4096, 64, device="cuda")

loss = SamplesLoss(loss="sinkhorn", blur=0.1, debias=True)
cost = loss(x, y)
```
Gradient Flow
```python
x = torch.randn(4096, 64, device="cuda", requires_grad=True)
y = torch.randn(4096, 64, device="cuda")

loss = SamplesLoss(loss="sinkhorn", blur=0.1, debias=True)
cost = loss(x, y)
grad_x = torch.autograd.grad(cost, x)[0]  # Analytic gradient
```
GeomLoss Parity
Use `half_cost=True` to match GeomLoss's cost convention:

```python
# FlashSinkhorn with half_cost matches GeomLoss exactly
flash_loss = SamplesLoss(loss="sinkhorn", blur=0.1, half_cost=True, debias=True)

# Equivalent GeomLoss call
# geomloss_loss = geomloss.SamplesLoss(loss="sinkhorn", p=2, blur=0.1, debias="positive")
```
Unbalanced OT
For distributions with different total mass or outliers:
```python
loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    debias=True,
    reach=1.0,  # Unbalanced OT with KL penalty
)
```
Semi-Unbalanced OT
Different constraints for source vs target:
```python
loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    reach_x=1.0,   # Relax source marginal
    reach_y=None,  # Keep target marginal strict (balanced)
)
```
Early Stopping
```python
loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    n_iters=100,
    threshold=1e-3,       # Stop when potential change < threshold
    inner_iterations=10,  # Check every N iterations
)
```
Hessian-Vector Product
```python
x = torch.randn(4096, 64, device="cuda", requires_grad=True)
y = torch.randn(4096, 64, device="cuda")
v = torch.randn_like(x)

loss = SamplesLoss(loss="sinkhorn", blur=0.1)
cost = loss(x, y)

# First-order gradient
grad_x = torch.autograd.grad(cost, x, create_graph=True)[0]

# HVP via double backward (uses streaming CG solver)
hvp = torch.autograd.grad((grad_x * v).sum(), x)[0]
```
FlashSinkhorn (v0.3.0)
FlashSinkhorn is a reformulated Sinkhorn kernel that uses shifted potentials inspired by FlashAttention. It reduces bias vector loads by 67% and elementwise operations by 78% per tile, and improves scalability on OT-based downstream tasks.
How It Works
Standard Sinkhorn loads 3 bias vectors per tile (g, log_b, y²). FlashSinkhorn precomputes a single fused bias u = (g_shifted + eps*log(b)) / eps and uses raw coordinates with an inline scale factor, matching FlashAttention's score interface exactly.
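The shifted-potential trick can be sketched in plain PyTorch. This is a hypothetical reconstruction of the idea, not the library's Triton kernel: the dual potential `g`, the log-weights `log_b`, and the `||y||²` term are folded into a single per-column bias `u`, so each tile needs only raw coordinates and one bias vector, and the score takes the FlashAttention form `scale * (x @ y.T) + bias`.

```python
import torch

def softmin_fused(x, y, g, log_b, eps):
    """Compare the standard softmin against a fused-bias reformulation."""
    # Dense reference: f_i = -eps * logsumexp_j((g_j - ||x_i - y_j||^2)/eps + log b_j)
    C = torch.cdist(x, y) ** 2
    ref = -eps * torch.logsumexp((g[None, :] - C) / eps + log_b[None, :], dim=1)

    # Fused bias: fold g, eps*log(b) and -||y||^2 into one vector per column.
    # (Illustrative reconstruction; names here are not the library API.)
    u = (g - (y * y).sum(-1) + eps * log_b) / eps
    scores = (2.0 / eps) * (x @ y.T) + u[None, :]  # FlashAttention-style scores
    fused = -eps * (torch.logsumexp(scores, dim=1) - (x * x).sum(-1) / eps)
    return ref, fused
```

Expanding `||x - y||² = ||x||² - 2x·y + ||y||²` shows the two formulations agree: the `||x||²` term is constant over `j` and factors out of the logsumexp, while `||y||²` is absorbed into `u`.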
Performance (d=64, A100-80GB, 100 iterations)
Symmetric solver (vs v0.2.0 GeomLoss-style kernel):
| n | v0.2.0 | v0.3.0 | Speedup |
|---|---|---|---|
| 50,000 | 1730 ms | 1450 ms | 1.19x |
| 10,000 | 88 ms | 61 ms | 1.43x |
| 5,000 | 25 ms | 24 ms | 1.04x |
Alternating solver (vs v0.2.0 OTT-style kernel, 10 iterations):
| n | v0.2.0 | v0.3.0 | Speedup |
|---|---|---|---|
| 50,000 | 137.9 ms | 102.6 ms | 1.34x |
| 20,000 | 25.7 ms | 21.7 ms | 1.19x |
| 10,000 | 8.9 ms | 8.3 ms | 1.07x |
Low-Level API
```python
from flash_sinkhorn.kernels import (
    sinkhorn_flashstyle_symmetric,    # Full symmetric solver
    sinkhorn_flashstyle_alternating,  # Full alternating solver
    flashsinkhorn_symmetric_step,     # Single fused iteration
    apply_plan_vec_flashstyle,        # Transport plan @ vector (shifted potentials)
    apply_plan_mat_flashstyle,        # Transport plan @ matrix (shifted potentials)
)
```
API Reference
SamplesLoss
```python
SamplesLoss(
    loss="sinkhorn",
    p=2,                  # Only p=2 supported (squared Euclidean)
    blur=0.05,            # Regularization: eps = blur^2
    debias=True,          # Debiased Sinkhorn divergence
    half_cost=False,      # Use ||x-y||²/2 to match GeomLoss
    reach=None,           # Unbalanced OT (None = balanced)
    reach_x=None,         # Semi-unbalanced: source marginal
    reach_y=None,         # Semi-unbalanced: target marginal
    scaling=0.5,          # Epsilon annealing factor
    n_iters=None,         # Max iterations (None = use scaling)
    threshold=None,       # Early stopping threshold
    inner_iterations=10,  # Check convergence every N iters
)
```
Low-Level API
```python
# FlashSinkhorn (recommended)
from flash_sinkhorn.kernels import (
    sinkhorn_flashstyle_symmetric,
    sinkhorn_flashstyle_alternating,
    apply_plan_vec_flashstyle,
    apply_plan_mat_flashstyle,
)

# Legacy kernels (still available)
from flash_sinkhorn.kernels.sinkhorn_triton_geomloss_sqeuclid import (
    sinkhorn_geomloss_online_potentials_sqeuclid,
)
from flash_sinkhorn.kernels.sinkhorn_triton_grad_sqeuclid import (
    sinkhorn_geomloss_online_grad_sqeuclid,
)
from flash_sinkhorn.hvp import hvp_x_sqeuclid_from_potentials
```
Key Concepts
Cost Convention
- FlashSinkhorn default: `C(x,y) = ||x-y||²`
- GeomLoss p=2 default: `C(x,y) = ||x-y||²/2`
- Use `half_cost=True` to match GeomLoss
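A tiny check of the two conventions (variable names here are illustrative, not the library API):

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 3)
y = torch.randn(7, 3)

C_full = torch.cdist(x, y) ** 2  # FlashSinkhorn default: ||x - y||^2
C_half = 0.5 * C_full            # GeomLoss p=2 convention: ||x - y||^2 / 2

# Same geometry; the costs differ only by the constant factor 1/2.
assert torch.allclose(2 * C_half, C_full)
```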
Memory Efficiency
FlashSinkhorn streams tiles of (x,y) and computes costs on-the-fly:
- Forward: O(nd) memory (no n×m cost matrix)
- Gradient: O(nd) memory (streaming accumulation)
- HVP: O(nd) memory (CG solver with streaming matvec)
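The streaming idea behind the O(nd) bound can be sketched with a chunked softmin that keeps only running max/sum accumulators per row. This is a plain PyTorch sketch of the online logsumexp (an assumption about the structure, not the Triton kernel itself):

```python
import torch

def streaming_softmin(x, y, g, log_b, eps, tile=1024):
    """Softmin over columns, processed tile by tile.

    Peak memory is O(n*d + n*tile) instead of O(n*m): only an
    n x tile block of the cost matrix ever exists at once.
    """
    m = torch.full((x.shape[0],), float("-inf"))  # running max per row
    s = torch.zeros(x.shape[0])                   # running sum of shifted exps
    for j0 in range(0, y.shape[0], tile):
        yj = y[j0:j0 + tile]
        C = torch.cdist(x, yj) ** 2               # n x tile cost block
        z = (g[j0:j0 + tile][None, :] - C) / eps + log_b[j0:j0 + tile][None, :]
        m_new = torch.maximum(m, z.max(dim=1).values)
        # Rescale the old accumulator to the new max, then add this tile.
        s = s * torch.exp(m - m_new) + torch.exp(z - m_new[:, None]).sum(dim=1)
        m = m_new
    return -eps * (m + torch.log(s))
```

This is the same online-softmax rescaling FlashAttention uses; the result is exactly the dense `-eps * logsumexp(...)`, independent of the tile size.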
Numerical Stability
- Uses `exp2`/`log2` for stable LSE computation
- Safe log/division guards against underflow
- TF32 enabled by default for ~2× speedup on A100/H100 (set `allow_tf32=False` for strict FP32)
- HVP (double backward) uses strict FP32 internally for numerical stability
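The base-2 trick is standard in Triton kernels, where `exp2`/`log2` map to fast hardware instructions. A minimal sketch (assuming the usual max-shifted formulation; this is not the library's kernel code):

```python
import torch

LOG2E = 1.4426950408889634  # log2(e)

def lse_base2(z, dim=-1):
    """logsumexp computed entirely in base 2.

    Rescaling the inputs by log2(e) lets the reduction use exp2/log2;
    dividing the result by log2(e) recovers the natural-log answer.
    The max shift keeps exp2 arguments <= 0, avoiding overflow.
    """
    z2 = z * LOG2E
    m = z2.max(dim=dim, keepdim=True).values
    out2 = m.squeeze(dim) + torch.log2(torch.exp2(z2 - m).sum(dim=dim))
    return out2 / LOG2E
```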
Benchmarks
Compare FlashSinkhorn against GeomLoss (KeOps) and OTT-JAX.
Install benchmark dependencies:
```bash
pip install geomloss pykeops ott-jax jax[cuda12]
```
Run benchmarks:
```bash
# Forward pass benchmark
python -m flash_sinkhorn.bench.bench_forward --sizes 5000,10000,20000 --dims 64 --verify

# Backward pass benchmark
python -m flash_sinkhorn.bench.bench_backward --sizes 5000,10000,20000 --dims 64 --verify

# Quick test (small size)
python -m flash_sinkhorn.bench.bench_forward --sizes 5000 --dims 4 --verify

# Run only FlashSinkhorn (skip GeomLoss/OTT-JAX)
python -m flash_sinkhorn.bench.bench_forward --sizes 10000 --dims 64 --no-geomloss --no-ott
```
Results are saved to `output/paper_benchmarks/forward/` and `output/paper_benchmarks/backward/`.
Citation
If you find FlashSinkhorn useful in your research, please cite our paper:
```bibtex
@article{ye2026flashsinkhorn,
  title={FlashSinkhorn: IO-Aware Entropic Optimal Transport},
  author={Ye, Felix X.-F. and Li, Xingjie and Yu, An and Chang, Ming-Ching and Chu, Linsong and Wertheimer, Davis},
  journal={arXiv preprint arXiv:2602.03067},
  year={2026},
  url={https://arxiv.org/abs/2602.03067}
}
```
License
MIT