
Utilities for optimal transport algorithms and benchmarking


uot-bench

uot-bench is a Python toolkit for optimal transport solvers and benchmarking. It provides JAX-first implementations of common OT methods, utilities for generating problems and measures, and a configurable benchmarking pipeline for running experiments at scale.

Install

Core install:

pip install uot-bench

Extras (optional):

pip install "uot-bench[viz,image-analysis,color-transfer,gurobi]"

CUDA (optional, JAX):

pip install "uot-bench[cuda12]"

Quickstart (Python API)

A minimal two-marginal Sinkhorn example:

import numpy as np

from uot.data.measure import PointCloudMeasure
from uot.problems.two_marginal import TwoMarginalProblem
from uot.solvers.sinkhorn import SinkhornTwoMarginalSolver
from uot.utils.costs import cost_euclid_squared

# Create two 1D point-cloud measures
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)

a = np.exp(-((x - 0.3) ** 2) / 0.01).reshape(-1)
a = a / a.sum()
b = np.exp(-((y - 0.7) ** 2) / 0.02).reshape(-1)
b = b / b.sum()

mu = PointCloudMeasure(x, a, name="mu")
nu = PointCloudMeasure(y, b, name="nu")

problem = TwoMarginalProblem("toy", mu, nu, cost_euclid_squared)
solver = SinkhornTwoMarginalSolver()
inputs = problem.solver_inputs()

result = solver.solve(
    marginals=inputs.marginals,
    costs=inputs.costs,
    reg=1e-2,
)

print("cost:", float(result["cost"]))
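The solver used here runs the standard entropic Sinkhorn iteration. As a reference, here is a plain-numpy sketch of the same algorithm on the quickstart's marginals (uot-bench's own implementation is JAX-based and may differ in details):

```python
import numpy as np

# Same 1D marginals as in the quickstart above.
x = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
y = np.linspace(0.0, 1.0, 64).reshape(-1, 1)
a = np.exp(-((x[:, 0] - 0.3) ** 2) / 0.01)
a /= a.sum()
b = np.exp(-((y[:, 0] - 0.7) ** 2) / 0.02)
b /= b.sum()

C = (x - y.T) ** 2          # squared-Euclidean cost matrix, shape (64, 64)
reg = 1e-2
K = np.exp(-C / reg)        # Gibbs kernel

u = np.ones_like(a)
for _ in range(500):        # alternating Sinkhorn scaling updates
    v = b / (K.T @ u)
    u = a / (K @ v)

P = u[:, None] * K * v[None, :]   # entropic transport plan
print("cost:", float((P * C).sum()))
```

Note that for very small regularization this multiplicative form can underflow; production solvers typically work in the log domain instead.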

Benchmarking CLI tools (Pixi)

This project uses Pixi to manage benchmarking dependencies and tasks.

Installing Pixi

Install Pixi by following the official instructions at pixi.sh. After installation, run pixi install to set up the environment. Available tasks can then be invoked with pixi run <task>.

Common commands

  • pixi run serialize --config <config.yaml> --export-dir <directory>
  • pixi run benchmark --config <config.yaml> --folds <n> --export <file>
  • pixi run lint or ruff check .

Slurm

On Compute Canada clusters

The generic entry script for both Slurm and local runs is scripts/run_benchmark.sh. Example:

sbatch scripts/run_benchmark.sh configs/generators/gaussians.yaml configs/runners/gaussians.yaml

Monitor GPU usage with:

srun --jobid 123456 --pty watch -n 30 nvidia-smi

Synthetic datasets

Create a generator config like:

generators:
  1D-gaussians-64:
    generator: uot.problems.generators.GaussianMixtureGenerator
    dim: 1
    num_components: 1
    n_points: 64
    num_datasets: 30
    borders: (-6, 6)
    cost_fn: uot.utils.costs.cost_euclid_squared
    use_jax: true
    seed: 42
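Conceptually, each dataset produced by this config is a discrete measure: a support of n_points inside borders with Gaussian-mixture weights. An illustrative numpy sketch of that idea (the names and sampling details here are hypothetical, not the actual GaussianMixtureGenerator internals):

```python
import numpy as np

rng = np.random.default_rng(42)

# 64-point support on the interval given by `borders`.
grid = np.linspace(-6.0, 6.0, 64).reshape(-1, 1)

# One Gaussian component (num_components: 1) with randomly drawn parameters.
mean = rng.uniform(-6.0, 6.0)
std = rng.uniform(0.5, 2.0)

weights = np.exp(-((grid[:, 0] - mean) ** 2) / (2.0 * std**2))
weights /= weights.sum()            # normalize to a probability measure

print(grid.shape, round(float(weights.sum()), 6))
```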

Then serialize:

pixi run serialize --config configs/generators/gaussians.yaml --export-dir datasets/synthetic

Running experiments

Example config:

param-grids:
  regularizations:
    - reg: 1
    - reg: 0.001

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    jit: true
    param-grid: regularizations

problems:
  dir: datasets/synthetic
  names:
    - 1D-gaussians-64
  
experiment: 
  name: Measure on Gaussians
  function: uot.experiments.measurement.measure_time
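The param-grid mechanism pairs each solver with every parameter set in its grid, so the config above yields one benchmark run per (solver, reg) combination per fold. A toy illustration of that expansion (variable names here are illustrative, not uot-bench internals):

```python
# Mirror of the yaml above: a named grid and a solver referencing it.
param_grids = {"regularizations": [{"reg": 1}, {"reg": 0.001}]}
solvers = {"sinkhorn": {"param-grid": "regularizations"}}

# One run per (solver, parameter set) pair.
runs = [
    (name, params)
    for name, cfg in solvers.items()
    for params in param_grids[cfg["param-grid"]]
]
print(runs)   # two runs: sinkhorn with reg=1 and with reg=0.001
```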

Run the benchmark:

pixi run benchmark --config configs/runners/gaussians.yaml --folds 1 --export results/raw/gaussians.csv

Color Transfer

Example config:

param-grids:
  epsilons:
    - reg: 1
    - reg: 0.01

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    param-grid: epsilons
    jit: true

bin-number:
  - 16
  - 32
soft-extension:
  - no
  - yes
displacement-interpolation:
  - 0.0
  - 1.0
color-space: rgb
# active-channels: [r, g]
batch-size: 100000
pair-number: 3
images-dir: ./datasets/images
rng-seed: 42

drop-columns:
  - transport_plan
  - monge_map
  - u_final
  - v_final

experiment: 
  name: Time and test
  output-dir: ./outputs/color_transfer

For a detailed explanation of the parameters, see docs/color_transfer.md. Note that soft-extension and displacement-interpolation accept either a single value or a list; the pipeline runs once per listed option.
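Displacement interpolation at t blends each source color toward its image under the barycentric (Monge) map derived from the transport plan. A conceptual numpy sketch with a toy plan (illustrative only; the actual pipeline computes the plan with the configured solver):

```python
import numpy as np

rng = np.random.default_rng(42)
X = rng.random((5, 3))                 # 5 source colors in RGB
Y = rng.random((4, 3))                 # 4 target colors
P = rng.random((5, 4))
P /= P.sum(axis=1, keepdims=True)      # toy row-stochastic transport plan

T = P @ Y                              # barycentric projection: plan-weighted
                                       # average target for each source color
t = 0.5
interpolated = (1.0 - t) * X + t * T   # displacement interpolation at t
print(interpolated.shape)              # (5, 3)
```

With t = 0.0 the colors stay at the source; with t = 1.0 they land on the mapped targets, matching the two list entries in the config above.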

Example: Lab space with selected channels.

color-space: lab
active-channels: [l, a]

Run the experiment:

pixi run color-transfer --config ./configs/color_transfer/example.yaml

Dashboard visualization:

pixi run color-transfer-visualization --origin_folder <path_to_input_images> --results_folder <path_to_resulting_images>

MNIST Classification

The MNIST experiment is performed in two steps:

  • Distance matrix calculation
  • Classification

See docs/mnist.md for configs.

Step 1:

pixi run mnist_distances --config ./configs/mnist_dist_example.yaml

Step 2:

pixi run mnist_classification --config ./configs/mnist_classification_example.yaml
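The two steps decouple cleanly: step 1 produces a matrix of pairwise OT distances, and step 2 classifies from that matrix. A toy sketch of the classification idea with a random distance matrix (illustrative only; the actual classifier lives in uot-bench):

```python
import numpy as np

rng = np.random.default_rng(0)
D = rng.random((6, 10))            # distances: 6 test images vs 10 train images
train_labels = rng.integers(0, 10, size=10)

# 1-nearest-neighbour: each test image takes the label of its closest
# training image under the precomputed OT distance.
pred = train_labels[np.argmin(D, axis=1)]
print(pred.shape)                  # one predicted digit per test image
```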

Linting

Run pixi run lint, or invoke ruff check . directly.

