
Utilities for optimal transport algorithms and benchmarking

Project description

uot-bench

uot-bench is a Python toolkit for optimal transport solvers and benchmarking. It provides JAX-first implementations of common OT methods, utilities for generating problems and measures, and a configurable benchmarking pipeline for running experiments at scale.

Install

Core install:

pip install uot-bench

Extras (optional):

pip install "uot-bench[viz,image-analysis,color-transfer,gurobi]"

CUDA (optional, JAX):

pip install "uot-bench[cuda12]"

Quickstart (Python API)

A minimal two-marginal Sinkhorn example:

import numpy as np

from uot.data.measure import PointCloudMeasure
from uot.problems.two_marginal import TwoMarginalProblem
from uot.solvers.sinkhorn import SinkhornTwoMarginalSolver
from uot.utils.costs import cost_euclid_squared

# Create two 1D point-cloud measures
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)

a = np.exp(-((x - 0.3) ** 2) / 0.01).reshape(-1)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02).reshape(-1)
b = b / b.sum()

mu = PointCloudMeasure(x, a, name="mu")
nu = PointCloudMeasure(y, b, name="nu")

problem = TwoMarginalProblem("toy", mu, nu, cost_euclid_squared)
solver = SinkhornTwoMarginalSolver()
inputs = problem.solver_inputs()

result = solver.solve(
    marginals=inputs.marginals,
    costs=inputs.costs,
    reg=1e-2,
)

print("cost:", float(result["cost"]))
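As a sanity check that does not depend on uot-bench, the exact squared-Euclidean OT cost between two weighted 1D point clouds can be computed with the monotone (inverse-CDF) coupling; with a small reg, the Sinkhorn cost above should land close to this value. A plain-NumPy sketch:

```python
import numpy as np

# Rebuild the same two 1D measures as in the quickstart
x = np.linspace(0.0, 1.0, 64)
y = np.linspace(0.0, 1.0, 64)
a = np.exp(-((x - 0.3) ** 2) / 0.01)
a /= a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02)
b /= b.sum()

def exact_w2_squared_1d(x, a, y, b, n_quantiles=100_000):
    """Exact 1D OT cost for the squared-Euclidean ground cost.

    Pairs equal quantiles of the two distributions (the monotone coupling,
    which is optimal in 1D) and averages the squared displacement.
    """
    q = (np.arange(n_quantiles) + 0.5) / n_quantiles
    ix = np.minimum(np.searchsorted(np.cumsum(a), q), len(x) - 1)
    iy = np.minimum(np.searchsorted(np.cumsum(b), q), len(y) - 1)
    return float(np.mean((x[ix] - y[iy]) ** 2))

print(f"exact 1D cost: {exact_w2_squared_1d(x, a, y, b):.4f}")
```

For these two near-Gaussians the value is roughly (0.7 - 0.3)^2 plus a small spread term, i.e. about 0.16; entropic regularization biases the Sinkhorn cost slightly away from it.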

Benchmarking CLI tools (Pixi)

This project uses Pixi to manage benchmarking dependencies and tasks.

Installing Pixi

Install Pixi first (for example via its official install script, curl -fsSL https://pixi.sh/install.sh | bash), then run pixi install in the repository root to set up the environment. Available tasks can be invoked with pixi run <task>.

Common commands

  • pixi run serialize --config <config.yaml> --export-dir <directory>
  • pixi run benchmark --config <config.yaml> --folds <n> --export <file>
  • pixi run lint (equivalently, ruff check .)

Slurm

On Compute Canada clusters

The generic script for both Slurm and local runs is scripts/run_benchmark.sh. Example:

sbatch scripts/run_benchmark.sh configs/generators/gaussians.yaml configs/runners/gaussians.yaml

Monitor GPU usage with:

srun --jobid 123456 --pty watch -n 30 nvidia-smi

Synthetic datasets

Create a generator config like:

generators:
  1D-gaussians-64:
    generator: uot.problems.generators.GaussianMixtureGenerator
    dim: 1
    num_components: 1
    n_points: 64
    num_datasets: 30
    borders: (-6, 6)
    cost_fn: uot.utils.costs.cost_euclid_squared
    use_jax: true
    seed: 42

Then serialize:

pixi run serialize --config configs/generators/gaussians.yaml --export-dir datasets/synthetic
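Config values such as generator and cost_fn are dotted import paths that get resolved to Python objects at load time. A minimal resolver along these lines (a sketch, not uot-bench's actual loader) looks like:

```python
import importlib

def resolve_dotted_path(path: str):
    """Split 'pkg.module.Attr' into module path and attribute, then import."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# Works for any importable object, e.g. a stdlib function:
sqrt = resolve_dotted_path("math.sqrt")
print(sqrt(9.0))  # 3.0
```

This is why the config can name uot.problems.generators.GaussianMixtureGenerator or uot.utils.costs.cost_euclid_squared as plain strings.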

Running experiments

Example config:

param-grids:
  regularizations:
    - reg: 1
    - reg: 0.001

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    jit: true
    param-grid: regularizations

problems:
  dir: datasets/synthetic
  names:
    - 1D-gaussians-64
  
experiment: 
  name: Measure on Gaussians
  function: uot.experiments.measurement.measure_time

Run the benchmark:

pixi run benchmark --config configs/runners/gaussians.yaml --folds 1 --export results/raw/gaussians.csv
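The export is a CSV of per-run measurements. The exact columns depend on the experiment function; assuming columns like solver, reg, and time_ms (hypothetical names for illustration), results can be aggregated with the stdlib alone:

```python
import csv
import io
import statistics
from collections import defaultdict

# Stand-in for the contents of results/raw/gaussians.csv
# (column names here are illustrative, not the pipeline's actual schema)
raw = """solver,reg,time_ms
sinkhorn,1,2.1
sinkhorn,1,2.3
sinkhorn,0.001,15.0
sinkhorn,0.001,14.2
"""

# Group timings by (solver, regularization) and report the mean
times = defaultdict(list)
for row in csv.DictReader(io.StringIO(raw)):
    times[(row["solver"], row["reg"])].append(float(row["time_ms"]))

for (solver, reg), vals in sorted(times.items()):
    print(f"{solver} reg={reg}: mean {statistics.mean(vals):.2f} ms")
```

With --folds greater than 1 the same grouping gives a per-fold spread as well (e.g. statistics.stdev).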

Color Transfer

Example config:

param-grids:
  epsilons:
    - reg: 1
    - reg: 0.01

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    param-grid: epsilons
    jit: true

bin-number:
  - 16
  - 32
soft-extension:
  - no
  - yes
displacement-interpolation:
  - 0.0
  - 1.0
color-space: rgb
# active-channels: [r, g]
batch-size: 100000
pair-number: 3
images-dir: ./datasets/images
rng-seed: 42

drop-columns:
  - transport_plan
  - monge_map
  - u_final
  - v_final

experiment: 
  name: Time and test
  output-dir: ./outputs/color_transfer

For a detailed explanation of the parameters, see docs/color_transfer.md. Notably, soft-extension and displacement-interpolation accept either a single value or a list of values; the pipeline runs once per listed option.
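To illustrate what displacement-interpolation controls: given a Monge map that sends each source color x to a target color T(x), the output at interpolation parameter t is (1 - t) * x + t * T(x), so 0.0 reproduces the original image and 1.0 the fully transferred one. A toy sketch on a few RGB values (not the pipeline's code; the mapped colors are made up):

```python
import numpy as np

# Three source pixels and the colors a hypothetical Monge map assigns them
src = np.array([[0.9, 0.1, 0.1],
                [0.1, 0.9, 0.1],
                [0.1, 0.1, 0.9]])
mapped = np.array([[0.2, 0.2, 0.8],
                   [0.8, 0.8, 0.2],
                   [0.5, 0.5, 0.5]])

def displacement_interpolate(src, mapped, t):
    """Blend each pixel along the straight line to its transported color."""
    return (1.0 - t) * src + t * mapped

print(displacement_interpolate(src, mapped, 0.0)[0])  # original color
print(displacement_interpolate(src, mapped, 1.0)[0])  # fully transferred
print(displacement_interpolate(src, mapped, 0.5)[0])  # halfway blend
```

Listing displacement-interpolation: [0.0, 1.0] as in the config above therefore produces both the untouched and the fully transferred output for comparison.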

Example: Lab space with selected channels.

color-space: lab
active-channels: [l, a]

Run the experiment:

pixi run color-transfer --config ./configs/color_transfer/example.yaml

Dashboard visualization:

pixi run color-transfer-visualization --origin_folder <path_to_input_images> --results_folder <path_to_resulting_images>

MNIST Classification

The MNIST experiment is performed in two steps:

  • Distance matrix calculation
  • Classification

See docs/mnist.md for configs.

Step 1:

pixi run mnist_distances --config ./configs/mnist_dist_example.yaml

Step 2:

pixi run mnist_classification --config ./configs/mnist_classification_example.yaml
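Conceptually, step 2 is k-nearest-neighbor classification on the pairwise OT distances computed in step 1. A minimal sketch of that idea (not the project's implementation) on a toy 4x4 distance matrix:

```python
import numpy as np

# Toy symmetric distance matrix between 4 images (rows/cols in the same order)
D = np.array([
    [0.0, 0.2, 0.9, 1.0],
    [0.2, 0.0, 0.8, 0.9],
    [0.9, 0.8, 0.0, 0.1],
    [1.0, 0.9, 0.1, 0.0],
])
labels = np.array([0, 0, 1, 1])  # true digit classes

def knn_predict(D, labels, i, k=1):
    """Predict item i's label by majority vote of its k nearest neighbors."""
    order = np.argsort(D[i])
    neighbors = [j for j in order if j != i][:k]  # exclude the item itself
    votes = np.bincount(labels[neighbors])
    return int(np.argmax(votes))

preds = [knn_predict(D, labels, i, k=1) for i in range(len(labels))]
print(preds)  # [0, 0, 1, 1]
```

Splitting the experiment this way means the expensive OT distances are computed once and can be reused across classifier settings.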

Linting

Run pixi run lint, or invoke ruff check . directly.
