Transformer-based embeddings for Wasserstein distances
WassersteinWormhole for Python3
Embedding point clouds with the Wormhole while preserving the Wasserstein distances between them.
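To make "preserving Wasserstein distances" concrete, here is the kind of quantity the embeddings are meant to reproduce. The sketch below computes the exact 2-Wasserstein distance between two tiny point clouds that have the same number of points and uniform weights. In that special case an optimal transport plan can be taken to be a one-to-one matching, so brute force over permutations is exact. This is an illustration under stated assumptions only: `wasserstein2_equal_size` is a hypothetical helper, not part of WassersteinWormhole. The package relies on JAX-OTT for its distance computations, and the ground cost and regularization it uses are not specified here.

```python
import itertools
import math


def wasserstein2_equal_size(x, y):
    """Exact 2-Wasserstein distance between two uniform point clouds of equal size.

    Hypothetical helper for illustration. With uniform weights and the same
    number of points, an optimal plan is a permutation (Birkhoff-von Neumann),
    so we try every matching. Only practical for very small clouds.
    """
    if len(x) != len(y):
        raise ValueError("point clouds must have the same number of points")
    n = len(x)
    if n == 0:
        return 0.0
    best = math.inf
    for perm in itertools.permutations(range(n)):
        # Mean squared Euclidean cost of matching x[i] to y[perm[i]].
        cost = sum(math.dist(x[i], y[j]) ** 2 for i, j in enumerate(perm)) / n
        best = min(best, cost)
    return math.sqrt(best)


a = [(0.0, 0.0), (1.0, 0.0)]
b = [(0.0, 1.0), (1.0, 1.0)]
print(wasserstein2_equal_size(a, b))  # 1.0
```

Shifting every point of a cloud by a vector v moves it exactly |v| in W2, which is a handy sanity check for any distance-preserving embedding.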
This implementation is written in Python3 and relies on FLAX, JAX, & JAX-OTT.
To install JAX with CUDA 11 GPU support, run:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then install WassersteinWormhole along with the rest of its requirements:
pip install wassersteinwormhole
Running the Wormhole on your own set of point clouds is as simple as:
from wassersteinwormhole import Wormhole
WormholeModel = Wormhole(point_clouds = point_clouds)
WormholeModel.train()
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)
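The `encode` call takes both the point clouds and a set of masks. This suggests that clouds with different numbers of points are padded to a common length, with a mask marking which entries are real points. The plain-Python sketch below illustrates that idea and shows why the mask matters when pooling over points. `pad_point_clouds` and `masked_mean` are hypothetical helpers. The exact padding scheme, mask dtype, and array layout that WassersteinWormhole uses internally are assumptions; in practice the model already builds its own `WormholeModel.masks`.

```python
def pad_point_clouds(point_clouds, pad_value=0.0):
    """Pad variable-size point clouds to a common length and build masks.

    Hypothetical helper: returns (padded, masks), where padded[i] has
    max_len points and masks[i][k] is 1.0 for a real point, 0.0 for padding.
    """
    if not point_clouds:
        raise ValueError("need at least one point cloud")
    max_len = max(len(pc) for pc in point_clouds)
    dim = len(point_clouds[0][0])
    padded, masks = [], []
    for pc in point_clouds:
        n_pad = max_len - len(pc)
        padded.append([list(p) for p in pc] + [[pad_value] * dim for _ in range(n_pad)])
        masks.append([1.0] * len(pc) + [0.0] * n_pad)
    return padded, masks


def masked_mean(points, mask):
    """Average only the real points; padded rows get zero weight."""
    total = sum(mask)
    dim = len(points[0])
    return [sum(m * p[d] for p, m in zip(points, mask)) / total for d in range(dim)]


clouds = [[(0.0, 0.0), (2.0, 0.0), (4.0, 3.0)], [(1.0, 1.0)]]
padded, masks = pad_point_clouds(clouds)
print(masks)  # [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]
print(masked_mean(padded[0], masks[0]))  # [2.0, 1.0]
print(masked_mean(padded[1], masks[1]))  # [1.0, 1.0]
```

Without the mask, the second cloud's mean would be dragged toward the padding value, which is why masked pooling (or masked attention, in a transformer) is needed when batching clouds of different sizes.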
For more details, follow the tutorial at https://github.com/dpeerlab/WassersteinWormhole.
Hashes for wassersteinwormhole-0.1.6.tar.gz (source distribution)
Algorithm | Hash digest
---|---
SHA256 | bc206c8f2e7a8e39e9c08b265a9ae7573ed8dda515ac365c561daa795739a6b2
MD5 | 0ef6f35c580ca85987419b6d1db88e3c
BLAKE2b-256 | 87c2d15ea02e2a8eff3d78b353dfef7737d5a45a4549d32fd196127a2a8205fe
Hashes for wassersteinwormhole-0.1.6-py3-none-any.whl (built distribution)
Algorithm | Hash digest
---|---
SHA256 | ce9199a3cbc33bb56e35319521fb49275989c9bdf396f0bf906d468a71b38d8c
MD5 | 32009c18c7bfdd5e72f00cb3031ca3ab
BLAKE2b-256 | 334c390f2c85758987e0a0b0c5852854a30c7e85b886962285f26c63d67c7106