
OTT: Optimal Transport Tools in Jax.


Optimal Transport Tools (OTT)

See full documentation for detailed info.

OTT is a JAX toolbox that bundles a few utilities to solve optimal transport problems. These tools can help you compare and match two weighted point clouds (or histograms, measures, etc.), given a cost (e.g. a distance) between single points.

Most of OTT is, for now, built on a sturdy, versatile and efficient implementation of the Sinkhorn algorithm that takes advantage of JAX features such as JIT compilation, auto-vectorization and implicit differentiation.

A typical OT problem has two main ingredients: a pair of weight vectors a and b (one for each measure) and a ground cost evaluated on the pair of measures, usually cast as a pairwise cost matrix. OTT encapsulates the ground cost (and several operations associated with it) in a Geometry object. The most common geometry is that of two point clouds compared with the squared Euclidean distance, as used in the example below.
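To make the notion of a pairwise cost matrix concrete, the sketch below computes, in plain jax.numpy rather than through OTT's Geometry API, the squared Euclidean cost matrix between two small point clouds and the entropic kernel exp(-C/ε) that Sinkhorn iterations typically work with. The helper names sqeuclidean_cost and gibbs_kernel and the value ε = 0.1 are illustrative assumptions.

```python
import jax.numpy as jnp


def sqeuclidean_cost(x, y):
    """Pairwise squared Euclidean cost: C[i, j] = ||x_i - y_j||^2."""
    # Broadcast (n, 1, d) against (1, m, d), then sum over the feature axis.
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def gibbs_kernel(cost, epsilon):
    """Entropic (Gibbs) kernel K = exp(-C / epsilon) built from a cost matrix."""
    return jnp.exp(-cost / epsilon)


x = jnp.array([[0.0, 0.0], [1.0, 0.0]])
y = jnp.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
C = sqeuclidean_cost(x, y)  # [[1., 2., 4.], [2., 1., 1.]]
K = gibbs_kernel(C, epsilon=0.1)
```

Small costs map to kernel entries close to 1 and large costs to entries close to 0, which is why ε controls how "sharp" the resulting transport is.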

Example

import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

# Samples two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n, d)) + 1
y = jax.random.uniform(rngs[1], (m, d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
# Normalizes the weights so that each vector sums to one.
a, b = a / jnp.sum(a), b / jnp.sum(b)

# Computes the couplings via Sinkhorn algorithm.
geom = pointcloud.PointCloud(x, y)
out = sinkhorn.sinkhorn(geom, a, b)
# Recovers the transport matrix from the optimal potentials f and g.
P = geom.transport_from_potentials(out.f, out.g)

One can then plot the transport linking each point from the first point cloud to one or more points from the second.

[Figure: obtained coupling]

As the example shows, the sinkhorn function operates on that Geometry, takes the weights a and b into account, and returns a named tuple that contains, among other things, two potentials f and g (vectors with the same sizes as a and b, respectively), as well as reg_ot_cost, the objective value of the regularized OT problem.
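For intuition about what those potentials are, here is a minimal log-domain Sinkhorn sketch in jax.numpy. It is not OTT's implementation: sinkhorn_log is a hypothetical name, epsilon and the iteration count are fixed, and there is no convergence check. Each update makes either the rows or the columns of the coupling match a or b, and the coupling is read off the potentials as P_ij = exp((f_i + g_j - C_ij)/ε), mirroring what transport_from_potentials returns in the example. Reporting <f, a> + <g, b> as the regularized cost is one common convention, assumed here rather than taken from reg_ot_cost.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log(cost, a, b, epsilon=0.1, num_iters=1000):
    """Log-domain Sinkhorn sketch: returns potentials f, g, the coupling P
    and the dual value <f, a> + <g, b>."""
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, fg):
        f, g = fg
        # Update f so that the rows of P sum to a.
        f = epsilon * log_a - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        # Update g so that the columns of P sum to b.
        g = epsilon * log_b - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    # Coupling recovered from the potentials: P_ij = exp((f_i + g_j - C_ij) / epsilon).
    P = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    return f, g, P, jnp.dot(f, a) + jnp.dot(g, b)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (5, 2))
y = jax.random.uniform(key_y, (6, 2))
a = jnp.full(5, 1.0 / 5)
b = jnp.full(6, 1.0 / 6)
cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
f, g, P, dual = sinkhorn_log(cost, a, b, epsilon=0.5)
```

A fairly large ε = 0.5 is used in the usage lines so that the fixed number of iterations comfortably converges; smaller values give sharper couplings but need more iterations.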

Overall description of source code

The package currently implements the following classes and functions:

  • In the geometry folder,

    • The CostFn class in costs.py and its descendants define cost functions between points. Two simple costs are currently provided: Euclidean, between vectors, and Bures, between pairs made of a mean vector and a (positive definite) covariance matrix.

    • The Geometry class in geometry.py and its descendants describe a cost structure between two measures. That cost structure is accessed through various member functions, used either while running the Sinkhorn algorithm (typically kernel multiplications, or row/column-wise log-sum-exp operations) or afterwards (to apply the OT matrix to a vector).

      • In its generic implementation in geometry.py, a Geometry object can be initialized either with a cost_matrix along with an epsilon regularization parameter (or scheduler), or with a kernel_matrix.

      • To compute OT between two weighted point clouds endowed with a given cost function (e.g. Euclidean), the PointCloud class in pointcloud.py can be used to define the corresponding kernel. When the number of points grows very large, this geometry can be instantiated with an online=True parameter, which avoids storing the kernel matrix and instead recomputes it on the fly at each application.

      • Similarly, if all measures under consideration are supported on a separable grid, and the cost is separable along all axes, i.e. the cost between two points on that grid equals the sum of (possibly different) cost functions evaluated on each pair of coordinates, then applying the kernel is much simpler, both in log space and on the histograms themselves. This particular case is exploited by the Grid geometry in grid.py, which can be instantiated as a hypercube using a grid_size parameter, or directly through grid locations in x.

  • In the core folder,

    • The sinkhorn function in sinkhorn.py runs the Sinkhorn algorithm to approximately solve one or several optimal transport problems in parallel. An OT problem is defined by a Geometry object and a pair (or batch thereof) of histograms. The function's outputs are stored in a SinkhornOutput named tuple containing the potentials, the regularized OT cost, the sequence of errors and a convergence flag. These outputs (except the errors and the convergence flag) can be differentiated w.r.t. any of the three inputs (Geometry, a, b), either by backpropagating through the Sinkhorn iterations or by implicit differentiation of the optimality conditions satisfied by the optimal potentials f and g.

    • In discrete_barycenter.py: implementation of discrete Wasserstein barycenters. Given histograms all supported on the same Geometry, it computes a barycenter of these measures using an algorithm by Janati et al. (2020).

  • In the tools folder,

    • In soft_sort.py: implementation of soft-sorting operators.

    • The sinkhorn_divergence function in sinkhorn_divergence.py implements the Sinkhorn divergence, a regularized variant of the Wasserstein distance computed by centering the output of sinkhorn when comparing two measures, i.e. correcting the regularized cost between the two measures with the regularized costs of each measure compared with itself.

    • The Transport class in sinkhorn_divergence.py provides a simple wrapper around the sinkhorn function defined above, for users primarily interested in computing and storing an OT matrix.
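The Bures cost mentioned in the geometry folder bullets has a closed form. The sketch below assumes it is the usual squared 2-Wasserstein distance between Gaussians N(m1, Σ1) and N(m2, Σ2), namely ||m1 - m2||² + tr Σ1 + tr Σ2 - 2 tr((Σ1^{1/2} Σ2 Σ1^{1/2})^{1/2}); OTT's exact scaling and input packing may differ, and sqrtm_psd and bures_cost are hypothetical helpers. Matrix square roots are taken through an eigendecomposition, which is valid for symmetric positive semi-definite matrices.

```python
import jax.numpy as jnp


def sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    # Clamp tiny negative eigenvalues caused by round-off before the square root.
    return (eigvecs * jnp.sqrt(jnp.maximum(eigvals, 0.0))) @ eigvecs.T


def bures_cost(mean1, cov1, mean2, cov2):
    """Squared 2-Wasserstein distance between N(mean1, cov1) and N(mean2, cov2)."""
    root1 = sqrtm_psd(cov1)
    cross = sqrtm_psd(root1 @ cov2 @ root1)
    bures_sq = jnp.trace(cov1) + jnp.trace(cov2) - 2.0 * jnp.trace(cross)
    return jnp.sum((mean1 - mean2) ** 2) + bures_sq


m1, cov1 = jnp.array([0.0, 0.0]), jnp.diag(jnp.array([1.0, 4.0]))
m2, cov2 = jnp.array([1.0, 0.0]), jnp.diag(jnp.array([4.0, 9.0]))
value = bures_cost(m1, cov1, m2, cov2)  # 1 + (1 - 2)**2 + (2 - 3)**2 = 3
```

With diagonal (commuting) covariances the Bures term reduces to the sum over axes of (√σ1 - √σ2)², which is how the value 3 in the comment is obtained.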
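The online=True option of PointCloud trades memory for computation. A minimal sketch of the idea in jax.numpy, assuming a squared Euclidean cost and using jax.lax.map to go over rows (this is not OTT's actual implementation, and apply_kernel_online is a hypothetical helper): only one row of the kernel exists at a time, so memory per step is O(m) instead of O(nm), at the price of recomputing the costs at every kernel application.

```python
import jax
import jax.numpy as jnp


def apply_kernel_online(x, y, vec, epsilon):
    """Computes K @ vec with K_ij = exp(-||x_i - y_j||^2 / epsilon), one row at a time."""
    def row(xi):
        # Only the i-th row of the kernel (length m) is materialized here.
        k_row = jnp.exp(-jnp.sum((xi[None, :] - y) ** 2, axis=-1) / epsilon)
        return jnp.dot(k_row, vec)

    # lax.map applies `row` sequentially over the leading axis of x.
    return jax.lax.map(row, x)


key_x, key_y, key_v = jax.random.split(jax.random.PRNGKey(1), 3)
x = jax.random.normal(key_x, (50, 3))
y = jax.random.normal(key_y, (40, 3))
vec = jax.random.uniform(key_v, (40,))
online = apply_kernel_online(x, y, vec, epsilon=1.0)

# Dense reference that stores the full (50, 40) kernel, for comparison only.
dense_kernel = jnp.exp(-jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1) / 1.0)
dense = dense_kernel @ vec
```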
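The saving exploited by the Grid geometry comes from the fact that, for a separable cost, the kernel of the whole grid is a Kronecker product of one small kernel per axis. A minimal 2-D sketch in jax.numpy, assuming a squared difference cost on each axis (apply_separable_kernel is a hypothetical helper, not the code in grid.py): applying the two per-axis kernels costs O(n1·n2·(n1 + n2)) operations instead of O((n1·n2)²) for the dense kernel, which is built here only to check the result.

```python
import jax.numpy as jnp


def apply_separable_kernel(k1, k2, hist):
    """Applies the kernel of a separable 2-D grid cost to a histogram of shape (n1, n2).

    For K = kron(k1, k2) and row-major flattening, K @ hist.ravel() equals
    (k1 @ hist @ k2.T).ravel(), so the (n1*n2, n1*n2) kernel is never formed.
    """
    return k1 @ hist @ k2.T


epsilon = 0.1
grid1 = jnp.linspace(0.0, 1.0, 4)  # locations along the first axis
grid2 = jnp.linspace(0.0, 1.0, 5)  # locations along the second axis
k1 = jnp.exp(-((grid1[:, None] - grid1[None, :]) ** 2) / epsilon)
k2 = jnp.exp(-((grid2[:, None] - grid2[None, :]) ** 2) / epsilon)
hist = jnp.full((4, 5), 1.0 / 20)
out = apply_separable_kernel(k1, k2, hist)

# Dense reference on the 20 grid points, built only to check the result.
points = jnp.stack(jnp.meshgrid(grid1, grid2, indexing="ij"), axis=-1).reshape(-1, 2)
dense_cost = jnp.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=-1)
dense = (jnp.exp(-dense_cost / epsilon) @ hist.ravel()).reshape(4, 5)
```

The same factorization carries over to log space, where each per-axis product becomes a per-axis log-sum-exp.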
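The differentiability described in the core folder bullets can be illustrated on a hand-written loop. This sketch only covers the backprop route, where jax.grad differentiates through every unrolled iteration; it does not implement implicit differentiation. reg_ot_cost is a hypothetical helper, not OTT's function, and it reports the dual value <f, a> + <g, b> as the regularized cost, which is an assumption about conventions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def reg_ot_cost(x, y, a, b, epsilon=0.5, num_iters=200):
    """Dual value <f, a> + <g, b> after a fixed number of log-domain Sinkhorn steps.

    jax.grad backpropagates through every iteration of the loop (unrolling).
    """
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

    def body(_, fg):
        f, g = fg
        f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    return jnp.dot(f, a) + jnp.dot(g, b)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (8, 2)) + 0.5
y = jax.random.uniform(key_y, (10, 2))
a = jnp.full(8, 1.0 / 8)
b = jnp.full(10, 1.0 / 10)

# Gradient of the regularized cost with respect to the positions of the first cloud.
grad_x = jax.grad(reg_ot_cost)(x, y, a, b)
```

A small step against grad_x moves the first cloud toward the second and lowers the cost, which is what makes such quantities usable as differentiable losses.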
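One common way to build soft-sorting operators from OT, assumed here and not necessarily what soft_sort.py does, is to transport the n input values onto n evenly spaced targets with entropic Sinkhorn and read each sorted slot as a plan-weighted average of the inputs. soft_sort is a hypothetical helper; it assumes inputs already lie roughly in [0, 1] and skips any rescaling. As ε shrinks the output approaches the hard sort, while remaining differentiable in the inputs for ε > 0.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def soft_sort(values, epsilon=0.01, num_iters=2000):
    """Soft-sorts a 1-D array via entropic OT onto n evenly spaced targets in [0, 1]."""
    n = values.shape[0]
    targets = jnp.linspace(0.0, 1.0, n)
    a = jnp.full(n, 1.0 / n)
    b = jnp.full(n, 1.0 / n)
    cost = (values[:, None] - targets[None, :]) ** 2

    def body(_, fg):
        f, g = fg
        f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros(n), jnp.zeros(n)))
    plan = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)
    # Slot j receives the plan-weighted average of the inputs transported to target j.
    return (plan.T @ values) / b


x = jnp.array([0.9, 0.1, 0.5, 0.3])
soft_sorted = soft_sort(x)  # close to [0.1, 0.3, 0.5, 0.9] for small epsilon
```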
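The centering used by the Sinkhorn divergence can be sketched as follows, assuming the common debiased definition S(a, b) = OT(a, b) - (OT(a, a) + OT(b, b))/2 with squared Euclidean costs. entropic_ot and sinkhorn_divergence here are hypothetical helpers, not the functions in sinkhorn_divergence.py, and a relatively large ε = 0.5 keeps the fixed-length loop well converged. Comparing a cloud with itself gives exactly zero by construction, which the uncentered regularized cost does not.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def entropic_ot(x, y, a, b, epsilon=0.5, num_iters=500):
    """Regularized OT cost, reported as the dual value <f, a> + <g, b>."""
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

    def body(_, fg):
        f, g = fg
        f = epsilon * jnp.log(a) - epsilon * logsumexp((g[None, :] - cost) / epsilon, axis=1)
        g = epsilon * jnp.log(b) - epsilon * logsumexp((f[:, None] - cost) / epsilon, axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    return jnp.dot(f, a) + jnp.dot(g, b)


def sinkhorn_divergence(x, y, a, b, epsilon=0.5):
    """Debiased (centered) cost: OT(a, b) - (OT(a, a) + OT(b, b)) / 2."""
    return (entropic_ot(x, y, a, b, epsilon)
            - 0.5 * entropic_ot(x, x, a, a, epsilon)
            - 0.5 * entropic_ot(y, y, b, b, epsilon))


x = jax.random.uniform(jax.random.PRNGKey(0), (6, 2))
a = jnp.full(6, 1.0 / 6)
same = sinkhorn_divergence(x, x, a, a)  # the three terms coincide, so this is 0
shifted = sinkhorn_divergence(x, x + 1.0, a, a)
```

For squared Euclidean costs and equal weights, translating a cloud by t only adds ||t||² to the regularized cost, so shifted is expected to be close to 2.0 = ||(1, 1)||².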

Disclaimer: this is not an official Google product.


Download files


Source Distribution

ott-jax-0.1.1.tar.gz (36.7 kB)

Uploaded Source

Built Distribution

ott_jax-0.1.1-py3-none-any.whl (45.8 kB)

Uploaded Python 3

File details

Details for the file ott-jax-0.1.1.tar.gz.

File metadata

  • Download URL: ott-jax-0.1.1.tar.gz
  • Upload date:
  • Size: 36.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/54.0.0 requests-toolbelt/0.9.1 tqdm/4.58.0 CPython/3.9.1

File hashes

Hashes for ott-jax-0.1.1.tar.gz
Algorithm Hash digest
SHA256 3307fcc73b76993cde22ec5b493d0d1990e1e827fb3a211a007bda34229e1d43
MD5 06c0237b11f4ebf86e9a40f0e532759f
BLAKE2b-256 d0b6a436b7ad412e7a5ce729fac1c0da5b57e8c445d32d94f82ce0b24ab110e6


File details

Details for the file ott_jax-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: ott_jax-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 45.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/54.0.0 requests-toolbelt/0.9.1 tqdm/4.58.0 CPython/3.9.1

File hashes

Hashes for ott_jax-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 3aa67ce23197994f97a8b9196e30c5c55041a44249de4760f2cbb5d5abfdd4ca
MD5 6bb390f6ae421f0d9d2b9831ae55c17d
BLAKE2b-256 63bf2762ce651284d43d7fbbc5adea369be3c1f223f0e21216dfb7dc644e88c5

