A JAX-based Dimension Reducer

DiRe - JAX

Python 3.8+

A new DImensionality REduction package written in JAX

We offer a new dimension reduction tool, DiRe - JAX, benchmarked against existing approaches: UMAP (original and Rapids.AI versions) and tSNE (Rapids.AI version).

Quick start

To install only the main DiRe class, run

pip install dire-jax

To also install the benchmarking utilities, run

pip install dire-jax[utils]

Note: For GPU or TPU acceleration, JAX needs to be specifically installed with hardware support. See the JAX documentation for more details on enabling GPU/TPU support.
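To check which hardware JAX actually picked up after installation, a small sketch like the one below can help. It is not part of DiRe; the helper name describe_jax_backend is made up for illustration, and it uses only jax.default_backend() and jax.devices().

```python
def describe_jax_backend():
    """Return a short summary of the backend JAX would use on this machine."""
    try:
        import jax
    except ImportError:
        # JAX is missing entirely; nothing to report.
        return "JAX is not installed"
    backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


if __name__ == "__main__":
    print(describe_jax_backend())
```

If the reported backend is "cpu" on a machine with an accelerator, JAX was most likely installed without hardware support.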

Then import the package:

from dire_jax import DiRe
from dire_jax.dire_utils import display_layout

Afterwards, try the following example:

from sklearn.datasets import make_blobs

# Synthetic data: 100k points in 1k dimensions, grouped into 12 blobs
n_samples  = 100_000
n_features = 1_000
n_centers  = 12
features_blobs, labels_blobs = make_blobs(n_samples=n_samples, n_features=n_features, centers=n_centers, random_state=42)

# Configure the reducer; dimension=2 is the target embedding dimension
reducer_blobs = DiRe(dimension=2,
                     n_neighbors=16,
                     init_embedding_type='pca',
                     max_iter_layout=32,
                     min_dist=1e-4,
                     spread=1.0,
                     cutoff=4.0,
                     n_sample_dirs=8,
                     sample_size=16,
                     neg_ratio=32,
                     verbose=False)

# Compute the embedding, then plot it together with the blob labels
_ = reducer_blobs.fit_transform(features_blobs)
reducer_blobs.visualize(labels=labels_blobs, point_size=4)

The output should look similar to the following figure:

[Figure: 12 blobs with 100k points in 1k dimensions embedded in dimension 2]
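For a sense of scale, the input matrix in this example is fairly large. The sketch below is not part of DiRe; it only computes the size of a dense array of that shape, assuming 8-byte float64 values (what make_blobs returns) or 4-byte float32 values (JAX's default float width). The helper name dense_array_bytes is made up for illustration, and whether DiRe converts the input internally is not covered here.

```python
def dense_array_bytes(n_samples, n_features, bytes_per_value=8):
    """Size in bytes of a dense (n_samples x n_features) array."""
    return n_samples * n_features * bytes_per_value


if __name__ == "__main__":
    # Same shape as the blobs example: 100k points in 1k dimensions.
    as_float64 = dense_array_bytes(100_000, 1_000, bytes_per_value=8)
    as_float32 = dense_array_bytes(100_000, 1_000, bytes_per_value=4)
    print(f"float64: {as_float64 / 1e9:.1f} GB")
    print(f"float32: {as_float32 / 1e9:.1f} GB")
```

On a machine with limited memory, lowering n_samples or n_features in the example is a simple way to try the package first.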

Documentation

Please refer to the DiRe API documentation for more instructions.

Working paper

Our working paper is available in the repository as a PDF.

Also, check out the Jupyter notebook with benchmarking results, which can be opened in Colab.

Benchmarking and utilities

To run the Jupyter notebook in the ./tests folder, install the extras:

pip install dire-jax[utils]

This installation gives you access to the utilities (metrics and benchmarking routines) implemented specifically for use with DiRe. Some of them rely on external packages, especially for persistent homology computations, and may have longer runtimes.

Contributing

Please follow the contributing guide. Thanks!

Acknowledgement

This work is supported by the Google Cloud Research Award number GCP19980904.

Download files

Download the file for your platform.

Source Distribution

dire_jax-0.1.0.tar.gz (7.3 MB)

Uploaded Source

Built Distribution

dire_jax-0.1.0-py3-none-any.whl (7.3 MB)

Uploaded Python 3

File details

Details for the file dire_jax-0.1.0.tar.gz.

File metadata

  • Download URL: dire_jax-0.1.0.tar.gz
  • Upload date:
  • Size: 7.3 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.10

File hashes

Hashes for dire_jax-0.1.0.tar.gz
Algorithm Hash digest
SHA256 713fd56383e9072c8d7379f134c6ecade5b18028db19a0b28f1bfce82e9cb7c2
MD5 1a69189270c487bc2e4d2e60f9f16217
BLAKE2b-256 9e7aa4d44cb1ac27e11caadc288b26d62e7f0a8773d0bcb45ac4970dc7f784b3
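A downloaded file can be checked against the digest listed above. The sketch below uses only Python's standard hashlib module; the helper names sha256_of_file and file_matches are made up for illustration, and the local file path is an assumption about where the download was saved.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for dire_jax-0.1.0.tar.gz
EXPECTED_SHA256 = "713fd56383e9072c8d7379f134c6ecade5b18028db19a0b28f1bfce82e9cb7c2"


def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def file_matches(path, expected_hex):
    """True if the file's SHA256 digest equals the expected hex string."""
    return sha256_of_file(path) == expected_hex.strip().lower()


if __name__ == "__main__":
    archive = Path("dire_jax-0.1.0.tar.gz")  # assumed download location
    if archive.exists():
        print("digest matches:", file_matches(archive, EXPECTED_SHA256))
    else:
        print(f"{archive} not found; download it first")
```

The same helpers work for the wheel by passing its file name and its listed SHA256 digest.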

File details

Details for the file dire_jax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: dire_jax-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 7.3 MB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.8.10

File hashes

Hashes for dire_jax-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e13fc8a2fbc3de6708202320a7c8ce0f0535c558b96044a2bf7ada654148da53
MD5 e9f631cc936f4a44fdb70580e0fc1ffb
BLAKE2b-256 7d610bfb50092d764f98240e59089516fffff57ea0f1299bc6196debd4a227b1
