
Minimal JAX implementation of k-nearest neighbors using a k-d tree.

Project description

JAX $k$-D

Find $k$-nearest neighbors using a $k$-d tree in JAX!

This is an implementation of two GPU-friendly tree algorithms [1, 2] using only XLA primitives. It is convenient and lightweight, but the CUDA-based cudaKDTree may be a better choice depending on the application.

The core build_tree, query_neighbors, and count_neighbors operations are compatible with JIT and automatic differentiation. They are reasonably fast when vectorized on GPU, but will be much slower on CPU than scipy.spatial.KDTree. For small problems where a pairwise distance matrix fits in memory, check whether brute force is faster (see extras.py). The main advantage of jaxkd is the ability to scale up to larger problems without the complexity of integrating non-JAX libraries, especially when the neighbor search should not be the primary computational load.
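The brute-force baseline mentioned above can be sketched in a few lines of JAX. This is a hypothetical illustration, not the code in jaxkd's extras.py: the names `pairwise_sq_dists`, `brute_force_query`, `brute_force_count` and `mean_sq_neighbor_dist` are made up here. It also assumes that a radius count includes points exactly at distance `r`. The last function shows one way to keep a neighbor-based loss differentiable. The integer neighbor indices carry no gradient, so the distances are recomputed from the gathered points.

```python
from functools import partial

import jax
import jax.numpy as jnp


def pairwise_sq_dists(points, queries):
    # Squared Euclidean distances, shape (num_queries, num_points).
    # The full matrix is materialized, so memory grows as Q * N.
    diff = queries[:, None, :] - points[None, :, :]
    return jnp.sum(diff**2, axis=-1)


@partial(jax.jit, static_argnames="k")
def brute_force_query(points, queries, k):
    # lax.top_k returns the k largest entries, so negate to get the k nearest.
    neg_d2, idx = jax.lax.top_k(-pairwise_sq_dists(points, queries), k)
    return idx, jnp.sqrt(-neg_d2)


@jax.jit
def brute_force_count(points, queries, r):
    # Number of points within distance r of each query (assumed: boundary included).
    return jnp.sum(pairwise_sq_dists(points, queries) <= r**2, axis=1)


def mean_sq_neighbor_dist(queries, points, k):
    # Neighbor search is treated as a constant: stop_gradient keeps tracers
    # out of the (non-differentiable) index computation.
    idx, _ = brute_force_query(
        jax.lax.stop_gradient(points), jax.lax.stop_gradient(queries), k
    )
    # Recompute squared distances from the gathered points so the loss is
    # differentiable with respect to both queries and points.
    diff = queries[:, None, :] - points[idx]  # shape (num_queries, k, dim)
    return jnp.mean(jnp.sum(diff**2, axis=-1))
```

For example, `jax.grad(mean_sq_neighbor_dist)(queries, points, 10)` gives the gradient of the mean squared neighbor distance with respect to the queries. Squared distances are used in the loss because the gradient of `sqrt` is infinite at zero distance.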

Usage

import jax
import jaxkd as jk

kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(100_000, 3))
queries = jax.random.normal(kq, shape=(10_000, 3))

tree = jk.build_tree(points)
counts = jk.count_neighbors(tree, queries, 0.1)
neighbors, distances = jk.query_neighbors(tree, queries, 10)

There is also jaxkd.extras with simple brute force versions of these for comparison, as well as a starter k-means implementation. More suggestions welcome!
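jaxkd.extras ships its own starter k-means. The block below is not that code. It is a hypothetical, self-contained sketch of Lloyd's algorithm in JAX, written to show the shape of the computation. Here the assignment step uses brute-force distances to the centers. With many centers, that step is a nearest-neighbor query with $k = 1$, which is where a $k$-d tree built on the centers could stand in (an assumption about how one might combine the two, not a description of jaxkd).

```python
from functools import partial

import jax
import jax.numpy as jnp


def assign_clusters(points, centers):
    # Index of the nearest center for every point (brute force, shape (n, c)).
    d2 = jnp.sum((points[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(d2, axis=1)


@partial(jax.jit, static_argnames="num_iters")
def kmeans(points, init_centers, num_iters=10):
    num_clusters = init_centers.shape[0]

    def lloyd_step(centers, _):
        labels = assign_clusters(points, centers)
        # Per-cluster coordinate sums and member counts.
        sums = jax.ops.segment_sum(points, labels, num_segments=num_clusters)
        counts = jax.ops.segment_sum(
            jnp.ones(points.shape[0], dtype=points.dtype),
            labels,
            num_segments=num_clusters,
        )
        # An empty cluster keeps its previous center instead of dividing by zero.
        means = sums / jnp.maximum(counts, 1.0)[:, None]
        return jnp.where(counts[:, None] > 0, means, centers), None

    # A fixed iteration count keeps the loop shape static, so it compiles once.
    centers, _ = jax.lax.scan(lloyd_step, init_centers, None, length=num_iters)
    return centers, assign_clusters(points, centers)
```

Using `lax.scan` with a fixed `num_iters` rather than a convergence test keeps the whole loop inside one compiled function. A `lax.while_loop` with a tolerance would be the natural alternative.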

Installation

To install, use pip. The only dependency is jax.

python -m pip install jaxkd

Or just grab tree.py.

Notes

  • The demo.ipynb notebook in the source repository has some additional examples, including gradient-based optimization using neighbors and iterative clustering with $k$-means.
  • The query_neighbors function is intended for small values of $k$ and does not use a max heap for simplicity.
  • Some common $k$-d tree operations such as ball search are not implemented because they do not return a fixed-size array. There are probably others that could be implemented if there is a need. Suggestions welcome!
  • Only the Euclidean distance is currently supported, but this is relatively easy to change if needed.
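One heap-free way to track the $k$ best candidates, in the spirit of the note on query_neighbors, is a fixed-size buffer kept sorted by distance. Each new candidate is inserted at its sorted position and the worst entry drops off. That costs $O(k)$ per insertion instead of $O(\log k)$ for a heap, but it is branch-free and cheap for small $k$. This is a hypothetical sketch: it is not known to be what jaxkd does internally. It scans every point rather than traversing a tree, and the names `insert_candidate`, `knn_single_query` and `knn` are made up for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


def insert_candidate(best_d, best_i, d, i):
    # best_d: the k smallest squared distances seen so far, sorted ascending.
    # best_i: the matching point indices.
    slots = jnp.arange(best_d.shape[0])
    # Sorted position for d; equals k when d exceeds every stored value,
    # in which case the buffer is returned unchanged.
    pos = jnp.searchsorted(best_d, d)
    # Entries after pos move one slot to the right; the last one falls off.
    shifted_d = jnp.roll(best_d, 1)
    shifted_i = jnp.roll(best_i, 1)
    new_d = jnp.where(slots < pos, best_d, jnp.where(slots == pos, d, shifted_d))
    new_i = jnp.where(slots < pos, best_i, jnp.where(slots == pos, i, shifted_i))
    return new_d, new_i


def knn_single_query(points, query, k):
    # Start with k empty slots: infinite distance, invalid index -1.
    init = (
        jnp.full(k, jnp.inf, dtype=points.dtype),
        jnp.full(k, -1, dtype=jnp.int32),
    )

    def body(carry, i):
        d = jnp.sum((points[i] - query) ** 2)
        return insert_candidate(*carry, d, i), None

    indices = jnp.arange(points.shape[0], dtype=jnp.int32)
    (best_d, best_i), _ = jax.lax.scan(body, init, indices)
    return best_i, jnp.sqrt(best_d)


@partial(jax.jit, static_argnames="k")
def knn(points, queries, k):
    # k must be static because it fixes the buffer size.
    return jax.vmap(lambda q: knn_single_query(points, q, k))(queries)
```

Because every buffer has exactly $k$ slots, the output always has shape `(num_queries, k)`. This fixed output size is the same property that rules out ball search in the note above.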

Download files


Source Distribution

jaxkd-0.0.3.tar.gz (131.8 kB)

Uploaded Source

Built Distribution


jaxkd-0.0.3-py3-none-any.whl (8.6 kB)

Uploaded Python 3

File details

Details for the file jaxkd-0.0.3.tar.gz.

File metadata

  • Download URL: jaxkd-0.0.3.tar.gz
  • Upload date:
  • Size: 131.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.0.3.tar.gz
  • SHA256: 4903bc0c8fc1206d71b38560356f764197a4cafd461e960f053039d0a42f4089
  • MD5: 0d211e8baced7166e002eb9eabfb7b7c
  • BLAKE2b-256: 464a349409b7348a619e254827a14600c2cfa86ec27b53d375c4eb9d3ed57221


Provenance

The following attestation bundles were made for jaxkd-0.0.3.tar.gz:

Publisher: publish.yml on dodgebc/jaxkd


File details

Details for the file jaxkd-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: jaxkd-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 8.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.0.3-py3-none-any.whl
  • SHA256: 636606814a3d7dda15df2163a993cc29846438cbeea3c5274056320a40395626
  • MD5: 322aa931a122fe65e76531152c850335
  • BLAKE2b-256: 714fc6c7771308c5fa86f915452326ad9dce9b8feac1ca6741ff7a6bbda896bb


Provenance

The following attestation bundles were made for jaxkd-0.0.3-py3-none-any.whl:

Publisher: publish.yml on dodgebc/jaxkd

