A JAX-based Dimension Reducer
DiRe - JAX
A new DImensionality REduction package written in JAX
We offer a new dimensionality reduction tool called DiRe - JAX, benchmarked against existing approaches: UMAP (original and Rapids.AI versions) and t-SNE (Rapids.AI version).
Quick start
To install the main DiRe class only, run
pip install dire-jax
To install the benchmarking utilities as well, run
pip install dire-jax[utils]
Note: For GPU or TPU acceleration, JAX needs to be specifically installed with hardware support. See the JAX documentation for more details on enabling GPU/TPU support.
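Before running a large reduction, it can help to confirm which backend JAX actually picked up. The snippet below is a minimal sketch, not part of DiRe: the helper name `describe_jax_backend` is hypothetical, and it only uses the standard `jax.default_backend()` and `jax.devices()` calls. It returns `None` when JAX is not installed at all.

```python
def describe_jax_backend():
    """Return the active JAX backend and devices, or None if JAX is missing.

    Hypothetical helper for illustration; it is not part of dire_jax.
    """
    try:
        import jax
    except ImportError:
        return None
    return {
        # Name of the default platform JAX will run on.
        "backend": jax.default_backend(),
        # Human-readable list of the devices visible to JAX.
        "devices": [str(d) for d in jax.devices()],
    }


info = describe_jax_backend()
if info is None:
    print("JAX is not installed")
else:
    print("backend:", info["backend"])
    print("devices:", info["devices"])
```

If the reported backend is the CPU on a machine with an accelerator, JAX was most likely installed without hardware support.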
Then, import the following:
from dire_jax import DiRe
from dire_jax.dire_utils import display_layout
and afterwards, for example, try this:
from sklearn.datasets import make_blobs
n_samples = 100_000
n_features = 1_000
n_centers = 12
features_blobs, labels_blobs = make_blobs(n_samples=n_samples, n_features=n_features, centers=n_centers, random_state=42)
reducer_blobs = DiRe(dimension=2,
                     n_neighbors=16,
                     init_embedding_type='pca',
                     max_iter_layout=32,
                     min_dist=1e-4,
                     spread=1.0,
                     cutoff=4.0,
                     n_sample_dirs=8,
                     sample_size=16,
                     neg_ratio=32,
                     verbose=False,)
_ = reducer_blobs.fit_transform(features_blobs)
reducer_blobs.visualize(labels=labels_blobs, point_size=4)
The output should look similar to the following:
[Figure: 2D DiRe layout of the blobs dataset]
Documentation
Please refer to the DiRe API documentation for more instructions.
Working paper
Our working paper is available in the repository.
Also, check out the Jupyter notebook with benchmarking results.
Benchmarking and utilities
In order to run the Jupyter notebook in the ./tests folder, you need to install some extras:
pip install dire-jax[utils]
This installation gives you access to the utilities (metrics and benchmarking routines) implemented specifically for use with DiRe. Some of them rely on external packages (especially for persistent homology computations) and may take longer to run.
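As a rough illustration of what an embedding-quality metric measures, the sketch below computes the fraction of each point's nearest neighbours that survive in a low-dimensional embedding. It is a generic, dependency-free example and an assumption on our part: the function names `knn_indices` and `neighbour_preservation` are hypothetical, they are not the metrics shipped with the DiRe utilities, and a practical version would be vectorized.

```python
import math
import random


def knn_indices(points, k):
    """Return, for each point, the set of indices of its k nearest neighbours."""
    neighbours = []
    for i, p in enumerate(points):
        # Sort all other points by Euclidean distance to p.
        others = sorted(
            (j for j in range(len(points)) if j != i),
            key=lambda j: math.dist(p, points[j]),
        )
        neighbours.append(set(others[:k]))
    return neighbours


def neighbour_preservation(high_dim, low_dim, k=5):
    """Mean fraction of each point's k nearest neighbours kept by the embedding."""
    high = knn_indices(high_dim, k)
    low = knn_indices(low_dim, k)
    overlaps = [len(h & l) / k for h, l in zip(high, low)]
    return sum(overlaps) / len(overlaps)


# Toy data: 60 Gaussian points in 5 dimensions (fixed seed for repeatability).
rng = random.Random(0)
data = [[rng.gauss(0, 1) for _ in range(5)] for _ in range(60)]

# A perfect "embedding" (a copy of the data) and a crude one (first two axes).
identity_embedding = [row[:] for row in data]
projected = [row[:2] for row in data]

score_identity = neighbour_preservation(data, identity_embedding)
score_projected = neighbour_preservation(data, projected)
print("identity:", score_identity)
print("projection onto two axes:", score_projected)
```

A score of 1 means every neighbourhood is preserved exactly; lower scores mean the embedding has mixed up local structure. The same function could be applied to the output of `fit_transform` on a small subsample.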
Contributing
Please follow the contributing guide. Thanks!
Acknowledgement
This work is supported by the Google Cloud Research Award number GCP19980904.