CUDA-Accelerated Regularized Optimal Transport

RegOT-CUDA

RegOT-CUDA (cuRegOT) is a CUDA-accelerated library for optimal transport computation, providing high-performance implementations of regularized optimal transport algorithms.

As a complement to this library, the RegOT-Python repository provides efficient CPU-based solvers for regularized optimal transport.

Problem Formulation

Currently RegOT-CUDA solves the entropic-regularized optimal transport problem:

\begin{align*}
\min_{T\in\mathbb{R}^{n\times m}}\quad & \langle T,M\rangle-\eta\cdot h(T),\\
\text{subject to}\quad & T\mathbf{1}_{m}=a,T^{T}\mathbf{1}_{n}=b,T\ge0,
\end{align*}

where $a\in\mathbb{R}^n$ and $b\in\mathbb{R}^m$ are two given probability vectors with $a_i>0$, $b_j>0$, $\sum_{i=1}^n a_i=\sum_{j=1}^m b_j=1$, and $M\in\mathbb{R}^{n\times m}$ is a given cost matrix. The function $h(T)=\sum_{i=1}^{n}\sum_{j=1}^{m}T_{ij}(1-\log T_{ij})$ is the entropy term, and $\eta>0$ is a regularization parameter.
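To make the formulation concrete, here is a minimal NumPy sketch that evaluates the objective and checks the constraints for a candidate plan. The helper names `entropic_objective` and `is_feasible` are illustrative only and are not part of the cuRegOT API; the convention $0\cdot\log 0=0$ is an assumption made so that plans with zero entries stay finite.

```python
import numpy as np

def entropic_objective(T, M, eta):
    """Evaluate <T, M> - eta * h(T) with h(T) = sum_ij T_ij * (1 - log T_ij)."""
    T = np.asarray(T, dtype=float)
    M = np.asarray(M, dtype=float)
    # Assumption: treat 0 * log(0) as 0 so that sparse plans do not give NaN.
    log_T = np.log(T, where=T > 0, out=np.zeros_like(T))
    h = np.sum(T * (1.0 - log_T))
    return float(np.sum(T * M) - eta * h)

def is_feasible(T, a, b, atol=1e-8):
    """Check T 1_m = a, T^T 1_n = b, and T >= 0 up to a small tolerance."""
    T = np.asarray(T, dtype=float)
    return bool(
        np.all(T >= -atol)
        and np.allclose(T.sum(axis=1), a, atol=atol)
        and np.allclose(T.sum(axis=0), b, atol=atol)
    )

# The independent coupling a b^T always satisfies the constraints.
a = np.array([0.2, 0.3, 0.5])
b = np.array([0.4, 0.6])
M = np.arange(6, dtype=float).reshape(3, 2)
T0 = np.outer(a, b)
print("feasible:", is_feasible(T0, a, b))
print("objective:", entropic_objective(T0, M, eta=0.1))
```

Any feasible plan gives an upper bound on the optimal value, so the number printed for `T0` is a simple sanity reference for a solver's output.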

Work in Progress

RegOT-CUDA is a work in progress. So far we have implemented the block coordinate descent algorithm (BCD, equivalent to the well-known Sinkhorn algorithm) and the sparse-plus-low-rank quasi-Newton method (SPLR) for entropic-regularized optimal transport. More state-of-the-art solvers are under development, and a list of candidate algorithms can be found in the RegOT-Python package.
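For readers unfamiliar with BCD/Sinkhorn, the following is a plain NumPy sketch of the textbook iteration. It is a CPU reference for illustration only, not the CUDA implementation in this package; the function name `sinkhorn_reference` and its return values are assumptions. It relies on the fact that the optimal plan has the form $T=\mathrm{diag}(u)\,K\,\mathrm{diag}(v)$ with $K=\exp(-M/\eta)$, and alternately rescales $u$ and $v$ to match the two marginals.

```python
import numpy as np

def sinkhorn_reference(M, a, b, reg, tol=1e-6, max_iter=1000):
    """Textbook Sinkhorn iterations in NumPy (illustrative, unoptimized).

    Assumes max_iter >= 1. Returns (plan, iterations used, final row error).
    """
    K = np.exp(-M / reg)              # Gibbs kernel; may underflow for tiny reg
    u = np.ones_like(a)
    v = np.ones_like(b)
    for it in range(max_iter):
        u = a / (K @ v)               # rescale rows to match a
        v = b / (K.T @ u)             # rescale columns to match b
        plan = u[:, None] * K * v[None, :]
        # Column sums are exact after the v-update, so measure the row error.
        err = np.abs(plan.sum(axis=1) - a).sum()
        if err < tol:
            break
    return plan, it + 1, err

rng = np.random.default_rng(123)
n, m = 50, 40
M = rng.random((n, m))
a = rng.random(n); a /= a.sum()
b = rng.random(m); b /= b.sum()
plan, n_iter, err = sinkhorn_reference(M, a, b, reg=0.1)
print("iterations:", n_iter, "row marginal error:", err)
```

This version works directly with the kernel `K`, which is simple but can underflow when `reg` is very small relative to the entries of `M`; log-domain updates are the usual remedy.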

Requirements

  • Python >= 3.11
  • NumPy >= 1.23.0
  • PyTorch >= 2.10 (optional, for PyTorch interface)
  • CUDA Toolkit >= 13.0
  • The cuDSS library
  • Compatible NVIDIA GPU
  • C++ compiler (C++11 or higher, for building from source)

Installation

Using pip

You can install cuRegOT with pip; the CUDA runtime libraries it depends on are installed automatically:

pip install curegot

Building from source

Environment Setup

The CUDA development environment is required to build this package from source. You can install the C++ compiler and CUDA libraries using Conda:

# Create CUDA development environment
conda create -n nvdev
conda activate nvdev
conda install python=3.12 gxx_linux-64
conda install -c nvidia cuda-toolkit=13.2 libcudss-dev cuda-nvtx-dev
pip install torch --index-url https://download.pytorch.org/whl/cu130
pip install numpy requests pybind11

You also need to set the CUDA_HOME environment variable, for example:

export CUDA_HOME=/usr/local/cuda

If you use the Conda installation method above, you can run the following commands to set the environment variable for the virtual environment. Reactivate the environment afterwards so that the new value takes effect:

conda activate nvdev
conda env config vars set CUDA_HOME="<path_to_conda>/envs/nvdev/"
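Before building, it can help to confirm that the variable is visible to Python and that a CUDA compiler can be found. The sketch below uses only the standard library; the helper name `check_cuda_env` is hypothetical, and the assumption that `nvcc` lives under `$CUDA_HOME/bin` reflects the usual toolkit layout rather than anything this package guarantees.

```python
import os
import shutil
from pathlib import Path

def check_cuda_env():
    """Report CUDA_HOME and the nvcc executable, if either can be found."""
    cuda_home = os.environ.get("CUDA_HOME")
    report = {"CUDA_HOME": cuda_home, "nvcc": None}
    if cuda_home:
        # Assumption: the compiler sits in the bin/ directory of CUDA_HOME.
        candidate = Path(cuda_home) / "bin" / "nvcc"
        if candidate.is_file():
            report["nvcc"] = str(candidate)
    if report["nvcc"] is None:
        # Fall back to whatever nvcc is on PATH (None if there is none).
        report["nvcc"] = shutil.which("nvcc")
    return report

print(check_cuda_env())
```

If `CUDA_HOME` is reported as `None`, the environment variable has not been picked up by the current shell.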

Build and Install

cd regot-cuda
pip install --no-build-isolation .

Verify Installation

python -c "import curegot; print('RegOT-CUDA imported successfully')"

Usage

NumPy interface:

import numpy as np
import curegot

# Create data
np.random.seed(123)
n, m = 100, 80
M = np.random.rand(n, m)  # Cost matrix
a = np.random.rand(n)     # Source distribution
a = a / np.sum(a)         # Normalize
b = np.random.rand(m)     # Target distribution
b = b / np.sum(b)         # Normalize
reg = 0.1                 # Regularization parameter

# Call algorithm
result1 = curegot.numpy.sinkhorn_bcd(M, a, b, reg, tol=1e-6, max_iter=1000, verbose=0)
plan1 = result1["plan"]
print(plan1[:3, :3])

result2 = curegot.numpy.sinkhorn_splr(M, a, b, reg, tol=1e-6, max_iter=1000, verbose=0)
plan2 = result2["plan"]
print(plan2[:3, :3])
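Because the entropic-regularized problem has a unique solution, the plans returned by the two solvers should agree up to the chosen tolerance. The sketch below shows one way to check this; `compare_plans` is a hypothetical helper, not part of curegot, and it assumes the plans are 2-D arrays of shape (n, m) as in the example above. Stand-in arrays are used here so the snippet runs without a GPU.

```python
import numpy as np

def compare_plans(plan1, plan2, a, b):
    """Summarize how closely two transport plans agree with each other
    and with the prescribed marginals a (rows) and b (columns)."""
    plan1 = np.asarray(plan1, dtype=float)
    plan2 = np.asarray(plan2, dtype=float)
    return {
        "max_abs_diff": float(np.max(np.abs(plan1 - plan2))),
        "row_err_1": float(np.abs(plan1.sum(axis=1) - a).sum()),
        "col_err_1": float(np.abs(plan1.sum(axis=0) - b).sum()),
        "row_err_2": float(np.abs(plan2.sum(axis=1) - a).sum()),
        "col_err_2": float(np.abs(plan2.sum(axis=0) - b).sum()),
    }

# Stand-in data: replace p1 and p2 with result1["plan"] and result2["plan"].
a = np.array([0.25, 0.75])
b = np.array([0.5, 0.3, 0.2])
p1 = np.outer(a, b)
p2 = p1 + 1e-9                    # a slightly perturbed copy
for key, value in compare_plans(p1, p2, a, b).items():
    print(f"{key}: {value:.3e}")
```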

PyTorch interface:

import torch
import curegot

# Create data
torch.manual_seed(123)
n, m = 100, 80
device = "cuda" if torch.cuda.is_available() else "cpu"
M = torch.rand(n, m, device=device)  # Cost matrix
a = torch.rand(n, device=device)     # Source distribution
a = a / torch.sum(a)                 # Normalize
b = torch.rand(m, device=device)     # Target distribution
b = b / torch.sum(b)                 # Normalize
reg = 0.1                            # Regularization parameter

# Call algorithm
result1 = curegot.torch.sinkhorn_bcd(M, a, b, reg, tol=1e-6, max_iter=1000, verbose=0)
plan1 = result1["plan"]
print(plan1[:3, :3])

result2 = curegot.torch.sinkhorn_splr(M, a, b, reg, tol=1e-6, max_iter=1000, verbose=0)
plan2 = result2["plan"]
print(plan2[:3, :3])

Tests

cd regot-cuda/test
pip install regot
python test_sinkhorn_bcd.py
python test_sinkhorn_splr.py

Benchmark

The plot below shows benchmark results for several popular GPU-based solvers. The horizontal axis is the elapsed wall time, and the vertical axis is the optimization error on a logarithmic scale. A lower curve is better: it reaches a smaller error in less time.

More details on the benchmark can be found in the paper.
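To illustrate how a time-versus-error curve of this kind can be recorded, here is a small NumPy sketch that traces a textbook Sinkhorn iteration on the CPU. It does not reproduce the benchmark: the function name `trace_sinkhorn` is hypothetical, and the error measure (the L1 violation of the row marginals) is an assumption, since the exact metric is defined in the paper rather than here.

```python
import time
import numpy as np

def trace_sinkhorn(M, a, b, reg, n_iter=200):
    """Run Sinkhorn updates and record (elapsed seconds, marginal error)."""
    K = np.exp(-M / reg)
    v = np.ones_like(b)
    times, errors = [], []
    start = time.perf_counter()
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
        # Row sums of diag(u) K diag(v) without forming the full plan.
        row_sums = u * (K @ v)
        errors.append(np.abs(row_sums - a).sum())
        times.append(time.perf_counter() - start)
    return np.array(times), np.array(errors)

rng = np.random.default_rng(123)
n, m = 100, 80
M = rng.random((n, m))
a = rng.random(n); a /= a.sum()
b = rng.random(m); b /= b.sum()
times, errors = trace_sinkhorn(M, a, b, reg=0.1)
# Plotting errors against times with a log-scaled vertical axis gives one curve.
print("final error:", errors[-1], "after", times[-1], "seconds")
```

When timing GPU code instead, the device should be synchronized before each timestamp is taken; otherwise asynchronous kernel launches make the measured times misleading.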

Bibliography

Please consider citing our work if you find our algorithms or software useful in your research and applications.

@inproceedings{qiu2026curegot,
  title={{{cuRegOT}}: A {{GPU}}-Accelerated Solver for Entropic-Regularized Optimal Transport},
  author={Qiu, Yixuan},
  booktitle={Forty-third International Conference on Machine Learning},
  year={2026}
}
