
Utilities for optimal transport algorithms and benchmarking


Utils for OT Methods Benchmark

Installing Pixi

This project uses Pixi to manage dependencies. Follow the official installation instructions for your platform.

After installation, run pixi install to set up the environment. Available tasks can be invoked with pixi run <task>.

Common commands

  • pixi run serialize --config <config.yaml> --export-dir <directory> - create problem datasets from config.yaml in the target directory.
  • pixi run benchmark --config <config.yaml> --folds <n> --export <file> - run experiments using the configuration for n folds and write results to file.
  • pixi run lint or ruff check . to lint the code.

Slurm

On Compute Canada clusters

The generic script for both SLURM and local runs is scripts/run_benchmark.sh. For example:

$ sbatch scripts/run_benchmark.sh configs/generators/gaussians.yaml configs/runners/gaussians.yaml

You can monitor GPU usage on the node with the following command, which runs nvidia-smi every 30 seconds:

$ srun --jobid 123456 --pty watch -n 30 nvidia-smi

Synthetic datasets

To create a synthetic dataset, first create a config file for generation:

generators:
  1D-gaussians-64:
    generator: uot.problems.generators.GaussianMixtureGenerator
    dim: 1
    num_components: 1
    n_points: 64
    num_datasets: 30
    borders: (-6, 6)
    cost_fn: uot.utils.costs.cost_euclid_squared
    use_jax: true
    seed: 42

The class specified in the generator field will be instantiated, and all other fields will be passed to it as init arguments. The section name (in this case 1D-gaussians-64) will be used as the generator name. Multiple generators are allowed in one config.
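The loading convention described above can be sketched as follows. This is an illustration only, not the library's actual code; build_generator is a hypothetical helper:

```python
import importlib

def build_generator(name, spec):
    # Resolve the dotted path in the "generator" field to a class,
    # then pass every remaining field as an init keyword argument.
    module_path, _, class_name = spec["generator"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    kwargs = {key: value for key, value in spec.items() if key != "generator"}
    return name, cls(**kwargs)
```

Any importable class works with this scheme, which is why the config can point at arbitrary generator implementations by dotted path.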

$ pixi run serialize --config configs/generators/gaussians.yaml --export-dir datasets/synthetic

Inside export-dir, a folder with serialized problems will be created for each generator. A meta.yaml file containing a copy of the generator config will be created in the same folder.

Running experiments

To run experiments, first create a config file like:

param-grids:
  regularizations:
    - reg: 1
    - reg: 0.001

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    jit: true
    param-grid: regularizations

problems:
  dir: datasets/synthetic
  names:
    - 1D-gaussians-64
  
experiment: 
  name: Measure on Gaussians
  function: uot.experiments.measurement.measure_time

Here you can define solvers and their param-grids (each solver will be run once per parameter set). In the problems section, dir specifies the export-dir used during serialization (see the previous section), and names lists the specific problem folders within that directory.
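The expansion of a solver's param-grid reference into individual runs can be sketched like this. The dictionaries mirror the YAML above; the expansion logic itself is illustrative, not the library's actual code:

```python
# Each entry in a param-grid is one parameter set; a solver's
# "param-grid" key names which list to sweep over.
param_grids = {"regularizations": [{"reg": 1}, {"reg": 0.001}]}
solvers = {
    "sinkhorn": {
        "solver": "uot.solvers.sinkhorn.SinkhornTwoMarginalSolver",
        "jit": True,
        "param-grid": "regularizations",
    },
}

# One (solver, params) pair per parameter set in the referenced grid.
runs = [
    (solver_name, params)
    for solver_name, spec in solvers.items()
    for params in param_grids[spec["param-grid"]]
]
```

With the config above this yields two runs of sinkhorn, one per regularization value.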

$ pixi run benchmark --config configs/runners/gaussians.yaml --folds 1 --export results/raw/gaussians.csv

The --export option specifies where to write the CSV report of the experiment.

Color Transfer

To run a Color Transfer experiment, first create a config file like:

param-grids:
  epsilons:
    - reg: 1
    - reg: 0.01

solvers:
  sinkhorn:
    solver: uot.solvers.sinkhorn.SinkhornTwoMarginalSolver
    param-grid: epsilons
    jit: true

bin-number:
  - 16
  - 32
soft-extension:
  - no
  - yes
displacement-interpolation:
  - 0.0
  - 1.0
color-space: rgb
# active-channels: [r, g]
batch-size: 100000
pair-number: 3
images-dir: ./datasets/images
rng-seed: 42

drop-columns:
  - transport_plan
  - monge_map
  - u_final
  - v_final

experiment: 
  name: Time and test
  output-dir: ./outputs/color_transfer

For detailed explanation of the parameters, please refer to docs/color_transfer.md. Notably, soft-extension and displacement-interpolation can be single values or lists (e.g. both no and yes) and the pipeline will run once per option, tagging each row with the applied setting.
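The multiplication of runs from list-valued options can be sketched with a simple Cartesian product. The option names mirror the YAML above; the expansion code is illustrative, not the library's actual implementation:

```python
from itertools import product

# Each list-valued option multiplies the number of pipeline runs:
# 2 bin numbers x 2 soft-extension flags x 2 interpolation values
# = 8 runs for the config above.
bin_numbers = [16, 32]
soft_extension = [False, True]           # "no" / "yes" in the YAML
displacement_interpolation = [0.0, 1.0]

runs = [
    {"bin-number": b, "soft-extension": s, "displacement-interpolation": t}
    for b, s, t in product(bin_numbers, soft_extension, displacement_interpolation)
]
```

Each resulting row in the report is tagged with the option combination that produced it.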

Example: Lab space with selected channels.

color-space: lab
active-channels: [l, a]

The corresponding pixi command example:

pixi run color-transfer --config ./configs/color_transfer/example.yaml

There is also a feature to create a dashboard for visual comparison of the input images and the results; the corresponding command is:

pixi run color-transfer-visualization --origin_folder <path_to_input_images> --results_folder <path_to_resulting_images>

MNIST Classification

The MNIST classification experiment is performed in two steps.

  • Distance matrix calculation.
  • Classification itself.

For detailed config examples for each of them, please refer to docs/mnist.md.

The corresponding pixi commands:

  • Step 1:
pixi run mnist_distances --config ./configs/mnist_dist_example.yaml
  • Step 2:
pixi run mnist_classification --config ./configs/mnist_classification_example.yaml

Linting

This project uses Black and Ruff for code style. Run pixi run lint or ruff check . to lint the code.
