Optimal Transport Tools in JAX

Optimal Transport Tools (OTT)

See the full documentation.

What is OTT-JAX?

OTT-JAX is a JAX-powered library for computing optimal transport at scale and on accelerators. It includes the fastest implementation of the Sinkhorn algorithm you will find around. We have implemented all tweaks (scheduling, momentum, acceleration, initializations) and extensions (low-rank, entropic maps). They can be used directly between two datasets, or within more advanced problems (Gromov-Wasserstein, barycenters). Several JAX features, including JIT compilation, auto-vectorization, and implicit differentiation, work towards the goal of having end-to-end differentiable outputs. OTT-JAX is led by a team of researchers at Apple, with contributions from Google and Meta researchers, as well as many academic partners, including TU München, Oxford, ENSAE/IP Paris, ENS Paris and the Hebrew University.

Installation

Install OTT-JAX from PyPI as:

pip install ott-jax

or with conda via conda-forge as:

conda install -c conda-forge ott-jax

What is optimal transport?

Optimal transport can be loosely described as the branch of mathematics and optimization that studies matching problems: given two families of points and a cost function on pairs of points, find a "good" (low-cost) way to associate each point in the first family with a distinct point in the second, so that the association is a bijection.

Such problems appear in all areas of science and are easy to describe, yet hard to solve. Indeed, while optimally matching two sets of $n$ points under a pairwise cost can be done with the Hungarian algorithm, doing so costs on the order of $n^3$ operations and lacks flexibility, since one may want to couple families of different sizes.
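To make the equal-size matching problem concrete, here is a minimal sketch that is not part of OTT-JAX. It assumes SciPy is available and uses scipy.optimize.linear_sum_assignment, which solves the same assignment problem (SciPy's documentation describes its method as a modified Jonker-Volgenant algorithm rather than the classical Hungarian method). The point clouds and variable names are illustrative only.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Two families with the SAME number of points (required for a bijection).
rng = np.random.default_rng(0)
n, d = 5, 2
x = rng.normal(size=(n, d)) + 1.0
y = rng.uniform(size=(n, d))

# Pairwise squared Euclidean costs, shape (n, n).
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)

# Solve the assignment problem: x[rows[i]] is matched to y[cols[i]].
rows, cols = linear_sum_assignment(cost)
total_cost = cost[rows, cols].sum()

print("matching:", list(zip(rows.tolist(), cols.tolist())))
print("total cost:", total_cost)
```

Each point is matched to exactly one partner, so this formulation cannot handle weights or families of different sizes.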

Optimal transport extends all of this through faster algorithms (with a cost in $n^2$, or even linear in $n$), along with numerous generalizations that handle weighted sets of different sizes, partial matchings, and even more involved, so-called quadratic, matching problems.
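As an illustration of one such algorithm, the following is a minimal NumPy sketch of plain Sinkhorn iterations between two weighted point sets of different sizes. It is a simplified textbook version under stated assumptions, not how OTT-JAX implements its solver: the function name sinkhorn_sketch, the fixed iteration count, the value of epsilon, and the rescaling of the cost matrix are all choices made for this example.

```python
import numpy as np

def sinkhorn_sketch(cost, a, b, epsilon=0.1, num_iters=1000):
    """Plain Sinkhorn iterations; returns an (n, m) coupling matrix."""
    kernel = np.exp(-cost / epsilon)  # Gibbs kernel, shape (n, m)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iters):
        # Each iteration is two matrix-vector products, i.e. O(n * m) work.
        u = a / (kernel @ v)    # rescale rows so that row sums match a
        v = b / (kernel.T @ u)  # rescale columns so that column sums match b
    return u[:, None] * kernel * v[None, :]

rng = np.random.default_rng(0)
n, m, d = 12, 14, 2
x = rng.normal(size=(n, d)) + 1.0
y = rng.uniform(size=(m, d))
a = rng.uniform(size=n)
a = a / a.sum()
b = rng.uniform(size=m)
b = b / b.sum()

# Squared Euclidean costs, rescaled to [0, 1] to keep exp() well behaved here.
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
cost = cost / cost.max()

coupling = sinkhorn_sketch(cost, a, b)
print(coupling.shape)          # (12, 14)
print(coupling.sum(axis=1))    # approximately equal to a
```

Unlike a bijective matching, the coupling spreads the weight of each point over several points of the other set, which is what allows different sizes and non-uniform weights.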

In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly (2D vectors, compared with the squared Euclidean distance):

Example

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

# Sample two point clouds and their weights.
rngs = jax.random.split(jax.random.key(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n, d)) + 1
y = jax.random.uniform(rngs[1], (m, d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)

# Compute the coupling using the Sinkhorn algorithm.
geom = pointcloud.PointCloud(x, y)
prob = linear_problem.LinearProblem(geom, a, b)

solver = sinkhorn.Sinkhorn()
out = solver(prob)

The call to solver(prob) above computes the optimal transport solution. The out object contains a transport matrix (here of size $12\times 14$) that quantifies the association strength between each point of the first point cloud and one or more points of the second, as illustrated in the plot below. Custom cost functions, objectives, and solvers can also be defined, as detailed in the full documentation.

Figure: obtained coupling.

Citation

If you have found this work useful, please consider citing this reference:

@article{cuturi2022optimal,
  title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein},
  author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and
          Davis, Geoff and Teboul, Olivier},
  journal={arXiv preprint arXiv:2201.12324},
  year={2022}
}

See also

The moscot package for OT analysis of multi-omics data also uses OTT as a backbone.
