
Transformer-based embeddings for Wasserstein distances


WassersteinWormhole

Embedding point clouds by preserving Wasserstein distances with the Wormhole.
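To make the target quantity concrete, here is a minimal sketch of the Wasserstein distance that such an embedding is meant to preserve. It is not the package's own code: it assumes two clouds of equal size with uniform weights, in which case optimal transport reduces to an assignment problem, and it uses SciPy's `linear_sum_assignment` rather than the OTT-JAX solvers the package relies on. The helper name `wasserstein2_uniform` is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def wasserstein2_uniform(x, y):
    """Exact 2-Wasserstein distance between two equal-size, uniformly weighted clouds.

    Hypothetical helper for illustration; not part of wassersteinwormhole.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("this sketch assumes clouds of identical shape (n_points, dim)")
    # Pairwise squared Euclidean costs, shape (n_points, n_points).
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    # With uniform weights and equal sizes, the optimal plan is a one-to-one matching.
    rows, cols = linear_sum_assignment(cost)
    return float(np.sqrt(cost[rows, cols].mean()))


rng = np.random.default_rng(0)
cloud = rng.normal(size=(64, 3))
shift = np.array([3.0, 4.0, 0.0])
# Translating a cloud by `shift` moves it a 2-Wasserstein distance of ||shift|| = 5,
# up to floating-point error.
distance = wasserstein2_uniform(cloud, cloud + shift)
```

An embedding that preserves Wasserstein distances would map `cloud` and `cloud + shift` to two vectors whose Euclidean distance is close to `distance`.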

This implementation is written in Python 3 and relies on Flax, JAX, and OTT-JAX.

To install JAX with CUDA 12 support, run:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then install WassersteinWormhole along with the rest of its requirements:

pip install wassersteinwormhole

Running the Wormhole on your own set of point clouds is then as simple as:

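from wassersteinwormhole import Wormhole

# point_clouds holds your own data and must be defined beforehand.
WormholeModel = Wormhole(point_clouds = point_clouds)
# Fit the model to the point clouds.
WormholeModel.train()
# Embed the training point clouds, passing the masks stored on the model.
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)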
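The `encode` call above takes both point clouds and masks. The snippet does not show how the package builds them, so the following is only a sketch of the usual idea: clouds of different sizes are padded to one fixed-shape array (which JAX-compiled code generally needs), and a boolean mask records which rows are real points. The helper `pad_point_clouds`, the list-of-arrays input, and the mask convention (True for real points) are all assumptions for illustration, not the library's API.

```python
import numpy as np


def pad_point_clouds(point_clouds):
    """Pad variable-size clouds to a common length and return (padded, masks).

    Hypothetical helper; the package's internal layout may differ.
    """
    max_points = max(len(pc) for pc in point_clouds)
    dim = point_clouds[0].shape[1]
    padded = np.zeros((len(point_clouds), max_points, dim), dtype=np.float32)
    masks = np.zeros((len(point_clouds), max_points), dtype=bool)
    for i, pc in enumerate(point_clouds):
        padded[i, : len(pc)] = pc
        masks[i, : len(pc)] = True  # True marks a real point, False marks padding
    return padded, masks


rng = np.random.default_rng(0)
# Three 2-D clouds with 5, 8, and 3 points respectively.
point_clouds = [rng.normal(size=(n, 2)) for n in (5, 8, 3)]
padded, masks = pad_point_clouds(point_clouds)
```

Under these assumptions, `padded` has shape `(3, 8, 2)` and each row of `masks` contains as many True entries as the corresponding cloud has points.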

For more details, follow the tutorial at https://wasserstienwormhole.readthedocs.io.

Download files


Source Distribution

wassersteinwormhole-0.3.0.tar.gz (9.6 kB)

Built Distribution

wassersteinwormhole-0.3.0-py3-none-any.whl (11.4 kB, Python 3)
