Transformer-based embeddings for Wasserstein distances
WassersteinWormhole for Python3
Embedding point clouds while preserving Wasserstein distances with the Wormhole.
This implementation is written in Python3 and relies on FLAX, JAX, & JAX-OTT.
To install JAX, simply run the command:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
To install WassersteinWormhole along with the rest of its requirements:
pip install WassersteinWormhole
Running the Wormhole on your own set of point clouds is as simple as:
from WassersteinWormhole import Wormhole
WormholeModel = Wormhole(point_clouds = point_clouds)
WormholeModel.train()
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)
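The `encode` call above takes both `point_clouds` and `masks` from the model, which suggests that clouds of different sizes are stored as padded arrays with a mask marking the real points. That is an inference from the snippet, not something stated here, and the package may well build these for you. As a rough illustration of the idea only, here is a NumPy sketch with a hypothetical helper, `pad_point_clouds`, that is not part of WassersteinWormhole:

```python
import numpy as np

def pad_point_clouds(point_clouds):
    """Hypothetical helper (not part of WassersteinWormhole).

    Pads a list of (n_i, dim) arrays to a common length and returns
    a boolean mask that is True for real points and False for padding.
    """
    max_points = max(len(pc) for pc in point_clouds)
    dim = point_clouds[0].shape[1]
    padded = np.zeros((len(point_clouds), max_points, dim), dtype=np.float32)
    masks = np.zeros((len(point_clouds), max_points), dtype=bool)
    for i, pc in enumerate(point_clouds):
        n = len(pc)
        padded[i, :n] = pc    # copy the real points
        masks[i, :n] = True   # mark them as valid
    return padded, masks

# Three toy 2-D point clouds with 5, 8, and 3 points.
rng = np.random.default_rng(0)
point_clouds = [rng.normal(size=(n, 2)) for n in (5, 8, 3)]
padded, masks = pad_point_clouds(point_clouds)
print(padded.shape, masks.sum(axis=1))
```

Whether `Wormhole` expects a list of arrays, a padded array, or something else is best confirmed in the tutorial linked from this page.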
For more details, follow the tutorial at https://github.com/dpeerlab/WassersteinWormhole.
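Since the embeddings are meant to preserve Wasserstein distances, it can help to have an exact reference value for a few pairs of small clouds. The sketch below is independent of WassersteinWormhole and uses SciPy: for two clouds with the same number of uniformly weighted points, the 2-Wasserstein distance reduces to an optimal assignment problem. The function name `exact_w2` is hypothetical, and the package may train against a different or regularized variant of the distance, so treat this as a reference point only.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

def exact_w2(x, y):
    """Exact 2-Wasserstein distance between two equal-size point clouds
    with uniform weights (hypothetical helper, not part of the package)."""
    assert len(x) == len(y), "this shortcut only holds for equal-size clouds"
    cost = cdist(x, y, metric="sqeuclidean")   # pairwise squared distances
    rows, cols = linear_sum_assignment(cost)   # optimal one-to-one matching
    return float(np.sqrt(cost[rows, cols].mean()))

# A cloud and a copy shifted by (3, 4): the optimal plan moves every
# point by the same vector, so the distance equals the shift length.
rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))
y = x + np.array([3.0, 4.0])
w2_shift = exact_w2(x, y)
print(w2_shift)
```

Comparing such reference values with distances between the corresponding embeddings gives a quick, informal check of how well a trained model preserves them.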
Hashes for wassersteinwormhole-0.1.5.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 1a938659059e80cd4b2b19168038a850045b5a92b18366750139e01787b1d3ed
MD5 | 736e019f386c3afd8f66575e64fad2dd
BLAKE2b-256 | 1f1309447d81057203266b8a976f49a76e78d7482f30507de9ef7680f7ecdf27
Hashes for wassersteinwormhole-0.1.5-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | a6d5aed78efaccf4b46d83d677344408e79082d4f5f58c11a794a79d18d634a9
MD5 | acd856b236813b825b7f056bc3f231d9
BLAKE2b-256 | f8e4cc56ef71814bb0666ba426a7a40a60de2053c8236feeeefc1917e01528b1