UMAP, but optimized with jax.
umapjax
UMAP, but accelerated. (Experimental implementation)
umapjax inherits the API of umap-learn. The UmapJax class is a drop-in replacement for umap.UMAP, with a few key differences:
- umapjax does not support densmap.
- umapjax does not support output_metric other than euclidean.
Note: umapjax does not fully replicate umap-learn and care should be used when interpreting results.
This package implements the following backends (despite being named umapjax):
- torch (PyTorch)
- mx (MLX)
- jax (JAX)
Getting started
from typing import Literal

import umapjax

# X is your data: an array of shape (n_samples, n_features).
layout_backend: Literal["jax", "mx", "torch"] = "jax"
spectral_backend: Literal["jax", "scipy", "torch"] = "scipy"
batch_size: int | None = None  # Defaults to X.shape[0]
model = umapjax.UmapJax(
    n_neighbors=15,
    layout_backend=layout_backend,
    spectral_backend=spectral_backend,
)
embedding = model.fit_transform(X)
If the optimization is slow, try increasing batch_size to a larger multiple of X.shape[0] (the default is X.shape[0]). All backends will automatically use accelerated hardware if available.
If using "torch", you can set umapjax.layouts.torch.TORCH_DEVICE and umapjax.spectral.torch.TORCH_DEVICE to control the default device used for the layout and spectral embedding, respectively.
Implementation details
The implementation used in umapjax is very similar to the one used in umap-learn. However, rather than updating a single point per step, we update a set of points in parallel using jax. The gradients of the points are weighted by edge weights, which control sampling frequencies in the original algorithm. If results look strange, try changing n_epochs or batch_size. The batch_size argument can also be used to control acceleration on GPUs/TPUs.
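To make the idea of a batched, edge-weighted update concrete, here is a minimal sketch. It is not the umapjax source. It uses NumPy in place of jax so it runs without an accelerator stack, covers only the attractive term (no negative sampling), and fixes the curve parameters a and b to 1.0 purely for illustration. The function name batched_attractive_step and all of its arguments are hypothetical.

```python
import numpy as np


def batched_attractive_step(emb, heads, tails, weights, lr=1.0, a=1.0, b=1.0):
    """One edge-weighted attractive update applied to every head point at once.

    Illustrative only: the name, signature, and defaults are assumptions,
    not part of the umapjax API.
    """
    diff = emb[heads] - emb[tails]                   # (n_edges, dim)
    dist2 = np.sum(diff**2, axis=1, keepdims=True)   # squared distances, (n_edges, 1)
    safe = np.maximum(dist2, 1e-12)                  # avoid 0 raised to a negative power
    # Attractive coefficient from the low-dimensional kernel 1 / (1 + a * d^(2b)).
    coeff = -2.0 * a * b * safe ** (b - 1.0) / (1.0 + a * safe**b)
    coeff = np.where(dist2 > 0.0, coeff, 0.0)        # coincident points get no update
    # Edge weights scale the gradient instead of setting a sampling frequency.
    grad = weights[:, None] * coeff * diff
    # Scatter-add so a point that heads several edges accumulates all its updates.
    update = np.zeros_like(emb)
    np.add.at(update, heads, grad)
    return emb + lr * update


emb = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 3.0]])
heads = np.array([0, 1, 2])
tails = np.array([1, 0, 0])
weights = np.array([1.0, 1.0, 0.0])  # the zero-weight edge leaves point 2 in place
new_emb = batched_attractive_step(emb, heads, tails, weights)
```

After this single step, points 0 and 1 have moved toward each other, while point 2 is unchanged because its only edge has weight zero. In jax, the scatter-add would typically be written as update.at[heads].add(grad), since jax arrays are immutable.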
Installation
You need to have Python 3.11 or newer installed on your system. If you don't have Python installed, we recommend installing uv.
There are several alternative options to install umapjax:
- Install the latest release of umapjax from PyPI with a preferred backend:
pip install "umapjax[jax,mlx,torch]"
- Install the latest development version:
pip install "umapjax[jax,mlx,torch] @ git+https://github.com/adamgayoso/umapjax.git@main"
Contact
If you found a bug, please use the issue tracker.
Citation
t.b.a