A fast, GPU-parallel, PyTorch-compatible optimal transport solver.
MDOT-TNT
A Truncated Newton Method for Optimal Transport
A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.
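For reference, the textbook entropic-regularized OT problem is shown below, with γ playing the role of gamma_f (inverse regularization). The paper's exact parameterization may differ in constants. Because the entropy of any feasible plan lies between 0 and log(nm), the cost of the entropic optimum exceeds the unregularized optimum by at most log(nm)/γ. This is why a large γ yields a precise approximation of unregularized OT.

$$
\min_{P \in U(r,c)} \; \langle P, C \rangle - \frac{1}{\gamma} H(P),
\qquad H(P) = -\sum_{i,j} P_{ij} \log P_{ij},
$$

$$
U(r,c) = \left\{ P \in \mathbb{R}_{\ge 0}^{n \times m} \;:\; P\mathbf{1}_m = r,\; P^\top \mathbf{1}_n = c \right\}.
$$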
Features
- High Precision: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
- GPU Accelerated: Fully compatible with CUDA for fast computation on large problems
- Batched Solving: Solve multiple OT problems simultaneously in batched mode
- Memory Efficient: Log-domain computations and efficient rounding avoid storing full transport plans
- PyTorch Native: Seamless integration with PyTorch, supporting autograd-compatible inputs
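To make the log-domain and memory points concrete, here is a minimal sketch of classic log-domain Sinkhorn in plain Python. Note that this is Sinkhorn, a simpler and different algorithm from MDOT-TNT's mirror descent with truncated Newton projections, and the function names are hypothetical. The sketch shows two things. First, working with log-potentials avoids the underflow of exp(-γC) at large γ (exp(-1024) is 0.0 in float64). Second, the cost ⟨P, C⟩ can be accumulated row by row without ever storing the n × m plan.

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum(exp(x) for x in xs)).
    mx = max(xs)
    if mx == -math.inf:
        return -math.inf
    return mx + math.log(sum(math.exp(x - mx) for x in xs))


def sinkhorn_log(r, c, C, gamma, iters=500):
    """Log-domain Sinkhorn (hypothetical helper, not part of mdot_tnt).

    The plan is implicit: P[i][j] = exp(f[i] + g[j] - gamma * C[i][j]).
    Assumes r and c are strictly positive and each sums to 1.
    """
    n, m = len(r), len(c)
    log_r = [math.log(v) for v in r]
    log_c = [math.log(v) for v in c]
    f, g = [0.0] * n, [0.0] * m
    for _ in range(iters):
        # Row update: makes every row of P sum to r[i].
        f = [log_r[i] - logsumexp([g[j] - gamma * C[i][j] for j in range(m)])
             for i in range(n)]
        # Column update: makes every column of P sum to c[j].
        g = [log_c[j] - logsumexp([f[i] - gamma * C[i][j] for i in range(n)])
             for j in range(m)]
    return f, g


def transport_cost(f, g, C, gamma):
    # Accumulate <P, C> one row at a time; the n x m plan is never stored.
    total = 0.0
    for i, row in enumerate(C):
        total += sum(math.exp(f[i] + g[j] - gamma * cij) * cij
                     for j, cij in enumerate(row))
    return total


r = c = [0.5, 0.5]
C = [[0.0, 1.0], [1.0, 0.0]]
f, g = sinkhorn_log(r, c, C, gamma=1024.0)
print(transport_cost(f, g, C, gamma=1024.0))  # 0.0: all mass stays on the zero-cost diagonal
```

A naive implementation that forms exp(-γC) directly would produce an all-zero kernel for off-diagonal entries and, for less favorable costs, divide by zero; the log-domain form keeps every step finite.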
Installation
Prerequisites: Install PyTorch for your system configuration first.
```bash
pip install mdot-tnt
```
For development:
```bash
git clone https://github.com/metekemertas/mdot_tnt.git
cd mdot_tnt
pip install -e ".[dev]"
```
Quick Start
Single Problem
```python
import torch
import mdot_tnt
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create marginals (probability distributions)
n, m = 512, 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, device=device, dtype=torch.float64)
c = c / c.sum()
# Cost matrix (e.g., pairwise distances)
C = torch.rand(n, m, device=device, dtype=torch.float64)
# Solve for optimal transport cost
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)
# Or get the full transport plan
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
```
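To sanity-check a returned plan, you can measure how far its row and column sums are from r and c and recompute the cost ⟨P, C⟩ from it. The sketch below does this on nested Python lists with a hypothetical helper name. For the tensors above, the equivalent checks would be (plan.sum(-1) - r).abs().sum(), (plan.sum(0) - c).abs().sum() and (plan * C).sum(); with round=True the marginal error should be close to machine precision.

```python
def check_plan(plan, r, c, C):
    """Return (marginal_error, cost) for a plan given as nested lists.

    marginal_error is the L1 distance of the row sums from r plus that of
    the column sums from c; cost is <plan, C> = sum_ij plan[i][j] * C[i][j].
    Hypothetical helper for illustration, not part of mdot_tnt.
    """
    row_err = sum(abs(sum(row) - ri) for row, ri in zip(plan, r))
    col_sums = [sum(col) for col in zip(*plan)]
    col_err = sum(abs(s - cj) for s, cj in zip(col_sums, c))
    cost = sum(p * k for prow, krow in zip(plan, C) for p, k in zip(prow, krow))
    return row_err + col_err, cost


# The independent (product) plan r c^T is always feasible.
r, c = [0.5, 0.5], [0.25, 0.75]
plan = [[ri * cj for cj in c] for ri in r]
err, cost = check_plan(plan, r, c, [[0.0, 1.0], [1.0, 0.0]])
print(err, cost)  # 0.0 0.5
```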
Batched Solving
When solving multiple OT problems, use the batched solver for a significant speedup over solving them sequentially:
```python
import torch
import mdot_tnt
device = "cuda"
batch_size, n, m = 32, 512, 512
# Multiple marginal pairs
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)
# Shared cost matrix (or per-problem: shape [batch_size, n, m])
C = torch.rand(n, m, device=device, dtype=torch.float64)
# Solve all problems at once
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024)  # Returns (batch_size,) tensor
```
The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
API Reference
solve_OT
```python
mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
```
| Parameter | Type | Description |
|---|---|---|
| r | Tensor | Row marginal of shape (n,); must sum to 1 |
| c | Tensor | Column marginal of shape (m,); must sum to 1 |
| C | Tensor | Cost matrix of shape (n, m); normalizing to [0, 1] is recommended |
| gamma_f | float | Temperature parameter (inverse regularization). Higher values give more accurate results. Default: 1024 |
| return_plan | bool | If True, return the transport plan instead of the cost |
| round | bool | If True, round the solution onto the feasible set |
| log | bool | If True, also return optimization logs |
Returns: Transport cost (scalar) or plan (n, m), optionally with logs dict.
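The round option maps an approximate plan onto the feasible set U(r, c). One standard way to do this is the rounding procedure of Altschuler, Weed and Rigollet (2017), sketched below in plain Python. It is shown under the assumption that round=True does something of this kind; the library's exact procedure is not documented here, and the function name is hypothetical.

```python
def round_to_feasible(P, r, c):
    """Round a nonnegative n x m matrix P onto U(r, c) (Altschuler et al., 2017).

    Hypothetical helper; r and c are assumed nonnegative and each sums to 1.
    """
    n, m = len(r), len(c)
    # 1) Scale down rows whose sums exceed r[i].
    row_sums = [sum(row) for row in P]
    x = [min(1.0, r[i] / row_sums[i]) if row_sums[i] > 0 else 1.0 for i in range(n)]
    Q = [[x[i] * P[i][j] for j in range(m)] for i in range(n)]
    # 2) Scale down columns whose sums exceed c[j].
    col_sums = [sum(Q[i][j] for i in range(n)) for j in range(m)]
    y = [min(1.0, c[j] / col_sums[j]) if col_sums[j] > 0 else 1.0 for j in range(m)]
    Q = [[Q[i][j] * y[j] for j in range(m)] for i in range(n)]
    # 3) Add back the missing mass as a rank-one correction err_r err_c^T / |err_r|_1.
    err_r = [r[i] - sum(Q[i]) for i in range(n)]
    err_c = [c[j] - sum(Q[i][j] for i in range(n)) for j in range(m)]
    total = sum(err_r)  # equals sum(err_c) when r and c both sum to 1
    if total > 0:
        Q = [[Q[i][j] + err_r[i] * err_c[j] / total for j in range(m)] for i in range(n)]
    return Q


Q = round_to_feasible([[0.5, 0.1], [0.1, 0.1]], r=[0.5, 0.5], c=[0.4, 0.6])
```

Steps 1 and 2 only ever shrink entries, so the leftover mass in err_r and err_c is nonnegative and the rank-one correction restores both marginals exactly while keeping every entry nonnegative.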
solve_OT_batched
```python
mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
```
Same parameters as solve_OT, but with batched inputs:
- r: shape (batch, n)
- c: shape (batch, m)
- C: shape (n, m) for a shared cost, or (batch, n, m) for per-problem costs
Returns: Costs (batch,) or plans (batch, n, m).
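The shape rules above can be captured in a small validation helper. This is a hypothetical sketch of the documented contract, written over shape tuples, not the library's own validation code. Since torch.Size is a tuple subclass, tensor .shape values can be passed in directly.

```python
def batched_shapes(r_shape, c_shape, C_shape):
    """Check shapes against the documented solve_OT_batched contract.

    Returns (batch, n, m, shared_cost). Hypothetical helper for illustration.
    """
    if len(r_shape) != 2 or len(c_shape) != 2:
        raise ValueError("r and c must have shapes (batch, n) and (batch, m)")
    (batch, n), (batch_c, m) = r_shape, c_shape
    if batch_c != batch:
        raise ValueError(f"batch mismatch: r has {batch}, c has {batch_c}")
    if tuple(C_shape) == (n, m):
        return batch, n, m, True    # one cost matrix shared by all problems
    if tuple(C_shape) == (batch, n, m):
        return batch, n, m, False   # one cost matrix per problem
    raise ValueError(f"C must be ({n}, {m}) or ({batch}, {n}, {m}), got {tuple(C_shape)}")


print(batched_shapes((32, 512), (32, 512), (512, 512)))  # (32, 512, 512, True)
```

With the tensors from the batched Quick Start example, batched_shapes(r.shape, c.shape, C.shape) reports a shared cost matrix.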
Performance Tips
- Use float64 for gamma_f > 1024 (converted automatically, with a warning)
- Normalize cost matrices to [0, 1] for numerical stability
- Use the batched solver when solving multiple problems with shared structure
- Increase gamma_f for higher precision (the error scales as O(log n / γ) in the worst case, though it can be much smaller in practice)
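The normalization and precision tips can be made concrete with a short sketch. It uses min-max scaling, which is one of several reasonable ways to bring costs into [0, 1] and an assumption rather than necessarily what the authors use; the helper names are hypothetical. Because every feasible plan has total mass 1, the affine map C' = (C - lo) / (hi - lo) leaves the unregularized optimal plan unchanged, and the original cost is recovered as lo + (hi - lo) · cost'. Note that γ then acts on the rescaled costs, so the effective regularization changes with the scale.

```python
import math


def normalize_cost(C):
    """Min-max scale a cost matrix (nested lists) into [0, 1].

    Returns (C_scaled, lo, span) so that an OT cost computed on C_scaled
    can be mapped back with lo + span * cost_scaled. Hypothetical helper.
    """
    lo = min(min(row) for row in C)
    hi = max(max(row) for row in C)
    span = hi - lo
    if span == 0:  # constant cost: every feasible plan has the same cost lo
        return [[0.0 for _ in row] for row in C], lo, 0.0
    return [[(x - lo) / span for x in row] for row in C], lo, span


def error_scale(n, gamma):
    # log(n) / gamma: the order of the worst-case entropic error for
    # costs in [0, 1]; the constant hidden in the O(.) bound is omitted.
    return math.log(n) / gamma


C_scaled, lo, span = normalize_cost([[2.0, 4.0], [6.0, 10.0]])
print(C_scaled)  # [[0.0, 0.25], [0.5, 1.0]]
```

For n = 512 and γ = 1024, error_scale gives ln(512)/1024 ≈ 0.0061, and each doubling of gamma_f halves it.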
Citation
If you use MDOT-TNT in your research, please cite:
```bibtex
@inproceedings{kemertas2025truncated,
  title={A Truncated Newton Method for Optimal Transport},
  author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=gWrWUaCbMa}
}
```
License
This code is released under a non-commercial use license. For commercial licensing inquiries, please contact the authors.
Contact
For questions or issues, please open an issue or email: kemertas [at] cs [dot] toronto [dot] edu