Scripts to auto-generate JAX foreign function interface (FFI) bindings from CUDA kernels

Project description

JAX-FFI-GEN

A few scripts that help auto-generate JAX foreign function interface (FFI) bindings for CUDA kernels. The code uses tree_sitter to parse CUDA sources and jinja2 to generate the corresponding FFI code. This supports a workflow in which one calls CUDA kernels "almost" directly, without having to write the large amount of boilerplate that JAX's FFI otherwise requires.
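
For illustration only, the following sketch shows the kind of template-driven code generation described above. It uses Python's standard string.Template instead of jinja2 and a hard-coded kernel description instead of tree_sitter output, so it runs without extra dependencies. The emitted text follows the XLA FFI C++ API that JAX documents (XLA_FFI_DEFINE_HANDLER_SYMBOL, ffi::Ffi::Bind()) and assumes a namespace ffi = xla::ffi alias and float32 arrays throughout. The names KERNEL, render_handler and the Scale kernel are hypothetical, and this is not the template jax_ffi_gen ships.

```python
from string import Template

# Hypothetical kernel description standing in for the parser's output.
# Assumes every array argument is float32.
KERNEL = {
    "name": "Scale",
    "inputs": ["x"],   # const float* arguments -> input buffers
    "outputs": ["y"],  # float* arguments       -> output buffers
    "attrs": ["n"],    # const int arguments    -> static attributes
}

HANDLER = Template("""\
// Assumes: #include "xla/ffi/api/ffi.h" and namespace ffi = xla::ffi;
static ffi::Error ${name}Impl(cudaStream_t stream, ${params}) {
  // ... compute the launch configuration, then
  // ${name}<<<grid, block, 0, stream>>>(...);
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    ${name}Handler, ${name}Impl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
${bindings});
""")


def render_handler(kernel: dict) -> str:
    """Render a C++ XLA FFI handler skeleton for one kernel description."""
    # Handler parameters, in the same order as the Bind() chain.
    params = (
        [f"ffi::Buffer<ffi::F32> {a}" for a in kernel["inputs"]]
        + [f"ffi::ResultBuffer<ffi::F32> {a}" for a in kernel["outputs"]]
        + [f"int64_t {a}" for a in kernel["attrs"]]
    )
    bindings = (
        ["        .Arg<ffi::Buffer<ffi::F32>>()"] * len(kernel["inputs"])
        + ["        .Ret<ffi::Buffer<ffi::F32>>()"] * len(kernel["outputs"])
        + [f'        .Attr<int64_t>("{a}")' for a in kernel["attrs"]]
    )
    return HANDLER.substitute(
        name=kernel["name"],
        params=", ".join(params),
        bindings="\n".join(bindings),
    )
```

Calling print(render_handler(KERNEL)) prints the handler skeleton for the hypothetical Scale kernel; a real generator would also fill in the kernel launch itself.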

It is recommended to place a small Python script next to the CUDA source files and re-run it whenever a kernel interface changes, so that the generated source is regenerated accordingly. For example:

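from pathlib import Path
from jax_ffi_gen import parse, generator

HERE = Path(__file__).resolve().parent

# Parse the header, keeping only the kernels
kernels = parse.get_functions_from_file(str(HERE / "my_kernels.cuh"), only_kernels=True)

# Make sure the output directory exists, then write the generated FFI bindings
(HERE / "generated").mkdir(exist_ok=True)
generator.generate_ffi_module_file(
    output_file=str(HERE / "generated/ffi_my_kernels.cu"),
    functions=kernels,
)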
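
Because the script only needs to re-run when a kernel interface changes, one option is to skip regeneration when the generated file is already newer than the header, similar to how make compares timestamps. The helper below is a hypothetical addition, not part of jax_ffi_gen; it compares modification times only, so any edit to the header (not just an interface change) triggers regeneration.

```python
from pathlib import Path


def needs_regeneration(source: Path, generated: Path) -> bool:
    """Return True if `generated` is missing or older than `source`.

    Hypothetical helper: a plain modification-time check. It does not
    inspect the kernel signatures, so any edit to `source` counts.
    """
    if not generated.exists():
        return True
    return source.stat().st_mtime > generated.stat().st_mtime
```

In the script above, the parse and generate calls could then be wrapped in if needs_regeneration(HERE / "my_kernels.cuh", HERE / "generated/ffi_my_kernels.cu"):.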

The parser interprets each argument of your kernel as follows:

  • const pointer (e.g. const float*): an input JAX array
  • non-const pointer (e.g. float*): an output JAX array
  • const non-pointer value (e.g. const int): a static parameter that must be known at JIT-compile time
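
The three rules can be sketched in a few lines. This is a simplified stand-in that applies regular expressions to a single signature string instead of using tree_sitter, so it ignores templates, macros, and commas inside types. The function name classify_params and the "unknown" label for arguments the rules do not cover are hypothetical.

```python
import re


def classify_params(signature: str) -> dict:
    """Map each kernel parameter name to 'input', 'output', 'static' or 'unknown'.

    Simplified sketch: splits the parameter list on commas, so it does not
    handle templates, macros, or default arguments.
    """
    inner = signature[signature.index("(") + 1 : signature.rindex(")")]
    roles = {}
    for param in inner.split(","):
        param = param.strip()
        if not param:
            continue
        name = re.findall(r"\w+", param)[-1]  # last identifier is the parameter name
        star = param.find("*")
        if star >= 0:
            # const before the '*' means a read-only pointee: const T* -> input
            pointee_const = re.search(r"\bconst\b", param[:star]) is not None
            roles[name] = "input" if pointee_const else "output"
        elif re.search(r"\bconst\b", param):
            roles[name] = "static"  # const T -> known at JIT-compile time
        else:
            roles[name] = "unknown"  # not covered by the three rules
    return roles
```

For example, classify_params("__global__ void Scale(const float* x, float* y, const int n)") returns {"x": "input", "y": "output", "n": "static"}. Checking for const only before the * keeps a const pointer to mutable data (float* const out) classified as an output.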

Some useful customization options

from pathlib import Path
from jax_ffi_gen import parse, generator

HERE = Path(__file__).resolve().parent

kernels = parse.get_functions_from_file(
    str(HERE / "my_kernels.cuh"),
    names=["MyKernelA", "MyKernelB"],  # only select some kernels by name
)

# Examples of a few useful settings that you may need to define per kernel
kernels["MyKernelA"].init_outputs_zero = True  # zero-initialize the output arrays
# Launch configuration, given as C++ expressions
kernels["MyKernelA"].grid_size_expression = "x.element_count()"
kernels["MyKernelA"].block_size_expression = "64"
kernels["MyKernelA"].smem_size_expression = "blockDim.x * sizeof(float4)"  # dynamic shared memory
kernels["MyKernelA"].par["num_particles"].expression = "x.element_count()/3"  # derived from input x
kernels["MyKernelA"].template_par["p"].instances = (0, 1, 2)  # template values to instantiate

generator.generate_ffi_module_file(
    output_file=str(HERE / "generated/ffi_new_kernels.cu"),
    functions=kernels,
    includes=["../math.cuh"],  # set includes
)
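
Setting template_par["p"].instances to a finite tuple suggests that each listed value gets its own instantiation. Since a C++ template argument must be fixed when the C++ code is compiled, one common way to pick among a fixed set of instantiations from a value that is only known when the handler runs is a switch statement. The sketch below emits such a switch; it illustrates the idea only, and emit_template_dispatch, the launch arguments, and the local variables grid, block, smem and stream are hypothetical rather than what jax_ffi_gen generates.

```python
def emit_template_dispatch(kernel: str, param: str, instances, launch_args: str) -> str:
    """Emit a C++ switch that launches kernel<value> for each allowed value.

    Assumes the surrounding C++ code defines grid, block, smem and stream,
    and returns an XLA FFI error for values outside `instances`.
    """
    lines = [f"switch ({param}) {{"]
    for value in instances:
        lines.append(f"  case {value}:")
        lines.append(f"    {kernel}<{value}><<<grid, block, smem, stream>>>({launch_args});")
        lines.append("    break;")
    lines.append("  default:")
    lines.append(f'    return ffi::Error::InvalidArgument("unsupported value for {param}");')
    lines.append("}")
    return "\n".join(lines)
```

With the settings above, emit_template_dispatch("MyKernelA", "p", (0, 1, 2), "x, y, n") produces three case labels, one per listed instance.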

Download files

Source Distribution

jax_ffi_gen-0.5.0.tar.gz (5.9 kB)

Uploaded Source

Built Distribution

jax_ffi_gen-0.5.0-py3-none-any.whl (6.4 kB)

Uploaded Python 3

File details

Details for the file jax_ffi_gen-0.5.0.tar.gz.

File metadata

  • Download URL: jax_ffi_gen-0.5.0.tar.gz
  • Upload date:
  • Size: 5.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for jax_ffi_gen-0.5.0.tar.gz
Algorithm Hash digest
SHA256 5bca251f6bb2fe9862612f66db601bf7e89b3c4e0bb6b3631c6d7b10aa2de9f3
MD5 0624b5bd8fa85921aebff0ace08115e8
BLAKE2b-256 f3f658e832b21c8a01d6bcc7265078ecd6105e9bbf25f175ca66e72ca8c97e85
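
To check a downloaded archive against the published SHA256 digest, a minimal sketch using Python's hashlib follows; the local path you pass in is up to you, and the function names sha256_of and verify are hypothetical. pip users can get the same check through pip's hash-checking mode by pinning the digest with --hash=sha256:... in a requirements file.

```python
import hashlib
from pathlib import Path

# SHA256 digest of jax_ffi_gen-0.5.0.tar.gz, copied from the table above.
EXPECTED_SHA256 = "5bca251f6bb2fe9862612f66db601bf7e89b3c4e0bb6b3631c6d7b10aa2de9f3"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Stream the file in chunks and return its hex SHA256 digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected: str = EXPECTED_SHA256) -> bool:
    """Return True if the file's SHA256 matches the published digest."""
    return sha256_of(path) == expected.lower()
```

For example, verify(Path("jax_ffi_gen-0.5.0.tar.gz")) should return True for an unmodified download of the source distribution.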

File details

Details for the file jax_ffi_gen-0.5.0-py3-none-any.whl.

File metadata

  • Download URL: jax_ffi_gen-0.5.0-py3-none-any.whl
  • Upload date:
  • Size: 6.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for jax_ffi_gen-0.5.0-py3-none-any.whl
Algorithm Hash digest
SHA256 bdae6200d428ac391bb7d33a989782c11516c3ab73260bb855242856089d53e4
MD5 86283474e889c8049ce9405fefedb829
BLAKE2b-256 f520ccb45fad56e757bf90a19286febf65a425e72ea11e5241f62695da858f7f
