A fast, GPU-parallel, PyTorch-compatible optimal transport solver.

MDOT-TNT

A Truncated Newton Method for Optimal Transport

Requires Python 3.8+.

A fast, GPU-accelerated solver for entropic-regularized optimal transport (OT) problems. MDOT-TNT combines mirror descent with a truncated Newton projection method to achieve high numerical precision while remaining stable under weak regularization.


Features

  • High Precision: Stable under extremely weak regularization (γ up to 2¹⁸), enabling highly precise approximations of unregularized OT
  • GPU Accelerated: Fully compatible with CUDA for fast computation on large problems
  • Batched Solving: Solve multiple OT problems simultaneously in batched mode
  • Memory Efficient: Log-domain computations and efficient rounding avoid storing full transport plans
  • PyTorch Native: Seamless integration with PyTorch, supporting autograd-compatible inputs
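
To illustrate why log-domain computation matters under weak regularization, here is a minimal sketch. It is not the library's implementation; the helper names and the toy costs are made up for illustration. It evaluates log Σ exp(-γ·C) naively and with the standard log-sum-exp shift. At γ = 2¹⁸ every naive exponential underflows to zero in float64, while the shifted version stays finite.

```python
import math

def naive_log_sum_exp(values):
    # Direct evaluation: exponentiate first, then take the log.
    total = sum(math.exp(v) for v in values)
    return math.log(total) if total > 0.0 else float("-inf")

def stable_log_sum_exp(values):
    # Shift by the maximum so the largest exponent is exactly 0.
    shift = max(values)
    return shift + math.log(sum(math.exp(v - shift) for v in values))

gamma = 2.0 ** 18
costs = [0.25, 0.5, 1.0]                    # toy cost entries in [0, 1]
logits = [-gamma * cost for cost in costs]

naive = naive_log_sum_exp(logits)    # all exp() calls underflow, result is -inf
stable = stable_log_sum_exp(logits)  # finite, dominated by -gamma * min(costs)
```

The shifted form is what makes large γ usable at all: the dominant term is carried exactly in the shift and only the relative magnitudes are exponentiated.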

Installation

Prerequisites: Install PyTorch for your system configuration first.

pip install mdot-tnt

For development:

git clone https://github.com/metekemertas/mdot_tnt.git
cd mdot_tnt
pip install -e ".[dev]"

Quick Start

Single Problem

import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create marginals (probability distributions)
n, m = 512, 512
r = torch.rand(n, device=device, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(m, device=device, dtype=torch.float64)
c = c / c.sum()

# Cost matrix (e.g., pairwise distances)
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve for optimal transport cost
cost = mdot_tnt.solve_OT(r, c, C, gamma_f=1024)

# Or get the full transport plan
plan = mdot_tnt.solve_OT(r, c, C, gamma_f=1024, return_plan=True)
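
A quick sanity check on a returned plan is to compare its row and column sums with the requested marginals. The helper below, `marginal_violation`, is a hypothetical name and not part of mdot_tnt. So that the sketch runs without the solver installed, it is demonstrated on the independent coupling r cᵀ, which is always feasible; in practice you would pass the `plan` returned with `return_plan=True`.

```python
import torch

def marginal_violation(plan, r, c):
    """L1 distance between the plan's marginals and the targets (r, c)."""
    row_err = (plan.sum(dim=-1) - r).abs().sum()
    col_err = (plan.sum(dim=-2) - c).abs().sum()
    return (row_err + col_err).item()

torch.manual_seed(0)
r = torch.rand(4, dtype=torch.float64)
r = r / r.sum()
c = torch.rand(5, dtype=torch.float64)
c = c / c.sum()

# Stand-in for a solver output (assumption: any feasible plan works here).
plan = torch.outer(r, c)
violation = marginal_violation(plan, r, c)  # zero up to floating-point error
```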

Batched Solving

When solving multiple OT problems, use the batched solver, which is significantly faster than solving them one at a time:

import torch
import mdot_tnt

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, n, m = 32, 512, 512

# Multiple marginal pairs
r = torch.rand(batch_size, n, device=device, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch_size, m, device=device, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)

# Shared cost matrix (or per-problem: shape [batch_size, n, m])
C = torch.rand(n, m, device=device, dtype=torch.float64)

# Solve all problems at once
costs = mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024)  # Returns (batch_size,) tensor

The batched solver achieves speedup by amortizing GPU synchronization overhead across all problems in the batch.
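
If plans are requested instead of costs, the per-problem costs can be recomputed from them. The sketch below assumes plans of shape (batch, n, m) and a shared cost matrix of shape (n, m), which broadcasts across the batch. `costs_from_plans` is a hypothetical helper, and independent couplings stand in for solver output so the example runs on its own.

```python
import torch

def costs_from_plans(plans, C):
    # C of shape (n, m) broadcasts against plans of shape (batch, n, m);
    # a per-problem C of shape (batch, n, m) works unchanged.
    return (plans * C).sum(dim=(-2, -1))

torch.manual_seed(0)
batch, n, m = 3, 4, 5
r = torch.rand(batch, n, dtype=torch.float64)
r = r / r.sum(-1, keepdim=True)
c = torch.rand(batch, m, dtype=torch.float64)
c = c / c.sum(-1, keepdim=True)
C = torch.rand(n, m, dtype=torch.float64)

# Stand-in plans: one independent coupling per problem, shape (batch, n, m).
plans = r.unsqueeze(-1) * c.unsqueeze(-2)
costs = costs_from_plans(plans, C)  # shape (batch,)
```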

API Reference

solve_OT

mdot_tnt.solve_OT(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)
| Parameter | Type | Description |
|---|---|---|
| r | Tensor | Row marginal of shape (n,), must sum to 1 |
| c | Tensor | Column marginal of shape (m,), must sum to 1 |
| C | Tensor | Cost matrix of shape (n, m), recommended to normalize to [0, 1] |
| gamma_f | float | Temperature parameter (inverse regularization). Higher = more accurate. Default: 1024 |
| return_plan | bool | If True, return transport plan instead of cost |
| round | bool | If True, round solution onto feasible set |
| log | bool | If True, also return optimization logs |

Returns: Transport cost (scalar) or plan (n, m), optionally with logs dict.
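
One way to follow the recommendation to normalize C to [0, 1] is an affine rescaling, sketched below. `normalize_cost` and `restore_cost` are hypothetical helpers, not part of mdot_tnt. Because any feasible plan sums to 1, a cost computed on the normalized matrix maps back to the original units as scale · cost + shift. An independent coupling stands in for a solver-produced plan.

```python
import torch

def normalize_cost(C):
    """Affinely map C onto [0, 1]; also return the (scale, shift) used."""
    shift = C.min()
    scale = C.max() - shift
    if scale == 0:
        raise ValueError("constant cost matrix cannot be rescaled")
    return (C - shift) / scale, scale, shift

def restore_cost(normalized_cost, scale, shift):
    # <P, C> = scale * <P, C_norm> + shift whenever P sums to 1.
    return scale * normalized_cost + shift

torch.manual_seed(0)
C = 50.0 + 20.0 * torch.rand(4, 5, dtype=torch.float64)  # toy costs in [50, 70]
r = torch.full((4,), 0.25, dtype=torch.float64)
c = torch.full((5,), 0.20, dtype=torch.float64)
plan = torch.outer(r, c)  # stand-in for a solver-produced plan

C_norm, scale, shift = normalize_cost(C)
restored = restore_cost((plan * C_norm).sum(), scale, shift)
original = (plan * C).sum()
```

Note that an error measured on the normalized problem is multiplied by `scale` when expressed in the original cost units.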

solve_OT_batched

mdot_tnt.solve_OT_batched(r, c, C, gamma_f=1024., return_plan=False, round=True, log=False)

Same parameters as solve_OT, but with batched inputs:

  • r: Shape (batch, n)
  • c: Shape (batch, m)
  • C: Shape (n, m) for shared cost, or (batch, n, m) for per-problem costs

Returns: Costs (batch,) or plans (batch, n, m).
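
The shape conventions above can be checked before calling the solver. The following is a minimal sketch of such a check; `check_batched_shapes` is a hypothetical helper and says nothing about what validation the library performs internally.

```python
import torch

def check_batched_shapes(r, c, C):
    """Validate (batch, n), (batch, m) marginals and a shared or per-problem C."""
    if r.dim() != 2 or c.dim() != 2:
        raise ValueError("r and c must have shapes (batch, n) and (batch, m)")
    batch, n = r.shape
    if c.shape[0] != batch:
        raise ValueError("r and c must share the same batch size")
    m = c.shape[1]
    if tuple(C.shape) not in ((n, m), (batch, n, m)):
        raise ValueError("C must have shape (n, m) or (batch, n, m)")
    return batch, n, m

r = torch.full((2, 3), 1.0 / 3.0, dtype=torch.float64)
c = torch.full((2, 4), 0.25, dtype=torch.float64)
shared = check_batched_shapes(r, c, torch.zeros(3, 4, dtype=torch.float64))
per_problem = check_batched_shapes(r, c, torch.zeros(2, 3, 4, dtype=torch.float64))
```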

Performance Tips

  1. Use float64 inputs when gamma_f > 1024 (inputs are otherwise converted automatically, with a warning)
  2. Normalize cost matrices to [0, 1] for numerical stability
  3. Use batched solver when solving multiple problems with shared structure
  4. Increase gamma_f for higher precision (the error scales as O(log n / γ) in the worst case, but can be much smaller in practice)
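
As a rough guide for choosing gamma_f from the worst-case scaling above, the sketch below solves constant · log(n) / γ ≤ ε for γ. Two things here are assumptions: the hidden constant in the O(·) bound is taken to be 1, and the result is rounded up to a power of two purely for convenience. `gamma_for_target_error` is a hypothetical helper, and costs are assumed to be normalized to [0, 1].

```python
import math

def gamma_for_target_error(n, eps, constant=1.0):
    """Smallest power-of-two gamma with constant * log(n) / gamma <= eps."""
    gamma = constant * math.log(n) / eps
    return 2.0 ** math.ceil(math.log2(gamma))

# n = 512 and a target error of 1e-2: log(512) / 1e-2 is about 624,
# which rounds up to the next power of two.
gamma_f = gamma_for_target_error(n=512, eps=1e-2)  # 1024.0
```

Since the bound is a worst case, the value obtained this way is conservative; a smaller gamma_f may already reach the target error on a given problem.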

Citation

If you use MDOT-TNT in your research, please cite:

@inproceedings{kemertas2025truncated,
  title={A Truncated Newton Method for Optimal Transport},
  author={Kemertas, Mete and Farahmand, Amir-massoud and Jepson, Allan Douglas},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=gWrWUaCbMa}
}

License

This code is released under a non-commercial use license. For commercial licensing inquiries, please contact the authors.

Contact

For questions or issues, please open an issue or email: kemertas [at] cs [dot] toronto [dot] edu
