Transformer-based embeddings for Wasserstein distances
WassersteinWormhole for Python3
Embedding point clouds while preserving the Wasserstein distances between them, using the Wormhole.
This implementation is written in Python 3 and relies on Flax, JAX, and OTT-JAX.
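To illustrate the kind of distance being preserved, take two point clouds with the same number of points and uniform weights. Here we assume the 2-Wasserstein distance with a Euclidean ground metric; the package's own choice of cost may differ. Under that assumption, the distance reduces to the cost of the best one-to-one matching between the two sets of points. The sketch below computes it exactly by trying every matching, so it is only practical for a handful of points. The function name `wasserstein2_uniform` is hypothetical, and this is not how WassersteinWormhole itself computes distances.

```python
import itertools
import math


def wasserstein2_uniform(x, y):
    """Exact 2-Wasserstein distance between two equal-size point clouds
    with uniform weights, found by brute force over all matchings.

    Hypothetical illustration only: the search is factorial in the number
    of points, so keep the clouds tiny.
    """
    if len(x) != len(y):
        raise ValueError("this sketch assumes both clouds have the same size")
    n = len(x)
    best = math.inf
    # With uniform weights and equal sizes, an optimal transport plan can be
    # taken to be a permutation (Birkhoff-von Neumann), so try every one.
    for perm in itertools.permutations(range(n)):
        cost = sum(
            sum((a - b) ** 2 for a, b in zip(x[i], y[j]))
            for i, j in enumerate(perm)
        )
        # W2^2 is the average squared distance under the best matching.
        best = min(best, cost / n)
    return math.sqrt(best)


# Two 2-D clouds of three points; y is x shifted by (1, 0).
x = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
y = [(1.0, 0.0), (2.0, 0.0), (1.0, 1.0)]
print(wasserstein2_uniform(x, y))  # prints 1.0
```

Practical solvers avoid this factorial search. The matching can be found with a linear assignment solver, and optimal transport libraries such as OTT-JAX typically use entropy-regularized (Sinkhorn) solvers, which also handle clouds of unequal size and non-uniform weights.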
To install JAX 0.4.23 with CUDA 11 GPU support, run:
pip install --upgrade "jax[cuda11_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then install WassersteinWormhole along with its remaining requirements:
pip install wassersteinwormhole
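After installation, you can confirm which versions of the main dependencies ended up in the environment by querying the installed distribution metadata. The sketch below uses the standard library's `importlib.metadata` (Python 3.8+). We assume `flax` and `ott-jax` are the PyPI distribution names of Flax and OTT-JAX. `installed_versions` is a hypothetical helper, not part of the package.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(distributions):
    """Return {distribution name: installed version, or None if missing}."""
    found = {}
    for name in distributions:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Distribution names assumed from PyPI for JAX, Flax, OTT-JAX and this package.
for name, ver in installed_versions(
    ["jax", "flax", "ott-jax", "wassersteinwormhole"]
).items():
    print(f"{name}: {ver or 'not installed'}")
```

To check that the CUDA build is active, call `jax.devices()`: it should list GPU devices rather than only a CPU.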
Running the Wormhole on your own set of point clouds is as simple as:
from wassersteinwormhole import Wormhole
WormholeModel = Wormhole(point_clouds = point_clouds)
WormholeModel.train()
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)
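`encode` takes both the stored point clouds and `WormholeModel.masks`. A common way to batch point clouds of different sizes is to pad every cloud to the size of the largest one and keep a 0/1 mask marking which rows are real points. We assume that is what these masks represent. The helper `pad_with_masks` below is hypothetical and not part of the package; it only sketches the idea with plain lists.

```python
def pad_with_masks(point_clouds, pad_value=0.0):
    """Pad variable-size point clouds to a common number of points.

    Returns (padded, masks): every padded[i] has max_len rows, and
    masks[i][k] is 1 for a real point and 0 for a padding row.
    Hypothetical helper; WassersteinWormhole's own masking may differ.
    """
    if not point_clouds:
        return [], []
    # Point dimension taken from the first non-empty cloud.
    dim = next((len(cloud[0]) for cloud in point_clouds if cloud), 0)
    max_len = max(len(cloud) for cloud in point_clouds)
    padded, masks = [], []
    for cloud in point_clouds:
        if any(len(point) != dim for point in cloud):
            raise ValueError("all points must share the same dimension")
        n_pad = max_len - len(cloud)
        rows = [list(point) for point in cloud]
        rows += [[pad_value] * dim for _ in range(n_pad)]
        padded.append(rows)
        masks.append([1] * len(cloud) + [0] * n_pad)
    return padded, masks


clouds = [
    [(0.0, 0.0), (1.0, 1.0)],
    [(2.0, 0.0), (0.0, 2.0), (1.0, 1.0)],
]
padded, masks = pad_with_masks(clouds)
print(masks)  # [[1, 1, 0], [1, 1, 1]]
```

With this layout, you can exclude padded rows from any distance computation by giving point k of cloud i the weight masks[i][k] / sum(masks[i]).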
For more details, follow the tutorial at https://github.com/dpeerlab/WassersteinWormhole.
Download files
Source distribution: wassersteinwormhole-0.2.2.tar.gz
Built distribution: wassersteinwormhole-0.2.2-py3-none-any.whl
Hashes for wassersteinwormhole-0.2.2.tar.gz
Algorithm | Hash digest
---|---
SHA256 | bf62f001fc69bee641d2ecc59914a777a0a4a8731cbf448a4b5b3697faef5eb3
MD5 | 07fc0884882d408413c9bc668f2bdcbb
BLAKE2b-256 | f4027d63775228102112bec80180f7ccad06183a169e72582e765a1a2f3d1c2a
Hashes for wassersteinwormhole-0.2.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 69c08d17c3f00b2e3a162526917372b9583a5d7ca5dcde99b39833c3d6342ee9
MD5 | 845959d0ecb71052a81e4da8475abff7
BLAKE2b-256 | d6f709ceff0c2ac93a14e564eb2f796df1e2fda844f479fa7071dcdeb7011b95
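To check a downloaded file against the digests listed above, compute the same digests locally and compare. The sketch below uses Python's standard `hashlib`, where BLAKE2b-256 means BLAKE2b with a 32-byte digest. The names `file_digests`, `sha256_matches` and `EXPECTED_WHEEL_SHA256` are placeholders chosen for this example.

```python
import hashlib


def file_digests(path, chunk_size=1 << 20):
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),
    }
    with open(path, "rb") as f:
        # Read in chunks so large files need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def sha256_matches(path, expected_hex):
    """True if the file's SHA256 digest equals the expected hex string."""
    return file_digests(path)["SHA256"] == expected_hex.strip().lower()


# SHA256 of the 0.2.2 wheel, copied from the table above.
EXPECTED_WHEEL_SHA256 = "69c08d17c3f00b2e3a162526917372b9583a5d7ca5dcde99b39833c3d6342ee9"
```

For an intact download, `sha256_matches("wassersteinwormhole-0.2.2-py3-none-any.whl", EXPECTED_WHEEL_SHA256)` should return True. pip can also enforce this itself in hash-checking mode, using `--hash=sha256:...` entries in a requirements file.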