Scripts to auto-generate jax's foreign function interface from CUDA kernels
JAX-FFI-GEN
A few scripts that help auto-generate jax's foreign function interface (FFI) bindings for CUDA kernels.
This code uses tree_sitter to parse CUDA code and jinja2 to auto-generate the corresponding FFI bindings. This supports a workflow in which one calls CUDA kernels "almost" directly, instead of writing the large amount of boilerplate code that comes with jax's FFI.
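The parse step turns CUDA source into a list of kernels and their parameters. As a rough, standard-library-only illustration, the sketch below swaps tree_sitter for a plain regular expression and keeps only `__global__` functions, similar in spirit to the `only_kernels=True` option of `parse.get_functions_from_file`. `find_kernels` is a hypothetical helper, not part of jax_ffi_gen, and a regex is far less robust than a real parser.

```python
import re

# Stand-in for the tree_sitter parse step (NOT what jax_ffi_gen does):
# match declarations of the form `__global__ void Name(params)`.
# Assumption: parameter lists contain no nested parentheses.
KERNEL_RE = re.compile(r"__global__\s+void\s+(\w+)\s*\(([^)]*)\)")


def find_kernels(source: str) -> dict[str, list[str]]:
    """Map each __global__ kernel name to its list of parameter declarations."""
    kernels = {}
    for match in KERNEL_RE.finditer(source):
        name, params = match.group(1), match.group(2)
        kernels[name] = [p.strip() for p in params.split(",") if p.strip()]
    return kernels


SOURCE = """
__device__ float square(float v) { return v * v; }

__global__ void Scale(const float* x, float* y, const int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = 2.0f * x[i];
}
"""

# The __device__ helper is skipped; only the kernel is reported.
print(find_kernels(SOURCE))
# {'Scale': ['const float* x', 'float* y', 'const int n']}
```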
It is recommended to put a small Python script next to the CUDA source files and to re-run it whenever a kernel interface changes, so that the generated FFI source file is regenerated. For example:
```python
from pathlib import Path

from jax_ffi_gen import parse, generator

HERE = Path(__file__).resolve().parent

# Parse the CUDA header; only_kernels=True restricts the result to kernels
kernels = parse.get_functions_from_file(str(HERE / "my_kernels.cuh"), only_kernels=True)

# Write the generated FFI bindings to a .cu file
generator.generate_ffi_module_file(
    output_file=str(HERE / "generated/ffi_my_kernels.cu"),
    functions=kernels,
)
```
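Since the script only needs to run when a kernel interface has changed, one option is to regenerate only when the CUDA header is newer than the generated file. `needs_regeneration` below is a hypothetical helper, not part of jax_ffi_gen; it compares modification times using the standard library only. It is conservative (any edit to the header triggers regeneration) and it ignores changes to the script itself.

```python
import os
from pathlib import Path


def needs_regeneration(source: Path, generated: Path) -> bool:
    """Return True if `generated` is missing or older than `source`.

    Hypothetical helper: compares file modification times only, so it also
    fires on header edits that did not change any kernel interface.
    """
    if not generated.exists():
        return True
    return os.path.getmtime(source) > os.path.getmtime(generated)
```

In the script above, the `generator.generate_ffi_module_file(...)` call could then be guarded by `if needs_regeneration(HERE / "my_kernels.cuh", HERE / "generated/ffi_my_kernels.cu"):`.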
The parser interprets each argument of your kernel as follows:
- `const` pointer types (e.g. `const float*`) as input jax arrays
- modifiable pointer types (e.g. `float*`) as output jax arrays
- `const` non-pointer types (e.g. `const int`) as static parameters that need to be known at jit-compile time
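The three rules above can be sketched as a small classifier. To show the kind of boilerplate the generator saves you from writing, the sketch also maps each role onto an XLA FFI binding chain: `.Arg` for inputs, `.Ret` for outputs, `.Attr` for static parameters, plus a `.Ctx<ffi::PlatformStream<cudaStream_t>>()` entry that passes the CUDA stream. `classify`, `binding_chain` and the `FFI_DTYPE` table are hypothetical illustrations; the code jax_ffi_gen actually emits comes from its jinja2 templates and may look different.

```python
import re

# Hypothetical C-type -> XLA FFI element-type table (only a few entries;
# jax_ffi_gen's own mapping may differ).
FFI_DTYPE = {"float": "ffi::F32", "double": "ffi::F64", "int": "ffi::S32"}

# "[const] type [*] name", e.g. "const float* x"; qualifiers such as
# __restrict__ are not handled by this sketch.
PARAM_RE = re.compile(r"\s*(const\s+)?(\w+)\s*(\*)?\s*(\w+)\s*")


def classify(param: str) -> tuple[str, str, str]:
    """Return (role, C type, name) with role "input", "output" or "static"."""
    m = PARAM_RE.fullmatch(param)
    if m is None:
        raise ValueError(f"unsupported parameter: {param!r}")
    is_const = m.group(1) is not None
    ctype, is_pointer, name = m.group(2), m.group(3) is not None, m.group(4)
    if is_pointer:
        # const T* -> input array, T* -> output array
        return ("input" if is_const else "output", ctype, name)
    if is_const:
        # const T -> static parameter fixed at jit-compile time
        return ("static", ctype, name)
    # Plain non-const scalars are not covered by the rules above.
    raise ValueError(f"non-const scalar parameter: {param!r}")


def binding_chain(params: list[str]) -> str:
    """Render an XLA FFI Bind() chain for the given parameter declarations."""
    lines = ["ffi::Ffi::Bind()", "    .Ctx<ffi::PlatformStream<cudaStream_t>>()"]
    for role, ctype, name in map(classify, params):
        if role == "input":
            lines.append(f"    .Arg<ffi::Buffer<{FFI_DTYPE[ctype]}>>()  // {name}")
        elif role == "output":
            lines.append(f"    .Ret<ffi::Buffer<{FFI_DTYPE[ctype]}>>()  // {name}")
        else:
            lines.append(f'    .Attr<{ctype}>("{name}")')
    return "\n".join(lines)


print(binding_chain(["const float* x", "float* y", "const int n"]))
```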
Some useful customization options
```python
from pathlib import Path

from jax_ffi_gen import parse, generator

HERE = Path(__file__).resolve().parent

kernels = parse.get_functions_from_file(
    str(HERE / "my_kernels.cuh"),
    names=["MyKernelA", "MyKernelB"],  # only select some kernels by name
)

# Examples of a few useful features that you may need to define per kernel
kernels["MyKernelA"].init_outputs_zero = True  # initialize output arrays with zeros
# Launch configuration, given as C++ expression strings
kernels["MyKernelA"].grid_size_expression = "x.element_count()"
kernels["MyKernelA"].block_size_expression = "64"
kernels["MyKernelA"].smem_size_expression = "blockDim.x * sizeof(float4)"  # dynamic shared memory
kernels["MyKernelA"].par["num_particles"].expression = "x.element_count()/3"  # derive from the size of x
kernels["MyKernelA"].template_par["p"].instances = (0, 1, 2)  # instantiate template parameter p for these values

generator.generate_ffi_module_file(
    output_file=str(HERE / "generated/ffi_new_kernels.cu"),
    functions=kernels,
    includes=["../math.cuh"],  # set includes
)
```
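The expression strings above appear to be C++ snippets inserted into the generated code, with `x` presumably one of the kernel's array arguments. To make the `MyKernelA` settings concrete for one input size, the plain-Python sketch below assumes that `grid_size_expression` gives the total number of threads and that this is rounded up to whole blocks of `block_size_expression` threads. Whether jax_ffi_gen counts threads or blocks here is an assumption, not something stated above; `launch_config` is a hypothetical helper.

```python
# Hypothetical illustration of the MyKernelA settings for a concrete input size.
SIZEOF_FLOAT4 = 16  # a CUDA float4 holds four 4-byte floats


def launch_config(x_element_count: int, block_size: int = 64) -> dict[str, int]:
    # Assumption: grid_size_expression ("x.element_count()") is the total
    # number of threads, rounded up to whole blocks via ceiling division.
    num_blocks = (x_element_count + block_size - 1) // block_size
    # smem_size_expression "blockDim.x * sizeof(float4)", with blockDim.x == block_size
    smem_bytes = block_size * SIZEOF_FLOAT4
    # par["num_particles"].expression "x.element_count()/3"; C++ integer
    # division of non-negative values truncates, matching // here
    num_particles = x_element_count // 3
    return {
        "blocks": num_blocks,
        "threads_per_block": block_size,
        "smem_bytes": smem_bytes,
        "num_particles": num_particles,
    }


print(launch_config(3000))
# {'blocks': 47, 'threads_per_block': 64, 'smem_bytes': 1024, 'num_particles': 1000}
```

Rounding up means the last block may contain threads past the end of `x`, so the kernel itself would still need a bounds check.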
Download files
File details
Details for the file jax_ffi_gen-0.5.0.tar.gz.
File metadata
- Download URL: jax_ffi_gen-0.5.0.tar.gz
- Upload date:
- Size: 5.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | 5bca251f6bb2fe9862612f66db601bf7e89b3c4e0bb6b3631c6d7b10aa2de9f3 |
| MD5 | 0624b5bd8fa85921aebff0ace08115e8 |
| BLAKE2b-256 | f3f658e832b21c8a01d6bcc7265078ecd6105e9bbf25f175ca66e72ca8c97e85 |
File details
Details for the file jax_ffi_gen-0.5.0-py3-none-any.whl.
File metadata
- Download URL: jax_ffi_gen-0.5.0-py3-none-any.whl
- Upload date:
- Size: 6.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.12.11
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | bdae6200d428ac391bb7d33a989782c11516c3ab73260bb855242856089d53e4 |
| MD5 | 86283474e889c8049ce9405fefedb829 |
| BLAKE2b-256 | f520ccb45fad56e757bf90a19286febf65a425e72ea11e5241f62695da858f7f |
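To check a downloaded file against the digests listed above, Python's `hashlib` is sufficient. The helper below is a sketch (the name `file_digests` is made up here); it streams the file in chunks and computes all three digests in one pass. BLAKE2b-256 corresponds to `hashlib.blake2b` with a 32-byte digest.

```python
import hashlib
from pathlib import Path


def file_digests(path: Path) -> dict[str, str]:
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file."""
    sha256, md5 = hashlib.sha256(), hashlib.md5()
    blake2b_256 = hashlib.blake2b(digest_size=32)  # 32 bytes = 256 bits
    with path.open("rb") as f:
        # Read in 64 KiB chunks so large files need not fit in memory
        for chunk in iter(lambda: f.read(1 << 16), b""):
            for h in (sha256, md5, blake2b_256):
                h.update(chunk)
    return {
        "SHA256": sha256.hexdigest(),
        "MD5": md5.hexdigest(),
        "BLAKE2b-256": blake2b_256.hexdigest(),
    }


EXPECTED_SDIST_SHA256 = "5bca251f6bb2fe9862612f66db601bf7e89b3c4e0bb6b3631c6d7b10aa2de9f3"
# Usage, assuming the sdist was downloaded to the current directory:
# assert file_digests(Path("jax_ffi_gen-0.5.0.tar.gz"))["SHA256"] == EXPECTED_SDIST_SHA256
```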