Transformer-based embeddings for Wasserstein distances
WassersteinWormhole for Python3
Embedding point clouds while preserving Wasserstein distances, using the Wormhole.
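To make the target concrete: for two point clouds with the same number of points and uniform weights, the 2-Wasserstein (W2) distance reduces to the cheapest one-to-one matching of points. The sketch below computes it by brute force (feasible only for a handful of points) and scores a set of embeddings by how far their pairwise Euclidean distances are from the W2 distances. The helper names `w2_exact` and `distance_matching_loss` are hypothetical, and the squared-gap loss is an illustrative assumption, not necessarily the objective the Wormhole optimizes.

```python
import itertools
import math


def w2_exact(x, y):
    """2-Wasserstein distance between two equal-size, uniformly weighted point
    clouds. With uniform weights an optimal plan is a permutation, so we try
    every one-to-one matching (n! of them): only usable for tiny clouds."""
    n = len(x)
    if n == 0 or n != len(y):
        raise ValueError("this sketch needs two non-empty clouds of equal size")
    best = math.inf
    for perm in itertools.permutations(range(n)):
        # Mean squared distance travelled under this matching.
        cost = sum(math.dist(x[i], y[j]) ** 2 for i, j in enumerate(perm)) / n
        best = min(best, cost)
    return math.sqrt(best)


def distance_matching_loss(embeddings, clouds):
    """Hypothetical objective: mean squared gap between the Euclidean distance
    of each pair of embeddings and the W2 distance of the corresponding clouds."""
    pairs = list(itertools.combinations(range(len(clouds)), 2))
    total = 0.0
    for i, j in pairs:
        gap = math.dist(embeddings[i], embeddings[j]) - w2_exact(clouds[i], clouds[j])
        total += gap ** 2
    return total / len(pairs)
```

Two clouds that differ only by a translation t are at W2 distance |t|, so placing their embeddings |t| apart gives zero loss; an encoder that preserves Wasserstein distances aims for that behaviour across all pairs.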
This implementation is written in Python 3 and relies on Flax, JAX and OTT-JAX.
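OTT-JAX provides optimal-transport solvers such as the Sinkhorn algorithm for entropy-regularized OT. As a rough, dependency-free illustration of that kind of computation (not the OTT-JAX API, and not necessarily the solver configuration the Wormhole uses), here is plain Sinkhorn between two uniformly weighted point clouds with squared Euclidean cost. The names `sinkhorn_cost`, `epsilon` and `n_iters` are hypothetical.

```python
import math


def sinkhorn_cost(x, y, epsilon=0.1, n_iters=200):
    """Entropy-regularized OT between two uniformly weighted point clouds.

    Illustrative sketch only; this is not the OTT-JAX API. Uses the squared
    Euclidean cost and returns (transport_cost, plan), where transport_cost is
    sum_ij P_ij * C_ij for the Sinkhorn plan P.
    """
    n, m = len(x), len(y)
    a = [1.0 / n] * n  # uniform weights on x
    b = [1.0 / m] * m  # uniform weights on y
    cost = [[math.dist(xi, yj) ** 2 for yj in y] for xi in x]
    # Gibbs kernel K = exp(-C / epsilon); may underflow for large costs.
    kernel = [[math.exp(-c / epsilon) for c in row] for row in cost]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the marginals a and b.
        u = [a[i] / sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]
    total = sum(plan[i][j] * cost[i][j] for i in range(n) for j in range(m))
    return total, plan
```

A smaller `epsilon` brings the returned cost closer to the unregularized squared W2 distance but makes the kernel entries underflow sooner; production solvers typically work in the log domain to avoid this.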
To install JAX with CUDA 11 support, run:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then install WassersteinWormhole along with its remaining requirements:
pip install WassersteinWormhole
Running the Wormhole on your own set of point clouds is as simple as:
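from wassersteinwormhole import Wormhole  # import path assumed from the package name
WormholeModel = Wormhole(point_clouds = point_clouds)  # point_clouds: your own collection of point clouds
WormholeModel.train()
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)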
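The call to `encode` takes both `point_clouds` and `masks`, which suggests that clouds with different numbers of points are stored padded to a common length, alongside a mask marking which entries are real points. The sketch below shows that general technique with hypothetical helpers `pad_point_clouds` and `masked_mean`; it is an assumption for illustration, not a description of how WassersteinWormhole builds its masks internally.

```python
# Hypothetical helpers; not part of the WassersteinWormhole API.


def pad_point_clouds(point_clouds, pad_value=0.0):
    """Pad variable-size point clouds to the same number of points.

    Returns (padded, masks): padded[i] is a list of max_len points, and
    masks[i][k] is 1.0 for a real point and 0.0 for a padding entry.
    """
    max_len = max(len(pc) for pc in point_clouds)
    dim = len(point_clouds[0][0])
    padded, masks = [], []
    for pc in point_clouds:
        n_pad = max_len - len(pc)
        padded.append([list(p) for p in pc] + [[pad_value] * dim for _ in range(n_pad)])
        masks.append([1.0] * len(pc) + [0.0] * n_pad)
    return padded, masks


def masked_mean(cloud, mask):
    """Average only the real points, so padding never shifts the result.

    Assumes the mask contains at least one 1.0 entry.
    """
    n_real = sum(mask)
    dim = len(cloud[0])
    return [sum(w * p[d] for p, w in zip(cloud, mask)) / n_real for d in range(dim)]
```

Any pooling or attention step in an encoder has to respect the mask in the same way; otherwise the padding points would leak into the embedding and two copies of the same cloud padded to different lengths would be encoded differently.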