
UMAP, but optimized with jax.

Project description

umapjax

UMAP, but optimized with jax. (Experimental implementation)

umapjax inherits the API of umap-learn. The UmapJax class is a drop-in replacement for umap.UMAP, with a few key differences:

  1. umapjax does not support densmap.
  2. umapjax does not support output_metric other than euclidean.

Note: umapjax does not fully replicate umap-learn, so care should be taken when interpreting results.

This package is intended for use with hardware accelerators such as GPUs and TPUs. There is no benefit to using umapjax on a CPU.
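Because umapjax only pays off on an accelerator, it can help to confirm which backend JAX will actually use before fitting. The sketch below relies only on JAX's own `jax.default_backend()` and `jax.devices()`. The helper name `accelerator_available` is hypothetical and is not part of umapjax.

```python
import jax


def accelerator_available() -> bool:
    """Return True if JAX's default backend is a GPU or TPU rather than the CPU."""
    return jax.default_backend() in ("gpu", "tpu")


print("JAX backend:", jax.default_backend())
print("Devices:", jax.devices())
if not accelerator_available():
    # Per the note above, umapjax is not expected to help on a CPU-only setup
    print("No GPU/TPU detected; umapjax offers no benefit on CPU.")
```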

Getting started

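import numpy as np
import umapjax

X = np.random.default_rng(0).normal(size=(1000, 50))  # placeholder data: 1,000 samples x 50 features
model = umapjax.UmapJax(n_neighbors=15)
embedding = model.fit_transform(X)  # low-dimensional embedding, one row per sample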

Implementation details

The implementation used in umapjax is very similar to the one in umap-learn; however, rather than each optimization step updating a single point, umapjax updates a set of points in parallel using JAX. The gradients of the points are weighted by the edge weights, which in the original algorithm instead control how often each edge is sampled. If the results look unexpected, try changing n_epochs or batch_size. The batch_size argument can also be used to tune how much work is done in parallel on GPUs/TPUs.
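The idea of replacing edge sampling with edge-weighted, batched gradients can be sketched in JAX. In umap-learn, `make_epochs_per_sample` gives an edge of weight w roughly n_epochs * w / max(w) updates over training, so scaling each edge's gradient by w / max(w) on every epoch yields the same expected total update. The following is a minimal sketch under stated assumptions, not umapjax's actual code: it assumes umap-learn's standard cross-entropy loss with curve parameters a and b (set here to approximately the values umap-learn fits for min_dist=0.1, spread=1.0), and the names `batch_loss` and `sgd_step`, the `eps` guard, the max-normalized weights, and clipping of the summed gradient are all hypothetical choices.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch, not umapjax's code. a and b approximate umap-learn's
# fitted values for min_dist=0.1, spread=1.0.
A, B = 1.577, 0.895


def batch_loss(emb, head, tail, weight, neg, gamma=1.0, eps=1e-3):
    """Edge-weighted UMAP cross-entropy for one batch of edges.

    emb:    (n_points, dim) current embedding
    head:   (batch,) index of one endpoint of each edge
    tail:   (batch,) index of the other endpoint
    weight: (batch,) edge weights, assumed rescaled so the largest is 1.0
    neg:    (batch, n_neg) indices of randomly drawn negative samples
    """
    # Attraction: -log q(d) with q(d) = 1 / (1 + a * d^(2b)); eps avoids d = 0
    d2_pos = jnp.sum((emb[head] - emb[tail]) ** 2, axis=-1) + eps
    attract = jnp.log1p(A * d2_pos**B)

    # Repulsion: -log(1 - q(d)) = log(1 + u) - log(u) with u = a * d^(2b).
    # Negatives are treated as constants, so only the head point is pushed away.
    neg_pts = jax.lax.stop_gradient(emb[neg])  # (batch, n_neg, dim)
    d2_neg = jnp.sum((emb[head][:, None, :] - neg_pts) ** 2, axis=-1) + eps
    u = A * d2_neg**B
    repel = gamma * jnp.sum(jnp.log1p(u) - jnp.log(u), axis=-1)

    # Weight each edge's whole contribution by its edge weight instead of
    # sampling the edge with a frequency proportional to that weight
    return jnp.sum(weight * (attract + repel))


@jax.jit
def sgd_step(emb, head, tail, weight, neg, lr):
    # jax.grad sums contributions from all edges that touch the same point
    grads = jax.grad(batch_loss)(emb, head, tail, weight, neg)
    # Clip per coordinate to [-4, 4], loosely mirroring umap-learn's clipping
    grads = jnp.clip(grads, -4.0, 4.0)
    return emb - lr * grads


# Toy inputs: random edges, weights, and negative samples
k_emb, k_head, k_tail, k_w, k_neg = jax.random.split(jax.random.PRNGKey(0), 5)
n_points, dim, batch, n_neg = 100, 2, 32, 5
emb = jax.random.normal(k_emb, (n_points, dim))
head = jax.random.randint(k_head, (batch,), 0, n_points)
tail = jax.random.randint(k_tail, (batch,), 0, n_points)
raw_w = jax.random.uniform(k_w, (batch,))
weight = raw_w / raw_w.max()  # largest edge weight becomes 1.0
neg = jax.random.randint(k_neg, (batch, n_neg), 0, n_points)
new_emb = sgd_step(emb, head, tail, weight, neg, 1.0)
```

Because every edge in the batch is processed in one vectorized call, batch size trades memory for parallelism on the accelerator; a full optimizer would also decay the learning rate over epochs, which this sketch omits.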

Installation

You need to have Python 3.11 or newer installed on your system. If you don't have Python installed, we recommend installing uv.

There are two ways to install umapjax:

  1. Install the latest release of umapjax from PyPI:
     pip install umapjax
  2. Install the latest development version:
     pip install git+https://github.com/adamgayoso/umapjax.git@main

Release notes

See the changelog.

Contact

If you find a bug, please use the issue tracker.

Citation

t.b.a



Download files


Source Distribution

umapjax-0.0.3.tar.gz (96.2 kB)

Uploaded Source

Built Distribution


umapjax-0.0.3-py3-none-any.whl (16.6 kB)

Uploaded Python 3

File details

Details for the file umapjax-0.0.3.tar.gz.

File metadata

  • Download URL: umapjax-0.0.3.tar.gz
  • Upload date:
  • Size: 96.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for umapjax-0.0.3.tar.gz
Algorithm     Hash digest
SHA256        2c1daec160ffad63f61097aa631959fbf6f34e022e5e846ced88f28028c9d8e1
MD5           3dc7d02b0928eb46277a581221ba7f2f
BLAKE2b-256   599036e281090ec5adb55e317e3cbb868cac29810e3c7eb2a065b1d484cc5293


Provenance

The following attestation bundles were made for umapjax-0.0.3.tar.gz:

Publisher: release.yaml on adamgayoso/umapjax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file umapjax-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: umapjax-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 16.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for umapjax-0.0.3-py3-none-any.whl
Algorithm     Hash digest
SHA256        dbe6a47cc4f0f3d933a60d6c76ce4c94a58f654ef8cd8c631069382dbce89a41
MD5           d419ec158a843271aa02144e77a1078b
BLAKE2b-256   ef131984aa8a9bbb43a9382ba7c0387a3ed808272aa14bc35e8d04f0245ef883


Provenance

The following attestation bundles were made for umapjax-0.0.3-py3-none-any.whl:

Publisher: release.yaml on adamgayoso/umapjax

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
