
Optimal Transport Tools in JAX.

Project description


Optimal Transport Tools (OTT)


See the full documentation.

What is OTT-JAX?

A JAX-powered library to compute optimal transport at scale and on accelerators, OTT-JAX includes what we believe is the fastest implementation of the Sinkhorn algorithm available. We have implemented its standard tweaks (scheduling, acceleration, initializations) and extensions (low-rank), which can be used directly or within more advanced problems (Gromov-Wasserstein, barycenters). Several JAX features, including JIT compilation, auto-vectorization and implicit differentiation, serve the goal of making outputs differentiable end to end. OTT-JAX is developed by a team of researchers from Apple, Google, Meta and many academic contributors, including TU München, Oxford, ENSAE/IP Paris and the Hebrew University.
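As a small, self-contained illustration of how JIT, auto-vectorization and differentiation combine on a transport quantity, the sketch below uses plain JAX rather than OTT-JAX. It relies on a special case we are assuming for simplicity: in 1D, between two uniform samples of equal size, the optimal matching pairs points in sorted order, so the squared 2-Wasserstein distance has a closed form and no iterative solver is needed. The function name w2_squared_1d is hypothetical; OTT-JAX's iterative solvers instead rely on techniques such as implicit differentiation.

```python
import jax
import jax.numpy as jnp


def w2_squared_1d(x, y):
    """Hypothetical helper: squared 2-Wasserstein distance between two 1D
    samples of equal size with uniform weights.

    In 1D the optimal matching pairs points in sorted order, so the
    distance is the mean squared gap between sorted samples.
    """
    return jnp.mean((jnp.sort(x) - jnp.sort(y)) ** 2)


w2_jit = jax.jit(w2_squared_1d)                      # compiled per input shape/dtype
w2_grad = jax.jit(jax.grad(w2_squared_1d))           # gradient with respect to x
w2_batch = jax.vmap(w2_squared_1d, in_axes=(None, 0))  # one x against many y's

x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([2.5, 0.5, 1.5])
ys = jnp.stack([y, y + 1.0])

value = w2_jit(x, y)        # sorted pairs differ by 0.5 each, so 0.25
grad_x = w2_grad(x, y)      # each entry is 2 * (-0.5) / 3, about -0.333
batched = w2_batch(x, ys)   # [0.25, 2.25]
```

The gradient flows through jnp.sort, so the same function can be jitted, differentiated and vectorized without modification.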

Installation

Install OTT-JAX from PyPI as:

pip install ott-jax

or with conda via conda-forge as:

conda install -c conda-forge ott-jax

What is optimal transport?

Optimal transport can be loosely described as the branch of mathematics and optimization that studies matching problems: given two families of points and a cost function on pairs of points, find a "good" (low-cost) way to associate, bijectively, each point in the first family with a point in the second.
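To make "good (low-cost) bijective matching" concrete, the sketch below enumerates every bijection between two tiny point sets and keeps the one with the lowest total squared Euclidean cost. This brute-force search is our own illustration, not part of OTT-JAX; it scales as $n!$ and is only meant to pin down what "optimal" means. The name brute_force_matching is hypothetical.

```python
import itertools

import jax.numpy as jnp


def brute_force_matching(x, y):
    """Hypothetical helper: return the permutation sigma minimizing
    sum_i ||x_i - y_sigma(i)||^2, together with that minimal cost."""
    n = x.shape[0]
    # Pairwise squared Euclidean costs, shape (n, n).
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    best_perm, best_cost = None, float("inf")
    for perm in itertools.permutations(range(n)):
        # Total cost of matching x_i with y_perm[i] for every i.
        total = float(cost[jnp.arange(n), jnp.array(perm)].sum())
        if total < best_cost:
            best_perm, best_cost = perm, total
    return best_perm, best_cost


x = jnp.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
y = jnp.array([[2.1, 0.0], [0.1, 0.0], [1.1, 0.0]])
perm, total = brute_force_matching(x, y)
# perm == (1, 2, 0): each x_i is matched to the y point just 0.1 to its right,
# for a total cost of about 3 * 0.1**2 = 0.03.
```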

Such problems appear in all areas of science and are easy to describe, yet hard to solve. Indeed, while optimally matching two sets of $n$ points under a pairwise cost can be done with the Hungarian algorithm, this requires $O(n^3)$ operations and lacks flexibility, since one may want to couple families of different sizes.
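For comparison, SciPy's scipy.optimize.linear_sum_assignment solves the same assignment problem exactly in polynomial time. SciPy documents it as a modified Jonker-Volgenant algorithm rather than the textbook Hungarian method, though both have cubic worst-case complexity. The sketch below assumes SciPy is installed (it is a dependency of JAX); the point clouds are random and purely illustrative.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
n = 200
x = rng.normal(size=(n, 2))
y = rng.uniform(size=(n, 2))

# Squared Euclidean cost matrix, shape (n, n).
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)

# Exact optimal bijection: x[row_ind[k]] is matched with y[col_ind[k]].
row_ind, col_ind = linear_sum_assignment(cost)
total_cost = cost[row_ind, col_ind].sum()

# A rectangular 3 x 5 cost is accepted, but each of the 3 rows simply gets
# one distinct column and 2 columns stay unmatched; no mass is split.
rows_r, cols_r = linear_sum_assignment(cost[:3, :5])
```

The rectangular case shows the inflexibility mentioned above: points are either matched one-to-one or dropped, never partially shared.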

Optimal transport extends all of this, through faster algorithms (with cost quadratic, or even linear, in $n$) along with numerous generalizations that can help it handle weighted sets of different sizes, partial matchings, and even more involved, so-called quadratic matching problems.
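One well-known route to quadratic-cost algorithms is entropic regularization: adding an entropy term weighted by epsilon leads to Sinkhorn's matrix-scaling iterations, each costing $O(nm)$. The sketch below is a minimal log-domain Sinkhorn loop in plain JAX, not OTT-JAX's solver (which adds the scheduling, acceleration and initialization tweaks mentioned earlier). It couples two weighted clouds of different sizes ($n=5$, $m=8$); the function name sinkhorn_log, the value epsilon=0.2 and the iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log(a, b, cost, epsilon=0.2, num_iters=1000):
    """Hypothetical minimal log-domain Sinkhorn.

    Returns a coupling P with P.sum(1) close to a and P.sum(0) close to b.
    """
    log_a, log_b = jnp.log(a), jnp.log(b)

    def update(_, potentials):
        f, g = potentials
        # Each update is one log-sum-exp over the (n, m) cost: O(n * m) work.
        f = -epsilon * logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        g = -epsilon * logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, update, init)
    # Coupling recovered from the dual potentials f and g.
    return jnp.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / epsilon)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n, m = 5, 8  # weighted clouds of different sizes
x = jax.random.uniform(k1, (n, 2))
y = jax.random.uniform(k2, (m, 2))
a = jax.random.uniform(k3, (n,), minval=0.5, maxval=1.5)
b = jax.random.uniform(k4, (m,), minval=0.5, maxval=1.5)
a, b = a / a.sum(), b / b.sum()

cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
P = sinkhorn_log(a, b, cost)  # shape (5, 8)
```

Working with log-domain potentials rather than scaling vectors keeps the iterations stable when cost / epsilon is large.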

In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly (2D vectors, compared with the squared Euclidean distance):

Example

import jax
import jax.numpy as jnp
from ott.tools import transport
# Samples two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n, d)) + 1
y = jax.random.uniform(rngs[1], (m, d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)
# Computes the couplings using the Sinkhorn algorithm.
ot = transport.solve(x, y, a=a, b=b)
P = ot.matrix

The call to solve above computes the optimal transport solution. The ot object contains a transport matrix (here of size $12\times 14$) that quantifies the strength of the link between each point of the first point cloud and one or more points of the second, as illustrated in the plot below. In this toy example most choices were arbitrary, which is reflected in the crude solve API. We provide far more flexibility to define custom cost functions, objectives, and solvers, as detailed in the full documentation.

[Figure: obtained coupling]
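The link strengths stored in a coupling matrix can be summarized per source point. The sketch below does not call OTT-JAX: it uses a small hand-written $2\times 3$ coupling (all values are made up for illustration) whose rows sum to the weights a and whose columns sum to b, then reads off each source point's strongest partner and its barycentric projection, i.e. the coupling-weighted average of the target points it is linked to.

```python
import jax.numpy as jnp

# Hypothetical 2 x 3 coupling: rows sum to a, columns sum to b.
a = jnp.array([0.5, 0.5])
b = jnp.array([0.25, 0.25, 0.5])
P = jnp.array([[0.20, 0.25, 0.05],
               [0.05, 0.00, 0.45]])
y = jnp.array([[0.0, 0.0], [2.0, 0.0], [4.0, 4.0]])  # target points

# Sanity check: the coupling's marginals match the weights.
assert jnp.allclose(P.sum(axis=1), a) and jnp.allclose(P.sum(axis=0), b)

# Index (into y) of the strongest link for each source point: [1, 2].
strongest = jnp.argmax(P, axis=1)

# Barycentric projection: P-weighted average of the linked target points,
# here [[1.4, 0.4], [3.6, 3.6]].
T = (P @ y) / P.sum(axis=1, keepdims=True)
```

A row with several nonzero entries, like the first one here, is exactly a point linked to more than one point of the second cloud.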

Citation

If you have found this work useful, please consider citing this reference:

@article{cuturi2022optimal,
  title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein},
  author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and
          Davis, Geoff and Teboul, Olivier},
  journal={arXiv preprint arXiv:2201.12324},
  year={2022}
}


Download files

Source Distribution

ott-jax-0.3.1.tar.gz (158.4 kB)

Uploaded Source

Built Distribution

ott_jax-0.3.1-py3-none-any.whl (193.9 kB)

Uploaded Python 3

File details

Details for the file ott-jax-0.3.1.tar.gz.

File metadata

  • Download URL: ott-jax-0.3.1.tar.gz
  • Upload date:
  • Size: 158.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.1

File hashes

Hashes for ott-jax-0.3.1.tar.gz
Algorithm Hash digest
SHA256 9bb555dccffa9b58e157a649808096e6dc2c46c2bfe6d12ec758d6114bf8af87
MD5 470437546f5d07eef0473aad720c52a0
BLAKE2b-256 a8f24fa930ad1021f7477ac4652e1cad4158e10898450e884ab8cfc1dc1c43f3
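To check a downloaded archive against the SHA256 digest listed above, one can hash it locally with Python's standard hashlib module and compare. This is a sketch: the local file name is an assumption about where the download was saved. pip can also enforce digests at install time through its hash-checking mode (a --hash option on a requirements-file line).

```python
import hashlib
import os


def sha256_of_file(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 in 1 MiB chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Copied from the SHA256 row for ott-jax-0.3.1.tar.gz above.
EXPECTED_SDIST_SHA256 = "9bb555dccffa9b58e157a649808096e6dc2c46c2bfe6d12ec758d6114bf8af87"

path = "ott-jax-0.3.1.tar.gz"  # assumed download location
if os.path.exists(path):
    ok = sha256_of_file(path) == EXPECTED_SDIST_SHA256
    print("hash OK" if ok else "hash MISMATCH: do not install this file")
```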

File details

Details for the file ott_jax-0.3.1-py3-none-any.whl.

File metadata

  • Download URL: ott_jax-0.3.1-py3-none-any.whl
  • Upload date:
  • Size: 193.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.11.1

File hashes

Hashes for ott_jax-0.3.1-py3-none-any.whl
Algorithm Hash digest
SHA256 7eb01380a067af8d03e21e917d6a2f0dd01ca255c2d7e53c50665fd36d84e813
MD5 fecbbf15c4c5800276dfc92d496c44f4
BLAKE2b-256 0f5aab873f4c4c828c4a70b1a24628e99b3467224b4cff3c726773c7b7c94530
