FlashSinkhorn


Streaming Entropic Optimal Transport in PyTorch + Triton

FlashSinkhorn computes Sinkhorn optimal transport with FlashAttention-style streaming: the n×m cost matrix is never materialized, so memory scales as O(nd) rather than O(n²).

Features

  • FlashSinkhorn kernels — shifted-potential formulation inspired by FlashAttention. On A100 GPUs, achieves up to 32× forward-pass and 161× end-to-end speedups over state-of-the-art online baselines on point-cloud OT
  • Fused Triton kernels for forward, gradient, and HVP
  • GeomLoss-compatible API (SamplesLoss)
  • Analytic gradients (no backprop through Sinkhorn iterations)
  • Hessian-vector products via streaming CG solver
  • Half-cost support (half_cost=True) for exact GeomLoss parity
  • Unbalanced/semi-unbalanced OT via reach parameter
  • Large-D support (d > 1024) with tiled gradient kernel
  • Early stopping with convergence threshold

Install

pip install flash-sinkhorn

# From source (development)
pip install -e ".[dev]"

Requirements: PyTorch ≥2.5, Triton ≥3.1, CUDA 12.x

Quick Start

Basic Usage

import torch
from flash_sinkhorn import SamplesLoss

x = torch.randn(4096, 64, device="cuda")
y = torch.randn(4096, 64, device="cuda")

loss = SamplesLoss(loss="sinkhorn", blur=0.1, debias=True)
cost = loss(x, y)

Gradient Flow

x = torch.randn(4096, 64, device="cuda", requires_grad=True)
y = torch.randn(4096, 64, device="cuda")

loss = SamplesLoss(loss="sinkhorn", blur=0.1, debias=True)
cost = loss(x, y)
grad_x = torch.autograd.grad(cost, x)[0]  # Analytic gradient

GeomLoss Parity

Use half_cost=True to match GeomLoss's cost convention:

# FlashSinkhorn with half_cost matches GeomLoss exactly
flash_loss = SamplesLoss(loss="sinkhorn", blur=0.1, half_cost=True, debias=True)

# Equivalent GeomLoss call
# geomloss_loss = geomloss.SamplesLoss(loss="sinkhorn", p=2, blur=0.1, debias="positive")

Unbalanced OT

For distributions with different total mass or outliers:

loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    debias=True,
    reach=1.0,  # Unbalanced OT with KL penalty
)

Semi-Unbalanced OT

Different constraints for source vs target:

loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    reach_x=1.0,   # Relax source marginal
    reach_y=None,  # Keep target marginal strict (balanced)
)

Early Stopping

loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    n_iters=100,
    threshold=1e-3,       # Stop when potential change < threshold
    inner_iterations=10,  # Check every N iterations
)

Hessian-Vector Product

x = torch.randn(4096, 64, device="cuda", requires_grad=True)
y = torch.randn(4096, 64, device="cuda")
v = torch.randn_like(x)

loss = SamplesLoss(loss="sinkhorn", blur=0.1)
cost = loss(x, y)

# First-order gradient
grad_x = torch.autograd.grad(cost, x, create_graph=True)[0]

# HVP via double backward (uses streaming CG solver)
hvp = torch.autograd.grad((grad_x * v).sum(), x)[0]

FlashSinkhorn (v0.3.0)

FlashSinkhorn is a reformulated Sinkhorn kernel that uses shifted potentials inspired by FlashAttention. It reduces bias vector loads by 67% and elementwise operations by 78% per tile, and improves scalability on OT-based downstream tasks.

How It Works

Standard Sinkhorn loads three bias vectors per tile (g, log b, and the target squared norms ‖y‖²). FlashSinkhorn precomputes a single fused bias u = (g_shifted + eps·log b) / eps and operates on raw coordinates with an inline scale factor, matching FlashAttention's score interface exactly.
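The fused-bias identity can be checked with a small NumPy sketch (illustrative only, not the Triton kernel; the shift `g_shifted = g − ‖y‖²` and the score layout are assumptions based on the description above):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d, eps = 8, 10, 4, 0.25
x, y = rng.normal(size=(n, d)), rng.normal(size=(m, d))
g = rng.normal(size=m)           # current target potential
b = np.full(m, 1.0 / m)          # target weights

def lse(a, axis):                # stable log-sum-exp
    amax = a.max(axis=axis, keepdims=True)
    return (amax + np.log(np.exp(a - amax).sum(axis=axis, keepdims=True))).squeeze(axis)

# Reference: softmin over the materialized cost matrix, three bias vectors
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)          # ||x - y||^2
f_ref = -eps * lse((g[None, :] + eps * np.log(b)[None, :] - C) / eps, axis=1)

# Fused form: one bias vector u plus a FlashAttention-style dot-product score
u = (g + eps * np.log(b) - (y ** 2).sum(-1)) / eps          # single fused bias
scores = (2.0 / eps) * (x @ y.T) + u[None, :]               # qk^T-style score
f_fused = (x ** 2).sum(-1) - eps * lse(scores, axis=1)      # row term folded back in

assert np.allclose(f_ref, f_fused)
```

The `-‖x_i‖²` term is constant along each row, so it can be pulled out of the log-sum-exp, leaving only a dot product and a single per-column bias inside the reduction.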

Performance (d=64, A100-80GB, 100 iterations)

Symmetric solver (vs v0.2.0 GeomLoss-style kernel):

n        v0.2.0     v0.3.0     Speedup
50,000   1730 ms    1450 ms    1.19x
10,000     88 ms      61 ms    1.43x
5,000      25 ms      24 ms    1.04x

Alternating solver (vs v0.2.0 OTT-style kernel, 10 iterations):

n        v0.2.0     v0.3.0     Speedup
50,000   137.9 ms   102.6 ms   1.34x
20,000    25.7 ms    21.7 ms   1.19x
10,000     8.9 ms     8.3 ms   1.07x

Low-Level API

The FlashSinkhorn kernels can also be imported directly:

from flash_sinkhorn.kernels import (
    sinkhorn_flashstyle_symmetric,     # Full symmetric solver
    sinkhorn_flashstyle_alternating,   # Full alternating solver
    flashsinkhorn_symmetric_step,      # Single fused iteration
    apply_plan_vec_flashstyle,         # Transport plan @ vector (shifted potentials)
    apply_plan_mat_flashstyle,         # Transport plan @ matrix (shifted potentials)
)

API Reference

SamplesLoss

SamplesLoss(
    loss="sinkhorn",
    p=2,                      # Only p=2 supported (squared Euclidean)
    blur=0.05,                # Regularization: eps = blur^2
    debias=True,              # Debiased Sinkhorn divergence
    half_cost=False,          # Use ||x-y||²/2 to match GeomLoss
    reach=None,               # Unbalanced OT (None = balanced)
    reach_x=None,             # Semi-unbalanced: source marginal
    reach_y=None,             # Semi-unbalanced: target marginal
    scaling=0.5,              # Epsilon annealing factor
    n_iters=None,             # Max iterations (None = use scaling)
    threshold=None,           # Early stopping threshold
    inner_iterations=10,      # Check convergence every N iters
)
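How `blur` and `scaling` interact can be sketched as follows (a hypothetical helper assuming the GeomLoss-style epsilon-annealing convention; `eps_schedule` and the `diameter` starting point are illustrative, not part of the package API):

```python
# Illustrative sketch: eps anneals geometrically from the point-cloud
# diameter down to blur, with the final regularization eps = blur^2
# (matching the documented convention above).
def eps_schedule(diameter, blur, scaling=0.5):
    eps_list, scale = [], diameter
    while scale > blur:
        eps_list.append(scale ** 2)   # eps at this annealing step
        scale *= scaling              # shrink the blur scale geometrically
    eps_list.append(blur ** 2)        # final eps = blur^2
    return eps_list

print(eps_schedule(diameter=1.0, blur=0.1, scaling=0.5))
# a strictly decreasing schedule ending at eps = blur^2 = 0.01
```

With `n_iters` set, the solver instead runs a fixed iteration count; with `threshold` set, it may stop earlier once potentials converge.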

Low-Level API

# FlashSinkhorn (recommended)
from flash_sinkhorn.kernels import (
    sinkhorn_flashstyle_symmetric,
    sinkhorn_flashstyle_alternating,
    apply_plan_vec_flashstyle,
    apply_plan_mat_flashstyle,
)

# Legacy kernels (still available)
from flash_sinkhorn.kernels.sinkhorn_triton_geomloss_sqeuclid import (
    sinkhorn_geomloss_online_potentials_sqeuclid,
)
from flash_sinkhorn.kernels.sinkhorn_triton_grad_sqeuclid import (
    sinkhorn_geomloss_online_grad_sqeuclid,
)
from flash_sinkhorn.hvp import hvp_x_sqeuclid_from_potentials

Key Concepts

Cost Convention

  • FlashSinkhorn default: C(x,y) = ||x-y||²
  • GeomLoss p=2 default: C(x,y) = ||x-y||²/2
  • Use half_cost=True to match GeomLoss
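As a toy check of the convention difference (plain Python with made-up points, not library code):

```python
# Two 2-D points; the conventions differ only by a factor of 2.
x, y = (1.0, 2.0), (4.0, 6.0)
full = sum((a - b) ** 2 for a, b in zip(x, y))   # FlashSinkhorn default: ||x-y||^2
half = full / 2                                  # GeomLoss p=2: ||x-y||^2 / 2
assert (full, half) == (25.0, 12.5)
```

Since eps enters the solver alongside the cost, losses computed under the two conventions are not interchangeable; pick one convention and keep it consistent across comparisons.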

Memory Efficiency

FlashSinkhorn streams tiles of (x,y) and computes costs on-the-fly:

  • Forward: O(nd) memory (no n×m cost matrix)
  • Gradient: O(nd) memory (streaming accumulation)
  • HVP: O(nd) memory (CG solver with streaming matvec)
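The streaming idea can be illustrated in NumPy: an online log-sum-exp over tiles of y reproduces the dense softmin while holding only O(n) running state plus one tile-sized scratch buffer (a sketch of the technique, not the actual Triton kernel):

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, d, eps, tile = 64, 96, 8, 0.5, 32
x, y = rng.normal(size=(n, d)), rng.normal(size=(m, d))

# Streaming softmin: process y in tiles, keep only running max/sum per row
running_max = np.full(n, -np.inf)
running_sum = np.zeros(n)
for j0 in range(0, m, tile):
    yt = y[j0:j0 + tile]
    C = ((x[:, None, :] - yt[None, :, :]) ** 2).sum(-1)   # tile-local cost only
    s = -C / eps
    new_max = np.maximum(running_max, s.max(axis=1))
    running_sum = (running_sum * np.exp(running_max - new_max)
                   + np.exp(s - new_max[:, None]).sum(axis=1))
    running_max = new_max
f_stream = -eps * (running_max + np.log(running_sum))

# Dense reference materializes the full n x m cost matrix
C_full = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
s_full = -C_full / eps
mx = s_full.max(axis=1)
f_dense = -eps * (mx + np.log(np.exp(s_full - mx[:, None]).sum(axis=1)))

assert np.allclose(f_stream, f_dense)
```

Peak scratch here is O(n·tile) instead of O(n·m); the real kernels tile both point clouds and fuse the cost computation into the reduction.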

Numerical Stability

  • Uses exp2/log2 for stable LSE computation
  • Safe log/division guards against underflow
  • TF32 enabled by default for ~2x speedup on A100/H100 (set allow_tf32=False for strict FP32)
  • HVP (double backward) uses strict FP32 internally for numerical stability
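A minimal illustration of the exp2/log2 rewrite (pure Python, assuming only the identity exp(z) = 2^(z·log₂e); GPU hardware exposes a fast base-2 exponential, which is why kernels prefer this form):

```python
import math

LOG2E = math.log2(math.e)   # conversion factor between base-e and base-2

def lse_exp2(vals):
    """Max-shifted LSE computed entirely in base 2."""
    m = max(vals)
    return m + math.log2(sum(2.0 ** ((v - m) * LOG2E) for v in vals)) / LOG2E

def lse_ref(vals):
    """Max-shifted LSE in base e, for comparison."""
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))

vals = [-1000.0, -1001.5, -999.2]          # large-magnitude inputs stay stable
assert abs(lse_exp2(vals) - lse_ref(vals)) < 1e-9
```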

Benchmarks

Compare FlashSinkhorn against GeomLoss (KeOps) and OTT-JAX.

Install benchmark dependencies:

pip install geomloss pykeops ott-jax "jax[cuda12]"

Run benchmarks:

# Forward pass benchmark
python -m flash_sinkhorn.bench.bench_forward --sizes 5000,10000,20000 --dims 64 --verify

# Backward pass benchmark
python -m flash_sinkhorn.bench.bench_backward --sizes 5000,10000,20000 --dims 64 --verify

# Quick test (small size)
python -m flash_sinkhorn.bench.bench_forward --sizes 5000 --dims 4 --verify

# Run only FlashSinkhorn (skip GeomLoss/OTT-JAX)
python -m flash_sinkhorn.bench.bench_forward --sizes 10000 --dims 64 --no-geomloss --no-ott

Results are saved to output/paper_benchmarks/forward/ and output/paper_benchmarks/backward/.

Citation

If you find FlashSinkhorn useful in your research, please cite our paper:

@article{ye2026flashsinkhorn,
  title={FlashSinkhorn: IO-Aware Entropic Optimal Transport},
  author={Ye, Felix X.-F. and Li, Xingjie and Yu, An and Chang, Ming-Ching and Chu, Linsong and Wertheimer, Davis},
  journal={arXiv preprint arXiv:2602.03067},
  year={2026},
  url={https://arxiv.org/abs/2602.03067}
}

License

MIT
