Minimal JAX implementation of k-nearest neighbors using a k-d tree.

JAX k-D

Find k-nearest neighbors using a k-d tree in JAX!

This is an implementation of two GPU-friendly tree algorithms [1, 2] using only JAX primitives. The core build_tree, query_neighbors, and count_neighbors operations are compatible with JIT and automatic differentiation. They are reasonably fast when vectorized on GPU/TPU, but will be slower than SciPy's KDTree on CPU. For small problems where a pairwise distance matrix fits in memory, check whether brute force is faster (see jaxkd.extras).
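To make the brute-force baseline concrete, here is a minimal pure-JAX sketch of a pairwise neighbor search. It is not the implementation in jaxkd.extras; the function name brute_force_neighbors and its argument order are assumptions chosen for illustration, with the return order mirroring query_neighbors.

```python
import jax
import jax.numpy as jnp


def brute_force_neighbors(points, queries, k):
    """Return (indices, distances) of the k nearest points for each query.

    Hypothetical helper for illustration only, not part of jaxkd.
    """
    # Pairwise squared distances, shape (num_queries, num_points).
    diff = queries[:, None, :] - points[None, :, :]
    sq_dists = jnp.sum(diff**2, axis=-1)
    # lax.top_k returns the largest entries, so negate to get the smallest distances.
    neg_sq, indices = jax.lax.top_k(-sq_dists, k)
    return indices, jnp.sqrt(-neg_sq)


# k fixes the output shape, so it is marked static under JIT.
knn = jax.jit(brute_force_neighbors, static_argnames="k")

kp, kq = jax.random.split(jax.random.key(0))
points = jax.random.normal(kp, shape=(1_000, 3))
queries = jax.random.normal(kq, shape=(100, 3))
neighbors, distances = knn(points, queries, k=10)
```

The distance matrix grows with the number of queries times the number of points, so this approach only makes sense while that product fits comfortably in memory. Beyond that, the tree-based search is the one to use.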

If neighbor search is the performance bottleneck and you only use NVIDIA GPUs, there is a CUDA extension that can be installed as an optional dependency (see below). It is intended to match the behavior of the pure-JAX version and is enabled with the cuda=True argument. Building the extension requires CMake and NVCC to be installed on your system. There may be some rough edges, and the internal workings may change.

For even more power and flexibility, consider binding the original cudaKDTree library to JAX. Functionality will be different as described in the jaxkd-cuda repository, where example bindings can also be found. Be warned that they will not spark joy. The advantage of the pure-JAX version is that it is portable and easy to use, with the ability to scale up to larger problems without the complexity of integrating non-JAX libraries. Try it out!

Usage

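import jax
import jaxkd as jk

# Random 3-D points to index and query locations to search from
kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(100_000, 3))
queries = jax.random.normal(kq, shape=(10_000, 3))

# Build the tree once, then reuse it for any number of searches
tree = jk.build_tree(points)
# Count the points within radius r of each query
counts = jk.count_neighbors(tree, queries, r=0.1)
# Find the k nearest points for each query, with their distances
neighbors, distances = jk.query_neighbors(tree, queries, k=10)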

There is also a one-step build_and_query for convenience, and all these functions accept cuda=True to use the CUDA extension if it is installed.

Additional helpful functionality can be found in jaxkd.extras.

  • query_neighbors_pairwise and count_neighbors_pairwise for brute-force neighbor searches
  • k_means for clustering using k-means++ initialization, thanks to @NeilGirdhar for contributions
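For readers unfamiliar with k-means++ initialization, the idea can be sketched in a few lines of JAX. This is an illustration of the general technique, not the k_means code in jaxkd.extras; the name kmeans_plus_plus_init and its signature are hypothetical. The sketch assumes the points are not all identical, since the sampling weights would otherwise sum to zero.

```python
import jax
import jax.numpy as jnp


def kmeans_plus_plus_init(key, points, num_clusters):
    """Choose initial centers from `points` with k-means++ seeding.

    Hypothetical helper for illustration only, not part of jaxkd.
    """
    n = points.shape[0]
    # The first center is drawn uniformly at random.
    key, sub = jax.random.split(key)
    first = jax.random.randint(sub, (), 0, n)
    centers = [points[first]]
    # Squared distance from every point to its nearest center chosen so far.
    min_sq = jnp.sum((points - centers[0]) ** 2, axis=-1)
    for _ in range(num_clusters - 1):
        key, sub = jax.random.split(key)
        # Draw the next center with probability proportional to min_sq,
        # which favors points far from all existing centers.
        idx = jax.random.choice(sub, n, p=min_sq / jnp.sum(min_sq))
        centers.append(points[idx])
        new_sq = jnp.sum((points - points[idx]) ** 2, axis=-1)
        min_sq = jnp.minimum(min_sq, new_sq)
    return jnp.stack(centers)


points = jax.random.normal(jax.random.key(1), shape=(500, 2))
centers = kmeans_plus_plus_init(jax.random.key(2), points, num_clusters=4)
```

The returned centers would then seed the usual alternating assign-and-update iterations of k-means.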

Suggestions and contributions for other extras are always welcome!

Installation

To install, use pip. The only dependency is jax.

python -m pip install jaxkd

Or with the CUDA extension.

python -m pip install "jaxkd[cuda]"

Or just grab tree.py.
