A JAX-based Dimension Reducer

A high-performance DImensionality REduction package with JAX

DiRe offers fast dimensionality reduction that preserves the global structure of the dataset, with benchmarks showing competitive performance against UMAP and t-SNE. It is built with JAX for efficient computation on CPUs and GPUs.

Quick start

Basic installation (JAX backend only):

pip install dire-jax

With utilities for benchmarking:

pip install "dire-jax[utils]"

Complete installation with utilities:

pip install "dire-jax[all]"

Note: GPU or TPU acceleration requires a JAX installation built for that hardware. See the JAX documentation for details on enabling GPU/TPU support.
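To confirm which backend JAX actually picked up after installation, a quick check along the following lines can help. This is a minimal sketch that relies only on `jax.default_backend()` and `jax.devices()`; the helper name `describe_jax_backend` is made up for this illustration and is not part of DiRe-JAX.

```python
def describe_jax_backend():
    """Return the active JAX backend and visible devices, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "backend": jax.default_backend(),            # e.g. "cpu", "gpu" or "tpu"
        "devices": [str(d) for d in jax.devices()],  # devices on the default backend
    }


# Hypothetical helper defined above; not a DiRe-JAX function.
info = describe_jax_backend()
if info is None:
    print("JAX is not installed")
else:
    print(f"backend: {info['backend']}, devices: {info['devices']}")
```

If the reported backend is `cpu` on a machine with an accelerator, the installed JAX build most likely lacks hardware support.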

Example usage:

from dire_jax import DiRe
from sklearn.datasets import make_blobs

# Synthetic dataset: 100k points in 1k dimensions, drawn from 12 blobs
n_samples  = 100_000
n_features = 1_000
n_centers  = 12
features_blobs, labels_blobs = make_blobs(n_samples=n_samples, n_features=n_features, centers=n_centers, random_state=42)

# Configure the reducer for a 2-dimensional embedding
reducer_blobs = DiRe(n_components=2,
                     n_neighbors=16,
                     init='pca',
                     max_iter_layout=32,
                     min_dist=1e-4,
                     spread=1.0,
                     cutoff=4.0,
                     n_sample_dirs=8,
                     sample_size=16,
                     neg_ratio=32,
                     verbose=False,)

# Compute the embedding (the return value is discarded here), then plot it,
# passing the blob labels to the visualization
_ = reducer_blobs.fit_transform(features_blobs)
reducer_blobs.visualize(labels=labels_blobs, point_size=4)

The output should look similar to the figure below:

[Figure: 12 blobs with 100k points in 1k dimensions embedded in dimension 2]

Documentation

Please refer to the DiRe API documentation for further details.

Project documentation structure:

  • /docs/ - API documentation and architecture details
  • /benchmarking/ - Performance benchmarks and scaling results
  • /examples/ - Example usage and demos
  • /tests/ - Test suite and benchmarking notebooks

Working paper

Our working paper is available on arXiv.

Also, check out the Jupyter notebook with benchmarking results, which can be opened in Colab.

Performance Characteristics

DiRe-JAX is optimized for small to medium datasets (<50K points), with excellent CPU performance and GPU acceleration via JAX. It features fully vectorized computation with JIT compilation.
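To illustrate what "fully vectorized computation with JIT compilation" means in JAX terms, here is a small generic sketch. It is not taken from the DiRe-JAX source; the function `pairwise_sq_dists` is a hypothetical example, chosen because pairwise distances are a common building block in neighbor-based methods.

```python
import jax
import jax.numpy as jnp


@jax.jit
def pairwise_sq_dists(x):
    """Squared Euclidean distances between all rows of x, without Python loops."""
    sq_norms = jnp.sum(x * x, axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (x @ x.T)
    return jnp.maximum(d2, 0.0)  # clip tiny negatives caused by round-off


x = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
d2 = pairwise_sq_dists(x)  # first call traces and compiles the function
d2 = pairwise_sq_dists(x)  # later calls with the same shape reuse the compiled code
print(d2)
```

The whole computation is expressed as array operations, so XLA can compile it once per input shape and run it on whichever device JAX is using.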

Benchmarking and utilities

For benchmarking utilities and quality metrics:

pip install "dire-jax[utils]"

This provides access to dimensionality reduction quality metrics and benchmarking routines. Some utilities rely on external packages for persistent homology computations, which may increase runtime.
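As a rough idea of what a dimensionality reduction quality metric can look like, the sketch below computes a simple k-nearest-neighbor preservation score with NumPy. It is an illustration only and does not reproduce the metrics shipped with the package; the names `knn_indices` and `knn_preservation` are hypothetical, and the brute-force distance computation is suitable only for small inputs.

```python
import numpy as np


def knn_indices(points, k):
    """Indices of the k nearest neighbors of each row (brute force, excludes self)."""
    diffs = points[:, None, :] - points[None, :, :]
    d2 = np.sum(diffs * diffs, axis=-1)
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]


def knn_preservation(high_dim, low_dim, k=5):
    """Mean fraction of each point's k nearest neighbors kept by the embedding."""
    nn_high = knn_indices(high_dim, k)
    nn_low = knn_indices(low_dim, k)
    overlaps = [len(set(a) & set(b)) / k for a, b in zip(nn_high, nn_low)]
    return float(np.mean(overlaps))


rng = np.random.default_rng(0)
data = rng.normal(size=(200, 10))
perfect = knn_preservation(data, data.copy(), k=5)             # identical geometry
shuffled = knn_preservation(data, rng.permutation(data), k=5)  # rows scrambled
print(perfect, shuffled)
```

A score near 1 means local neighborhoods survive the embedding; the scrambled baseline shows the score expected when they do not.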

Contributing

Please follow the contributing guide. Thanks!

Acknowledgement

This work is supported by the Google Cloud Research Award number GCP19980904.
