
UMAP, but optimized with jax.

Project description

umapjax


UMAP, but optimized with jax. (Experimental implementation)

umapjax inherits the API of umap-learn. The UmapJax class is a drop-in replacement for umap.UMAP, with two key differences:

  1. umapjax does not support densmap.
  2. umapjax does not support output_metric other than euclidean.

Note: umapjax does not exactly replicate umap-learn, so results should be interpreted with care.
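When porting code that already constructs umap.UMAP, a small pre-flight check can flag the two unsupported options listed above before swapping in UmapJax. The helper below is a hypothetical sketch, not part of umapjax; it only inspects a dict of keyword arguments you would otherwise pass to the constructor.

```python
# Hypothetical pre-flight check (not part of umapjax): flags the umap-learn
# options that umapjax does not support (densmap, non-euclidean output_metric).
def unsupported_umap_options(params: dict) -> list[str]:
    """Return a description of each option in `params` that umapjax lacks."""
    problems = []
    if params.get("densmap", False):
        problems.append("densmap=True is not supported")
    # umap-learn defaults output_metric to "euclidean"; anything else is flagged,
    # including custom callables.
    metric = params.get("output_metric", "euclidean")
    if metric != "euclidean":
        problems.append(f"output_metric={metric!r} is not supported (only 'euclidean')")
    return problems


kwargs = {"n_neighbors": 15, "densmap": True, "output_metric": "cosine"}
for problem in unsupported_umap_options(kwargs):
    print(problem)
```

An empty list means the keyword arguments avoid both known gaps; it does not guarantee identical results, per the note above.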

This package is intended to be used with hardware accelerators such as GPUs and TPUs. There is no benefit to using umapjax on a CPU.
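Since there is no benefit on a CPU, it can help to check which backend jax will use before running umapjax. The sketch below relies on jax.default_backend(), which returns a string such as "cpu", "gpu", or "tpu"; the helper name is an assumption, and treating a missing jax install as "no accelerator" is a choice made for illustration.

```python
def accelerator_available() -> bool:
    """Return True if jax's default backend is something other than the CPU.

    Returns False when jax is not installed (umapjax needs jax anyway).
    """
    try:
        import jax
    except ImportError:
        return False
    return jax.default_backend() != "cpu"


if not accelerator_available():
    print("jax is running on CPU; umapjax offers no speedup here.")
```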

Getting started

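import numpy as np
import umapjax

# X: an (n_samples, n_features) data matrix; random data is used here only for illustration
X = np.random.default_rng(0).normal(size=(1000, 50)).astype(np.float32)

model = umapjax.UmapJax(n_neighbors=15)
embedding = model.fit_transform(X)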

Implementation details

The implementation used in umapjax is very similar to the one in umap-learn; however, rather than updating a single point at each step, umapjax updates a batch of points in parallel using jax. The gradient contributed by each edge is weighted by that edge's weight, which in the original algorithm controls how often the edge is sampled. If results look strange, try changing n_epochs or batch_size. The batch_size argument can also be tuned to make better use of GPUs/TPUs.
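The idea of replacing weight-proportional sampling with weight-scaled gradients can be illustrated with a small NumPy sketch of one batched step. This is not umapjax's code: the helper name weighted_batch_step, normalizing weights by their maximum, and scaling both attraction and repulsion by the edge weight are assumptions. The attraction and repulsion coefficients and the clipping to [-4, 4] follow umap-learn's SGD update, and the default a and b are roughly the values umap-learn derives for min_dist=0.1, spread=1.0.

```python
import numpy as np


def weighted_batch_step(Y, heads, tails, weights, lr=0.1, a=1.577, b=0.895,
                        gamma=1.0, n_neg=5, rng=None):
    """One batched, edge-weighted UMAP-style update (hypothetical helper).

    Y: (n, d) embedding. heads, tails: (m,) integer endpoints of a batch of
    edges. weights: (m,) edge weights. Every edge in the batch is used once,
    and its gradient is scaled by its normalized weight instead of the edge
    being sampled in proportion to that weight.
    """
    rng = np.random.default_rng() if rng is None else rng
    w = (weights / weights.max())[:, None]  # (m, 1), values in (0, 1]
    Y_new = Y.copy()  # all gradients below are computed from the old Y

    # Attraction along each edge (coefficient as in umap-learn's SGD loop).
    diff = Y[heads] - Y[tails]                       # (m, d)
    d2 = np.sum(diff**2, axis=1, keepdims=True)      # (m, 1)
    safe = np.where(d2 > 0, d2, 1.0)                 # avoid 0 ** (b - 1)
    coeff = -2.0 * a * b * safe ** (b - 1.0) / (1.0 + a * safe**b)
    coeff = np.where(d2 > 0, coeff, 0.0)             # no force for coincident points
    g_attr = np.clip(coeff * diff, -4.0, 4.0)
    # Scatter-add, because several edges in a batch can share an endpoint.
    np.add.at(Y_new, heads, lr * w * g_attr)         # move head toward tail
    np.add.at(Y_new, tails, -lr * w * g_attr)        # move tail toward head

    # Repulsion of each head from n_neg randomly drawn points.
    neg = rng.integers(0, Y.shape[0], size=(heads.shape[0], n_neg))  # (m, n_neg)
    diff_n = Y[heads][:, None, :] - Y[neg]           # (m, n_neg, d)
    d2n = np.sum(diff_n**2, axis=-1, keepdims=True)  # (m, n_neg, 1)
    coeff_n = 2.0 * gamma * b / ((0.001 + d2n) * (1.0 + a * d2n**b))
    g_rep = np.clip(coeff_n * diff_n, -4.0, 4.0).sum(axis=1)  # (m, d)
    np.add.at(Y_new, heads, lr * w * g_rep)          # push head away
    return Y_new


Y = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
Y1 = weighted_batch_step(Y, np.array([0]), np.array([1]), np.array([1.0]),
                         gamma=0.0)  # attraction only
print(np.linalg.norm(Y1[0] - Y1[1]))  # smaller than the original distance of 1.0
```

In jax, the scatter-add np.add.at(Y_new, idx, v) would be written Y_new = Y_new.at[idx].add(v), which also accumulates duplicate indices; the remaining array operations have jax.numpy equivalents, while random draws would go through jax.random.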

Installation

You need to have Python 3.11 or newer installed on your system. If you don't have Python installed, we recommend installing uv.

There are two ways to install umapjax:

  1. Install the latest release of umapjax from PyPI:
     pip install umapjax
  2. Install the latest development version:
     pip install git+https://github.com/adamgayoso/umapjax.git@main

Release notes

See the changelog.

Contact

If you find a bug, please use the issue tracker.

Citation

t.b.a



Download files


Source Distribution

umapjax-0.0.2.tar.gz (91.6 kB)

Built Distribution


umapjax-0.0.2-py3-none-any.whl (14.0 kB)

File details

Details for the file umapjax-0.0.2.tar.gz.

File metadata

  • Download URL: umapjax-0.0.2.tar.gz
  • Upload date:
  • Size: 91.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for umapjax-0.0.2.tar.gz
Algorithm Hash digest
SHA256 89de5248ba95e3130dbce7749c5c8f52b94c643baf0e761a0de2d33201094790
MD5 c107f3091782118e444ced1bf36ae1b3
BLAKE2b-256 782fa4b832bd49a899c5dfcdf838e5a4171d38696e2d25242c91e7e70d52c687


Provenance

The following attestation bundles were made for umapjax-0.0.2.tar.gz:

Publisher: release.yaml on adamgayoso/umapjax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file umapjax-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: umapjax-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 14.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for umapjax-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 6aecf14778ee3d13433ae876a6de45c339e911ebe923ad7ef26f75d203b369c4
MD5 8b55de1d13932c535035b99fb17668ca
BLAKE2b-256 75ad85e6c7f2ec498ef912bbcb20de5dd3628670c00654f874b3919c198a047c


Provenance

The following attestation bundles were made for umapjax-0.0.2-py3-none-any.whl:

Publisher: release.yaml on adamgayoso/umapjax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
