Transformer-based embeddings for Wasserstein distances
WassersteinWormhole
Embedding point clouds while preserving Wasserstein distances with the Wormhole.
This implementation is written in Python 3 and relies on Flax, JAX, and OTT-JAX.
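For two point clouds with the same number of equally weighted points, the 2-Wasserstein distance is the square root of the smallest average squared distance over all one-to-one matchings of their points. Roughly speaking, the Wormhole is trained so that Euclidean distances between embeddings approximate distances of this kind. The sketch below computes the exact value by brute force in pure Python, which is only feasible for a handful of points; `wasserstein2_uniform` is a hypothetical illustrative helper, not part of the package, and it is not how the package computes distances at scale.

```python
import itertools
import math


def wasserstein2_uniform(xs, ys):
    """Exact 2-Wasserstein distance between two equal-size point clouds
    with uniform weights (illustrative brute force, hypothetical helper).

    With uniform weights and equal sizes, an optimal transport plan can be
    taken to be a permutation (Birkhoff-von Neumann theorem), so we search
    over all n! one-to-one matchings. Only practical for very small n.
    """
    if len(xs) != len(ys):
        raise ValueError("point clouds must have the same number of points")
    n = len(xs)
    if n == 0:
        raise ValueError("point clouds must be non-empty")
    best = math.inf
    for perm in itertools.permutations(range(n)):
        # Average squared Euclidean cost of matching xs[i] to ys[perm[i]].
        cost = sum(math.dist(xs[i], ys[perm[i]]) ** 2 for i in range(n)) / n
        best = min(best, cost)
    return math.sqrt(best)


X = [(0.0, 0.0), (1.0, 0.0)]
Y = [(1.0, 1.0), (0.0, 1.0)]
print(wasserstein2_uniform(X, Y))  # 1.0: each point moves straight up by 1
```

Translating a cloud by a vector t gives a 2-Wasserstein distance of exactly the length of t, which is a handy sanity check for any implementation.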
To install JAX with CUDA 12 GPU support, run:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then install WassersteinWormhole along with its remaining requirements:
pip install wassersteinwormhole
Running the Wormhole on your own set of point clouds is as simple as:
from wassersteinwormhole import Wormhole
WormholeModel = Wormhole(point_clouds = point_clouds)
WormholeModel.train()
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)
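The call to `encode` passes `WormholeModel.masks` alongside the point clouds, which suggests that clouds with different numbers of points are padded to a common length, with a mask marking which rows are real. The sketch below illustrates that general padding-and-masking idea in pure Python. `pad_point_clouds` and `masked_mean` are hypothetical helpers, and the exact layout Wormhole uses internally (array shapes, dtypes, padding value) is an assumption, not taken from the package. The sketch also assumes every cloud is non-empty and all points share the same dimension.

```python
def pad_point_clouds(point_clouds):
    """Pad variable-size point clouds to a common length (hypothetical helper).

    Returns (padded, masks): padded[i] has max_len rows, with zero rows
    appended, and masks[i][j] is 1.0 for a real point and 0.0 for padding.
    Assumes every cloud is non-empty and all points share one dimension.
    """
    if not point_clouds:
        return [], []
    max_len = max(len(pc) for pc in point_clouds)
    dim = len(point_clouds[0][0])
    padded, masks = [], []
    for pc in point_clouds:
        n_pad = max_len - len(pc)
        padded.append([list(p) for p in pc] + [[0.0] * dim for _ in range(n_pad)])
        masks.append([1.0] * len(pc) + [0.0] * n_pad)
    return padded, masks


def masked_mean(points, mask):
    """Centroid of a padded cloud; padded rows get zero weight via the mask."""
    total = sum(mask)
    dim = len(points[0])
    return [sum(m * p[k] for p, m in zip(points, mask)) / total for k in range(dim)]


clouds = [[(0.0, 0.0), (2.0, 2.0)], [(1.0, 1.0)]]
padded, masks = pad_point_clouds(clouds)
print(masks)                            # [[1.0, 1.0], [1.0, 0.0]]
print(masked_mean(padded[1], masks[1])) # [1.0, 1.0], padding row ignored
```

Without the mask, the padded zero row would drag the second cloud's centroid to [0.5, 0.5]; any per-cloud computation on padded data needs the mask for the same reason.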
For more details, follow the tutorial at https://wasserstienwormhole.readthedocs.io.
Download files
Hashes for wassersteinwormhole-0.3.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | c5bbd93b34f911116e60429b0f28c81043c223f32247a83552186b8b68b41a66
MD5 | 7e3e3956cf60d5d8746ec252e0c3b5fc
BLAKE2b-256 | 6ca06928c9228ee4d16cf8df796f74acfe85fb92cb61d372b6e66eaae8698a2c
Hashes for wassersteinwormhole-0.3.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | a54ef97c818b6ff7ceca81a3d518168176267587f4ecce20c942581a8bec832a
MD5 | a1e36cfaaed108650f7ca197ffd0d175
BLAKE2b-256 | b419620b9c66e1676bcbc6e9ec686fd64f99d0609740d37bd081de562c2a6b09