A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
MDOT-TNT
A Truncated Newton Method for Optimal Transport
A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
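For orientation, the entropic-regularized problem is commonly written as below. The notation is an assumption chosen to match the API later in this README (marginals r and c, cost matrix C, and γ = gamma_f acting as an inverse regularization strength); the paper's exact conventions may differ.

```latex
\min_{P \in \mathcal{U}(r, c)} \; \langle P, C \rangle - \frac{1}{\gamma} H(P),
\qquad
\mathcal{U}(r, c) = \left\{ P \in \mathbb{R}_{\ge 0}^{n \times m} : P \mathbf{1}_m = r,\; P^\top \mathbf{1}_n = c \right\},
\qquad
H(P) = -\sum_{i,j} P_{ij} \log P_{ij}
```

As γ grows, the entropy term's weight shrinks and the solution approaches that of unregularized OT, which is why a larger gamma_f corresponds to higher precision.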
Features
- High Precision: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
- GPU Accelerated: Fully compatible with CUDA for fast computation on large problems
- Multi-GPU Support: Distribute computation across multiple GPUs via the devices or num_gpus arguments for near-linear speedup
- Batched Solving: Solve multiple OT problems simultaneously in batched mode
- Memory Efficient: O(nk) working memory mode — never materialise the full cost or transport plan; process columns in blocks of size k
- Point Cloud Support: Pass source/target point clouds directly with a custom cost function; achieves true O(nk + (n+m)d) memory
- PyTorch Native: Seamless integration with PyTorch, supporting autograd-compatible inputs
Color Transfer Example
OT naturally solves the color transfer problem: find the optimal map between two color palettes and use it to recolor an image. MDOT-TNT's low-memory point-cloud solver makes this practical for full-resolution images without ever materialising an n × m cost matrix.
| Source image | Palette donor | Result |
|---|---|---|
| 1.webp | 2.webp | 1 recolored with 2's palette |
| 2.webp | 1.webp | 2 recolored with 1's palette |
Run the bundled example (from the repo root):
cd color_transfer
python run.py --both-directions
Key options:
| Flag | Default | Description |
|---|---|---|
| --gamma-f | 1024 | OT precision (higher = sharper color match) |
| --color-space | lab | Working color space (lab or rgb) |
| --max-pixels | 8192 | Max pixels sampled per image |
| --block-size | 512 | Memory/speed trade-off for the low-mem solver |
| --both-directions | off | Also produce the reverse transfer |
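One common way to turn a transport plan into a color map is barycentric projection: each source color is replaced by the plan-weighted average of the donor colors it sends mass to. The sketch below illustrates that idea in plain Python on toy values. The helper name barycentric_map is hypothetical, and the bundled run.py may implement the recoloring step differently.

```python
def barycentric_map(plan, target_colors):
    """Map each source color to the plan-weighted mean of target colors.

    plan: n x m nested list of non-negative transport masses.
    target_colors: m colors, each a sequence of channel values.
    Returns n recolored source colors as lists of floats.
    """
    channels = len(target_colors[0])
    mapped = []
    for row in plan:
        mass = sum(row)  # equals the source marginal r_i for a feasible plan
        if mass == 0.0:
            raise ValueError("source pixel carries no mass")
        mapped.append([
            sum(p * color[ch] for p, color in zip(row, target_colors)) / mass
            for ch in range(channels)
        ])
    return mapped


# Two source pixels and two donor colors (toy values, not from the example images).
plan = [[0.5, 0.0],
        [0.25, 0.25]]
donor = [[1.0, 0.0, 0.0],   # red
         [0.0, 0.0, 1.0]]   # blue
recolored = barycentric_map(plan, donor)
# The first pixel maps fully to red; the second is an even red/blue mix.
```

With a plan held as a torch tensor, the same projection is a single matrix product divided by the row sums.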
Installation
Prerequisites: Install PyTorch for your system configuration first.
pip install mdot-tnt
For development:
git clone https://github.com/metekemertas/mdot_tnt.git
cd mdot_tnt
pip install -e ".[dev]"
Quick Start
Single Problem
import torch
import mdot_tnt
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create marginals (probability distributions)
n, m = 512, 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, device=device, dtype=torch.float64)
c = c / c.sum()
# Cost matrix (e.g., pairwise distances)
C = torch.rand(n, m, device=device, dtype=torch.float64)
# Solve for optimal transport cost
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
# Or get the full transport plan
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
Batched Solving
When solving multiple OT problems, use the batched solver, which is significantly faster than solving them one at a time:
import torch
import mdot_tnt
device = "cuda"
batch_size, n, m = 32, 512, 512
# Multiple marginal pairs
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)
# Shared cost matrix (or per-problem: shape [batch_size, n, m])
C = torch.rand(n, m, device=device, dtype=torch.float64)
# Solve all problems at once
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024) # Returns (batch_size,) tensor
The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
Low-Memory / Point Cloud Solving
When n or m is large, use solve_OT_lowmem to avoid materialising the full n × m cost matrix. Cost blocks are computed on-the-fly, giving O(nk) working memory (k = block_size):
import torch
from mdot_tnt import solve_OT_lowmem
device = "cuda" if torch.cuda.is_available() else "cpu"
# Point cloud mode — no n×m matrix is ever allocated
n, m, d = 10000, 10000, 64
X = torch.rand(n, d, device=device, dtype=torch.float64)
Y = torch.rand(m, d, device=device, dtype=torch.float64)
r = torch.ones(n, device=device, dtype=torch.float64) / n
c = torch.ones(m, device=device, dtype=torch.float64) / m
cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, block_size=512)
# Dense matrix mode — full C provided, but only k columns in memory at once
C = torch.rand(n, m, device=device, dtype=torch.float64)
cost = solve_OT_lowmem(r, c, C=C, gamma_f=1024, block_size=512)
The block_size parameter controls the memory / speed trade-off:
- block_size = m (default): fastest, equivalent to solve_OT
- block_size = sqrt(m): good balance
- block_size = 1: minimum memory (slowest)
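To get a feel for these settings, the following back-of-the-envelope sketch evaluates the O(nk + (n+m)d) expression for the point-cloud example above and compares it with a dense float64 cost matrix. It counts only the leading-order terms with a constant of 1, so treat the numbers as rough estimates rather than the solver's measured footprint; the helper names are hypothetical.

```python
import math


def lowmem_elements(n, m, d, block_size):
    """Leading-order element count for point cloud mode: n*k + (n + m)*d."""
    k = min(block_size, m)
    return n * k + (n + m) * d


def to_megabytes(elements, bytes_per_element=8):  # 8 bytes per float64
    return elements * bytes_per_element / 1e6


n, m, d = 10_000, 10_000, 64
dense_mb = to_megabytes(n * m)                                   # full n x m cost matrix
block_mb = to_megabytes(lowmem_elements(n, m, d, 512))           # block_size = 512
sqrt_mb = to_megabytes(lowmem_elements(n, m, d, math.isqrt(m)))  # block_size = sqrt(m)
print(f"dense: {dense_mb:.1f} MB, k=512: {block_mb:.1f} MB, k=sqrt(m): {sqrt_mb:.1f} MB")
# prints "dense: 800.0 MB, k=512: 51.2 MB, k=sqrt(m): 18.2 MB"
```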
Multi-GPU
All three solvers accept devices (list of GPU indices or torch.device objects) or num_gpus (int). Passing a single device falls back to the standard single-GPU path.
from mdot_tnt import solve_OT, solve_OT_batched, solve_OT_lowmem
# Single large problem — columns distributed across GPUs
cost = solve_OT(r, c, C, gamma_f=1024, devices=[0, 1, 2])
# Batched — batch split across GPUs (near-linear speedup)
costs = solve_OT_batched(r, c, C, gamma_f=1024, num_gpus=4)
# Low-memory / point cloud — column blocks distributed across GPUs
# (3.2× speedup at n=8192, m=65536, γ=1024 on 3× RTX 2080 Ti)
cost = solve_OT_lowmem(r, c, X=X, Y=Y, gamma_f=1024, devices=[0, 1, 2])
API Reference
solve_OT
mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False,
devices=None, num_gpus=None)
| Parameter | Type | Description |
|---|---|---|
| r | Tensor | Row marginal of shape (n,), must sum to 1 |
| c | Tensor | Column marginal of shape (m,), must sum to 1 |
| C | Tensor | Cost matrix of shape (n, m), recommended to normalize to [0, 1] |
| gamma_f | float | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
| return_plan | bool | If True, return transport plan instead of cost |
| round | bool | If True, round solution onto feasible set |
| log | bool | If True, also return optimization logs |
| devices | list | GPU indices or torch.device objects to distribute across. Mutually exclusive with num_gpus |
| num_gpus | int | Number of GPUs to use (starting from index 0). Mutually exclusive with devices |
Returns: Transport cost (scalar) or plan (n, m), optionally with logs dict.
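The feasible set mentioned for round is presumably the set of plans whose row sums equal r and whose column sums equal c; that reading is an assumption, since this README does not define it. Under that assumption, a returned plan can be sanity-checked with a small helper like the one below, written in plain Python so it runs without the package or a GPU. The name marginal_violation is hypothetical.

```python
def marginal_violation(plan, r, c):
    """Return the L1 distance of the plan's row/column sums from r and c."""
    row_err = sum(abs(sum(row) - r_i) for row, r_i in zip(plan, r))
    col_err = sum(abs(sum(col) - c_j) for col, c_j in zip(zip(*plan), c))
    return row_err + col_err


# The independent coupling r c^T is always feasible, so its violation is zero.
r = [0.25, 0.75]
c = [0.5, 0.5]
plan = [[r_i * c_j for c_j in c] for r_i in r]
violation = marginal_violation(plan, r, c)

# For a torch plan the equivalent check is:
#   (plan.sum(1) - r).abs().sum() + (plan.sum(0) - c).abs().sum()
```

A violation far above floating-point noise would suggest the plan was returned with round=False or that the marginals were not normalized.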
solve_OT_batched
mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False,
devices=None, num_gpus=None)
Same parameters as solve_OT, but with batched inputs:
- r: Shape (batch, n)
- c: Shape (batch, m)
- C: Shape (n, m) for shared cost, or (batch, n, m) for per-problem costs
With devices or num_gpus, the batch is split across GPUs and each sub-batch is solved independently.
Returns: Costs (batch,) or plans (batch, n, m).
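As a rough picture of what splitting a batch across GPUs involves, the sketch below partitions batch indices into contiguous, near-equal chunks, one per device. This is only an illustration under assumed behavior: the helper split_batch is hypothetical, and this README does not specify how the library actually partitions the batch.

```python
def split_batch(batch_size, num_gpus):
    """Split range(batch_size) into at most num_gpus contiguous, near-equal chunks."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    base, extra = divmod(batch_size, num_gpus)
    chunks, start = [], 0
    for gpu in range(num_gpus):
        size = base + (1 if gpu < extra else 0)  # first `extra` GPUs take one more
        if size == 0:
            break  # fewer problems than GPUs
        chunks.append(range(start, start + size))
        start += size
    return chunks


chunks = split_batch(32, 3)
sizes = [len(chunk) for chunk in chunks]  # [11, 11, 10]
```

Because each sub-batch is solved independently, the slowest chunk sets the overall run time, which is why near-equal chunk sizes matter for near-linear speedup.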
solve_OT_lowmem
mdot_tnt.solve_OT_lowmem(r, c, C=None, X=None, Y=None, cost_fn=None,
gamma_f=1024., block_size=None,
drop_tiny=False, return_plan=False, round=True, log=False,
devices=None, num_gpus=None)
Exactly one of C or (X, Y) must be provided.
| Parameter | Type | Description |
|---|---|---|
| r | Tensor | Row marginal of shape (n,), must sum to 1 |
| c | Tensor | Column marginal of shape (m,), must sum to 1 |
| C | Tensor | Dense cost matrix (n, m). Mutually exclusive with X/Y |
| X | Tensor | Source points (n, d). Use together with Y |
| Y | Tensor | Target points (m, d). Use together with X |
| cost_fn | callable | (X, Y_block) -> C_block of shape (n, k). Defaults to squared Euclidean |
| gamma_f | float | Temperature parameter. Default: 1024 |
| block_size | int | Columns per block (controls memory/speed trade-off). Default: m |
| drop_tiny | bool | Drop tiny marginal entries for speedup with sparse marginals |
| return_plan | bool | If True, return transport plan instead of cost |
| round | bool | If True, round solution onto feasible set |
| log | bool | If True, also return optimization logs |
| devices | list | GPU indices or torch.device objects to distribute column blocks across |
| num_gpus | int | Number of GPUs to use (starting from index 0). Mutually exclusive with devices |
Returns: Transport cost (scalar) or plan (n, m), optionally with logs dict.
Performance Tips
- Use float64 for gamma_f > 1024 (automatic conversion with warning)
- Normalize cost matrices to [0, 1] for numerical stability
- Use the batched solver when solving multiple problems with shared structure
- Increase gamma_f for higher precision (error scales as O(log n / γ) in the worst case, but can be much better)
- Use multi-GPU (devices / num_gpus) for large point-cloud problems or large batches; solve_OT_lowmem with column-parallel blocking achieves near-linear speedup (3.2× on 3 GPUs at n=8192, m=65536)
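To make the O(log n / γ) rule of thumb concrete, the sketch below evaluates log(n) / γ for a few values of gamma_f. It assumes costs normalized to [0, 1] and takes the hidden constant as 1 purely for illustration, so the outputs are order-of-magnitude guides, not guarantees from the library. Both helper names are hypothetical.

```python
import math


def worst_case_gap(n, gamma_f):
    """log(n) / gamma, taking the hidden constant in O(log n / gamma) as 1."""
    return math.log(n) / gamma_f


def gamma_for_gap(n, target_gap):
    """Value of gamma at which the worst-case estimate equals target_gap."""
    return math.log(n) / target_gap


# Estimate the gap for n = 512 at increasing temperatures.
for gamma_f in (2 ** 10, 2 ** 14, 2 ** 18):
    print(f"gamma_f={gamma_f:>6}: gap <= {worst_case_gap(512, gamma_f):.2e}")
```

Since the estimate is inversely proportional to γ, each doubling of gamma_f halves it; as the tip above notes, observed errors can be much smaller.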
Citation
If you use MDOT-TNT in your research, please cite:
@inproceedings{kemertas2025truncated,
title={A Truncated Newton Method for Optimal Transport},
author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=gWrWUaCbMa}
}
License
This code is released under the BSD 3-Clause license.
Contact
For questions or issues, please open an issue or email: kemertas [at] cs [dot] toronto [dot] edu