Transformer-based embeddings for Wasserstein distances
WassersteinWormhole for Python3
Embedding point clouds with the Wormhole while preserving the Wasserstein distances between them.
This implementation is written in Python 3 and relies on Flax, JAX, and OTT-JAX.
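To make "Wasserstein distance between point clouds" concrete, here is a small illustration. It is not how WassersteinWormhole computes distances; the package relies on OTT-JAX. For two point clouds with the same number of points and uniform weights, an optimal transport plan can be taken to be a permutation (Birkhoff's theorem), so the 2-Wasserstein distance can be found exactly by trying every matching. The function name `w2_exact` is hypothetical, and the brute-force search is only feasible for a handful of points.

```python
import itertools

import numpy as np


def w2_exact(x: np.ndarray, y: np.ndarray) -> float:
    """Exact 2-Wasserstein distance between two uniform point clouds.

    Hypothetical helper for illustration. Assumes x and y both have shape
    (n, d) with uniform weights 1/n, so an optimal coupling is a permutation.
    Tries all n! matchings, so it only works for tiny n.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("this sketch needs point clouds of identical shape")
    n = x.shape[0]
    # Squared Euclidean cost between every point of x and every point of y.
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    # Cheapest one-to-one matching of x's points onto y's points.
    best = min(
        cost[np.arange(n), list(perm)].sum()
        for perm in itertools.permutations(range(n))
    )
    # W2 is the square root of the mean squared transport cost.
    return float(np.sqrt(best / n))


x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(w2_exact(x, x + np.array([3.0, 4.0])))  # → 5.0
```

Translating a cloud by a vector moves every point by the same amount, so its W2 distance to the original is the length of that vector (here 5). Exact matching scales factorially, which is one reason to learn an embedding in which distances are cheap to evaluate.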
To install JAX with CUDA 11 GPU support, run:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then install WassersteinWormhole along with its remaining dependencies:
pip install wassersteinwormhole
Running the Wormhole on your own set of point clouds is then as simple as:
from wassersteinwormhole import Wormhole
WormholeModel = Wormhole(point_clouds = point_clouds)
WormholeModel.train()
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)
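If you have reference Wasserstein distances for some of your point clouds (for example from an OT solver), a quick sanity check is to compare them with distances between the embeddings. This sketch assumes the embeddings are meant to be compared with plain Euclidean distance, and that `Embeddings` can be turned into a NumPy array of shape (number of point clouds, embedding dimension) with `np.asarray`. Both are assumptions, and the function names below are hypothetical.

```python
import numpy as np


def pairwise_euclidean(emb: np.ndarray) -> np.ndarray:
    """Matrix of Euclidean distances between all pairs of embedding rows."""
    diff = emb[:, None, :] - emb[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def distance_preservation_error(emb, w_dists) -> float:
    """Mean absolute gap between embedding distances and Wasserstein distances.

    Hypothetical helper. emb has shape (n, k); w_dists is an (n, n) matrix of
    reference Wasserstein distances. The diagonal (self-distances) is ignored.
    """
    emb = np.asarray(emb, dtype=float)
    w = np.asarray(w_dists, dtype=float)
    d = pairwise_euclidean(emb)
    off_diagonal = ~np.eye(d.shape[0], dtype=bool)
    return float(np.abs(d - w)[off_diagonal].mean())


# Three 1-D embeddings whose distances exactly match the reference matrix.
emb = np.array([[0.0], [3.0], [5.0]])
w = np.array([[0.0, 3.0, 5.0], [3.0, 0.0, 2.0], [5.0, 2.0, 0.0]])
print(distance_preservation_error(emb, w))  # → 0.0
```

A value near zero means the embedding reproduces the reference distances. The mean absolute gap is just one simple choice; relative error or rank correlation would be reasonable alternatives.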
For more details, follow the tutorial at https://github.com/dpeerlab/WassersteinWormhole.
Hashes for wassersteinwormhole-0.2.0.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 98850da92c518eabf7ae22ff10fcccfbfa2b8eee6e5f84ccef62cd2ae7981f11
MD5 | cede9f9c2bdc9832b8eb450210f93c0d
BLAKE2b-256 | 3f5a94bfb716942f8423e8a560555a5bf250ad1efd6343f61fbb88f9ac898bdf
Hashes for wassersteinwormhole-0.2.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 992c526e5a5c38936b2634c23204e0a692b3eeeda1b75aba6846566ea8d19f9d
MD5 | 5083180a040fd59e7bef8a4490861a3a
BLAKE2b-256 | 40a2bc233bdae49e2a299fcc87c5ddf77dbb708c483ae941298c3f96bbb56ded