Generate PyTorch Custom Operators from Numba-CUDA kernels

Pytorch-Numba Extension JIT

Documentation

Writing custom CUDA operators in C and C++ can make certain operations significantly more efficient, but it requires setting up a full C++ project and involves a great deal of boilerplate. Writing CUDA kernels with numba-cuda is considerably easier, but it incurs overhead on every call and still requires some boilerplate to integrate with the tracing systems that underlie torch.compile.

However, many of the CUDA kernels used in deep learning follow a similar pattern: they read from a set of input arrays and write to a set of output arrays. As such, most of the boilerplate and binding code for C++ extensions can be generated automatically.

This project aims to do exactly that: pnex.jit takes a Python function in the form of a Numba CUDA kernel, along with some type annotations, and compiles it into a user-friendly, high-performance PyTorch C++ extension.

Additionally, if a convenient wrapper for PyTorch Custom Operators is all that is desired, this library also allows skipping the C++ compilation phase and only generating the boilerplate for a Custom Operator definition.

For an example of how this package is used, see my other package, pytorch-nd-semiconv.
