
Minimal JAX implementation of k-nearest neighbors using a k-d tree.


JAX k-D

Find k-nearest neighbors using a k-d tree in JAX!

This is an implementation of two GPU-friendly tree algorithms [1, 2] using only JAX primitives. The core build_tree, query_neighbors, and count_neighbors operations are compatible with JIT and automatic differentiation. They are reasonably fast when vectorized on GPU/TPU, but will be slower than SciPy's KDTree on CPU. For small problems where a pairwise distance matrix fits in memory, check whether brute force is faster (see jaxkd.extras).
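The pruning that lets a k-d tree beat brute force can be illustrated with a minimal pure-Python sketch of the classic recursive tree. Points are split at the median along one axis per level, cycling through the axes, and a k-nearest-neighbor query skips any subtree whose splitting plane is farther away than the current k-th best distance. This pointer-based, recursive version is for intuition only. It is not one of the GPU-friendly algorithms [1, 2] that jaxkd implements, and `Node`, `build_kd_tree`, and `knn_query` are hypothetical names.

```python
import heapq
from dataclasses import dataclass
from typing import Optional, Sequence

Point = Sequence[float]


@dataclass
class Node:
    index: int               # index of the splitting point in the input list
    axis: int                # coordinate used to split at this node
    left: Optional["Node"]   # points whose coordinate is <= the split value
    right: Optional["Node"]  # points whose coordinate is >= the split value


def build_kd_tree(points: Sequence[Point], indices=None, depth: int = 0) -> Optional[Node]:
    """Recursively split at the median, cycling through axes (hypothetical helper)."""
    if indices is None:
        indices = list(range(len(points)))
    if not indices:
        return None
    axis = depth % len(points[0])
    indices = sorted(indices, key=lambda i: points[i][axis])
    mid = len(indices) // 2
    return Node(
        index=indices[mid],
        axis=axis,
        left=build_kd_tree(points, indices[:mid], depth + 1),
        right=build_kd_tree(points, indices[mid + 1:], depth + 1),
    )


def knn_query(tree: Optional[Node], points: Sequence[Point], query: Point, k: int) -> list[tuple[int, float]]:
    """Return [(index, squared_distance), ...] for the k nearest points, closest first."""
    heap: list[tuple[float, int]] = []  # max-heap of (-squared_distance, index)

    def visit(node: Optional[Node]) -> None:
        if node is None:
            return
        p = points[node.index]
        d2 = sum((a - b) ** 2 for a, b in zip(p, query))
        if len(heap) < k:
            heapq.heappush(heap, (-d2, node.index))
        elif d2 < -heap[0][0]:
            heapq.heapreplace(heap, (-d2, node.index))
        # Signed distance from the query to this node's splitting plane
        diff = query[node.axis] - p[node.axis]
        near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
        visit(near)
        # Prune: the far side can only help if the plane is closer than the k-th best
        if len(heap) < k or diff * diff < -heap[0][0]:
            visit(far)

    visit(tree)
    return [(i, -neg_d2) for neg_d2, i in sorted(heap, reverse=True)]
```

Sorting at every level makes this build O(n log² n), which is fine for illustration; the point is the `diff * diff < -heap[0][0]` test, which lets most queries touch only a small fraction of the points.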

If query speed is the performance bottleneck and you only use NVIDIA GPUs, the jaxkd-cuda extension can be installed as an optional dependency (see below) to enable more efficient tree operations, particularly traversal. The intention is to match the behavior of the pure-JAX version and integrate seamlessly through the cuda=True argument. Building the extension requires CMake and NVCC to be installed on your system. There may be some rough edges, and the internal workings may change.

For even more power, flexibility, and speed, consider binding the original cudaKDTree library to JAX. Its functionality differs, as described in the jaxkd-cuda repository, where example bindings can also be found and adapted to your needs. Be warned that these will not spark joy. The advantage of the pure-JAX version is that it is portable and easy to use, and it can scale up to larger problems without the complexity of integrating non-JAX libraries. Try it out!


Usage

import jax
import jaxkd as jk

kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(100_000, 3))
queries = jax.random.normal(kq, shape=(10_000, 3))

tree = jk.build_tree(points)
counts = jk.count_neighbors(tree, queries, r=0.1)
neighbors, distances = jk.query_neighbors(tree, queries, k=10)

There is also a one-step build_and_query for convenience, and all these functions accept cuda=True to use the CUDA extension if it is installed.

Additional helpful functionality can be found in jaxkd.extras.

  • query_neighbors_pairwise and count_neighbors_pairwise for brute-force neighbor searches
  • k_means for clustering using k-means++ initialization, thanks to @NeilGirdhar for contributions
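For problems small enough that the full query-by-point distance matrix fits in memory, brute force can be competitive with the tree. Here is a minimal sketch of that approach in plain JAX, in the spirit of the pairwise extras listed above. `brute_force_knn` and `brute_force_count` are hypothetical names; this is not the jaxkd.extras source, and its return conventions may differ.

```python
from functools import partial

import jax
import jax.numpy as jnp


def pairwise_sq_dists(points, queries):
    """Squared Euclidean distances with shape (n_queries, n_points)."""
    diff = queries[:, None, :] - points[None, :, :]
    return jnp.sum(diff**2, axis=-1)


@partial(jax.jit, static_argnames="k")
def brute_force_knn(points, queries, k):
    """Indices and distances of the k nearest points to each query (hypothetical helper)."""
    # lax.top_k returns the k largest entries, so negate to get the k smallest distances
    neg_sq, idx = jax.lax.top_k(-pairwise_sq_dists(points, queries), k)
    return idx, jnp.sqrt(-neg_sq)


@jax.jit
def brute_force_count(points, queries, r):
    """Number of points within distance r of each query (hypothetical helper)."""
    return jnp.sum(pairwise_sq_dists(points, queries) <= r**2, axis=-1)
```

`k` is marked static because it fixes the output shape. Memory grows with n_queries × n_points, which is where the tree becomes the better choice.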

Suggestions and contributions for other extras are always welcome!
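The k_means extra mentioned above uses k-means++ initialization. To illustrate that seeding step alone, here is a minimal JAX sketch. `kmeans_plus_plus_init` is a hypothetical name and this is not the jaxkd implementation, whose interface is not described here.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n_clusters")
def kmeans_plus_plus_init(key, points, n_clusters):
    """Choose initial centroids by k-means++ seeding (hypothetical helper).

    The first centroid is drawn uniformly; each later one is drawn with
    probability proportional to its squared distance to the nearest centroid
    chosen so far. Assumes at least n_clusters distinct points.
    """
    n = points.shape[0]
    key, subkey = jax.random.split(key)
    first = points[jax.random.randint(subkey, (), 0, n)]
    centroids = jnp.zeros((n_clusters, points.shape[1]), points.dtype).at[0].set(first)
    # Squared distance from every point to its nearest chosen centroid
    min_sq = jnp.sum((points - first) ** 2, axis=-1)

    def add_centroid(i, carry):
        key, centroids, min_sq = carry
        key, subkey = jax.random.split(key)
        # Points already chosen have min_sq == 0, i.e. zero sampling weight
        idx = jax.random.choice(subkey, n, p=min_sq / jnp.sum(min_sq))
        new = points[idx]
        min_sq = jnp.minimum(min_sq, jnp.sum((points - new) ** 2, axis=-1))
        return key, centroids.at[i].set(new), min_sq

    _, centroids, _ = jax.lax.fori_loop(1, n_clusters, add_centroid, (key, centroids, min_sq))
    return centroids
```

A full k-means would then alternate between assigning each point to its nearest centroid and recomputing each centroid as the mean of its assigned points.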

Installation

To install, use pip. The only dependency is jax.

python -m pip install jaxkd

Or, with the CUDA extension (the quotes keep shells such as zsh from interpreting the brackets):

python -m pip install "jaxkd[cuda]"

Or just grab tree.py.



Download files


Source Distribution

jaxkd-0.1.2.tar.gz (9.4 kB)

Uploaded Source

Built Distribution


jaxkd-0.1.2-py3-none-any.whl (11.4 kB)

Uploaded Python 3

File details

Details for the file jaxkd-0.1.2.tar.gz.

File metadata

  • Download URL: jaxkd-0.1.2.tar.gz
  • Upload date:
  • Size: 9.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.1.2.tar.gz
Algorithm Hash digest
SHA256 842e7837450feec39fa03bb6f5369fa4123ec526339dc359bfba7157884b77ec
MD5 35da2db044a432d9f9f0add72812c097
BLAKE2b-256 fc6b40de65a682f0212ff13facf0c2a9b1a07e1b0f719906c11b39a9855275b3


Provenance

The following attestation bundles were made for jaxkd-0.1.2.tar.gz:

Publisher: publish.yml on dodgebc/jaxkd

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file jaxkd-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: jaxkd-0.1.2-py3-none-any.whl
  • Upload date:
  • Size: 11.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.12.9

File hashes

Hashes for jaxkd-0.1.2-py3-none-any.whl
Algorithm Hash digest
SHA256 63560da791809a6c2cc7f3bf7e9d1509f2d880ba1bb7e6695727e6a41032bd98
MD5 f230b9b8ab3da4917b2a152d76ea7858
BLAKE2b-256 283153a56d0bb104b6726f6d6858304fe7f4fb63fb120112f446ebecae5168c4


Provenance

The following attestation bundles were made for jaxkd-0.1.2-py3-none-any.whl:

Publisher: publish.yml on dodgebc/jaxkd

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
