Extension package for jaxkd.
Project description
jaxkd-cuda
This package contains CUDA extensions for JAX k-D. Building it requires JAX, CMake, and a CUDA compiler (nvcc). It is intended to be installed as an optional dependency of JAX k-D and used as an add-on:
`python -m pip install "jaxkd[cuda]"`
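Code that should use the CUDA extension only when it is present can check for it before importing. This is a sketch of a common optional-dependency pattern, not something this page says JAX k-D does internally. The import name `jaxkd_cuda` is assumed from the package name and the `jaxkd_cuda/cukd` path, and `cuda_extension_available` is a hypothetical helper.

```python
import importlib.util


def cuda_extension_available(module_name: str = "jaxkd_cuda") -> bool:
    """True if the (assumed) top-level module of the CUDA extension is importable.

    For a top-level name, find_spec only locates the module; it does not
    execute it or load any CUDA code.
    """
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    if cuda_extension_available():
        print("jaxkd_cuda is installed")
    else:
        print("jaxkd_cuda is not installed; install it with the [cuda] extra")
```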
Note that the cudaKDTree library is more powerful and flexible than this extension, and it can be bound to JAX through JAX's foreign function interface (FFI). See the sample bindings in `jaxkd_cuda/cukd` for a rough example of how to do this. JaxKDTree also provides an example, although it no longer works with the current JAX API.
This extension uses a slightly different tree-building method from cudaKDTree so that it exactly matches the behavior of the pure-JAX version. It only permutes an index array, and it splits along the dimension in which the points have the widest spread (rather than the dimension in which the bounding box is largest). Currently, the performance bottleneck is the reduction operations needed to compute this spread. There is also substantial memory overhead (a few times the number of points), which can probably be reduced in the future.
The neighbor query algorithm follows [2], and neighbor counting is a trivial modification of it.
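The build and query described above can be sketched in pure Python. This is a minimal illustration under stated assumptions, not the extension's CUDA implementation: the tree is stored implicitly in a permuted index array with each node at the midpoint of its slice (a simple balanced layout that may differ from the one JAX k-D uses), each index slice is fully sorted rather than partitioned, the k-nearest-neighbor search is a standard recursive branch-and-bound traversal rather than necessarily the algorithm of [2], and "neighbor counting" is assumed to mean counting points within a radius `r`. The names `build_tree`, `knn`, and `count_within` are hypothetical.

```python
import heapq
from typing import List, Sequence, Tuple

Point = Sequence[float]


def widest_spread_dim(points: Sequence[Point], idx: List[int]) -> int:
    """Dimension with the largest (max - min) over the points in idx.

    This uses the spread of the points themselves (min/max reductions),
    not the extent of the node's bounding box.
    """
    dims = len(points[idx[0]])
    spreads = [
        max(points[i][d] for i in idx) - min(points[i][d] for i in idx)
        for d in range(dims)
    ]
    return max(range(dims), key=lambda d: spreads[d])


def build_tree(points: Sequence[Point]) -> Tuple[List[int], List[int]]:
    """Build an implicit k-d tree by permuting an index array only.

    The node for slice [lo, hi) sits at mid = (lo + hi) // 2; its left subtree
    is [lo, mid) and its right subtree is [mid + 1, hi). The points never move.
    """
    n = len(points)
    perm = list(range(n))
    split_dim = [0] * n

    def build(lo: int, hi: int) -> None:
        if hi <= lo:
            return
        d = widest_spread_dim(points, perm[lo:hi])
        # Sorting the index slice puts the median at mid (a real build would partition).
        perm[lo:hi] = sorted(perm[lo:hi], key=lambda i: points[i][d])
        mid = (lo + hi) // 2
        split_dim[mid] = d
        build(lo, mid)
        build(mid + 1, hi)

    build(0, n)
    return perm, split_dim


def sq_dist(p: Point, q: Point) -> float:
    return sum((a - b) ** 2 for a, b in zip(p, q))


def knn(points, perm, split_dim, q: Point, k: int) -> List[int]:
    """Indices of the k nearest points to q, nearest first."""
    if k <= 0:
        return []
    heap: List[Tuple[float, int]] = []  # max-heap of (-squared distance, index)

    def visit(lo: int, hi: int) -> None:
        if hi <= lo:
            return
        mid = (lo + hi) // 2
        i = perm[mid]
        d2 = sq_dist(points[i], q)
        if len(heap) < k:
            heapq.heappush(heap, (-d2, i))
        elif d2 < -heap[0][0]:
            heapq.heapreplace(heap, (-d2, i))
        diff = q[split_dim[mid]] - points[i][split_dim[mid]]
        near, far = ((lo, mid), (mid + 1, hi)) if diff < 0 else ((mid + 1, hi), (lo, mid))
        visit(*near)
        # The far side can only help if the splitting plane is closer than the k-th best.
        if len(heap) < k or diff * diff < -heap[0][0]:
            visit(*far)

    visit(0, len(perm))
    return [i for _, i in sorted(heap, reverse=True)]


def count_within(points, perm, split_dim, q: Point, r: float) -> int:
    """Number of points whose squared distance to q is at most r * r."""
    r2 = r * r

    def visit(lo: int, hi: int) -> int:
        if hi <= lo:
            return 0
        mid = (lo + hi) // 2
        p = points[perm[mid]]
        count = 1 if sq_dist(p, q) <= r2 else 0
        diff = q[split_dim[mid]] - p[split_dim[mid]]
        if diff <= r:  # left subtree holds coordinates <= p[d]
            count += visit(lo, mid)
        if diff >= -r:  # right subtree holds coordinates >= p[d]
            count += visit(mid + 1, hi)
        return count

    return visit(0, len(perm))
```

The per-node min/max reductions in `widest_spread_dim` correspond to the reduce operations the paragraph above identifies as the bottleneck. Counting reuses the same traversal as the query but replaces the candidate heap with a running total and prunes with a fixed radius.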
Download files
Source Distribution
File details
Details for the file jaxkd_cuda-0.0.0.tar.gz.
File metadata
- Download URL: jaxkd_cuda-0.0.0.tar.gz
- Upload date:
- Size: 131.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.12.9
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 86cb49923df389793a4a917c6caa15821d6a3c41eac753dd87ca0793b85fe645 |
| MD5 | 57142abc10042be9677e9a69cacea221 |
| BLAKE2b-256 | e09fafaec109be9ff0d247ba19eebe220464454815e339ca9a85229e6d298ff3 |
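A downloaded archive can be checked against these digests with a short standard-library script. This is a sketch: the file is assumed to be in the current directory under its published name, and `make_hasher`, `file_digest`, and `verify_sdist` are hypothetical helper names. BLAKE2b-256 is computed as BLAKE2b with a 32-byte digest.

```python
import hashlib

# SHA256 digest listed on this page for jaxkd_cuda-0.0.0.tar.gz.
EXPECTED_SHA256 = "86cb49923df389793a4a917c6caa15821d6a3c41eac753dd87ca0793b85fe645"


def make_hasher(algorithm: str):
    """Return a hashlib object for 'sha256', 'md5', or 'blake2b-256'."""
    if algorithm == "blake2b-256":
        # BLAKE2b with a 32-byte (256-bit) digest.
        return hashlib.blake2b(digest_size=32)
    return hashlib.new(algorithm)


def file_digest(path: str, algorithm: str = "sha256", chunk_size: int = 1 << 16) -> str:
    """Hex digest of a file, read in chunks so large files need little memory."""
    h = make_hasher(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_sdist(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the published digest."""
    return file_digest(path, "sha256") == expected.lower()
```

For an intact download, `verify_sdist("jaxkd_cuda-0.0.0.tar.gz")` should return `True`. pip's hash-checking mode (`--hash=sha256:...` entries in a requirements file) performs the same check at install time.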
Provenance
The following attestation bundles were made for jaxkd_cuda-0.0.0.tar.gz:
- Publisher: publish.yml on dodgebc/jaxkd-cuda
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: jaxkd_cuda-0.0.0.tar.gz
- Subject digest: 86cb49923df389793a4a917c6caa15821d6a3c41eac753dd87ca0793b85fe645
- Sigstore transparency entry: 242601101
- Sigstore integration time:
- Permalink: dodgebc/jaxkd-cuda@8ed41baa88506d8a31f44311b91e8ff840e753ec
- Branch / Tag: refs/tags/v0.0.1
- Owner: https://github.com/dodgebc
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: publish.yml@8ed41baa88506d8a31f44311b91e8ff840e753ec
- Trigger Event: push
Statement type: