Optimal Transport Tools in JAX
Optimal Transport Tools (OTT)
See the full documentation.
What is OTT-JAX?
A JAX-powered library to solve a wide variety of problems leveraging optimal transport theory, at scale and on accelerators.
In particular, OTT-JAX implements several discrete solvers to match two point clouds, notably the Sinkhorn algorithm, which works on a range of geometric domains and is sped up using tweaks (scheduling, momentum, acceleration, initializations) and extensions (low-rank).
These algorithms power the resolution of more advanced problems (Gromov-Wasserstein, Wasserstein barycenter) to compare point clouds in versatile settings.
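The Wasserstein barycenter mentioned above has a simple closed form in one special case: for uniform empirical measures on the real line that all have the same number of points, the $W_2$ barycenter averages the sorted samples rank by rank, because optimal maps in 1D pair points by rank. The plain-Python sketch below covers only this 1D case; the barycenter_1d helper is hypothetical and not part of OTT, whose solvers handle general dimensions and weights.

```python
def barycenter_1d(samples, weights):
    """W2 barycenter of uniform 1D empirical measures of equal size.

    Hypothetical helper for illustration only (not an OTT function).
    samples: list of measures, each a list of k real numbers.
    weights: nonnegative barycentric weights summing to 1.
    """
    k = len(samples[0])
    if any(len(s) != k for s in samples):
        raise ValueError("all measures must have the same number of points")
    sorted_samples = [sorted(s) for s in samples]
    # In 1D, optimal transport pairs points by rank, so the barycenter is
    # the weighted average of the i-th smallest point of every measure.
    return [sum(w * s[i] for w, s in zip(weights, sorted_samples)) for i in range(k)]


mu = [3.0, 0.0, 1.0]
nu = [10.0, 12.0, 11.0]
bary = barycenter_1d([mu, nu], [0.5, 0.5])  # midpoint between mu and nu
```

With equal weights, the barycenter of the two measures above is the uniform measure on [5.0, 6.0, 7.5].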
On top of these discrete solvers, we also propose implementations of neural network approaches. Given a source/target pair of measures, they output a neural network that seeks to approximate their optimal transport map.
OTT-JAX is led by a team of researchers at Apple, with past contributions from Google and Meta researchers, as well as academic partners, including TU München, Oxford, ENSAE/IP Paris, ENS Paris and the Hebrew University.
Installation
Install OTT-JAX from PyPI as:
pip install ott-jax
or with conda via conda-forge as:
conda install -c conda-forge ott-jax
What is optimal transport?
Optimal transport can be loosely described as the branch of mathematics and optimization that studies matching problems: given two families of points, and a cost function on pairs of points, find a "good" (low-cost) way to associate each point in the first family bijectively with a point in the second.
Such problems appear in all areas of science, are easy to describe, yet hard to solve. Indeed, while optimally matching two sets of $n$ points under a pairwise cost can be done with the Hungarian algorithm, doing so requires on the order of $O(n^3)$ operations and lacks flexibility, since one may want to couple families of different sizes.
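To make the combinatorial nature of the matching problem concrete, the plain-Python sketch below solves a tiny instance by enumerating all $n!$ bijections. It is meant only as an illustration: the brute_force_matching helper is hypothetical, and exhaustive search scales far worse than the $O(n^3)$ Hungarian algorithm.

```python
from itertools import permutations


def brute_force_matching(cost):
    """Exhaustively search all bijections for an n x n cost matrix.

    Hypothetical helper for illustration only. Returns (perm, total) where
    perm[i] is the index of the second-family point matched to point i.
    """
    n = len(cost)
    best_perm, best_cost = None, float("inf")
    for perm in permutations(range(n)):  # n! candidate matchings
        total = sum(cost[i][perm[i]] for i in range(n))
        if total < best_cost:
            best_perm, best_cost = perm, total
    return best_perm, best_cost


# Squared Euclidean cost between two families of 1D points.
xs = [0.0, 1.0, 2.0]
ys = [2.1, 0.1, 0.9]
cost = [[(x - y) ** 2 for y in ys] for x in xs]
perm, total = brute_force_matching(cost)  # perm == (1, 2, 0)
```

In 1D with a convex cost such as the squared distance, the optimal matching pairs points in sorted order, which is what the result (1, 2, 0) reflects.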
Optimal transport extends all of this, through faster algorithms (in $n^2$ or even linear in $n$) along with numerous generalizations that help it handle weighted sets of different sizes, partial matchings, and even more advanced so-called quadratic matching problems.
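The Sinkhorn algorithm is one of these faster algorithms: it solves an entropy-regularized version of the problem by alternately rescaling the rows and columns of the kernel $K_{ij} = \exp(-C_{ij}/\varepsilon)$, each pass costing $O(nm)$ operations, and it handles weighted sets of different sizes directly. Below is a minimal plain-Python sketch, assuming a dense cost matrix and a fixed number of iterations. The sinkhorn helper is hypothetical and much simpler than OTT's implementation, which, for instance, can run in the log domain for numerical stability.

```python
import math


def sinkhorn(a, b, cost, epsilon=1.0, num_iters=500):
    """Entropic OT via plain Sinkhorn scaling (no log-domain stabilization).

    Hypothetical helper for illustration only.
    a: weights of the n source points, b: weights of the m target points
    (each summing to 1), cost: n x m nested list of pairwise costs.
    Returns the n x m coupling P = diag(u) K diag(v).
    """
    n, m = len(a), len(b)
    # Gibbs kernel K_ij = exp(-C_ij / epsilon).
    K = [[math.exp(-cost[i][j] / epsilon) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(num_iters):
        # Rescale rows so the row sums of diag(u) K diag(v) equal a.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so the column sums equal b.
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


# Two weighted source points, three weighted target points (different sizes).
a = [0.5, 0.5]
b = [0.2, 0.3, 0.5]
cost = [[0.0, 1.0, 4.0],
        [4.0, 1.0, 0.0]]
P = sinkhorn(a, b, cost, epsilon=1.0)
row_sums = [sum(row) for row in P]
col_sums = [sum(P[i][j] for i in range(len(a))) for j in range(len(b))]
```

After the final column update the column sums match b up to rounding, while the row sums match a up to the convergence tolerance; smaller values of epsilon bring the coupling closer to an unregularized one but slow convergence.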
In the simple toy example below, we compute the optimal coupling matrix between two point clouds sampled randomly (2D vectors, compared with the squared Euclidean distance):
Example
```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.solvers import linear
from ott.tools import plot

# Sample two point clouds and their weights.
rngs = jax.random.split(jax.random.key(42), 4)
n, m, d = 6, 11, 2
x = jax.random.uniform(rngs[0], (n, d))
y = jax.random.uniform(rngs[1], (m, d))
a = jax.random.uniform(rngs[2], (n,)) + 0.2
b = jax.random.uniform(rngs[3], (m,)) + 0.2
a, b = a / jnp.sum(a), b / jnp.sum(b)

# Instantiate a geometry object to compare the point clouds.
geom = pointcloud.PointCloud(x, y)

# Compute the coupling using the Sinkhorn algorithm.
out = jax.jit(linear.solve)(geom, a, b)

# Plot the coupling.
plot.Plot()(out)
```
The call to linear.solve(geom, a, b) above computes the optimal transport solution. The out object contains a transport matrix (here of size $6\times 11$) that quantifies the association strength between each point of the first point cloud and one or more points of the second; the last command displays this matrix using a Plot object. We provide more flexibility to define custom cost functions, objectives, and solvers, as detailed in the full documentation.
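To make the role of the cost explicit, the plain-Python sketch below builds the squared Euclidean cost matrix $C_{ij} = \|x_i - y_j\|^2$ used to compare points in the example, and evaluates the linear objective $\sum_{ij} P_{ij} C_{ij}$ for two feasible couplings with the same weights. The helpers sq_euclidean_cost and transport_cost are hypothetical illustrations, not OTT functions.

```python
def sq_euclidean_cost(xs, ys):
    """Pairwise squared Euclidean costs C[i][j] = ||x_i - y_j||^2."""
    return [[sum((xk - yk) ** 2 for xk, yk in zip(x, y)) for y in ys] for x in xs]


def transport_cost(P, C):
    """Linear OT objective: sum over i, j of P[i][j] * C[i][j]."""
    return sum(p * c for p_row, c_row in zip(P, C) for p, c in zip(p_row, c_row))


xs = [(0.0, 0.0), (1.0, 0.0)]               # first point cloud (n = 2)
ys = [(0.0, 1.0), (1.0, 1.0), (2.0, 0.0)]   # second point cloud (m = 3)
a = [0.5, 0.5]                              # weights of xs
b = [0.25, 0.25, 0.5]                       # weights of ys
C = sq_euclidean_cost(xs, ys)

# Independent coupling P_ij = a_i * b_j: feasible, but blind to the cost.
P_indep = [[ai * bj for bj in b] for ai in a]
# A hand-built coupling with the same marginals that favors nearby points.
P_local = [[0.25, 0.25, 0.0],
           [0.0, 0.0, 0.5]]

cost_indep = transport_cost(P_indep, C)
cost_local = transport_cost(P_local, C)
```

Both couplings satisfy the marginal constraints, but the second one sends mass to nearby points and has a lower cost (1.25 versus 2.0); an OT solver searches over all such feasible matrices for the cheapest one.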
Citation
If you have found this work useful, please consider citing this reference:
@article{cuturi2022optimal,
title={Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein},
author={Cuturi, Marco and Meng-Papaxanthos, Laetitia and Tian, Yingtao and Bunne, Charlotte and
Davis, Geoff and Teboul, Olivier},
journal={arXiv preprint arXiv:2201.12324},
year={2022}
}
See also
The moscot package for OT analysis of multi-omics data uses OTT as a backbone.
File details
Details for the file ott_jax-0.6.0.tar.gz.
File metadata
- Download URL: ott_jax-0.6.0.tar.gz
- Upload date:
- Size: 233.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 86fda8aea68a6c29c989b1be2fc8bc7e5344e6353d3e486a3edbc36364b07483 |
| MD5 | 92bb3a26055df4e91fe19a9c881b5bc3 |
| BLAKE2b-256 | 8a46a45f351fcfaa456a42bc318d7ee8641ff4cef758aace295ae4a27459ac3a |
Provenance
The following attestation bundles were made for ott_jax-0.6.0.tar.gz:
- Publisher: publish.yml on ott-jax/ott
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: ott_jax-0.6.0.tar.gz
- Subject digest: 86fda8aea68a6c29c989b1be2fc8bc7e5344e6353d3e486a3edbc36364b07483
- Sigstore transparency entry: 667823333
- Permalink: ott-jax/ott@5aa6da0fbd8fe9c255b749ac5fb49f3092860cc4
- Branch / Tag: refs/tags/0.6.0
- Owner: https://github.com/ott-jax
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@5aa6da0fbd8fe9c255b749ac5fb49f3092860cc4
- Trigger Event: release
File details
Details for the file ott_jax-0.6.0-py3-none-any.whl.
File metadata
- Download URL: ott_jax-0.6.0-py3-none-any.whl
- Upload date:
- Size: 309.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 89d7ee73ae105a56385b66a43e9acd5ae070e66d991efc56047ab47cc756f3e6 |
| MD5 | c7266954e3cb2cfa356dc127b9f0bfe8 |
| BLAKE2b-256 | 46780f66915250058498915b3944ac7eccb0da5ce1f5984664fbb6a448cf73db |
Provenance
The following attestation bundles were made for ott_jax-0.6.0-py3-none-any.whl:
- Publisher: publish.yml on ott-jax/ott
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: ott_jax-0.6.0-py3-none-any.whl
- Subject digest: 89d7ee73ae105a56385b66a43e9acd5ae070e66d991efc56047ab47cc756f3e6
- Sigstore transparency entry: 667823338
- Permalink: ott-jax/ott@5aa6da0fbd8fe9c255b749ac5fb49f3092860cc4
- Branch / Tag: refs/tags/0.6.0
- Owner: https://github.com/ott-jax
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@5aa6da0fbd8fe9c255b749ac5fb49f3092860cc4
- Trigger Event: release