
UMAP, but optimized with jax.


umapjax


UMAP, but accelerated. (Experimental implementation)

umapjax inherits the API of umap-learn: the UmapJax class is a drop-in replacement for umap.UMAP, with two key differences:

  1. umapjax does not support densmap.
  2. umapjax does not support output_metric other than euclidean.
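Code that swaps umap.UMAP for umapjax.UmapJax can check its keyword arguments against these two differences first. The helper below is a hypothetical sketch: check_umapjax_kwargs is not part of umapjax. It encodes only the two documented restrictions, not every behavioral difference from umap-learn.

```python
from typing import Any


def check_umapjax_kwargs(kwargs: dict[str, Any]) -> None:
    """Raise ValueError if umap.UMAP keyword arguments use a feature umapjax lacks.

    Hypothetical helper: it checks only the two documented differences.
    """
    if kwargs.get("densmap", False):
        raise ValueError("umapjax does not support densmap=True")
    output_metric = kwargs.get("output_metric", "euclidean")
    if output_metric != "euclidean":
        raise ValueError(
            f"umapjax only supports output_metric='euclidean', got {output_metric!r}"
        )


params = {"n_neighbors": 15, "min_dist": 0.1}
check_umapjax_kwargs(params)  # no unsupported options, so nothing is raised
```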

Note: umapjax does not fully replicate umap-learn, so its results may differ and should be interpreted with care.

This package implements the following backends (despite being named umapjax):

  1. torch (PyTorch)
  2. mx (MLX)
  3. jax (JAX)

Getting started

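from typing import Literal

import numpy as np
import umapjax

# Placeholder data: replace with your own (n_samples, n_features) array
X = np.random.default_rng(0).normal(size=(1000, 50)).astype(np.float32)

layout_backend: Literal["jax", "mx", "torch"] = "jax"
spectral_backend: Literal["jax", "scipy", "torch"] = "scipy"
batch_size: int | None = None  # None defaults to X.shape[0]

model = umapjax.UmapJax(
    n_neighbors=15,
    layout_backend=layout_backend,
    spectral_backend=spectral_backend,
    batch_size=batch_size,
)
embedding = model.fit_transform(X)  # low-dimensional embedding of X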

If the optimization is slow, try increasing batch_size to a multiple of X.shape[0] (for example, 2 * X.shape[0]). All backends automatically use hardware accelerators (such as GPUs) when they are available.
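One way to act on this advice is to time a few multiples of X.shape[0]. The sketch below is hypothetical: time_batch_sizes is not part of umapjax. It accepts any callable fit(batch_size), so it could wrap a UmapJax fit. The commented usage assumes UmapJax accepts batch_size as a keyword argument.

```python
import time
from typing import Callable


def time_batch_sizes(
    fit: Callable[[int], object],
    n_samples: int,
    multiples: tuple[int, ...] = (1, 2, 4, 8),
) -> dict[int, float]:
    """Time fit(batch_size) for batch sizes that are multiples of n_samples."""
    timings: dict[int, float] = {}
    for k in multiples:
        batch_size = k * n_samples
        start = time.perf_counter()
        fit(batch_size)
        timings[batch_size] = time.perf_counter() - start
    return timings


# Hypothetical usage (assumes UmapJax accepts batch_size as a keyword argument):
# timings = time_batch_sizes(
#     lambda bs: umapjax.UmapJax(n_neighbors=15, batch_size=bs).fit_transform(X),
#     n_samples=X.shape[0],
# )
```

With the jax backend, the first call may also include JIT compilation time. A warm-up run before timing gives fairer comparisons.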

If using "torch", you can set umapjax.layouts.torch.TORCH_DEVICE and umapjax.spectral.torch.TORCH_DEVICE to control the default device used for the layout and spectral embedding, respectively.

Implementation details

The implementation used in umapjax is very similar to the one in umap-learn. However, instead of each step updating a single point, umapjax updates a set of points in parallel on the selected backend. In umap-learn, edge weights control how often each edge is sampled; in umapjax, they instead weight the gradient contribution of each edge. If the results look unexpected, try changing n_epochs or batch_size. Because batch_size sets how many updates run in parallel, it can also be used to control how much work each step gives a GPU/TPU.
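To make the difference concrete, here is a minimal NumPy sketch of one parallel, weight-scaled update over a batch of edges. It is not umapjax's actual code, and weighted_batched_step and its parameters are hypothetical. The gradient formulas follow umap-learn's attractive and repulsive terms, with repulsion strength 1 and gradients clipped to [-4, 4]. The defaults for a and b are approximately the curve parameters umap-learn fits for min_dist=0.1. By contrast, umap-learn samples each edge roughly n_epochs * w / max(w) times over the run. NumPy is used so the sketch runs without an accelerator; the same operations translate to jax.numpy, using x.at[idx].add(v) in place of np.add.at.

```python
import numpy as np


def weighted_batched_step(
    Y: np.ndarray,        # (n_points, dim) current embedding
    heads: np.ndarray,    # (n_edges,) source point of each edge
    tails: np.ndarray,    # (n_edges,) target point of each edge
    weights: np.ndarray,  # (n_edges,) edge weights, e.g. in [0, 1]
    rng: np.random.Generator,
    a: float = 1.577,     # approx. umap-learn curve fit for min_dist=0.1
    b: float = 0.895,
    lr: float = 1.0,
    n_negative: int = 5,
) -> np.ndarray:
    """Update every edge of the batch at once; returns a new array (Y is unchanged)."""
    Y_new = Y.copy()

    # Attraction: umap-learn's per-edge gradient, scaled by the edge weight
    # instead of sampling the edge with a weight-dependent frequency.
    diff = Y[heads] - Y[tails]
    d2 = np.sum(diff**2, axis=1)
    safe_d2 = np.maximum(d2, 1e-12)  # avoid 0 ** negative power
    coeff = np.where(
        d2 > 0.0, -2.0 * a * b * safe_d2 ** (b - 1.0) / (a * safe_d2**b + 1.0), 0.0
    )
    grad = np.clip(coeff[:, None] * diff, -4.0, 4.0) * weights[:, None]
    # np.add.at accumulates correctly when a point index repeats in the batch
    np.add.at(Y_new, heads, lr * grad)   # move heads toward tails
    np.add.at(Y_new, tails, -lr * grad)  # and tails toward heads

    # Repulsion: push each head away from uniformly sampled points. In umap-learn
    # negatives are drawn only when an edge is sampled, so they are weighted too.
    neg = rng.integers(0, Y.shape[0], size=(heads.size, n_negative))
    ndiff = Y[heads][:, None, :] - Y[neg]  # (n_edges, n_negative, dim)
    nd2 = np.sum(ndiff**2, axis=2)
    ncoeff = np.where(nd2 > 0.0, 2.0 * b / ((0.001 + nd2) * (a * nd2**b + 1.0)), 0.0)
    ngrad = np.clip(ncoeff[..., None] * ndiff, -4.0, 4.0).sum(axis=1)
    np.add.at(Y_new, heads, lr * ngrad * weights[:, None])
    return Y_new
```

Calling this once per epoch with a linearly decaying lr would mirror umap-learn's learning-rate schedule. Edges with weight 0 then leave the embedding untouched rather than simply being sampled rarely.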

Installation

You need to have Python 3.11 or newer installed on your system. If you don't have Python installed, we recommend installing uv.

There are two ways to install umapjax:

  1. Install the latest release of umapjax from PyPI, choosing the extras for the backends you need (all three are shown here):
     pip install "umapjax[jax,mlx,torch]"
  2. Install the latest development version:
     pip install "umapjax[jax,mlx,torch] @ git+https://github.com/adamgayoso/umapjax.git@main"
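After installing, a quick way to see which optional backends are importable is to look up their top-level modules. This sketch assumes the extras map to the modules jax, mlx and torch; the MLX backend is named "mx" in umapjax, after the usual `import mlx.core as mx`. Both available_backends and python_supported are hypothetical helpers, not part of umapjax.

```python
import importlib.util
import sys


def python_supported() -> bool:
    """umapjax requires Python 3.11 or newer."""
    return sys.version_info >= (3, 11)


def available_backends() -> dict[str, bool]:
    """Map umapjax layout_backend names to whether their package can be found."""
    modules = {"jax": "jax", "mx": "mlx", "torch": "torch"}  # assumed module names
    return {
        name: importlib.util.find_spec(module) is not None
        for name, module in modules.items()
    }


print(python_supported(), available_backends())
```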

Contact

If you find a bug, please use the issue tracker.

Citation

To be announced.



Download files

Download the file for your platform.

Source Distribution

umapjax-0.1.0.tar.gz (127.4 kB)

Built Distribution


umapjax-0.1.0-py3-none-any.whl (26.7 kB)

File details

Details for the file umapjax-0.1.0.tar.gz.

File metadata

  • Download URL: umapjax-0.1.0.tar.gz
  • Upload date:
  • Size: 127.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for umapjax-0.1.0.tar.gz
Algorithm Hash digest
SHA256 1fc8b994bbaf155d7d08d27d3ee79ef6189d8ab5e46e8b24d3664cb331d07471
MD5 55687bae60583e8775d39f2c57c569ab
BLAKE2b-256 63b43ead1e48a4df5908ddda2c1cf5ae91fcb36b3cc204df048ab669fbc53c36


Provenance

The following attestation bundles were made for umapjax-0.1.0.tar.gz:

Publisher: release.yaml on adamgayoso/umapjax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file umapjax-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: umapjax-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 26.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for umapjax-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c06e5ba443d8a28f10c4908355839349402a23c69ef492a1be7b3cc876e999eb
MD5 0ba8cd71aac109de8f858922ef0d78e4
BLAKE2b-256 00bcdc42366d5f3d13cff3b2513c94524f5393984d4b9806a27fdcfaba168dbd


Provenance

The following attestation bundles were made for umapjax-0.1.0-py3-none-any.whl:

Publisher: release.yaml on adamgayoso/umapjax

