Minimal JAX implementation of k-nearest neighbors using a k-d tree.

JAX k-D

Find k-nearest neighbors using a k-d tree in JAX!

This is an XLA version of two GPU-friendly tree algorithms [1, 2]. It is convenient and lightweight, but the original CUDA implementation [3] may be a better choice depending on the application.

The build_tree, query_neighbors, and count_neighbors operations are compatible with JIT and automatic differentiation. They are reasonably fast when vectorized on GPU, but are much slower than scipy.spatial.KDTree on CPU. The main advantage is that a scalable nearest-neighbor search can be used inside a larger JAX program without the complexity of calling non-JAX libraries, which may require leaving JIT and the GPU.
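As a minimal sketch of that pattern, not using jaxkd itself, the snippet below differentiates a nearest-neighbor loss with jax.grad under jax.jit. The nearest neighbors are found by brute force, which needs O(N·M) memory and only suits small inputs; in a larger program, that search is the part a tree query would replace. The names nearest_sq_distance, loss, and grad_loss are hypothetical, and the step size is chosen only for this toy example.

```python
import jax
import jax.numpy as jnp


def nearest_sq_distance(points, queries):
    # Brute-force stand-in for a tree query (hypothetical helper).
    # Squared Euclidean distances, shape (num_queries, num_points).
    diff = queries[:, None, :] - points[None, :, :]
    sq_dist = jnp.sum(diff ** 2, axis=-1)
    # Squared distance from each query to its closest point.
    return jnp.min(sq_dist, axis=1)


def loss(queries, points):
    # Mean squared nearest-neighbor distance of the queries.
    return jnp.mean(nearest_sq_distance(points, queries))


# JIT-compiled gradient with respect to the query positions (first argument).
grad_loss = jax.jit(jax.grad(loss))

kp, kq = jax.random.split(jax.random.key(0))
points = jax.random.normal(kp, shape=(1_000, 3))
queries = jax.random.normal(kq, shape=(100, 3))

start = loss(queries, points)
# Plain gradient descent pulls each query toward its nearest point.
for _ in range(20):
    queries = queries - 10.0 * grad_loss(queries, points)
end = loss(queries, points)
```

Using squared distances avoids the undefined gradient of the square root at zero distance.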

Usage

import jax
import jaxkd as jk

kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(100_000, 3))
queries = jax.random.normal(kq, shape=(10_000, 3))

tree = jk.build_tree(points)
neighbors, distances = jk.query_neighbors(tree, queries, 10)
counts = jk.count_neighbors(tree, queries, 0.1)
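For small inputs, a brute-force search gives a reference for sanity-checking results. The sketch below assumes that distances are Euclidean (not squared) and sorted nearest first, and that a point exactly at radius r counts as a neighbor; these are assumptions about query_neighbors and count_neighbors, not documented behavior. brute_force_neighbors and brute_force_count are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def brute_force_neighbors(points, queries, k):
    # Pairwise Euclidean distances, shape (num_queries, num_points).
    diff = queries[:, None, :] - points[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    # top_k returns the largest values, so negate to get the k smallest
    # distances, ordered from nearest to farthest.
    neg_dist, idx = jax.lax.top_k(-dist, k)
    return idx, -neg_dist


def brute_force_count(points, queries, r):
    # Number of points within distance r of each query (inclusive: an assumption).
    diff = queries[:, None, :] - points[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    return jnp.sum(dist <= r, axis=1)


kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(2_000, 3))
queries = jax.random.normal(kq, shape=(100, 3))

ref_neighbors, ref_distances = brute_force_neighbors(points, queries, 10)
ref_counts = brute_force_count(points, queries, 0.1)
```

With jaxkd installed, these arrays could be compared against the outputs of jk.query_neighbors and jk.count_neighbors on the same inputs.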

There is also a simple k-means implementation in jaxkd.extras. More suggestions welcome!
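The interface of jaxkd.extras is not described here, so the following is a standalone sketch of Lloyd's k-means algorithm in plain JAX, not the library's implementation. Points are assigned to centers by brute force; with many points or clusters, that assignment step is where a nearest-neighbor query would help. kmeans_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def kmeans_sketch(key, points, num_clusters, num_iters=20):
    # Initialize centers by sampling distinct points at random.
    init = jax.random.choice(key, points.shape[0], shape=(num_clusters,), replace=False)
    centers = points[init]

    def step(centers, _):
        # Assignment: index of the nearest center for every point (brute force).
        sq_dist = jnp.sum((points[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        labels = jnp.argmin(sq_dist, axis=1)
        # Update: mean of the points assigned to each center.
        one_hot = jax.nn.one_hot(labels, num_clusters)  # (N, K)
        counts = one_hot.sum(axis=0)                    # (K,)
        sums = one_hot.T @ points                       # (K, D)
        # Keep the previous center if a cluster received no points.
        new_centers = jnp.where(
            counts[:, None] > 0, sums / jnp.maximum(counts, 1.0)[:, None], centers
        )
        return new_centers, None

    centers, _ = jax.lax.scan(step, centers, None, length=num_iters)
    sq_dist = jnp.sum((points[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    return centers, jnp.argmin(sq_dist, axis=1)


# Three well-separated Gaussian blobs in 2D.
k1, k2 = jax.random.split(jax.random.key(0))
offsets = jnp.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
points = jax.random.normal(k1, (300, 2)) * 0.5 + jnp.repeat(offsets, 100, axis=0)
centers, labels = kmeans_sketch(k2, points, 3)
```

Lloyd's algorithm never increases the within-cluster sum of squares, but it can stop in a local optimum, so results depend on the initial centers.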

Installation

To install, use pip. The only dependency is jax.

python -m pip install jaxkd

Or just grab tree.py.

Notes

  • The demo.ipynb notebook in the source repository has some additional examples, including gradient-based optimization using neighbors and iterative clustering with k-means.
  • The query_neighbors function is intended for small values of k and does not use a max heap for simplicity.
  • Some common k-d tree operations, such as ball search, are not implemented because they do not return a fixed-size array. Others could probably be implemented if there is a need. Suggestions welcome!
  • Only the Euclidean distance is currently supported, but this is relatively easy to change if needed.
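The note above does not spell out what replaces the heap. One simple heap-free scheme, shown here as an assumption rather than jaxkd's actual code, keeps the k best candidates in a sorted fixed-size buffer and inserts each new candidate by shifting larger entries one slot to the right. Each insertion costs O(k) instead of O(log k), which is acceptable for small k and keeps every array at a fixed shape, as XLA requires. insert_candidate and k_smallest are hypothetical names.

```python
import jax
import jax.numpy as jnp


def insert_candidate(best_dist, best_idx, dist, idx):
    # best_dist is sorted ascending with shape (k,); unused slots hold +inf.
    # The candidate goes after every entry that is not larger than it.
    pos = jnp.sum(best_dist <= dist)
    slots = jnp.arange(best_dist.shape[0])
    # Entries after pos move right by one slot; the last entry falls off.
    shifted_dist = jnp.roll(best_dist, 1)
    shifted_idx = jnp.roll(best_idx, 1)
    new_dist = jnp.where(slots < pos, best_dist, jnp.where(slots == pos, dist, shifted_dist))
    new_idx = jnp.where(slots < pos, best_idx, jnp.where(slots == pos, idx, shifted_idx))
    return new_dist, new_idx


def k_smallest(dists, k):
    # Scan over all candidates, keeping the k smallest distances and their indices.
    init = (jnp.full((k,), jnp.inf), jnp.full((k,), -1, dtype=jnp.int32))

    def body(carry, x):
        d, i = x
        return insert_candidate(*carry, d, i), None

    candidates = (dists, jnp.arange(dists.shape[0], dtype=jnp.int32))
    (best_dist, best_idx), _ = jax.lax.scan(body, init, candidates)
    return best_dist, best_idx


dists = jnp.array([0.5, 0.1, 0.9, 0.3, 0.7])
best_dist, best_idx = k_smallest(dists, 3)
# best_dist is approximately [0.1, 0.3, 0.5]; best_idx is [1, 3, 0]
```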

Download files

Download the file for your platform.

Source Distribution

jaxkd-0.0.2.tar.gz (131.1 kB)

Uploaded Source

Built Distribution

jaxkd-0.0.2-py3-none-any.whl (7.9 kB)

Uploaded Python 3

File details

Details for the file jaxkd-0.0.2.tar.gz.

File metadata

  • Download URL: jaxkd-0.0.2.tar.gz
  • Size: 131.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.0.2.tar.gz
Algorithm Hash digest
SHA256 84c8804579b8454b76448b7a4325466663ad0bf35c34d44fb45d8a369e88fb71
MD5 50f6e187b6eedcf9ac12c48c20680dd8
BLAKE2b-256 6b7cef70d4d57a2be7ecad3a9e176b1cb98edfdce63defd54a59d10cb98ec27e

Provenance

The following attestation bundles were made for jaxkd-0.0.2.tar.gz:

Publisher: publish.yml on dodgebc/jaxkd

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxkd-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: jaxkd-0.0.2-py3-none-any.whl
  • Size: 7.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 5442d3e0ba1e05c3fc5464c3056ae2abb1e6a8c4b8e92e00ee83c500b29e9192
MD5 b841f457bd0119f1d87562e8d9b7ee25
BLAKE2b-256 b0cc148dafc593571f70c3a3d7f2df52f5b6a0833fc9e7a100c083ceaf429eef

Provenance

The following attestation bundles were made for jaxkd-0.0.2-py3-none-any.whl:

Publisher: publish.yml on dodgebc/jaxkd
