Transformer-based embeddings for Wasserstein distances
WassersteinWormhole for Python3
Embedding point clouds by preserving Wasserstein distances with the Wormhole.
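For intuition about the distance the Wormhole aims to preserve: when two point clouds have the same number of points and uniform weights, the 2-Wasserstein distance reduces to the best one-to-one matching of their points. The sketch below computes it by brute force over all matchings, which only works for tiny clouds. It assumes equal-size clouds, uniform weights and a squared Euclidean ground cost; `wasserstein2_uniform` is an illustrative name, not part of the package, which relies on JAX-OTT for optimal transport instead.

```python
import itertools
import math
from typing import Sequence

Point = Sequence[float]


def squared_euclidean(p: Point, q: Point) -> float:
    """Squared Euclidean distance between two points of equal dimension."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def wasserstein2_uniform(x: Sequence[Point], y: Sequence[Point]) -> float:
    """Exact W2 between two equal-size, uniformly weighted point clouds.

    With mass 1/n on every point, an optimal transport plan can be taken to be
    a permutation (Birkhoff-von Neumann), so we search all n! matchings.
    Illustrative only: this is exponential in n.
    """
    if len(x) != len(y):
        raise ValueError("this sketch assumes equal-size point clouds")
    n = len(x)
    best = min(
        sum(squared_euclidean(x[i], y[j]) for i, j in enumerate(perm))
        for perm in itertools.permutations(range(n))
    )
    return math.sqrt(best / n)


# Two 2-D point clouds; y is x translated by (3, 4).
x = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
y = [(3.0, 4.0), (4.0, 4.0), (3.0, 5.0)]
print(wasserstein2_uniform(x, y))  # 5.0
```

Because `y` is `x` shifted by (3, 4), the distance equals the length of that shift; an embedding that preserves Wasserstein distances should place the two clouds roughly that far apart.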
This implementation is written in Python3 and relies on FLAX, JAX, and JAX-OTT.
To install JAX with CUDA 11 GPU support, run:
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then install WassersteinWormhole along with the rest of its requirements:
pip install WassersteinWormhole
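After installing, a quick sanity check can confirm that the packages are importable and which backend JAX picked up (a working GPU install should report `gpu`). This is a minimal sketch; the module names `jax`, `flax`, `ott` (the import name used by JAX-OTT) and `wassersteinwormhole` are assumed import names for the packages installed above.

```python
import importlib.util

# Assumed import names for the packages installed above.
REQUIRED_MODULES = ("jax", "flax", "ott", "wassersteinwormhole")


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found on the current path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def jax_backend():
    """Return JAX's default backend ('cpu', 'gpu', ...) or None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return jax.default_backend()


if __name__ == "__main__":
    missing = missing_modules()
    print("missing modules:", missing or "none")
    print("JAX backend:", jax_backend())
```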
Running the Wormhole on your own set of point clouds is then as simple as:
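from wassersteinwormhole import Wormhole

# point_clouds: one (n_points, dim) array per point cloud (assumed input format)
WormholeModel = Wormhole(point_clouds=point_clouds)
WormholeModel.train()
Embeddings = WormholeModel.encode(WormholeModel.point_clouds, WormholeModel.masks)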
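The `encode` call reads back `WormholeModel.masks`, which suggests that point clouds of different sizes are padded to a common length and paired with masks marking the real points. The helper below is a hypothetical NumPy illustration of that layout: `pad_point_clouds` is not part of the package, and the Wormhole's internal format may differ.

```python
import numpy as np


def pad_point_clouds(point_clouds):
    """Hypothetical helper: pad clouds of shape (n_i, d) to a common length.

    Returns
      padded: float32 array (num_clouds, max_n, d); padding rows are zeros
      masks:  float32 array (num_clouds, max_n); 1.0 marks a real point
    """
    dim = point_clouds[0].shape[1]
    max_n = max(pc.shape[0] for pc in point_clouds)
    padded = np.zeros((len(point_clouds), max_n, dim), dtype=np.float32)
    masks = np.zeros((len(point_clouds), max_n), dtype=np.float32)
    for i, pc in enumerate(point_clouds):
        n = pc.shape[0]
        padded[i, :n] = pc  # copy the real points
        masks[i, :n] = 1.0  # mark them as valid
    return padded, masks


# Example: three random 2-D clouds with 5, 8 and 3 points.
rng = np.random.default_rng(0)
point_clouds = [rng.normal(size=(n, 2)) for n in (5, 8, 3)]
padded, masks = pad_point_clouds(point_clouds)
print(padded.shape, masks.shape)  # (3, 8, 2) (3, 8)
print(masks.sum(axis=1))          # [5. 8. 3.]
```

Masks let a model ignore padding rows when it pools over points, so clouds of any size can share one dense array.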
For more details, follow the tutorial at https://github.com/dpeerlab/WassersteinWormhole.
Download files
Hashes for wassersteinwormhole-0.1.3.tar.gz
Algorithm | Hash digest
---|---
SHA256 | 76f588419f2412d116896707caa245fff02360996407fed1c655a7b4ac92439b
MD5 | cd56d47c483e5b2bd6fdf2ab36ca392f
BLAKE2b-256 | b447699ef0a05c8ad9ad87ec61072a118f0b71dc70896eeb6cda90725fd3e7ca
Hashes for wassersteinwormhole-0.1.3-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | a231bfcd67e860306a8f7591be343f97e18b1d69e484655c67b116cfd1b9f8d8
MD5 | caa4b92bbf3c490923a03a13649b3bf2
BLAKE2b-256 | 3fc1780d28efbb1e7e160e202bedc29e597986e2eb327c7b6bbb0239574d2940
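The digests above can be used to check a downloaded file before installing it. A minimal sketch using Python's standard `hashlib`, assuming the wheel was saved under its original file name in the current directory:

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for wassersteinwormhole-0.1.3-py3-none-any.whl.
EXPECTED_SHA256 = "a231bfcd67e860306a8f7591be343f97e18b1d69e484655c67b116cfd1b9f8d8"


def sha256_of(path, chunk_size=1 << 20):
    """Compute the SHA256 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest matches the expected digest."""
    return sha256_of(path) == expected.lower()


if __name__ == "__main__":
    wheel = Path("wassersteinwormhole-0.1.3-py3-none-any.whl")
    if wheel.exists():
        print("hash OK" if verify(wheel) else "hash MISMATCH")
    else:
        print(f"{wheel} not found")
```

pip can enforce the same check at install time through its hash-checking mode (`--require-hashes`, with `--hash=sha256:...` entries in a requirements file).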