OTT: Optimal Transport Tools in Jax.

Optimal Transport Tools (OTT): a toolbox for all things Wasserstein.

See full documentation for detailed info on the toolbox.

The goal of OTT is to provide sturdy, versatile and efficient optimal transport solvers, taking advantage of JAX features, such as JIT, auto-vectorization and implicit differentiation.

A typical OT problem has two ingredients: a pair of weight vectors a and b (one for each measure), and a ground cost matrix that is either given directly or derived as the pairwise evaluation of a cost function on pairs of points taken from the two measures. The main design choice in OTT is to encapsulate the cost in a Geometry object and to bundle it with a few useful operations (notably kernel applications). The most common geometry is that of two clouds of vectors compared with the squared Euclidean distance, as illustrated in the example below:

Example

import jax
import jax.numpy as jnp
from ott.tools import transport
# Samples two point clouds and their weights.
rngs = jax.random.split(jax.random.PRNGKey(0),4)
n, m, d = 12, 14, 2
x = jax.random.normal(rngs[0], (n,d)) + 1
y = jax.random.uniform(rngs[1], (m,d))
a = jax.random.uniform(rngs[2], (n,))
b = jax.random.uniform(rngs[3], (m,))
a, b = a / jnp.sum(a), b / jnp.sum(b)
# Computes the couplings via Sinkhorn algorithm.
ot = transport.Transport(x, y, a=a, b=b)
P = ot.matrix

The call above runs the Sinkhorn algorithm and stores its output. The transport matrix can then be instantiated from those optimal solutions and the Geometry. That transport matrix links each point from the first point cloud to one or more points from the second, as illustrated below.

[Figure: obtained coupling]

To be more precise, the sinkhorn algorithm operates on the Geometry, taking into account the weights a and b, to solve the OT problem. It produces a named tuple that contains two optimal dual potentials f and g (vectors of the same size as a and b), the objective reg_ot_cost, a log of the errors of the algorithm as it converges, and a converged flag.
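To make the roles of f, g, the error log and the converged flag concrete, here is a minimal NumPy sketch of log-domain Sinkhorn iterations on a dense cost matrix. It is not OTT's implementation: the names SinkhornSketchOutput, sinkhorn_sketch and transport_matrix are hypothetical, the convention P_ij = exp((f_i + g_j - C_ij) / epsilon) is an assumption, and the value stored in reg_ot_cost is simply the dual value <f, a> + <g, b>, which may differ from the library's objective by terms that depend on the entropy convention.

```python
from typing import NamedTuple

import numpy as np


class SinkhornSketchOutput(NamedTuple):
    """Hypothetical container mirroring the fields described in the text."""
    f: np.ndarray
    g: np.ndarray
    reg_ot_cost: float
    errors: np.ndarray
    converged: bool


def logsumexp(z, axis):
    """Numerically stable log-sum-exp along `axis`."""
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.sum(np.exp(z - z_max), axis=axis))


def transport_matrix(cost, f, g, epsilon):
    """Builds the coupling from the potentials and the cost (assumed convention)."""
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon)


def sinkhorn_sketch(cost, a, b, epsilon=0.5, threshold=1e-6, max_iterations=2000):
    f, g = np.zeros_like(a), np.zeros_like(b)
    log_a, log_b = np.log(a), np.log(b)
    errors, converged = [], False
    for _ in range(max_iterations):
        # Update f so that the rows of the coupling sum to a.
        f = epsilon * (log_a - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        # Update g so that the columns of the coupling sum to b.
        g = epsilon * (log_b - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        # After the g update the columns are exact, so track the row marginal error.
        row_sums = transport_matrix(cost, f, g, epsilon).sum(axis=1)
        errors.append(np.abs(row_sums - a).sum())
        if errors[-1] < threshold:
            converged = True
            break
    # Dual value only; the library's exact normalization is not reproduced here.
    reg_ot_cost = float(f @ a + g @ b)
    return SinkhornSketchOutput(f, g, reg_ot_cost, np.array(errors), converged)


# Small example: two uniform point clouds with a squared-Euclidean cost.
rng = np.random.default_rng(0)
x = rng.uniform(size=(5, 2))
y = rng.uniform(size=(6, 2))
a = np.full(5, 1 / 5)
b = np.full(6, 1 / 6)
cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

out = sinkhorn_sketch(cost, a, b, epsilon=0.5)
P = transport_matrix(cost, out.f, out.g, epsilon=0.5)
```

Once the iterations have converged, the rows of P sum to a and its columns sum to b, up to the chosen threshold.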

Overall description of source code

OTT currently implements the following classes and functions:

  • In the geometry folder,

    • The CostFn class in costs.py and its descendants define cost functions between points. Two simple costs are currently provided: Euclidean, between vectors, and Bures, between pairs of a mean vector and a covariance (positive-definite) matrix.

    • The Geometry class in geometry.py and its descendants describe a cost structure between two measures. That cost structure is accessed through various member functions, either used when running the Sinkhorn algorithm (typically kernel multiplications, or log-sum-exp row/column-wise application) or after (to apply the OT matrix to a vector).

      • In its generic Geometry implementation, as in geometry.py, an object can be initialized with either a cost_matrix along with an epsilon regularization parameter (or scheduler), or with a kernel_matrix.

      • If one wishes to compute OT between two weighted point clouds endowed with a given cost function (e.g. Euclidean), the PointCloud class in pointcloud.py can be used to define the corresponding kernel. When the number of points grows very large, this geometry can be instantiated with an online=True parameter, which avoids storing the kernel matrix and instead recomputes it on the fly at each application.

      • Similarly, if all measures to be considered are supported on a separable grid, and the cost is separable along all axes, i.e. the cost between two points on that grid is equal to the sum of (possibly different) cost functions evaluated on each of the pairs of coordinates, then the application of the kernel is much simplified, both in log space and on the histograms themselves. This particular case is exploited in the Grid geometry in grid.py, which can be instantiated as a hypercube using a grid_size parameter, or directly through grid locations in x.

      • LRCGeometry, low-rank cost geometries, of which a PointCloud endowed with a squared-Euclidean distance is a particular example, can efficiently apply their cost to another matrix. This is leveraged in particular in the low-rank Sinkhorn (and Gromov-Wasserstein) solvers.
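Two of the ideas in the geometry items above can be sketched in a few lines of NumPy: applying a kernel without ever storing it (the idea behind online=True), and applying a squared-Euclidean cost through a low-rank factorization (the idea behind low-rank cost geometries). This is an illustration under stated assumptions, not the library's code: the kernel is assumed to be the Gibbs kernel K_ij = exp(-C_ij / epsilon), and the function names and the batch_size parameter are hypothetical.

```python
import numpy as np


def apply_kernel_online(x, y, v, epsilon, batch_size=4):
    """Computes K @ v for K_ij = exp(-||x_i - y_j||^2 / epsilon) without storing K."""
    out = np.empty(x.shape[0])
    for start in range(0, x.shape[0], batch_size):
        xb = x[start:start + batch_size]
        # Only a (batch_size, m) slab of the cost matrix exists at any time.
        cost_b = np.sum((xb[:, None, :] - y[None, :, :]) ** 2, axis=-1)
        out[start:start + batch_size] = np.exp(-cost_b / epsilon) @ v
    return out


def low_rank_sqeuclidean_factors(x, y):
    """Returns (A, B), each with d + 2 columns, such that A @ B.T is the cost matrix.

    Uses ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.
    """
    n, m = x.shape[0], y.shape[0]
    A = np.concatenate(
        [np.sum(x ** 2, axis=1, keepdims=True), np.ones((n, 1)), -2.0 * x], axis=1)
    B = np.concatenate(
        [np.ones((m, 1)), np.sum(y ** 2, axis=1, keepdims=True), y], axis=1)
    return A, B


def apply_cost_low_rank(A, B, v):
    """Computes (A @ B.T) @ v without forming the n-by-m cost matrix."""
    return A @ (B.T @ v)


# Compare both shortcuts against the dense computation on a small example.
rng = np.random.default_rng(0)
x = rng.normal(size=(7, 3))
y = rng.uniform(size=(5, 3))
v = rng.uniform(size=5)
epsilon = 1.0

dense_cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
kernel_times_v = apply_kernel_online(x, y, v, epsilon)
A, B = low_rank_sqeuclidean_factors(x, y)
cost_times_v = apply_cost_low_rank(A, B, v)
```

The online variant trades memory for recomputation at every application. The low-rank variant reduces the cost of one application from the order of n * m * d operations to the order of (n + m) * d.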

  • In the core folder,

    • The sinkhorn function in sinkhorn.py is a wrapper around the Sinkhorn solver class. It runs the Sinkhorn algorithm to approximately solve one or several optimal transport problems in parallel. An OT problem is defined by a Geometry object and a pair (or batch thereof) of histograms. The function's outputs are stored in a SinkhornOutput named tuple, containing potentials, regularized OT cost, sequence of errors and a convergence flag. Such outputs (with the exception of errors and convergence flag) can be differentiated w.r.t. any of the three inputs (Geometry, a, b), either through backprop or through implicit differentiation of the optimality conditions of the optimal potentials f and g.

    • A later addition, sinkhorn_lr.py, provides the LRSinkhorn solver class, which can solve OT problems at larger scales by explicitly factorizing couplings as low-rank.

    • In discrete_barycenter.py: implementation of discrete Wasserstein barycenters: given histograms all supported on the same Geometry, compute a barycenter of these measures, using an algorithm by Janati et al. (2020).

    • In gromov_wasserstein.py: implementation of two Gromov-Wasserstein solvers (both entropy-regularized and low-rank) to compare two measured-metric spaces, here encoded as a pair of Geometry objects, geom_xx and geom_yy, along with weights a and b. Additional options include using a fused term by specifying geom_xy.
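As a point of reference for the Gromov-Wasserstein item above, the following NumPy sketch evaluates the usual squared-loss Gromov-Wasserstein objective of a given coupling, with an optional fused term. It only evaluates the objective and is not a solver. The squared loss, the function name gw_objective and the fused_penalty weight are assumptions made for illustration, and are not taken from the library.

```python
import numpy as np


def gw_objective(cost_xx, cost_yy, coupling, cost_xy=None, fused_penalty=1.0):
    """Evaluates sum_{i,j,k,l} (cost_xx[i,k] - cost_yy[j,l])^2 P[i,j] P[k,l].

    If cost_xy is given, adds fused_penalty * sum_{i,j} cost_xy[i,j] P[i,j].
    """
    a = coupling.sum(axis=1)  # first marginal of the coupling
    b = coupling.sum(axis=0)  # second marginal of the coupling
    # Expanding the square avoids building the (n, m, n, m) tensor of differences.
    quad_xx = a @ (cost_xx ** 2) @ a
    quad_yy = b @ (cost_yy ** 2) @ b
    cross = np.sum(coupling * (cost_xx @ coupling @ cost_yy.T))
    value = quad_xx + quad_yy - 2.0 * cross
    if cost_xy is not None:
        value += fused_penalty * np.sum(cost_xy * coupling)
    return float(value)


def pairwise_sqdist(u, v):
    """Squared Euclidean distances between the rows of u and the rows of v."""
    return np.sum((u[:, None, :] - v[None, :, :]) ** 2, axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))
y = rng.normal(size=(5, 3))  # the two spaces may have different dimensions
cost_xx, cost_yy = pairwise_sqdist(x, x), pairwise_sqdist(y, y)

# Independent coupling of uniform weights, used here only as a test input.
a = np.full(4, 1 / 4)
b = np.full(5, 1 / 5)
independent = np.outer(a, b)
value = gw_objective(cost_xx, cost_yy, independent)

# Comparing a space with itself through the identity matching gives zero.
self_value = gw_objective(cost_xx, cost_xx, np.eye(4) / 4)
```

The expanded form costs a few matrix products per evaluation, whereas the direct four-index sum grows with the product of the squared sizes of the two spaces.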

  • In the tools folder,
