
Utilities for optimal transport algorithms and benchmarking


uot-bench

uot-bench is a Python toolkit for optimal transport solvers and benchmarking. It provides JAX-first implementations of common OT methods, utilities for generating problems and measures, and a configurable benchmarking pipeline for running experiments at scale.

Install

Core install:

pip install uot-bench

Extras (optional):

pip install "uot-bench[viz,image-analysis,color-transfer,gurobi]"

CUDA (optional, JAX):

pip install "uot-bench[cuda12]"

Quickstart (Python API)

A minimal two-marginal Sinkhorn example:

import numpy as np

from uot.data.measure import DiscreteMeasure
from uot.problems.two_marginal import TwoMarginalProblem
from uot.solvers.sinkhorn import SinkhornTwoMarginalSolver
from uot.utils.costs import cost_euclid_squared

# Create two 1D discrete measures
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)

a = np.exp(-((x - 0.3) ** 2) / 0.01).reshape(-1)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02).reshape(-1)
b = b / b.sum()

mu = DiscreteMeasure(x, a, name="mu")
nu = DiscreteMeasure(y, b, name="nu")

problem = TwoMarginalProblem("toy", mu, nu, cost_euclid_squared)
solver = SinkhornTwoMarginalSolver()

result = solver.solve(
    marginals=problem.get_marginals(),
    costs=problem.get_costs(),
    reg=1e-2,
)

print("cost:", float(result["cost"]))
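To show what an entropic two-marginal solver computes, here is a minimal log-domain Sinkhorn sketch in plain NumPy. It is not uot-bench's implementation, and the names sinkhorn_log and logsumexp are hypothetical. The log domain avoids underflow when reg is small. The sketch reports the linear transport cost <P, C> of the entropic plan, which may differ from result["cost"] if the library also includes the entropy term.

```python
import numpy as np


def logsumexp(m, axis):
    """Numerically stable log(sum(exp(m))) along one axis."""
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def sinkhorn_log(a, b, C, reg, n_iter=5000, tol=1e-9):
    """Hypothetical sketch: min <P, C> - reg * H(P) subject to P 1 = a, P^T 1 = b."""
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    for _ in range(n_iter):
        # Dual updates: f makes the row sums equal a, then g makes the column sums equal b.
        f = reg * log_a - reg * logsumexp((g[None, :] - C) / reg, axis=1)
        g = reg * log_b - reg * logsumexp((f[:, None] - C) / reg, axis=0)
        P = np.exp((f[:, None] + g[None, :] - C) / reg)
        if np.abs(P.sum(axis=1) - a).max() < tol:
            break
    return P, float(np.sum(P * C))


# Same data as the quickstart, written with 1D arrays.
x = np.linspace(0.0, 1.0, 64)
y = np.linspace(0.0, 1.0, 64)
a = np.exp(-((x - 0.3) ** 2) / 0.01)
a /= a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02)
b /= b.sum()
C = (x[:, None] - y[None, :]) ** 2  # squared Euclidean cost
P, cost = sinkhorn_log(a, b, C, reg=1e-2)
print("cost:", cost)
```

Decreasing reg makes the plan sharper and the cost closer to the unregularized optimum, at the price of slower convergence.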

Benchmarking CLI tools (Pixi)

This project uses Pixi to manage benchmarking dependencies and tasks.

Installing Pixi

After installing Pixi, run pixi install to set up the environment. Individual tasks are then invoked with pixi run <task>.

Common commands

  • pixi run serialize --config <config.yaml> --export-dir <directory>
  • pixi run benchmark --config <config.yaml> --folds <n> --export <file>
  • pixi run lint or ruff check .

Slurm

On Compute Canada clusters

The generic script for both SLURM and local runs is scripts/run_benchmark.sh. Example:

sbatch scripts/run_benchmark.sh configs/generators/gaussians.yaml configs/runners/gaussians.yaml

Monitor GPU usage with:

srun --jobid 123456 --pty watch -n 30 nvidia-smi

Synthetic datasets

Create a generator config like:

generators:
  1D-gaussians-64:
    generator: uot.problems.generators.GaussianMixtureGenerator
    dim: 1
    num_components: 1
    n_points: 64
    num_datasets: 30
    borders: (-6, 6)
    cost_fn: uot.utils.costs.cost_euclid_squared
    use_jax: true
    seed: 42

Then serialize:

pixi run serialize --config configs/generators/gaussians.yaml --export-dir datasets/synthetic
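The generator config does not spell out how GaussianMixtureGenerator builds its measures. As a rough mental model only, the sketch below assumes a uniform grid of n_points inside borders, randomly drawn mixture means, widths and weights, and a fixed seed for reproducibility. The names gaussian_mixture_on_grid and generate_pairs are hypothetical, and the parameter ranges are arbitrary illustrative choices.

```python
import numpy as np


def gaussian_mixture_on_grid(n_points, num_components, borders, rng):
    """Hypothetical sketch: a normalized 1D Gaussian-mixture histogram on a uniform grid."""
    lo, hi = borders
    x = np.linspace(lo, hi, n_points)
    # Arbitrary illustrative ranges: means in the middle half, moderate widths.
    means = rng.uniform(lo / 2, hi / 2, size=num_components)
    stds = rng.uniform(0.5, 1.5, size=num_components)
    mix = rng.dirichlet(np.ones(num_components))
    dens = np.zeros(n_points)
    for w, m, s in zip(mix, means, stds):
        dens += w * np.exp(-0.5 * ((x - m) / s) ** 2) / s
    return x.reshape(-1, 1), dens / dens.sum()


def generate_pairs(num_datasets, n_points=64, num_components=1, borders=(-6, 6), seed=42):
    """Return num_datasets (x, a, y, b, C) tuples with a squared Euclidean cost matrix."""
    rng = np.random.default_rng(seed)  # fixed seed -> reproducible datasets
    pairs = []
    for _ in range(num_datasets):
        x, a = gaussian_mixture_on_grid(n_points, num_components, borders, rng)
        y, b = gaussian_mixture_on_grid(n_points, num_components, borders, rng)
        C = (x - y.T) ** 2  # C[i, j] = |x_i - y_j|^2
        pairs.append((x, a, y, b, C))
    return pairs


pairs = generate_pairs(num_datasets=30)  # mirrors the 1D-gaussians-64 entry above
```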

Running experiments

Example config:

param-grids:
  regularizations:
    - reg: 1
    - reg: 0.001

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    jit: true
    param-grid: regularizations

problems:
  dir: datasets/synthetic
  names:
    - 1D-gaussians-64
  
experiment: 
  name: Measure on Gaussians
  function: uot.experiments.measurement.measure_time

Run the benchmark:

pixi run benchmark --config configs/runners/gaussians.yaml --folds 1 --export results/raw/gaussians.csv
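The param-grid above yields one run per reg value, and jit: true means the first call to a solver may include compilation time. The sketch below shows one common way to time such runs: warm-up calls first, then the median of several timed calls. time_call is a hypothetical helper. The README does not describe how uot.experiments.measurement.measure_time works, so this should not be read as its behaviour. With JAX, passing sync=jax.block_until_ready makes the timer wait for asynchronous dispatch to finish.

```python
import statistics
import time


def time_call(fn, *args, repeats=5, warmup=1, sync=lambda out: out, **kwargs):
    """Hypothetical helper: return fn's last output and its median wall-clock time in seconds."""
    for _ in range(warmup):
        # Warm-up calls absorb one-off costs such as JIT compilation.
        sync(fn(*args, **kwargs))
    times = []
    out = None
    for _ in range(repeats):
        start = time.perf_counter()
        # sync waits for the result; with JAX, pass sync=jax.block_until_ready.
        out = sync(fn(*args, **kwargs))
        times.append(time.perf_counter() - start)
    return out, statistics.median(times)


# Usage: one timed run per entry of a param grid like "regularizations" above.
param_grid = [{"reg": 1}, {"reg": 0.001}]
for params in param_grid:
    _, seconds = time_call(lambda reg: sum(i * reg for i in range(10_000)), **params)
    print(params, f"{seconds:.6f} s")
```

The median is less sensitive than the mean to an occasional slow call caused by other load on the machine.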

Color Transfer

Example config:

param-grids:
  epsilons:
    - reg: 1
    - reg: 0.01

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    param-grid: epsilons
    jit: true

bin-number:
  - 16
  - 32
soft-extension:
  - no
  - yes
displacement-interpolation:
  - 0.0
  - 1.0
color-space: rgb
# active-channels: [r, g]
batch-size: 100000
pair-number: 3
images-dir: ./datasets/images
rng-seed: 42

drop-columns:
  - transport_plan
  - monge_map
  - u_final
  - v_final

experiment: 
  name: Time and test
  output-dir: ./outputs/color_transfer

For a detailed explanation of the parameters, see docs/color_transfer.md. Notably, soft-extension and displacement-interpolation accept either a single value or a list; when given a list, the pipeline runs once per option.
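One way to read "runs once per option" is as a Cartesian product over every list-valued setting, with single values treated as one-element lists. The sketch below assumes that interpretation; expand_options is a hypothetical helper, and the real pipeline may combine options differently.

```python
import itertools


def expand_options(config, keys):
    """Hypothetical sketch: one run configuration per combination of the given options."""
    choices = []
    for key in keys:
        value = config[key]
        # A single value behaves like a one-element list.
        choices.append(value if isinstance(value, list) else [value])
    return [dict(zip(keys, combo)) for combo in itertools.product(*choices)]


config = {
    "bin-number": [16, 32],
    # YAML 1.1 loaders such as PyYAML read the no/yes entries as booleans.
    "soft-extension": [False, True],
    "displacement-interpolation": [0.0, 1.0],
    "color-space": "rgb",
}
runs = expand_options(config, list(config))
print(len(runs))  # 8
```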

Example: Lab space with selected channels.

color-space: lab
active-channels: [l, a]
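The bin-number and active-channels options suggest that each image is reduced to a histogram over color bins restricted to the selected channels. The following is a minimal sketch of that idea under stated assumptions: channel values are already scaled to [0, 1], bins is the number of bins per active channel, and only occupied bins enter the measure. color_histogram is a hypothetical name, not part of uot-bench.

```python
import numpy as np


def color_histogram(pixels, channels, bins):
    """Hypothetical sketch: reduce pixels to a discrete measure over occupied color bins.

    pixels:   (N, C) array, every channel already scaled to [0, 1]
    channels: indices of the active channels, e.g. [0, 1] for (l, a)
    bins:     number of bins per active channel (cf. bin-number)
    """
    selected = pixels[:, channels]
    # Integer bin index per active channel; the value 1.0 goes into the last bin.
    idx = np.minimum((selected * bins).astype(int), bins - 1)
    occupied, counts = np.unique(idx, axis=0, return_counts=True)
    support = (occupied + 0.5) / bins  # bin centers in [0, 1]
    weights = counts / counts.sum()
    return support, weights


rng = np.random.default_rng(0)
pixels = rng.random((10_000, 3))  # stand-in for an image's pixels
support, weights = color_histogram(pixels, channels=[0, 1], bins=16)
```

Keeping only occupied bins keeps the cost matrix small for natural images, whose colors usually cover a fraction of the full grid.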

Run the experiment:

pixi run color-transfer --config ./configs/color_transfer/example.yaml
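The drop-columns list mentions transport_plan and monge_map. A standard way to turn a (possibly entropic) transport plan P into a map is barycentric projection, T(x_i) = sum_j P_ij y_j / sum_j P_ij; displacement interpolation then moves each source color to (1 - t) x_i + t T(x_i). Reading displacement-interpolation as this t (0.0 keeps the original colors, 1.0 applies the full map) is an assumption, and both functions below are hypothetical sketches.

```python
import numpy as np


def barycentric_map(P, y):
    """Send each source point x_i to the P-weighted mean of the target points."""
    return (P @ y) / P.sum(axis=1, keepdims=True)


def displacement_interpolation(x, P, y, t):
    """Move each source point a fraction t of the way to its barycentric image."""
    return (1.0 - t) * x + t * barycentric_map(P, y)


# Toy example: two source colors, two target colors, and a plan matching i to i.
x = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
y = np.array([[0.2, 0.4, 0.6], [0.8, 0.6, 0.4]])
P = np.eye(2) / 2
print(displacement_interpolation(x, P, y, 0.5))
```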

Dashboard visualization:

pixi run color-transfer-visualization --origin_folder <path_to_input_images> --results_folder <path_to_resulting_images>

MNIST Classification

The MNIST experiment is performed in two steps:

  • Distance matrix calculation
  • Classification

See docs/mnist.md for configs.

Step 1:

pixi run mnist_distances --config ./configs/mnist_dist_example.yaml

Step 2:

pixi run mnist_classification --config ./configs/mnist_classification_example.yaml
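The README does not say which classifier step 2 uses. A natural fit for a precomputed distance matrix is k-nearest-neighbour voting, sketched below under the assumption that the matrix has shape (n_test, n_train) and that labels are non-negative integers such as the MNIST digits 0 to 9. knn_predict is a hypothetical name.

```python
import numpy as np


def knn_predict(D, train_labels, k=5):
    """Hypothetical sketch: majority vote among the k nearest training items per test row."""
    nearest = np.argsort(D, axis=1)[:, :k]  # indices of the k smallest distances
    neighbor_labels = train_labels[nearest]
    # bincount counts each label; argmax breaks ties toward the smaller label.
    return np.array([np.bincount(row).argmax() for row in neighbor_labels])


D = np.array([[0.1, 0.9, 0.5], [0.8, 0.2, 0.3]])  # toy distances: 2 test x 3 train images
train_labels = np.array([3, 7, 7])
print(knn_predict(D, train_labels, k=1))  # [3 7]
```

Because the expensive OT distances are computed once in step 1, different values of k can be compared in step 2 without re-solving any transport problems.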

Linting

Run either of:

pixi run lint
ruff check .


Download files


Source Distribution

uot_bench-0.1.1.tar.gz (197.9 kB)

Built Distribution


uot_bench-0.1.1-py3-none-any.whl (244.8 kB)

File details

Details for the file uot_bench-0.1.1.tar.gz.

File metadata

  • Download URL: uot_bench-0.1.1.tar.gz
  • Size: 197.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for uot_bench-0.1.1.tar.gz:

  • SHA256: 455120cd9f7a340a6d54662f67fe836e0ac1a4ada347e4cb89f7c179cf4d2351
  • MD5: f802c1467ac12a0db860197bcb025bc8
  • BLAKE2b-256: 1e209e445ebdd594ae4d4e325b393d067f75ba6648f1af4e195ddbe1f9aaaffa


File details

Details for the file uot_bench-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: uot_bench-0.1.1-py3-none-any.whl
  • Size: 244.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.0

File hashes

Hashes for uot_bench-0.1.1-py3-none-any.whl:

  • SHA256: 9ee3ebb77570848b3c48bd8e20349dfb140c9a9916c96e4fca090d29785753c4
  • MD5: a332dc4b15413522e255bb74bee13a4a
  • BLAKE2b-256: 0005c9bf16a0023f1f1a6692b291e5b8d38ef33d4665d49bb614ab8c79062298

