
Minimal JAX implementation of k-nearest neighbors using a k-d tree.

Project description

JAX k-D

Find k-nearest neighbors using a k-d tree in JAX!

This is an implementation of two GPU-friendly tree algorithms [1, 2] using only XLA primitives. It is convenient and lightweight, but the CUDA-based cudaKDTree may be a better choice depending on the application.

The core build_tree, query_neighbors, and count_neighbors operations are compatible with JIT and automatic differentiation. They are reasonably fast when vectorized on GPU, but will be much slower on CPU than SciPy's KDTree. For small problems where a pairwise distance matrix fits in memory, check whether brute force is faster (see jaxkd.extras below).
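The brute-force route mentioned above can be sketched in plain JAX. This is an illustration only, not the jaxkd.extras implementation, and `knn_brute_force` is a hypothetical name. It builds the full query-by-point squared-distance matrix and keeps the k smallest entries with `jax.lax.top_k`. Because it uses only jit-able, differentiable primitives, gradients flow back to the query coordinates.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="k")
def knn_brute_force(points, queries, k):
    """Hypothetical brute-force k-NN: returns (indices, distances), each (n_queries, k)."""
    # Squared distances via |q|^2 - 2 q.p + |p|^2, avoiding an (n_q, n_p, dim) intermediate.
    d2 = (
        jnp.sum(queries**2, axis=-1)[:, None]
        - 2.0 * queries @ points.T
        + jnp.sum(points**2, axis=-1)[None, :]
    )
    d2 = jnp.maximum(d2, 0.0)  # clamp small negative values caused by rounding
    # top_k selects the largest values, so negate to get the k nearest points.
    neg_d2, idx = jax.lax.top_k(-d2, k)
    return idx, jnp.sqrt(-neg_d2)


points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
queries = jnp.array([[0.1, 0.0]])
idx, dist = knn_brute_force(points, queries, k=2)

# Gradient of the distance to the nearest neighbor with respect to the query position.
grad = jax.grad(lambda q: knn_brute_force(points, q, k=1)[1].sum())(queries)
```

The expanded distance formula trades some floating-point accuracy (when points are far from the origin relative to their spacing) for lower memory use. The gradient is undefined when a query coincides exactly with a point, because the square root is not differentiable at zero.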

The main advantage of jaxkd is the ability to scale up to larger problems without the complexity of integrating non-JAX libraries, especially when the neighbor search should not be the primary computational load.

Usage

import jax
import jaxkd as jk

kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(100_000, 3))
queries = jax.random.normal(kq, shape=(10_000, 3))

tree = jk.build_tree(points)
counts = jk.count_neighbors(tree, queries, r=0.1)
neighbors, distances = jk.query_neighbors(tree, queries, k=10)

Additional helpful functionality can be found in jaxkd.extras.

  • query_neighbors_pairwise and count_neighbors_pairwise for brute-force neighbor searches
  • k_means for clustering using k-means++ initialization, thanks to @NeilGirdhar for contributions
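k-means++ seeding picks the first center uniformly at random, then picks each later center with probability proportional to the squared distance from a point to its nearest already-chosen center. The sketch below is an independent illustration in plain JAX, not the jaxkd.extras.k_means code. `kmeans_pp_init` is a hypothetical name, and the sketch assumes the points are not all identical, since otherwise every sampling weight would be zero.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n_clusters")
def kmeans_pp_init(key, points, n_clusters):
    """Hypothetical k-means++ seeding; returns an (n_clusters, dim) array of centers."""
    n = points.shape[0]
    key, subkey = jax.random.split(key)
    first = points[jax.random.randint(subkey, (), 0, n)]
    centers = jnp.zeros((n_clusters, points.shape[1]), points.dtype).at[0].set(first)
    # Squared distance from every point to its nearest center chosen so far.
    min_d2 = jnp.sum((points - first) ** 2, axis=-1)

    def add_center(i, carry):
        key, centers, min_d2 = carry
        key, subkey = jax.random.split(key)
        # Sample the next center with probability proportional to min_d2.
        j = jax.random.choice(subkey, n, p=min_d2 / jnp.sum(min_d2))
        centers = centers.at[i].set(points[j])
        min_d2 = jnp.minimum(min_d2, jnp.sum((points - points[j]) ** 2, axis=-1))
        return key, centers, min_d2

    # Static trip count, so the loop body is traced once instead of unrolled in Python.
    _, centers, _ = jax.lax.fori_loop(1, n_clusters, add_center, (key, centers, min_d2))
    return centers


kp, kc = jax.random.split(jax.random.key(0))
# Two well-separated blobs: one near the origin, one near (100, 100).
blob = 0.1 * jax.random.normal(kp, (50, 2))
points = jnp.concatenate([blob, blob + 100.0])
centers = kmeans_pp_init(kc, points, n_clusters=2)
```

With two distant blobs, the second center almost always lands in the blob that does not contain the first one, because points there carry nearly all of the sampling weight.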

Suggestions and contributions for other extras are always welcome!

Installation

To install, use pip. The only dependency is jax.

python -m pip install jaxkd

Or just grab tree.py.

Notes

  • The demo.ipynb notebook in the source repository has some additional examples.
  • The query_neighbors function is intended for small values of k and does not use a max heap for simplicity.
  • Some common k-d tree operations, such as ball search, are not implemented because they do not return a fixed-size array. There are probably others that could be implemented if there is a need. Suggestions welcome!
  • Only the Euclidean distance is currently supported; this is relatively easy to change if needed.
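The note on query_neighbors says it skips a max heap because k is expected to be small. One heap-free way to track the k best candidates is to keep them in a sorted, fixed-size buffer and shift entries on each insertion. This sketch in plain JAX is an assumption and not the library's actual code; `insert_candidate` and `k_smallest` are hypothetical names. Each insertion costs O(k), which is cheap for small k and keeps every array shape static.

```python
import jax
import jax.numpy as jnp


def insert_candidate(best_dist, best_idx, dist, idx):
    """Insert (dist, idx) into a buffer of the k smallest distances, sorted ascending."""
    k = best_dist.shape[0]
    slots = jnp.arange(k)
    # Slot where the new distance belongs; k means it is worse than everything kept.
    pos = jnp.searchsorted(best_dist, dist)
    # Shift entries after pos one slot to the right, dropping the current worst.
    shifted_dist = jnp.where(slots > pos, jnp.roll(best_dist, 1), best_dist)
    shifted_idx = jnp.where(slots > pos, jnp.roll(best_idx, 1), best_idx)
    return (
        jnp.where(slots == pos, dist, shifted_dist),
        jnp.where(slots == pos, idx, shifted_idx),
    )


def k_smallest(dists, k):
    """Scan over candidate distances, keeping the k smallest and their indices."""
    init = (jnp.full((k,), jnp.inf, dists.dtype), jnp.full((k,), -1, jnp.int32))
    cand_idx = jnp.arange(dists.shape[0], dtype=jnp.int32)

    def step(carry, cand):
        return insert_candidate(*carry, *cand), None

    (best_dist, best_idx), _ = jax.lax.scan(step, init, (dists, cand_idx))
    return best_dist, best_idx


best_dist, best_idx = k_smallest(jnp.array([5.0, 1.0, 4.0, 2.0, 3.0]), k=3)
```

Unfilled slots keep distance `inf` and index -1, so fewer than k candidates is still a valid, fixed-shape result.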
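Ball search is awkward under XLA because the number of points within radius r varies per query, while compiled functions need static output shapes. A common workaround is to return a fixed number of slots padded with -1, together with the true count, so callers can detect truncation. The brute-force sketch below is not part of jaxkd; `ball_search_padded` and its `max_neighbors` parameter are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="max_neighbors")
def ball_search_padded(points, query, r, max_neighbors):
    """Indices of points within distance r of query, padded with -1 to max_neighbors."""
    d2 = jnp.sum((points - query) ** 2, axis=-1)
    # Nearest max_neighbors candidates (top_k of the negated squared distances).
    neg_d2, idx = jax.lax.top_k(-d2, max_neighbors)
    idx = jnp.where(-neg_d2 <= r**2, idx, -1)
    # True count; if it exceeds max_neighbors, the index list was truncated.
    count = jnp.sum(d2 <= r**2)
    return idx, count


points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
idx, count = ball_search_padded(points, jnp.array([0.0, 0.0]), 1.5, max_neighbors=3)
```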
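In a k-d tree, the subtree on the far side of a splitting plane at coordinate s along axis d can be skipped when the single-axis gap |x_d - s| already exceeds the current k-th best distance. For any Minkowski p-norm with p ≥ 1, that gap is a lower bound on the distance from x to every point beyond the plane. In general, then, the pruning test carries over and mainly the point-to-point distance needs to change. The sketch below uses hypothetical helpers `minkowski` and `plane_gap` (not jaxkd code) to check that bound numerically.

```python
import jax
import jax.numpy as jnp


def minkowski(x, y, p):
    """Minkowski p-distance along the last axis (assumes p >= 1)."""
    return jnp.sum(jnp.abs(x - y) ** p, axis=-1) ** (1.0 / p)


def plane_gap(x, axis, split):
    """Distance from x to the splitting plane {y : y[axis] == split}."""
    return jnp.abs(x[..., axis] - split)


kx, ky = jax.random.split(jax.random.key(1))
x = jax.random.normal(kx, (3,))
axis = 0
split = x[axis] + 0.5  # a splitting plane 0.5 units from x along axis 0
# Random points moved onto the far side of the plane from x.
ys = jax.random.normal(ky, (1000, 3))
ys = ys.at[:, axis].set(split + jnp.abs(ys[:, axis]))

# For every p, no far-side point is closer than the single-axis gap.
bound_holds = {
    p: bool(jnp.all(minkowski(x, ys, p) >= plane_gap(x, axis, split) - 1e-6))
    for p in (1.0, 2.0, 3.0)
}
```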


Download files


Source Distribution

jaxkd-0.1.0.tar.gz (134.7 kB)

Uploaded Source

Built Distribution


jaxkd-0.1.0-py3-none-any.whl (9.3 kB)

Uploaded Python 3

File details

Details for the file jaxkd-0.1.0.tar.gz.

File metadata

  • Download URL: jaxkd-0.1.0.tar.gz
  • Upload date:
  • Size: 134.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.1.0.tar.gz
Algorithm Hash digest
SHA256 44abfda3b07e09039d2c7e853ab88b1230f98bb84fa135ab53c3415d70f904ef
MD5 6f2bf7bed6c722f6e9cf50266855c876
BLAKE2b-256 8cc8a11c7e43193600eb2d4ca321605659c7478c84dc44cb1db1f48fd7e7399e


Provenance

The following attestation bundles were made for jaxkd-0.1.0.tar.gz:

Publisher: publish.yml on dodgebc/jaxkd

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxkd-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: jaxkd-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 9e94f53558810922072b3866dadd8c903120dd68de3eadf9a84e5d33c7aed785
MD5 084ca7f247cd0572ddc37652e7728a97
BLAKE2b-256 3e0857fae683d6ed66e49bc6181b34d2d046bbd38aff4207017bd9793951c381


Provenance

The following attestation bundles were made for jaxkd-0.1.0-py3-none-any.whl:

Publisher: publish.yml on dodgebc/jaxkd

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
