OTT: Optimal Transport Tools in Jax.
Project description
Optimal Transport Tools (OTT): a toolbox for all things Wasserstein.
See the full documentation for detailed information on the toolbox.
Most of OTT is, for now, supported by a sturdy, versatile and efficient implementation of the Sinkhorn algorithm that takes advantage of JAX features, such as JIT, auto-vectorization and implicit differentiation.
A typical OT problem has two ingredients: a pair of weight vectors `a` and `b` (one for each measure), and a ground cost matrix that is either directly given, or derived as the pairwise evaluation of a cost function on pairs of points taken from the two measures. The main design choice in OTT comes from encapsulating the cost in a `Geometry` object, and bundling it with a few useful operations (notably kernel applications). The most common geometry is that of two clouds of vectors compared with the squared Euclidean distance, as illustrated in the example below:
Example
```python
import jax
import jax.numpy as jnp
from ott.tools import transport

# Samples two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n, d)) + 1
y = jax.random.uniform(rngs[1], (m, d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)

# Computes the couplings via Sinkhorn algorithm.
ot = transport.Transport(x, y, a=a, b=b)
P = ot.matrix
```
The `Transport` object above wraps a call to `sinkhorn` and works out the optimal transport solution by storing its output. The transport matrix can then be instantiated from those optimal solutions and the `Geometry`. That transport matrix links each point from the first point cloud to one or more points from the second.
To be more precise, the `sinkhorn` algorithm operates on the `Geometry`, taking into account weights `a` and `b`, to solve the OT problem. It produces a named tuple that contains two optimal dual potentials `f` and `g` (vectors of the same size as `a` and `b`), the objective `reg_ot_cost`, a log of the `errors` of the algorithm as it converges, and a `converged` flag.
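The outputs listed above can be illustrated without OTT. The following is a minimal sketch in plain NumPy, not OTT's implementation (which runs in JAX and supports JIT compilation and implicit differentiation): it alternates log-domain updates of the potentials `f` and `g` on a squared Euclidean cost matrix, records the marginal error at each iteration, and rebuilds the transport matrix as P_ij = exp((f_i + g_j - C_ij) / epsilon). The function names `sinkhorn_log` and `transport_matrix`, the values of `epsilon`, `threshold` and `max_iterations`, and the use of the dual value <f, a> + <g, b> as the regularized cost are assumptions made for this sketch; OTT's own defaults and conventions may differ.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along `axis`."""
    z_max = np.max(z, axis=axis, keepdims=True)
    out = z_max + np.log(np.sum(np.exp(z - z_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def transport_matrix(f, g, cost, epsilon):
    """Hypothetical helper: P_ij = exp((f_i + g_j - C_ij) / epsilon)."""
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def sinkhorn_log(cost, a, b, epsilon=1.0, threshold=1e-6, max_iterations=5000):
    """Hypothetical log-domain Sinkhorn loop (a sketch, not OTT's `sinkhorn`)."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    errors = []
    converged = False
    for _ in range(max_iterations):
        # Row update: makes the row sums of P equal to a.
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        # Column update: makes the column sums of P equal to b.
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        # Columns are now exact, so track the L1 error on the row marginal.
        row_sums = transport_matrix(f, g, cost, epsilon).sum(axis=1)
        errors.append(np.abs(row_sums - a).sum())
        if errors[-1] < threshold:
            converged = True
            break
    # Dual value <f, a> + <g, b>, assumed here as the regularized OT cost.
    reg_ot_cost = float(np.dot(f, a) + np.dot(g, b))
    return f, g, reg_ot_cost, np.array(errors), converged


# Same kind of data as in the example above, sampled with NumPy instead of JAX.
rng = np.random.default_rng(0)
n, m, d = 12, 14, 2
x = rng.normal(size=(n, d)) + 1
y = rng.uniform(size=(m, d))
a = rng.uniform(size=n)
b = rng.uniform(size=m)
a, b = a / a.sum(), b / b.sum()

# Squared Euclidean cost matrix between the two point clouds.
cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

epsilon = 1.0
f, g, reg_ot_cost, errors, converged = sinkhorn_log(cost, a, b, epsilon=epsilon)
P = transport_matrix(f, g, cost, epsilon)
```

Because each `g` update makes the column sums of P match `b` exactly, only the row marginal needs monitoring. The arithmetic carries over to `jax.numpy`, but a JIT-compiled version would have to replace the Python loop and `break` with a `jax.lax` loop primitive.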
Overall description of source code
OTT currently implements the following classes and functions:
- In the geometry folder:
  - The `CostFn` class in costs.py and its descendants define cost functions between points. Two simple costs are currently provided: `Euclidean`, between vectors, and `Bures`, between pairs made of a mean vector and a covariance (p.d.) matrix.
  - The `Geometry` class in geometry.py and its descendants describe a cost structure between two measures. That cost structure is accessed through various member functions, used either when running the Sinkhorn algorithm (typically kernel multiplications, or row/column-wise log-sum-exp applications) or afterwards (to apply the OT matrix to a vector).
    - In its generic `Geometry` implementation, as in geometry.py, an object can be initialized either with a `cost_matrix` along with an `epsilon` regularization parameter (or scheduler), or with a `kernel_matrix`.
    - If one wishes to compute OT between two weighted point clouds endowed with a given cost function (e.g. Euclidean), the `PointCloud` class in pointcloud.py can be used to define the corresponding kernel. When the number of points grows very large, this geometry can be instantiated with an `online=True` parameter, to avoid storing the kernel matrix and instead recompute it on the fly at each application.
    - Similarly, if all measures to be considered are supported on a separable grid, and the cost is separable along all axes, i.e. the cost between two points on that grid is equal to the sum of (possibly different) cost functions evaluated on each of the pairs of coordinates, then the application of the kernel is much simplified, both in log space and on the histograms themselves. This particular case is exploited in the `Grid` geometry in grid.py, which can be instantiated as a hypercube using a `grid_size` parameter, or directly through grid locations in `x`.
- In the core folder:
  - The `sinkhorn` function in sinkhorn.py runs the Sinkhorn algorithm, with the aim of approximately solving one or several optimal transport problems in parallel. An OT problem is defined by a `Geometry` object and a pair (or batch thereof) of histograms. The function's outputs are stored in a `SinkhornOutput` named tuple, containing potentials, regularized OT cost, sequence of errors and a convergence flag. Such outputs (with the exception of errors and convergence flag) can be differentiated w.r.t. any of the three inputs `(Geometry, a, b)`, either through backprop or through implicit differentiation of the optimality conditions of the optimal potentials `f` and `g`.
  - In discrete_barycenter.py: implementation of discrete Wasserstein barycenters: given histograms all supported on the same `Geometry`, compute a barycenter of these measures, using an algorithm by Janati et al. (2020).
  - In gromov_wasserstein.py: implementation of the Gromov-Wasserstein metric between metric measure spaces, here encoded as a pair of `Geometry` objects along with weights `a` and `b`.
- In the tools folder:
  - In soft_sort.py: implementation of soft-sorting operators, notably soft-quantile transforms.
  - The `sinkhorn_divergence` function in sinkhorn_divergence.py implements the unbalanced formulation of the Sinkhorn divergence, a variant of the Wasserstein distance that uses regularization and is computed by centering the output of `sinkhorn` when comparing two measures.
  - The `Transport` class in transport.py provides a simple wrapper to the `sinkhorn` function defined above, for when the user is primarily interested in computing and storing an OT matrix.
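The saving that the `Grid` geometry in the geometry folder relies on can be checked numerically. The sketch below is written in plain NumPy and is not taken from grid.py; it assumes a squared-difference cost on each axis, so that the total cost is the squared Euclidean distance between grid points and the Gibbs kernel exp(-C / epsilon) factorizes as a Kronecker product of per-axis kernels. Applying that kernel to a histogram then reduces to one small matrix product per axis. The helpers `axis_kernel` and `apply_separable_kernel`, the 5 x 7 grid and the value of `epsilon` are hypothetical choices for illustration; `Grid`'s own interface (`grid_size`, `x`) is not reproduced.

```python
import numpy as np


def axis_kernel(coords, epsilon):
    """Hypothetical helper: Gibbs kernel exp(-c / epsilon) for one axis."""
    cost = (coords[:, None] - coords[None, :]) ** 2
    return np.exp(-cost / epsilon)


def apply_separable_kernel(kernels, hist):
    """Apply K = K_1 (x) ... (x) K_d to a histogram of shape (n_1, ..., n_d).

    Each axis-wise kernel acts on its own axis, so the dense
    (n_1 * ... * n_d) x (n_1 * ... * n_d) kernel matrix is never formed.
    """
    out = hist
    for axis, k in enumerate(kernels):
        # Contract the second index of k with `axis` of `out`, then move the
        # new index (placed first by tensordot) back to position `axis`.
        out = np.moveaxis(np.tensordot(k, out, axes=([1], [axis])), 0, axis)
    return out


epsilon = 0.1
grid = [np.linspace(0.0, 1.0, 5), np.linspace(0.0, 1.0, 7)]  # a 5 x 7 grid
kernels = [axis_kernel(coords, epsilon) for coords in grid]

rng = np.random.default_rng(0)
hist = rng.uniform(size=(5, 7))
hist /= hist.sum()

# Separable application: one small product per axis.
fast = apply_separable_kernel(kernels, hist)

# Dense reference: flatten the grid (row-major) and build the full 35 x 35 kernel.
points = np.stack(np.meshgrid(*grid, indexing="ij"), axis=-1).reshape(-1, 2)
dense_cost = np.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=-1)
dense = np.exp(-dense_cost / epsilon) @ hist.ravel()
```

With n points per axis in d dimensions, the dense product takes on the order of n^(2d) operations, whereas the axis-wise version takes on the order of d * n^(d+1), which is why separability pays off quickly as the grid grows.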
Download files
Source distribution: ott-jax-0.2.0.tar.gz
Built distribution: ott_jax-0.2.0-py3-none-any.whl
File details
Details for the file ott-jax-0.2.0.tar.gz.
File metadata
- Download URL: ott-jax-0.2.0.tar.gz
- Upload date:
- Size: 74.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 122f5690a328c9289ea227e7fbef0f0767f53b3c0914c84954ccf9691f36eb7a
MD5 | 502d21e147d91603297b05f8b767d642
BLAKE2b-256 | cdd5989a32ca24f5236abbab77099518873617f0e48243a6ffc710ca5eb91920
File details
Details for the file ott_jax-0.2.0-py3-none-any.whl.
File metadata
- Download URL: ott_jax-0.2.0-py3-none-any.whl
- Upload date:
- Size: 94.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.1 pkginfo/1.8.2 requests/2.27.1 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.2
File hashes
Algorithm | Hash digest
---|---
SHA256 | 748615d06a7f8555265cce98f009ee391a6aca06bb45043ec9f180a8f635c817
MD5 | 3602fc4a2a1130742d4062c0a2d7aeca
BLAKE2b-256 | 959f876402ff6c46351a244d0503dde73eef61a5068170477a342234637293c0