Utilities for optimal transport algorithms and benchmarking

uot-bench

uot-bench is a Python toolkit for optimal transport solvers and benchmarking. It provides JAX-first implementations of common OT methods, utilities for generating problems and measures, and a configurable benchmarking pipeline for running experiments at scale.

Install

Core install:

pip install uot-bench

Extras (optional):

pip install "uot-bench[viz,image-analysis,color-transfer,gurobi]"

CUDA (optional, JAX):

pip install "uot-bench[cuda12]"

Quickstart (Python API)

A minimal two-marginal Sinkhorn example:

import numpy as np

from uot.data.measure import PointCloudMeasure
from uot.problems.two_marginal import TwoMarginalProblem
from uot.solvers.sinkhorn import SinkhornTwoMarginalSolver
from uot.utils.costs import cost_euclid_squared

# Create two 1D point-cloud measures
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)

a = np.exp(-((x - 0.3) ** 2) / 0.01).reshape(-1)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02).reshape(-1)
b = b / b.sum()

mu = PointCloudMeasure(x, a, name="mu")
nu = PointCloudMeasure(y, b, name="nu")

problem = TwoMarginalProblem("toy", mu, nu, cost_euclid_squared)
solver = SinkhornTwoMarginalSolver()
inputs = problem.solver_inputs()

result = solver.solve(
    marginals=inputs.marginals,
    costs=inputs.costs,
    reg=1e-2,
)

print("cost:", float(result["cost"]))
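For orientation, the NumPy sketch below shows the kind of computation a two-marginal Sinkhorn solve performs on the same toy problem. It is not uot-bench's implementation: the function name sinkhorn_plan, the fixed iteration count, and the plain (non-log-domain) updates are assumptions made for illustration only.

```python
import numpy as np


def sinkhorn_plan(a, b, cost, reg, n_iters=500):
    """Run plain Sinkhorn iterations and return the entropic transport plan.

    Hypothetical helper for illustration; not part of uot-bench.
    """
    kernel = np.exp(-cost / reg)      # Gibbs kernel K = exp(-C / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        v = b / (kernel.T @ u)        # rescale to match the column marginal b
        u = a / (kernel @ v)          # rescale to match the row marginal a
    return u[:, None] * kernel * v[None, :]


# Same toy measures as in the quickstart: two Gaussian bumps on [0, 1].
x = np.linspace(0.0, 1.0, 64)
a = np.exp(-((x - 0.3) ** 2) / 0.01)
a = a / a.sum()
b = np.exp(-((x - 0.7) ** 2) / 0.02)
b = b / b.sum()

cost = (x[:, None] - x[None, :]) ** 2   # squared Euclidean cost matrix

plan = sinkhorn_plan(a, b, cost, reg=1e-2)
transport_cost = float((plan * cost).sum())
print("cost:", transport_cost)
```

Because the loop ends on the u update, the row sums of the plan match a to floating-point precision, while the column sums match b only approximately until the iterations converge. Smaller reg values generally need more iterations and, in practice, a log-domain formulation.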

Benchmarking CLI tools (Pixi)

This project uses Pixi to manage benchmarking dependencies and tasks.

Installing Pixi

Install Pixi by following its official installation instructions. Then run pixi install to set up the environment. Tasks are invoked with pixi run <task>.

Common commands

  • pixi run serialize --config <config.yaml> --export-dir <directory>
  • pixi run benchmark --config <config.yaml> --folds <n> --export <file>
  • pixi run lint or ruff check .

Slurm

On Compute Canada clusters

scripts/run_benchmark.sh is a generic script that works for both Slurm submissions and local runs. Example submission:

sbatch scripts/run_benchmark.sh configs/generators/gaussians.yaml configs/runners/gaussians.yaml

Monitor GPU usage of a running job (replace 123456 with your job ID):

srun --jobid 123456 --pty watch -n 30 nvidia-smi

Synthetic datasets

Create a generator config like:

generators:
  1D-gaussians-64:
    generator: uot.problems.generators.GaussianMixtureGenerator
    dim: 1
    num_components: 1
    n_points: 64
    num_datasets: 30
    borders: (-6, 6)
    cost_fn: uot.utils.costs.cost_euclid_squared
    use_jax: true
    seed: 42

Then serialize:

pixi run serialize --config configs/generators/gaussians.yaml --export-dir datasets/synthetic
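To give a feel for the data such a config describes, here is a rough NumPy sketch that builds 30 normalized 1D Gaussian-mixture measures on a 64-point grid over [-6, 6]. It does not reproduce uot.problems.generators.GaussianMixtureGenerator: the helper gaussian_mixture_weights, the sampling ranges for means and standard deviations, and the equal mixture weights are all assumptions, and how the real generator pairs measures into problems is not shown here.

```python
import numpy as np


def gaussian_mixture_weights(points, means, stds, mix):
    """Evaluate a 1D Gaussian mixture on `points` and normalize it to sum to 1.

    Hypothetical helper for illustration; not part of uot-bench.
    """
    density = np.zeros_like(points)
    for mean, std, weight in zip(means, stds, mix):
        z = (points - mean) / std
        density += weight * np.exp(-0.5 * z**2) / (std * np.sqrt(2.0 * np.pi))
    return density / density.sum()


rng = np.random.default_rng(42)        # mirrors seed: 42
grid = np.linspace(-6.0, 6.0, 64)      # mirrors borders: (-6, 6) and n_points: 64
num_components = 1                     # mirrors num_components: 1

datasets = []
for _ in range(30):                    # mirrors num_datasets: 30
    means = rng.uniform(-3.0, 3.0, size=num_components)  # assumed range
    stds = rng.uniform(0.5, 1.5, size=num_components)    # assumed range
    mix = np.full(num_components, 1.0 / num_components)  # assumed equal weights
    datasets.append(gaussian_mixture_weights(grid, means, stds, mix))

print(len(datasets), datasets[0].shape)
```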

Running experiments

Example config:

param-grids:
  regularizations:
    - reg: 1
    - reg: 0.001

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    jit: true
    param-grid: regularizations

problems:
  dir: datasets/synthetic
  names:
    - 1D-gaussians-64
  
experiment: 
  name: Measure on Gaussians
  function: uot.experiments.measurement.measure_time

Run the benchmark:

pixi run benchmark --config configs/runners/gaussians.yaml --folds 1 --export results/raw/gaussians.csv
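One way to read the runner config is as a grid: each solver is paired with every entry of its named param-grid, for every problem set and every fold. The sketch below spells out that reading with a plain dictionary and a stand-in solver. The functions expand_runs, timed, and dummy_solver are hypothetical, and the actual expansion and timing logic in uot-bench may differ.

```python
import time

# The example config, written as a Python dict (only the keys used here).
config = {
    "param-grids": {"regularizations": [{"reg": 1}, {"reg": 0.001}]},
    "solvers": {"sinkhorn": {"param-grid": "regularizations", "jit": True}},
    "problems": {"names": ["1D-gaussians-64"]},
}


def expand_runs(config, folds):
    """Yield one (solver, problem, params, fold) tuple per assumed benchmark run."""
    for solver_name, solver_cfg in config["solvers"].items():
        grid = config["param-grids"][solver_cfg["param-grid"]]
        for problem_name in config["problems"]["names"]:
            for params in grid:
                for fold in range(folds):
                    yield solver_name, problem_name, params, fold


def timed(fn, **kwargs):
    """Call fn(**kwargs) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(**kwargs)
    return result, time.perf_counter() - start


def dummy_solver(reg):
    """Stand-in for a real solver call; returns its input unchanged."""
    return {"reg": reg}


records = []
for solver_name, problem_name, params, fold in expand_runs(config, folds=1):
    _, seconds = timed(dummy_solver, **params)
    records.append(
        {"solver": solver_name, "problem": problem_name, "fold": fold,
         "time": seconds, **params}
    )

for record in records:
    print(record)
```

With the example config and --folds 1, this reading gives two timed runs per problem set, one for each reg value.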

Color Transfer

Example config:

param-grids:
  epsilons:
    - reg: 1
    - reg: 0.01

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    param-grid: epsilons
    jit: true

bin-number:
  - 16
  - 32
soft-extension:
  - no
  - yes
displacement-interpolation:
  - 0.0
  - 1.0
color-space: rgb
# active-channels: [r, g]
batch-size: 100000
pair-number: 3
images-dir: ./datasets/images
rng-seed: 42

drop-columns:
  - transport_plan
  - monge_map
  - u_final
  - v_final

experiment: 
  name: Time and test
  output-dir: ./outputs/color_transfer

For a detailed explanation of the parameters, see docs/color_transfer.md. Notably, soft-extension and displacement-interpolation each accept either a single value or a list; when given a list, the pipeline runs once per option.
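The scalar-or-list behaviour can be pictured as follows. This sketch normalizes each option to a list and enumerates the combinations with itertools.product. Treating the two options as a Cartesian product is an assumption (the text only says the pipeline runs once per option), and as_list is a hypothetical helper, not a uot-bench function.

```python
from itertools import product


def as_list(value):
    """Wrap a single value in a list; return lists and tuples as a list."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


# Subset of the example config; YAML yes/no shown here as Python booleans.
config = {
    "soft-extension": [False, True],
    "displacement-interpolation": [0.0, 1.0],
    "color-space": "rgb",
}

swept_keys = ["soft-extension", "displacement-interpolation"]
options = [as_list(config[key]) for key in swept_keys]

# One dict per assumed pipeline run.
runs = [dict(zip(swept_keys, combo)) for combo in product(*options)]
for run in runs:
    print(run)
```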

Example: Lab space with selected channels.

color-space: lab
active-channels: [l, a]

Run the experiment:

pixi run color-transfer --config ./configs/color_transfer/example.yaml

Dashboard visualization:

pixi run color-transfer-visualization --origin_folder <path_to_input_images> --results_folder <path_to_resulting_images>

MNIST Classification

The MNIST experiment is performed in two steps:

  • Distance matrix calculation
  • Classification

See docs/mnist.md for configs.

Step 1:

pixi run mnist_distances --config ./configs/mnist_dist_example.yaml

Step 2:

pixi run mnist_classification --config ./configs/mnist_classification_example.yaml
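To illustrate why the two steps are separable, the sketch below classifies test points from a precomputed distance matrix with a k-nearest-neighbour majority vote. The choice of k-NN is an assumption made for illustration (see docs/mnist.md for what the pipeline actually does), and knn_predict and the toy distances are hypothetical.

```python
import numpy as np


def knn_predict(distances, train_labels, k=3):
    """Predict labels from an (n_test, n_train) distance matrix by majority vote.

    Hypothetical helper for illustration; not part of uot-bench.
    """
    nearest = np.argsort(distances, axis=1)[:, :k]   # indices of the k closest training points
    neighbor_labels = train_labels[nearest]          # shape (n_test, k)
    return np.array([np.bincount(row).argmax() for row in neighbor_labels])


# Toy data: four training points with two classes, two test points.
train_labels = np.array([0, 0, 1, 1])
distances = np.array([
    [0.1, 0.2, 0.9, 0.8],   # test point 0 is close to the class-0 samples
    [0.7, 0.9, 0.2, 0.1],   # test point 1 is close to the class-1 samples
])

predictions = knn_predict(distances, train_labels, k=3)
print(predictions)  # [0 1]
```

Since the classifier only needs the distance matrix, the expensive OT distance computation can be run once and reused across classification settings.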

Linting

Run pixi run lint, or run ruff check . directly.
