Utilities for optimal transport algorithms and benchmarking
uot-bench
uot-bench is a Python toolkit for optimal transport solvers and benchmarking. It provides JAX-first implementations of common OT methods, utilities for generating problems and measures, and a configurable benchmarking pipeline for running experiments at scale.
- Documentation index: docs/index.md
- Problems module: docs/problems.md
- Problem generators: docs/generators.md
- SLURM guide: docs/slurm.md
- Color transfer experiment: docs/color_transfer.md
- MNIST classification experiment: docs/mnist.md
Install
Core install:
```bash
pip install uot-bench
```
Extras (optional):
```bash
pip install "uot-bench[viz,image-analysis,color-transfer,gurobi]"
```
CUDA (optional, JAX):
```bash
pip install "uot-bench[cuda12]"
```
Quickstart (Python API)
A minimal two-marginal Sinkhorn example:
```python
import numpy as np
from uot.data.measure import PointCloudMeasure
from uot.problems.two_marginal import TwoMarginalProblem
from uot.solvers.sinkhorn import SinkhornTwoMarginalSolver
from uot.utils.costs import cost_euclid_squared

# Create two 1D point-cloud measures
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
a = np.exp(-((x - 0.3) ** 2) / 0.01).reshape(-1)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02).reshape(-1)
b = b / b.sum()
mu = PointCloudMeasure(x, a, name="mu")
nu = PointCloudMeasure(y, b, name="nu")

problem = TwoMarginalProblem("toy", mu, nu, cost_euclid_squared)
solver = SinkhornTwoMarginalSolver()
inputs = problem.solver_inputs()
result = solver.solve(
    marginals=inputs.marginals,
    costs=inputs.costs,
    reg=1e-2,
)
print("cost:", float(result["cost"]))
```
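To show roughly what an entropic solver computes, the sketch below solves the same toy problem with a plain NumPy, log-domain Sinkhorn loop. It is an independent illustration, not uot-bench's implementation: the helpers `sinkhorn_log` and `logsumexp`, the iteration cap and the stopping tolerance are assumptions, and the reported value is the linear term <P, C> of the entropic plan, which may differ from how `SinkhornTwoMarginalSolver` defines `result["cost"]`.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def sinkhorn_log(a, b, C, reg, n_iter=2000, tol=1e-9):
    """Entropic OT between histograms a and b; returns the plan P and <P, C>."""
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(a)  # dual potential on the source points
    g = np.zeros_like(b)  # dual potential on the target points
    for _ in range(n_iter):
        # The f-update makes row sums equal a; the g-update makes column sums equal b.
        f = reg * (log_a - logsumexp((g[None, :] - C) / reg, axis=1))
        g = reg * (log_b - logsumexp((f[:, None] - C) / reg, axis=0))
        P = np.exp((f[:, None] + g[None, :] - C) / reg)
        # Columns are exact after the g-update, so check how far the rows drifted.
        if np.abs(P.sum(axis=1) - a).sum() < tol:
            break
    return P, float(np.sum(P * C))


# Same toy data as the quickstart.
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
a = np.exp(-((x - 0.3) ** 2) / 0.01).reshape(-1)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02).reshape(-1)
b = b / b.sum()
C = (x - y.T) ** 2  # squared Euclidean cost matrix, shape (64, 64)
P, cost = sinkhorn_log(a, b, C, reg=1e-2)
print("cost:", cost)
```

Working in the log domain keeps the iteration stable when C / reg is large (here up to 100), where a plain exp(-C / reg) kernel would underflow.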
Benchmarking CLI tools (Pixi)
This project uses Pixi to manage benchmarking dependencies and tasks.
Installing Pixi
Install Pixi by following the official Pixi installation instructions, then run `pixi install` from the project directory to set up the environment. Tasks are invoked with `pixi run <task>`.
Common commands
```bash
pixi run serialize --config <config.yaml> --export-dir <directory>
pixi run benchmark --config <config.yaml> --folds <n> --export <file>
pixi run lint   # or: ruff check .
```
Slurm
On Compute Canada clusters
The generic script for both Slurm and local runs is scripts/run_benchmark.sh.
Example:
```bash
sbatch scripts/run_benchmark.sh configs/generators/gaussians.yaml configs/runners/gaussians.yaml
```
Monitor GPU usage of a running job (replace 123456 with its job ID):
```bash
srun --jobid 123456 --pty watch -n 30 nvidia-smi
```
Synthetic datasets
Create a generator config (saved here as configs/generators/gaussians.yaml) like:
```yaml
generators:
  1D-gaussians-64:
    generator: uot.problems.generators.GaussianMixtureGenerator
    dim: 1
    num_components: 1
    n_points: 64
    num_datasets: 30
    borders: (-6, 6)
    cost_fn: uot.utils.costs.cost_euclid_squared
    use_jax: true
    seed: 42
```
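The generator's internals are not described here. As a rough sketch of what these fields might control, the snippet below draws `num_datasets` pairs of random 1D Gaussian-mixture histograms on an `n_points` grid spanning `borders`, seeded by `seed`, and builds the squared Euclidean cost on that grid. The helper `gaussian_mixture_histogram` and its sampling ranges for means, widths and mixture weights are hypothetical; this will not reproduce `GaussianMixtureGenerator` output.

```python
import numpy as np


def gaussian_mixture_histogram(grid, rng, num_components, borders):
    """Random Gaussian-mixture density on a 1D grid, normalized to sum to 1."""
    lo, hi = borders
    means = rng.uniform(lo, hi, size=num_components)    # assumed: means anywhere in borders
    stds = rng.uniform(0.2, 1.0, size=num_components)   # assumed width range
    weights = rng.dirichlet(np.ones(num_components))    # random mixture weights
    dens = sum(
        w * np.exp(-0.5 * ((grid - m) / s) ** 2) / s
        for w, m, s in zip(weights, means, stds)
    )
    return dens / dens.sum()


# Values mirror the config above.
n_points, num_components, num_datasets, borders, seed = 64, 1, 30, (-6.0, 6.0), 42
rng = np.random.default_rng(seed)
grid = np.linspace(*borders, n_points)
C = (grid[:, None] - grid[None, :]) ** 2  # squared Euclidean cost on the grid
pairs = [
    (
        gaussian_mixture_histogram(grid, rng, num_components, borders),
        gaussian_mixture_histogram(grid, rng, num_components, borders),
    )
    for _ in range(num_datasets)
]
```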
Then serialize:
```bash
pixi run serialize --config configs/generators/gaussians.yaml --export-dir datasets/synthetic
```
Running experiments
Example runner config (configs/runners/gaussians.yaml):
```yaml
param-grids:
  regularizations:
    - reg: 1
    - reg: 0.001
solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    jit: true
    param-grid: regularizations
problems:
  dir: datasets/synthetic
  names:
    - 1D-gaussians-64
experiment:
  name: Measure on Gaussians
  function: uot.experiments.measurement.measure_time
```
Run the benchmark:
```bash
pixi run benchmark --config configs/runners/gaussians.yaml --folds 1 --export results/raw/gaussians.csv
```
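Each solver entry names a grid under param-grids, so one solver is run once per parameter set. A minimal sketch of that expansion, using a hand-written dict in place of the parsed YAML, is below. The function `expand_runs` is hypothetical, treating `--folds` as a repeat count is an assumption, and the real runner presumably also iterates over the individual datasets (30 here) inside each named problem set.

```python
from itertools import product

# Hand-written stand-in for the parsed runner config above.
config = {
    "param-grids": {"regularizations": [{"reg": 1}, {"reg": 0.001}]},
    "solvers": {
        "sinkhorn": {
            "solver": "uot.solvers.sinkhorn.SinkhornTwoMarginalSolver",
            "jit": True,
            "param-grid": "regularizations",
        }
    },
    "problems": {"dir": "datasets/synthetic", "names": ["1D-gaussians-64"]},
}


def expand_runs(config, folds=1):
    """Yield one (solver, params, problem_set, fold) tuple per planned run."""
    grids = config["param-grids"]
    for solver_name, spec in config["solvers"].items():
        for params, problem, fold in product(
            grids[spec["param-grid"]], config["problems"]["names"], range(folds)
        ):
            yield solver_name, params, problem, fold


runs = list(expand_runs(config, folds=1))
for run in runs:
    print(run)
```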
Color Transfer
Example config:
```yaml
param-grids:
  epsilons:
    - reg: 1
    - reg: 0.01
solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    param-grid: epsilons
    jit: true
bin-number:
  - 16
  - 32
soft-extension:
  - no
  - yes
displacement-interpolation:
  - 0.0
  - 1.0
color-space: rgb
# active-channels: [r, g]
batch-size: 100000
pair-number: 3
images-dir: ./datasets/images
rng-seed: 42
drop-columns:
  - transport_plan
  - monge_map
  - u_final
  - v_final
experiment:
  name: Time and test
  output-dir: ./outputs/color_transfer
```
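This README does not define displacement-interpolation. One common reading, assumed here, is McCann's displacement interpolation: each source color x moves to (1 - t) x + t T(x), so 0.0 leaves colors unchanged and 1.0 applies the full transfer. With an entropic plan P, the map T is often approximated by the barycentric projection T(x_i) = sum_j P_ij y_j / sum_j P_ij. The function names below are illustrative only.

```python
import numpy as np


def barycentric_map(P, Y):
    """Approximate a transport map from plan P (n x m) and target support Y (m x d)."""
    return (P @ Y) / P.sum(axis=1, keepdims=True)


def displacement_interpolation(X, P, Y, t):
    """Move source points X (n x d) a fraction t of the way toward their images."""
    return (1.0 - t) * X + t * barycentric_map(P, Y)


# Toy example: 3 source colors, 2 target colors, and a hand-made plan summing to 1.
X = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]])
Y = np.array([[0.2, 0.0, 0.0], [0.0, 0.0, 0.8]])
P = np.array([[0.3, 0.0], [0.1, 0.1], [0.0, 0.5]])
for t in (0.0, 1.0):
    print(t, displacement_interpolation(X, P, Y, t))
```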
For a detailed explanation of the parameters, see docs/color_transfer.md.
Notably, soft-extension and displacement-interpolation can be single values
or lists, and the pipeline will run once per option.
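Counting the configurations in the example config, assuming every list-valued option is swept as a full Cartesian product and scalars count as one-element lists (the exact combination rule is an assumption; pair-number: 3 presumably multiplies the work further):

```python
from itertools import product

# Sweep values copied from the example config.
sweep = {
    "reg": [1, 0.01],
    "bin-number": [16, 32],
    "soft-extension": [False, True],  # YAML 1.1 loaders such as PyYAML read no/yes as booleans
    "displacement-interpolation": [0.0, 1.0],
}


def as_list(value):
    """Treat a scalar option as a one-element list, so both forms sweep the same way."""
    return value if isinstance(value, list) else [value]


keys = list(sweep)
combos = [dict(zip(keys, values)) for values in product(*(as_list(sweep[k]) for k in keys))]
print(len(combos))  # 16 (= 2 * 2 * 2 * 2)
```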
For example, to work in Lab space with only selected channels:
```yaml
color-space: lab
active-channels: [l, a]
```
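active-channels restricts which channels take part in the transport problem. One way to turn an image into a discrete measure over the chosen channels is to bin its pixels into a histogram with bin-number bins per axis and keep the non-empty bins as a weighted point cloud. The sketch below does this with `numpy.histogramdd`; `color_histogram` is a hypothetical helper, it assumes channel values already lie in the given ranges (for Lab you would convert first and pass suitable ranges), and the pipeline may bin differently.

```python
import numpy as np


def color_histogram(pixels, bins, channels, channel_names="rgb", ranges=None):
    """Bin the selected channels of an (N, 3) pixel array into a weighted point cloud.

    Returns bin centers of shape (K, len(channels)) and weights of shape (K,)
    that sum to 1, keeping only non-empty bins.
    """
    idx = [channel_names.index(c) for c in channels]
    if ranges is None:
        ranges = [(0.0, 1.0)] * len(idx)  # assumes values scaled to [0, 1]
    counts, edges = np.histogramdd(pixels[:, idx], bins=bins, range=ranges)
    centers = [(e[:-1] + e[1:]) / 2 for e in edges]
    # "ij" indexing matches the C-order flattening of counts below.
    support = np.stack(np.meshgrid(*centers, indexing="ij"), axis=-1).reshape(-1, len(idx))
    weights = counts.reshape(-1)
    keep = weights > 0
    return support[keep], weights[keep] / weights.sum()


rng = np.random.default_rng(0)
pixels = rng.random((10_000, 3))  # stand-in for an RGB image reshaped to (H * W, 3)
support, weights = color_histogram(pixels, bins=16, channels=["r", "g"])
print(support.shape, weights.sum())
```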
Run the experiment:
```bash
pixi run color-transfer --config ./configs/color_transfer/example.yaml
```
Dashboard visualization:
```bash
pixi run color-transfer-visualization --origin_folder <path_to_input_images> --results_folder <path_to_resulting_images>
```
MNIST Classification
The MNIST experiment is performed in two steps:
- Distance matrix calculation
- Classification
See docs/mnist.md for configs.
Step 1:
```bash
pixi run mnist_distances --config ./configs/mnist_dist_example.yaml
```
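Step 1 needs each digit as a discrete measure. A common setup, assumed here rather than taken from docs/mnist.md, places mass on the 28 x 28 pixel grid in proportion to intensity and uses the squared Euclidean cost between pixel coordinates; a two-marginal solver can then be applied to each pair of weight vectors. The helpers `image_to_measure` and `grid_cost` are hypothetical.

```python
import numpy as np


def image_to_measure(img):
    """Turn a nonnegative (H, W) image into pixel coordinates and probability weights."""
    h, w = img.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)
    # Zero-intensity pixels get zero mass; log-domain solvers may need them dropped or smoothed.
    weights = img.ravel().astype(float)
    return coords, weights / weights.sum()


def grid_cost(coords):
    """Squared Euclidean cost between all pairs of pixel coordinates."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sum(diff ** 2, axis=-1)


img = np.zeros((28, 28))
img[10:18, 12:16] = 1.0  # a toy "digit" stroke
coords, weights = image_to_measure(img)
C = grid_cost(coords)
print(coords.shape, C.shape)  # (784, 2) (784, 784)
```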
Step 2:
```bash
pixi run mnist_classification --config ./configs/mnist_classification_example.yaml
```
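Step 2 works from the precomputed distances. The classifier used by mnist_classification is not specified here; a k-nearest-neighbor vote is one natural choice for a precomputed distance matrix and is sketched below as an assumption (`knn_predict` is a hypothetical helper).

```python
import numpy as np


def knn_predict(D, train_labels, k=5):
    """Predict labels from an (n_test, n_train) distance matrix by k-NN majority vote.

    Ties go to the smallest label, since np.argmax returns the first maximum.
    """
    nearest = np.argsort(D, axis=1)[:, :k]
    votes = train_labels[nearest]  # shape (n_test, k)
    return np.array([np.bincount(row).argmax() for row in votes])


# Toy distance matrix: 2 test items, 4 training items labeled 0, 0, 1, 1.
D = np.array([[0.1, 0.2, 0.9, 0.8],
              [0.7, 0.9, 0.05, 0.1]])
train_labels = np.array([0, 0, 1, 1])
pred = knn_predict(D, train_labels, k=3)
print(pred)  # [0 1]
```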
Linting
Run `pixi run lint` or `ruff check .`.