Utilities for optimal transport algorithms and benchmarking

uot-bench

uot-bench is a Python toolkit for optimal transport solvers and benchmarking. It provides JAX-first implementations of common OT methods, utilities for generating problems and measures, and a configurable benchmarking pipeline for running experiments at scale.

Install

Core install:

pip install uot-bench

Extras (optional):

pip install "uot-bench[viz,image-analysis,color-transfer,gurobi]"

CUDA (optional, JAX):

pip install "uot-bench[cuda12]"

Quickstart (Python API)

A minimal two-marginal Sinkhorn example:

import numpy as np

from uot.data.measure import DiscreteMeasure
from uot.problems.two_marginal import TwoMarginalProblem
from uot.solvers.sinkhorn import SinkhornTwoMarginalSolver
from uot.utils.costs import cost_euclid_squared

# Create two 1D discrete measures
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)

# Gaussian-shaped weights centered at 0.3 and 0.7, each normalized to sum to 1
a = np.exp(-((x - 0.3) ** 2) / 0.01).reshape(-1)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02).reshape(-1)
b = b / b.sum()

mu = DiscreteMeasure(x, a, name="mu")
nu = DiscreteMeasure(y, b, name="nu")

problem = TwoMarginalProblem("toy", mu, nu, cost_euclid_squared)
solver = SinkhornTwoMarginalSolver()

result = solver.solve(
    marginals=problem.get_marginals(),
    costs=problem.get_costs(),
    reg=1e-2,
)

print("cost:", float(result["cost"]))
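
For readers who want to see what an entropy-regularized Sinkhorn solve does on this toy problem, here is a minimal NumPy-only sketch. It is not the implementation behind SinkhornTwoMarginalSolver and does not use the uot API; the logsumexp helper, the iteration cap, and the stopping tolerance are assumptions made for illustration.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis (illustrative helper)."""
    m = z.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True)), axis=axis)


# Same toy data as the quickstart: two Gaussian-shaped measures on [0, 1]
x = np.linspace(0.0, 1.0, 64)
y = np.linspace(0.0, 1.0, 64)
a = np.exp(-((x - 0.3) ** 2) / 0.01)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02)
b = b / b.sum()

C = (x[:, None] - y[None, :]) ** 2  # squared Euclidean cost matrix
reg = 1e-2                          # entropic regularization strength

log_a, log_b = np.log(a), np.log(b)
f = np.zeros_like(a)  # dual potential for the first marginal
g = np.zeros_like(b)  # dual potential for the second marginal

# Log-domain Sinkhorn: alternately rescale so rows match a, then columns match b.
for _ in range(5000):  # iteration cap is an assumption
    f = reg * (log_a - logsumexp((g[None, :] - C) / reg, axis=1))
    g = reg * (log_b - logsumexp((f[:, None] - C) / reg, axis=0))
    P = np.exp((f[:, None] + g[None, :] - C) / reg)  # current transport plan
    if np.abs(P.sum(axis=1) - a).sum() < 1e-9:       # row-marginal violation
        break

cost = float((P * C).sum())
print("transport cost <P, C>:", cost)
```

The sketch reports the plain transport cost of the regularized plan. Whether result["cost"] from the library uses the same definition is not stated here, so treat the two numbers as comparable only after checking the solver's documentation.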

Benchmarking CLI tools (Pixi)

This project uses Pixi to manage benchmarking dependencies and tasks.

Install Pixi first (see "Installing Pixi" in the Pixi documentation), then run pixi install to set up the environment. Invoke any available task with pixi run <task>.

Common commands

  • pixi run serialize --config <config.yaml> --export-dir <directory>
  • pixi run benchmark --config <config.yaml> --folds <n> --export <file>
  • pixi run lint or ruff check .

Slurm

On Compute Canada clusters

The generic script scripts/run_benchmark.sh handles both Slurm and local runs. Example:

sbatch scripts/run_benchmark.sh configs/generators/gaussians.yaml configs/runners/gaussians.yaml

Monitor GPU usage on a running job (replace 123456 with your job ID):

srun --jobid 123456 --pty watch -n 30 nvidia-smi

Synthetic datasets

Create a generator config like:

generators:
  1D-gaussians-64:
    generator: uot.problems.generators.GaussianMixtureGenerator
    dim: 1
    num_components: 1
    n_points: 64
    num_datasets: 30
    borders: (-6, 6)
    cost_fn: uot.utils.costs.cost_euclid_squared
    use_jax: true
    seed: 42

Then serialize:

pixi run serialize --config configs/generators/gaussians.yaml --export-dir datasets/synthetic
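
To make the generator config more concrete, the following NumPy sketch builds the kind of data it describes: pairs of single-component Gaussian measures on a 64-point grid over (-6, 6), with a squared Euclidean cost. This is not the code of GaussianMixtureGenerator. The function random_gaussian_weights, the sampling ranges for mean and standard deviation, and the reading of num_datasets as "30 pairs" are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(42)    # mirrors `seed: 42`
grid = np.linspace(-6.0, 6.0, 64)  # mirrors `borders: (-6, 6)` and `n_points: 64`


def random_gaussian_weights(grid, rng):
    """Weights of one Gaussian bump on the grid (hypothetical helper).

    The mean and standard deviation ranges are illustrative assumptions.
    """
    mean = rng.uniform(grid[0] / 2, grid[-1] / 2)
    std = rng.uniform(0.5, 1.5)
    w = np.exp(-0.5 * ((grid - mean) / std) ** 2)
    return w / w.sum()


# Assumed meaning of `num_datasets: 30`: thirty (source, target) pairs.
pairs = [
    (random_gaussian_weights(grid, rng), random_gaussian_weights(grid, rng))
    for _ in range(30)
]

# Squared Euclidean cost between grid points, shared by every pair.
cost = (grid[:, None] - grid[None, :]) ** 2

print(len(pairs), cost.shape)
```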

Running experiments

Example config:

param-grids:
  regularizations:
    - reg: 1
    - reg: 0.001

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    jit: true
    param-grid: regularizations

problems:
  dir: datasets/synthetic
  names:
    - 1D-gaussians-64
  
experiment: 
  name: Measure on Gaussians
  function: uot.experiments.measurement.measure_time

Run the benchmark:

pixi run benchmark --config configs/runners/gaussians.yaml --folds 1 --export results/raw/gaussians.csv
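
One way to read the runner config is as a grid of runs: each solver is paired with every entry of its param-grid, every named problem, and every fold. The sketch below expands the example config that way using only the standard library. The function expand_runs and the shape of the run records are hypothetical; the actual pipeline may organize its runs differently.

```python
import itertools

# The example config above, written as the dict a YAML loader would produce.
config = {
    "param-grids": {"regularizations": [{"reg": 1}, {"reg": 0.001}]},
    "solvers": {
        "sinkhorn": {
            "solver": "uot.solvers.sinkhorn.SinkhornTwoMarginalSolver",
            "jit": True,
            "param-grid": "regularizations",
        }
    },
    "problems": {"dir": "datasets/synthetic", "names": ["1D-gaussians-64"]},
}


def expand_runs(config, folds):
    """List one record per (solver, problem, parameter set, fold). Hypothetical helper."""
    runs = []
    for solver_name, spec in config["solvers"].items():
        grid = config["param-grids"][spec["param-grid"]]
        for problem, params, fold in itertools.product(
            config["problems"]["names"], grid, range(folds)
        ):
            runs.append({"solver": solver_name, "problem": problem, "fold": fold, **params})
    return runs


runs = expand_runs(config, folds=1)
for run in runs:
    print(run)
```

Under this reading, --folds 1 with two regularization values and one problem set gives two runs, and raising --folds multiplies that count.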

Color Transfer

Example config:

param-grids:
  epsilons:
    - reg: 1
    - reg: 0.01

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    param-grid: epsilons
    jit: true

bin-number:
  - 16
  - 32
soft-extension:
  - no
  - yes
displacement-interpolation:
  - 0.0
  - 1.0
color-space: rgb
# active-channels: [r, g]
batch-size: 100000
pair-number: 3
images-dir: ./datasets/images
rng-seed: 42

drop-columns:
  - transport_plan
  - monge_map
  - u_final
  - v_final

experiment: 
  name: Time and test
  output-dir: ./outputs/color_transfer

For a detailed explanation of the parameters, see docs/color_transfer.md. Notably, soft-extension and displacement-interpolation each accept either a single value or a list; given a list, the pipeline runs once per option.
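
The scalar-or-list behavior can be pictured with a small sketch. It assumes the options are crossed, so that two values for each setting give four runs; the README does not say whether the pipeline crosses them, and the helpers as_list and option_combinations are hypothetical. Note that YAML 1.1 loaders such as PyYAML read unquoted no and yes as booleans, which is why the dict below uses False and True.

```python
from itertools import product


def as_list(value):
    """Wrap a single value in a list; leave lists and tuples as lists."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def option_combinations(config):
    """Enumerate (soft-extension, displacement-interpolation) pairs.

    Assumption: the two settings are combined as a Cartesian product.
    """
    return list(
        product(
            as_list(config["soft-extension"]),
            as_list(config["displacement-interpolation"]),
        )
    )


# Lists, as in the example config (no/yes loaded as False/True).
listed = {"soft-extension": [False, True], "displacement-interpolation": [0.0, 1.0]}
# Single values are accepted as well.
single = {"soft-extension": True, "displacement-interpolation": 1.0}

print(option_combinations(listed))
print(option_combinations(single))
```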

Example: Lab space with selected channels.

color-space: lab
active-channels: [l, a]

Run the experiment:

pixi run color-transfer --config ./configs/color_transfer/example.yaml

Dashboard visualization:

pixi run color-transfer-visualization --origin_folder <path_to_input_images> --results_folder <path_to_resulting_images>

MNIST Classification

The MNIST experiment is performed in two steps:

  • Distance matrix calculation
  • Classification

See docs/mnist.md for configs.

Step 1 (distance matrix calculation):

pixi run mnist_distances --config ./configs/mnist_dist_example.yaml

Step 2 (classification):

pixi run mnist_classification --config ./configs/mnist_classification_example.yaml
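
The two-step split works because a classifier can operate on a precomputed distance matrix without touching the images again. The sketch below shows this with a k-nearest-neighbor vote in NumPy. The choice of k-NN, the function knn_predict, and the toy 1D data standing in for OT distances between digits are all assumptions; docs/mnist.md describes what the project actually runs.

```python
import numpy as np


def knn_predict(dist, train_labels, k=3):
    """Majority vote among the k nearest training samples (hypothetical helper).

    dist has shape (n_test, n_train) and holds precomputed distances.
    """
    nearest = np.argsort(dist, axis=1)[:, :k]  # indices of the k closest training samples
    votes = train_labels[nearest]              # their labels, shape (n_test, k)
    n_classes = int(train_labels.max()) + 1
    return np.array([np.bincount(v, minlength=n_classes).argmax() for v in votes])


# Toy stand-in for step 1: absolute differences instead of OT distances.
train_x = np.array([0.0, 0.1, 0.2, 1.0, 1.1, 1.2])
train_labels = np.array([0, 0, 0, 1, 1, 1])
test_x = np.array([0.05, 1.15])
dist = np.abs(test_x[:, None] - train_x[None, :])

# Step 2: classify from the distance matrix alone.
pred = knn_predict(dist, train_labels, k=3)
print(pred)
```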

Linting

Run pixi run lint or ruff check . to lint the code.

