
A fast, GPU-parallel, PyTorch-compatible optimal transport solver.


MDOT-TNT


A Truncated Newton Method for Optimal Transport

Python 3.8+ · BSD 3-Clause license

A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
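For orientation, the sketch below solves the same kind of entropic problem with plain log-domain Sinkhorn. This is a different and much simpler algorithm. It is not MDOT-TNT, and sinkhorn_log is a hypothetical helper. It approximately minimizes ⟨P, C⟩ − H(P)/γ over couplings P with row sums r and column sums c. We assume γ plays the same inverse-regularization role as gamma_f, although the library's exact objective may differ in details. Plain Sinkhorn needs more and more iterations as γ grows, and that weak-regularization regime is what MDOT-TNT is designed to handle.

```python
import torch


def sinkhorn_log(r, c, C, gamma, n_iter=1000):
    """Plain log-domain Sinkhorn for entropic OT (a baseline, NOT MDOT-TNT).

    Uses dual potentials f (rows) and g (columns) with
    P_ij = exp(gamma * (f_i + g_j - C_ij)).
    """
    f = torch.zeros_like(r)
    g = torch.zeros_like(c)
    log_r, log_c = r.log(), c.log()
    for _ in range(n_iter):
        # Choose f so that the row sums of P equal r exactly.
        f = (log_r - torch.logsumexp(gamma * (g[None, :] - C), dim=1)) / gamma
        # Choose g so that the column sums of P equal c exactly.
        g = (log_c - torch.logsumexp(gamma * (f[:, None] - C), dim=0)) / gamma
    P = torch.exp(gamma * (f[:, None] + g[None, :] - C))
    return P, (P * C).sum()
```

At small sizes and moderate γ, this can serve as a rough sanity check on mdot_tnt.solve_OT output. It is not an exact reference.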


Features

  • High Precision: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
  • GPU Accelerated: Fully compatible with CUDA for fast computation on large problems
  • Multi-GPU Support: Distribute computation across multiple GPUs via the devices or num_gpus arguments for near-linear speedup
  • Batched Solving: Solve multiple OT problems simultaneously in batched mode
  • Memory Efficient: O(nk) working memory mode — never materialise the full cost or transport plan; process columns in blocks of size k
  • Point Cloud Support: Pass source/target point clouds directly with a custom cost function; achieves true O(nk + (n+m)d) memory
  • PyTorch Native: Seamless integration with PyTorch, supporting autograd-compatible inputs
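To put the memory claims above in numbers, the hypothetical helper below counts bytes for a single cost-sized buffer in each mode: a full n × m matrix versus one n × k column block plus the point clouds X and Y. The O(·) notation hides how many such buffers the solver actually keeps, so treat these as per-buffer figures rather than total usage.

```python
def working_memory_bytes(n, m, k, d, bytes_per_elem=8):
    """Per-buffer memory estimate (bytes) for dense vs. low-memory point-cloud mode.

    Counts one cost-sized buffer only; the real solver may keep several.
    bytes_per_elem=8 corresponds to float64.
    """
    dense = n * m * bytes_per_elem         # one full n x m cost (or plan) matrix
    blocked = n * k * bytes_per_elem       # one n x k column block
    points = (n + m) * d * bytes_per_elem  # X of shape (n, d) and Y of shape (m, d)
    return {"dense": dense, "lowmem_point_cloud": blocked + points}
```

For n = m = 10000, d = 64 and k = 512 in float64, this gives 800,000,000 bytes (800 MB) for one dense matrix versus 51,200,000 bytes (about 51 MB) in point-cloud mode.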

Color Transfer Example

OT naturally solves the color transfer problem: find the optimal map between two color palettes and use it to recolor an image. MDOT-TNT's low-memory point-cloud solver makes this practical for full-resolution images without ever materialising an n × m cost matrix.

[Figure: color transfer results. Row 1: source image (1.webp), palette donor (2.webp), result (1 recolored with 2's palette). Row 2: source image (2.webp), palette donor (1.webp), result (2 recolored with 1's palette).]

Run the bundled example (from the repo root):

cd color_transfer
python run.py --both-directions

Key options:

Flag                 Default  Description
--gamma-f            1024     OT precision (higher = sharper color match)
--color-space        lab      Working color space (lab or rgb)
--max-pixels         8192     Max pixels sampled per image
--block-size         512      Memory/speed trade-off for the low-mem solver
--both-directions    off      Also produce the reverse transfer
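One common way to turn a transport plan into a recoloring is barycentric projection: each source color is replaced by the plan-weighted mean of the target colors it sends mass to. The sketch below is an assumption about the general approach, not a description of run.py. Its color-space handling and the way it extends the sampled map to every pixel are not documented here. sample_pixels and barycentric_recolor are hypothetical helpers. The plan P would come from one of the solvers called with return_plan=True on the sampled pixels.

```python
import torch


def sample_pixels(img, max_pixels, seed=0):
    """Flatten an (H, W, 3) image tensor and keep at most max_pixels random pixels."""
    pixels = img.reshape(-1, img.shape[-1])
    if pixels.shape[0] <= max_pixels:
        return pixels
    gen = torch.Generator().manual_seed(seed)
    idx = torch.randperm(pixels.shape[0], generator=gen)[:max_pixels]
    return pixels[idx.to(pixels.device)]


def barycentric_recolor(P, Y):
    """Send source color i to the P-weighted mean of the target colors.

    P: (n, m) transport plan, Y: (m, 3) target colors -> (n, 3) new colors.
    """
    row_mass = P.sum(dim=1, keepdim=True)  # equals r when P is feasible
    return (P @ Y) / row_mass.clamp_min(torch.finfo(P.dtype).tiny)
```

With a sharper plan (larger gamma_f), each source color maps to a tighter average of target colors. This is consistent with the "higher = sharper color match" description of --gamma-f.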

Installation

Prerequisites: Install PyTorch for your system configuration first.

pip install mdot-tnt

For development:

git clone https://github.com/metekemertas/mdot_tnt.git
cd mdot_tnt
pip install -e ".[dev]"

Quick Start

Single Problem

import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create marginals (probability distributions)
n, m = 512, 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, device=device, dtype=torch.float64)
c = c / c.sum()

# Cost matrix (e.g., pairwise distances)
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve for optimal transport cost
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)

# Or get the full transport plan
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
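To inspect a returned plan, a small diagnostic helper can report how far its row and column sums are from r and c, and compute its transport cost ⟨P, C⟩. plan_diagnostics is hypothetical and not part of mdot_tnt. With round=True (the default), the marginal errors should be near floating-point precision. We'd also expect the plan cost to agree closely with the scalar returned by solve_OT, assuming that scalar is ⟨P, C⟩ for the same rounded plan.

```python
import torch


def plan_diagnostics(P, r, c, C):
    """L1 marginal violations and transport cost <P, C> of an (n, m) plan."""
    row_err = (P.sum(dim=1) - r).abs().sum().item()  # how far row sums are from r
    col_err = (P.sum(dim=0) - c).abs().sum().item()  # how far column sums are from c
    cost = (P * C).sum().item()
    return row_err, col_err, cost
```

Usage after the snippet above: row_err, col_err, plan_cost = plan_diagnostics(plan, r, c, C).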

Batched Solving

When solving multiple OT problems, use the batched solver; it is significantly faster than solving the problems one at a time:

import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, n, m = 32, 512, 512

# Multiple marginal pairs
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)

# Shared cost matrix (or per-problem: shape [batch_size, n, m])
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve all problems at once
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024)  # Returns (batch_size,) tensor

The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
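When return_plan=True, the batched solver returns plans of shape (batch, n, m). The hypothetical helper below recomputes each problem's cost from those plans. It handles both the shared (n, m) cost and the per-problem (batch, n, m) cost layouts by broadcasting.

```python
import torch


def batched_transport_costs(P, C):
    """<P_b, C_b> for each problem b in a batch of plans.

    P: (batch, n, m) plans; C: (n, m) shared cost or (batch, n, m) per-problem costs.
    """
    if C.dim() == 2:
        C = C.unsqueeze(0)  # broadcast one shared cost matrix across the batch
    return (P * C).sum(dim=(-2, -1))  # shape (batch,)
```

This mirrors the (batch,) shape of the costs returned by solve_OT_batched, so the two can be compared element-wise.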

Low-Memory / Point Cloud Solving

When n or m is large, use solve_OT_lowmem to avoid materialising the full n × m cost matrix. Cost blocks are computed on-the-fly, giving O(nk) working memory (k = block_size):

import torch
from mdot_tnt import solve_OT_lowmem

device = "cuda" if torch.cuda.is_available() else "cpu"

# Point cloud mode — no n×m matrix is ever allocated
n, m, d = 10000, 10000, 64
X = torch.rand(n, d, device=device, dtype=torch.float64)
Y = torch.rand(m, d, device=device, dtype=torch.float64)
r = torch.ones(n, device=device, dtype=torch.float64) / n
c = torch.ones(m, device=device, dtype=torch.float64) / m

cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, block_size=512)

# Dense matrix mode: C is allocated in full by the caller; the solver then processes it k columns at a time
C = torch.rand(n, m, device=device, dtype=torch.float64)
cost = solve_OT_lowmem(r, c, C=C, gamma_f=1024, block_size=512)

The block_size parameter controls the memory / speed trade-off:

  • block_size = m (default): fastest, equivalent to solve_OT
  • block_size = sqrt(m): good balance
  • block_size = 1: minimum memory (slowest)
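To illustrate why block_size bounds working memory, the sketch below computes a product with the cost matrix, C @ v, while only ever holding one (n, block_size) cost block. It is a hypothetical illustration. The library's internals are not documented here, and we simply assume its iterations reduce to column-blocked reductions of this kind. sq_euclidean follows the documented cost_fn signature (X, Y_block) -> (n, k) and computes the documented default, squared Euclidean distance.

```python
import torch


def sq_euclidean(X, Y_block):
    """(n, d), (k, d) -> (n, k) squared Euclidean costs, clamped at 0 against round-off."""
    sq = (X * X).sum(dim=1, keepdim=True) + (Y_block * Y_block).sum(dim=1)
    sq = sq - 2.0 * X @ Y_block.T
    return sq.clamp_min(0.0)


def blocked_matvec(X, Y, v, block_size, cost_fn=sq_euclidean):
    """Compute C @ v without forming C; one (n, block_size) block is alive at a time."""
    m = Y.shape[0]
    out = torch.zeros(X.shape[0], dtype=X.dtype, device=X.device)
    for start in range(0, m, block_size):
        stop = min(start + block_size, m)
        C_block = cost_fn(X, Y[start:stop])  # (n, stop - start) slice of the cost
        out += C_block @ v[start:stop]
    return out
```

Smaller blocks mean more, smaller kernel launches. This matches the trade-off listed above: block_size = m is fastest, and block_size = 1 uses the least memory.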

Multi-GPU

All three solvers accept devices (list of GPU indices or torch.device objects) or num_gpus (int). Passing a single device falls back to the standard single-GPU path.

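from mdot_tnt import solve_OT, solve_OT_batched, solve_OT_lowmem

# r, c, C, X, Y as in the examples above; solve_OT_batched expects r of shape (batch, n) and c of shape (batch, m)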

# Single large problem — columns distributed across GPUs
cost = solve_OT(r, c, C, gamma_f=1024, devices=[0, 1, 2])

# Batched — batch split across GPUs (near-linear speedup)
costs = solve_OT_batched(r, c, C, gamma_f=1024, num_gpus=4)

# Low-memory / point cloud — column blocks distributed across GPUs
#   (3.2× speedup at n=8192, m=65536, γ=1024 on 3× RTX 2080 Ti)
cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, devices=[0, 1, 2])
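When the same script must run on machines with different GPU counts, a small helper can pick the devices argument. pick_devices is hypothetical and not part of mdot_tnt. It lists the visible CUDA GPUs and returns None when fewer than two are available, so the call uses the default devices=None from the documented signatures.

```python
import torch


def pick_devices(max_gpus=None):
    """Visible CUDA device indices for multi-GPU solving, or None for the default path."""
    count = torch.cuda.device_count()
    if max_gpus is not None:
        count = min(count, max_gpus)
    # With fewer than two GPUs there is nothing to distribute across.
    return list(range(count)) if count >= 2 else None
```

Usage: cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, devices=pick_devices()). Since a single device already falls back to the single-GPU path, returning [0] would also work. Returning None just keeps the fallback explicit.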

API Reference

solve_OT

mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False,
                  devices=None, num_gpus=None)
Parameter    Type    Description
r            Tensor  Row marginal of shape (n,), must sum to 1
c            Tensor  Column marginal of shape (m,), must sum to 1
C            Tensor  Cost matrix of shape (n, m), recommended to normalize to [0, 1]
gamma_f      float   Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024
return_plan  bool    If True, return transport plan instead of cost
round        bool    If True, round solution onto feasible set
log          bool    If True, also return optimization logs
devices      list    GPU indices or torch.device objects to distribute across. Mutually exclusive with num_gpus
num_gpus     int     Number of GPUs to use (starting from index 0). Mutually exclusive with devices

Returns: Transport cost (scalar) or plan (n, m), optionally with logs dict.

solve_OT_batched

mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False,
                           devices=None, num_gpus=None)

Same parameters as solve_OT, but with batched inputs:

  • r: Shape (batch, n)
  • c: Shape (batch, m)
  • C: Shape (n, m) for shared cost, or (batch, n, m) for per-problem costs

With devices or num_gpus, the batch is split across GPUs and each sub-batch is solved independently.

Returns: Costs (batch,) or plans (batch, n, m).
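The multi-GPU batched mode splits the batch across GPUs and solves each sub-batch independently. The exact partitioning scheme is not documented, so the following is a minimal sketch of one plausible approach, using contiguous, near-equal chunks from torch.tensor_split. split_batch is a hypothetical helper. A shared (n, m) cost has to be copied to every device, while a per-problem cost is split along the batch axis like r and c.

```python
import torch


def split_batch(r, c, C, devices):
    """Split a batched OT problem into contiguous sub-batches, one per device."""
    n_parts = len(devices)
    r_parts = torch.tensor_split(r, n_parts, dim=0)
    c_parts = torch.tensor_split(c, n_parts, dim=0)
    if C.dim() == 3:
        C_parts = torch.tensor_split(C, n_parts, dim=0)  # per-problem costs
    else:
        C_parts = [C] * n_parts  # shared cost: every device needs a full copy
    return [
        (rp.to(dev), cp.to(dev), Cp.to(dev))
        for rp, cp, Cp, dev in zip(r_parts, c_parts, C_parts, devices)
    ]
```

The per-device results can then be moved to one device and joined with torch.cat in the original order.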

solve_OT_lowmem

mdot_tnt.solve_OT_lowmem(r, c, C=None, X=None, Y=None, cost_fn=None,
                          gamma_f=1024., block_size=None,
                          drop_tiny=False, return_plan=False, round=True, log=False,
                          devices=None, num_gpus=None)

Exactly one of C or (X, Y) must be provided.

Parameter    Type      Description
r            Tensor    Row marginal of shape (n,), must sum to 1
c            Tensor    Column marginal of shape (m,), must sum to 1
C            Tensor    Dense cost matrix (n, m). Mutually exclusive with X/Y
X            Tensor    Source points (n, d). Use together with Y
Y            Tensor    Target points (m, d). Use together with X
cost_fn      callable  (X, Y_block) -> C_block of shape (n, k). Defaults to squared Euclidean
gamma_f      float     Temperature parameter. Default: 1024
block_size   int       Columns per block (controls memory/speed trade-off). Default: m
drop_tiny    bool      Drop tiny marginal entries for speedup with sparse marginals
return_plan  bool      If True, return transport plan instead of cost
round        bool      If True, round solution onto feasible set
log          bool      If True, also return optimization logs
devices      list      GPU indices or torch.device objects to distribute column blocks across. Mutually exclusive with num_gpus
num_gpus     int       Number of GPUs to use (starting from index 0). Mutually exclusive with devices

Returns: Transport cost (scalar) or plan (n, m), optionally with logs dict.
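Custom cost functions follow the documented (X, Y_block) -> (n, k) signature. The README recommends costs in [0, 1], but it does not say whether solve_OT_lowmem rescales the default squared Euclidean cost internally. For torch.rand points in d = 64, that default cost can reach 64. A pre-scaled cost_fn is one way to follow the recommendation. Both functions below are hypothetical examples, not part of mdot_tnt.

```python
import torch
import torch.nn.functional as F


def cosine_cost(X, Y_block):
    """(n, d), (k, d) -> (n, k) cosine distance rescaled to [0, 1]."""
    Xn = F.normalize(X, dim=1)
    Yn = F.normalize(Y_block, dim=1)
    return ((1.0 - Xn @ Yn.T) / 2.0).clamp(0.0, 1.0)


def make_unit_cube_sqeuclid(d):
    """Squared Euclidean cost divided by d, which lies in [0, 1] for points in [0, 1]^d."""
    def cost_fn(X, Y_block):
        sq = (X * X).sum(dim=1, keepdim=True) + (Y_block * Y_block).sum(dim=1)
        sq = sq - 2.0 * X @ Y_block.T
        return sq.clamp_min(0.0) / d
    return cost_fn
```

Usage with documented parameters only: cost = solve_OT_lowmem(r, c, X=X, Y=Y, cost_fn=cosine_cost, gamma_f=1024, block_size=512).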

Performance Tips

  1. Use float64 when gamma_f > 1024 (inputs are converted automatically, with a warning)
  2. Normalize cost matrices to [0, 1] for numerical stability
  3. Use batched solver when solving multiple problems with shared structure
  4. Increase gamma_f for higher precision (the error scales as O(log n / γ) in the worst case, and is often much smaller in practice)
  5. Use multi-GPU (devices / num_gpus) for large point-cloud problems or large batches — solve_OT_lowmem with column-parallel blocking achieves near-linear speedup (3.2× on 3 GPUs at n=8192, m=65536)
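Tips 2 and 4 can be expressed numerically. The hypothetical helpers below map a cost matrix affinely into [0, 1] and convert a cost computed on the normalized matrix back to the original units. Because a plan sums to 1, ⟨P, scale·C01 + shift⟩ = scale·⟨P, C01⟩ + shift. They also invert the worst-case rate log n / γ to suggest a γ for a target error. The constant hidden in O(·) is unknown, so that number is only an order-of-magnitude guide. For n = 512 and γ = 1024, log n / γ ≈ 0.0061 in units of the normalized cost.

```python
import math

import torch


def normalize_cost(C):
    """Affinely map C into [0, 1]; returns (C01, shift, scale) with C = C01 * scale + shift."""
    shift = C.min()
    scale = (C.max() - shift).clamp_min(torch.finfo(C.dtype).tiny)  # avoid divide-by-zero
    return (C - shift) / scale, shift, scale


def recover_cost(cost01, shift, scale):
    """Undo normalize_cost on a transport cost (valid because the plan sums to 1)."""
    return cost01 * scale + shift


def gamma_for_tolerance(n, eps):
    """gamma with log(n) / gamma = eps, ignoring the constant hidden in O(log n / gamma)."""
    return math.log(n) / eps
```

Note that rescaling C by a positive factor at fixed gamma_f changes the effective regularization strength. The shift, however, does not affect the plan.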

Citation

If you use MDOT-TNT in your research, please cite:

@inproceedings{kemertas2025truncated,
  title={A Truncated Newton Method for Optimal Transport},
  author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=gWrWUaCbMa}
}

License

This code is released under the BSD 3-Clause license.

Contact

For questions or issues, please open an issue or email: kemertas [at] cs [dot] toronto [dot] edu



Download files

Download the file for your platform.

Source Distribution

mdot_tnt-1.2.0.tar.gz (43.0 kB)

Uploaded Source

Built Distribution


mdot_tnt-1.2.0-py3-none-any.whl (35.0 kB)

Uploaded Python 3

File details

Details for the file mdot_tnt-1.2.0.tar.gz.

File metadata

  • Download URL: mdot_tnt-1.2.0.tar.gz
  • Upload date:
  • Size: 43.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for mdot_tnt-1.2.0.tar.gz
Algorithm    Hash digest
SHA256       86aeb41c2df5442456ca03c7860538caa6dbb9d3b55b7145c9197aa1a5120248
MD5          712bdb71e016dee31546dcfcc9030110
BLAKE2b-256  832b9686bbaf0cf7e36f501ed7c9020eadf48ea12b80546505b6109d329c9465


File details

Details for the file mdot_tnt-1.2.0-py3-none-any.whl.

File metadata

  • Download URL: mdot_tnt-1.2.0-py3-none-any.whl
  • Upload date:
  • Size: 35.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for mdot_tnt-1.2.0-py3-none-any.whl
Algorithm    Hash digest
SHA256       7c6ca132f15e82bbdf63cacb0459846baadd5586e8f981de99c6108185d21c93
MD5          206421771ac00035deb56aeae19ee4a4
BLAKE2b-256  b732ac8a3d8789bbb99d57337816a48a5a679a3cee99d55a7809fc868830b85d

