A JAX-based Dimension Reducer
A high-performance DImensionality REduction package with JAX
DiRe offers fast dimensionality reduction that preserves the global structure of a dataset, with benchmarks showing competitive performance against UMAP and t-SNE. It is built with JAX for efficient computation on CPUs and GPUs.
Quick start
Basic installation (JAX backend only):
```bash
pip install dire-jax
```
With utilities for benchmarking:
```bash
pip install "dire-jax[utils]"
```
Complete installation with all optional extras:
```bash
pip install "dire-jax[all]"
```
Note: For GPU or TPU acceleration, JAX must be installed with the matching hardware support. See the JAX documentation for details on enabling GPU/TPU support.
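To confirm which hardware JAX actually picked up after installation, you can query its default backend. The helper below is a small sketch; the name `describe_jax_backend` is hypothetical and not part of DiRe. It relies only on `jax.default_backend()` and `jax.devices()`, and reports when JAX is not installed at all.

```python
import importlib.util


def describe_jax_backend() -> str:
    """Return a short description of the backend JAX will use (hypothetical helper)."""
    # Check for JAX first so a missing install does not raise ImportError
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    # default_backend() names the platform JAX computes on, e.g. "cpu", "gpu" or "tpu"
    backend = jax.default_backend()
    devices = jax.devices()
    return f"{backend} ({len(devices)} device(s))"


print(describe_jax_backend())
```

If this reports `cpu` on a machine that has a GPU, JAX was most likely installed without GPU support.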
Example usage:
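```python
from dire_jax import DiRe
from sklearn.datasets import make_blobs

# Synthetic data: 100,000 points in 1,000 dimensions, drawn around 12 cluster centers
n_samples = 100_000
n_features = 1_000
n_centers = 12
features_blobs, labels_blobs = make_blobs(
    n_samples=n_samples,
    n_features=n_features,
    centers=n_centers,
    random_state=42,
)

reducer_blobs = DiRe(
    n_components=2,   # dimension of the embedding
    n_neighbors=16,   # number of nearest neighbors per point
    init='pca',       # initialize the layout with PCA
    max_iter_layout=32,
    min_dist=1e-4,
    spread=1.0,
    cutoff=4.0,
    n_sample_dirs=8,
    sample_size=16,
    neg_ratio=32,
    verbose=False,
)

# Compute the 2D embedding, then plot it colored by the true blob labels
_ = reducer_blobs.fit_transform(features_blobs)
reducer_blobs.visualize(labels=labels_blobs, point_size=4)
```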
The output should be a two-dimensional scatter plot of the embedding, with points colored by their blob labels.
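The example above works on a fairly large input. As a back-of-the-envelope check, assuming `make_blobs` returns float64 arrays (its default), the raw feature matrix alone needs `n_samples * n_features * 8` bytes, i.e. 800 MB for 100,000 x 1,000, before counting any memory DiRe itself uses. The hypothetical helper below does this arithmetic; an `itemsize` of 4 gives the float32 size.

```python
def feature_matrix_bytes(n_samples: int, n_features: int, itemsize: int = 8) -> int:
    """Bytes needed for a dense n_samples x n_features array (itemsize=8 for float64)."""
    return n_samples * n_features * itemsize


float64_bytes = feature_matrix_bytes(100_000, 1_000)              # 800_000_000 bytes
float32_bytes = feature_matrix_bytes(100_000, 1_000, itemsize=4)  # 400_000_000 bytes
print(f"float64: {float64_bytes / 1e6:.0f} MB, float32: {float32_bytes / 1e6:.0f} MB")
```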
Documentation
Please refer to the DiRe API documentation for more instructions.
Project documentation structure:
- /docs/: API documentation and architecture details
- /benchmarking/: Performance benchmarks and scaling results
- /examples/: Example usage and demos
- /tests/: Test suite and benchmarking notebooks
Working paper
Our working paper is available on the arXiv.
Also, check out the Jupyter notebook with benchmarking results.
Performance Characteristics
DiRe-JAX is optimized for small to medium datasets (fewer than 50K points). It performs well on CPUs and supports GPU acceleration via JAX, with fully vectorized, JIT-compiled computation.
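To illustrate what vectorized, JIT-compiled computation looks like in JAX, here is a minimal sketch of a k-nearest-neighbor query, a common building block of neighbor-based reducers. It is not DiRe's internal implementation; `knn_indices` is a hypothetical name, and the dense all-pairs distance matrix needs O(n²) memory, so it only suits small inputs.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="k")
def knn_indices(x: jax.Array, k: int) -> jax.Array:
    """Indices of the k nearest neighbors of each row of x (hypothetical sketch)."""
    # ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j, for all pairs at once
    sq_norms = jnp.sum(x * x, axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (x @ x.T)
    d2 = jnp.maximum(d2, 0.0)  # clamp small negatives caused by round-off
    # top_k picks the largest entries, so negate to get the smallest distances
    _, idx = jax.lax.top_k(-d2, k + 1)
    # Drop column 0: with no duplicate points, each point is its own nearest neighbor
    return idx[:, 1:]


points = jnp.array([[0.0], [1.0], [3.0], [10.0]])
print(knn_indices(points, k=1).tolist())  # [[1], [0], [1], [2]]
```

Because `k` determines the output shape, it is marked static, so `jax.jit` compiles one specialized version per value of `k`.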
Benchmarking and utilities
For benchmarking utilities and quality metrics:
```bash
pip install "dire-jax[utils]"
```
This provides access to dimensionality reduction quality metrics and benchmarking routines. Some utilities use external packages for persistent homology computations, which may increase runtime.
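For intuition about what a quality metric measures, the sketch below computes a simple local-structure score: the average fraction of each point's k nearest neighbors in the original space that are also among its k nearest neighbors in the embedding. This is a generic NumPy illustration, not the API of the `dire-jax[utils]` package; the function names are hypothetical.

```python
import numpy as np


def knn_indices_np(x: np.ndarray, k: int) -> np.ndarray:
    """Brute-force indices of the k nearest neighbors of each row of x, excluding itself."""
    d2 = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]


def neighborhood_preservation(high: np.ndarray, low: np.ndarray, k: int = 10) -> float:
    """Mean fraction of shared k-nearest neighbors between two representations (1.0 = perfect)."""
    nn_high = knn_indices_np(high, k)
    nn_low = knn_indices_np(low, k)
    shared = [len(set(a) & set(b)) / k for a, b in zip(nn_high, nn_low)]
    return float(np.mean(shared))


rng = np.random.default_rng(0)
data = rng.normal(size=(200, 50))
# Naive "embedding" that keeps only the first two coordinates
print(neighborhood_preservation(data, data[:, :2], k=10))
```

This score only measures how well local neighborhoods survive; it says nothing about global structure.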
Contributing
Please follow the contributing guide. Thanks!
Acknowledgement
This work is supported by the Google Cloud Research Award number GCP19980904.
Download files
Source Distribution: dire_jax-0.2.1.tar.gz
Built Distribution: dire_jax-0.2.1-py3-none-any.whl
File details
Details for the file dire_jax-0.2.1.tar.gz.
File metadata
- Download URL: dire_jax-0.2.1.tar.gz
- Upload date:
- Size: 39.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 68be1c7860e21624f94030d0572ad8b8e5f1649d47bd4304cb0bc30ff1c24bdf |
| MD5 | a778bc1f4eaa8eb64c8adde6e318ac05 |
| BLAKE2b-256 | bf892d0a6fa9ad2ba540611218c498ac3ddd48760a762b0a2411e3717440b5b3 |
File details
Details for the file dire_jax-0.2.1-py3-none-any.whl.
File metadata
- Download URL: dire_jax-0.2.1-py3-none-any.whl
- Upload date:
- Size: 40.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | cb364729f3c77b1aee75f00c56e5dbde7912c22ea4695f6cf6dc27521103604d |
| MD5 | 5ad6216af6c058634d5528d1835f91a5 |
| BLAKE2b-256 | 518a527e62d7598b6a6e00000432034ca6a642c2942b8087d404f761077a62f2 |