
Generate PyTorch Custom Operators from Numba-CUDA kernels


Pytorch-Numba Extension JIT


Writing custom CUDA operators in C and C++ can make certain operations significantly more efficient, but it requires setting up a full C++ project and involves a great deal of boilerplate. Writing CUDA kernels with numba-cuda is significantly easier, but it incurs overhead on every call and still requires some boilerplate to integrate with the tracing systems that underlie torch.compile.

However, many of the CUDA kernels used for deep learning share a similar structure: they read from a set of input arrays and write to a set of output arrays. As such, most of the boilerplate and binding code for C++ extensions can be generated automatically.

This project aims to automate that generation: pnex.jit takes a Python function in the form of a Numba CUDA kernel, along with some type annotations, and compiles it into a user-friendly, high-performance PyTorch C++ extension.

Additionally, if all you need is a convenient wrapper for PyTorch Custom Operators, this library can skip the C++ compilation phase and generate only the boilerplate for a Custom Operator definition.

For an example of how this package is used, see my other package, pytorch-nd-semiconv.

This package is listed on PyPI and can be installed with:

pip install pytorch-numba-extension-jit


