Python (ctypes) bindings for the kernel-set LLM inference & training kernels
kernel_set — Python bindings
Pure-ctypes Python bindings for kernel-set, a library of high-performance
CUDA/HIP kernels for LLM inference and training (norm, activation, attention,
gemm, moe, rope, quant, sampling, embedding, elementwise, loss, optimizer).
- No compilation on install. The package `dlopen`s a prebuilt shared library (`libkernel_set.so` / `libkernel_set.dylib` / `kernel_set.dll`).
- torch-friendly, but torch-optional. Every wrapper accepts a `torch.Tensor` or a raw integer device pointer; `torch` is only needed for the tensor-convenience paths.
- Full ABI coverage: all ~60 C functions, plus runtime/device/stream/memory helpers and a `torch` <-> `ks_dtype` mapping.
Install
```bash
pip install ./bindings/python        # editable: pip install -e ./bindings/python
```
The binding has no required dependencies. Install `torch` for the ergonomic tensor paths:
```bash
pip install "kernel_set[torch]"
```
Locating the shared library
At import time the package searches, in order:
1. `KERNEL_SET_LIB`: full path or bare filename of the library (e.g. `export KERNEL_SET_LIB=/opt/kernel_set/lib/libkernel_set.so`).
2. `KERNEL_SET_LIB_DIR`: a directory to search for the platform basename.
3. Next to the installed package (`kernel_set/`, `kernel_set/lib/`), for wheels that vendor the `.so`.
4. Repo build trees relative to the source checkout (`build/`, `build/lib/`, `lib/`, `install/lib/`).
5. Common system paths (`/usr/local/lib`, `/usr/lib`, `/opt/kernel_set/lib`) and the platform loader's own search path (`LD_LIBRARY_PATH`, `DYLD_LIBRARY_PATH`, `PATH`).

If none is found, the import fails with an `OSError` that lists every path tried.
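The search order above can be sketched as a small resolver. This is a hypothetical reconstruction from the list, not the package's own loader: `platform_basename`, `candidate_paths`, and `load_first` are illustrative names, and trying the bare filename last is one assumed way of deferring to the platform loader's search path.

```python
import ctypes
import os
import sys
from pathlib import Path


def platform_basename() -> str:
    """Library filename for the current platform, per the list above."""
    if sys.platform.startswith("win"):
        return "kernel_set.dll"
    if sys.platform == "darwin":
        return "libkernel_set.dylib"
    return "libkernel_set.so"


def candidate_paths(package_dir: Path, repo_root: Path, env=os.environ) -> list:
    """Ordered candidates (hypothetical helper mirroring the documented order)."""
    name = platform_basename()
    out = []
    # 1. Explicit override: a full path or a bare filename.
    if env.get("KERNEL_SET_LIB"):
        out.append(env["KERNEL_SET_LIB"])
    # 2. A directory to search for the platform basename.
    if env.get("KERNEL_SET_LIB_DIR"):
        out.append(str(Path(env["KERNEL_SET_LIB_DIR"]) / name))
    # 3. Vendored next to the installed package.
    out += [str(package_dir / name), str(package_dir / "lib" / name)]
    # 4. Build trees of a source checkout.
    out += [str(repo_root / d / name) for d in ("build", "build/lib", "lib", "install/lib")]
    # 5. Common system prefixes, then the bare name so the loader's own
    #    search path (LD_LIBRARY_PATH etc.) applies (assumed mechanism).
    out += [str(Path(d) / name) for d in ("/usr/local/lib", "/usr/lib", "/opt/kernel_set/lib")]
    out.append(name)
    return out


def load_first(paths):
    """Return the first loadable library, or raise OSError listing every attempt."""
    tried = []
    for p in paths:
        try:
            return ctypes.CDLL(p)
        except OSError:
            tried.append(p)
    raise OSError("kernel_set shared library not found; tried:\n  " + "\n  ".join(tried))
```

Keeping candidate generation separate from loading makes the order easy to print or test without a GPU or the library present.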
Usage
```python
import torch
import kernel_set as ks

print(ks.version(), ks.backend_name())  # "0.1.0" "cuda"

# --- RMSNorm (dtype, shape, and stream all inferred from the tensors) -------
x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)
ks.norm.rms_norm(out, x, w, eps=1e-6)

# --- SwiGLU MLP gate ------------------------------------------------------
gate = torch.randn(8, 11008, device="cuda", dtype=torch.bfloat16)
up = torch.randn(8, 11008, device="cuda", dtype=torch.bfloat16)
y = torch.empty_like(gate)
ks.activation.swiglu(y, gate, up)

# --- Greedy decode --------------------------------------------------------
logits = torch.randn(2, 32000, device="cuda", dtype=torch.float16)
tokens = torch.empty(2, device="cuda", dtype=torch.int32)
ks.sampling.argmax(tokens, logits, num_seqs=2, vocab_size=32000)
torch.cuda.synchronize()
print(tokens)
```
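To sanity-check these three kernels on the host, a pure-Python reference can help. It assumes the common textbook definitions (RMSNorm with `eps` inside the square root, SwiGLU as `silu(gate) * up`, lowest-index tie-breaking for argmax); the kernels' exact conventions are not stated in this README, so treat these as assumptions rather than the library's spec.

```python
import math

# Reference semantics for host-side checks. These follow common definitions;
# the kernels' exact conventions are an assumption here.


def rms_norm_ref(x, w, eps=1e-6):
    """One row: y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i] (eps inside sqrt assumed)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * g for v, g in zip(x, w)]


def silu(v):
    """v * sigmoid(v), written to avoid overflow in exp() for large |v|."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def swiglu_ref(gate, up):
    """Elementwise silu(gate) * up (assumed SwiGLU convention)."""
    return [silu(g) * u for g, u in zip(gate, up)]


def argmax_ref(row):
    """Index of the largest logit; ties go to the lowest index in this sketch."""
    return max(range(len(row)), key=row.__getitem__)
```

For example, compare `rms_norm_ref(x[0].float().tolist(), w.float().tolist())` with `out[0].float().tolist()` using a tolerance suited to fp16.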
Raw device pointers (no torch)
```python
import kernel_set as ks
from kernel_set import DType

rows, cols = 8, 4096
nbytes = rows * cols * 2  # float16
x = ks.runtime.malloc_device(nbytes)
w = ks.runtime.malloc_device(cols * 2)
out = ks.runtime.malloc_device(nbytes)
# ...fill x/w via ks.runtime.memcpy(..., ks.MemcpyKind.HOST_TO_DEVICE)...
ks.norm.rms_norm(out, x, w, rows=rows, cols=cols, eps=1e-6,
                 dtype=DType.F16, stream=0)
ks.runtime.stream_synchronize(0)
for p in (x, w, out):
    ks.runtime.free_device(p)
```
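The raw-pointer example hardcodes 2 bytes per element. A small sketch can generalize the size arithmetic and build a float16 host buffer with the standard library's `struct` half-precision format (`'e'`). The bit widths below are the standard sizes of those formats written out by hand (an assumption; in real code `ks.runtime.dtype_size_bits` is the source of truth, though its signature is not documented here). The copy to the device is left out because the argument order of `ks.runtime.memcpy` is not shown.

```python
import struct

# Bits per element, keyed by ks.DType member name. Hand-written assumption;
# prefer ks.runtime.dtype_size_bits in real code.
DTYPE_BITS = {
    "F32": 32, "F16": 16, "BF16": 16, "F64": 64,
    "F8E4M3": 8, "F8E5M2": 8,
    "I64": 64, "I32": 32, "I8": 8, "U8": 8,
    "I4": 4,  # packed, two values per byte
}


def nbytes_for(num_elements: int, dtype: str) -> int:
    """Bytes for a contiguous buffer; sub-byte types round up to a whole byte."""
    return (num_elements * DTYPE_BITS[dtype] + 7) // 8


def pack_f16(values) -> bytes:
    """Little-endian IEEE half-precision host buffer, ready to copy to the device."""
    return struct.pack(f"<{len(values)}e", *values)
```

With this helper, `nbytes = nbytes_for(rows * cols, "F16")` replaces the hardcoded `* 2`.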
Tensor-library interop (passing GPU pointers)
The wrappers extract device pointers and streams from your tensor library for you. The rules:
- Pointers come from `tensor.data_ptr()`. Tensors must be CUDA and contiguous (call `.contiguous()` first); both conditions are validated, and a clear `ValueError` is raised otherwise. Any other object exposing a `data_ptr()` method works too; for arrays that expose their address differently (e.g. a CuPy array's `.data.ptr`), pass the raw integer pointer.
- Streams: if you don't pass `stream=`, the wrappers default to `torch.cuda.current_stream(tensor.device).cuda_stream` for torch inputs, so launches are ordered against your torch work. For raw-pointer inputs they fall back to the default stream (`0`). You can override with `stream=<int>`, `stream=torch.cuda.Stream(...)`, `stream="current"`, or a `kernel_set.runtime.Stream`.
- dtype: inferred from the torch tensor's `.dtype` via the `TORCH_TO_KS` map; pass `dtype=ks.DType.*` for raw pointers or to override.
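The pointer and stream rules above can be sketched as two resolver functions. `resolve_ptr` and `resolve_stream` are hypothetical names, not the package's internals, and the stream resolver covers only three of the documented forms (an `int`, an object with a `cuda_stream` attribute such as `torch.cuda.Stream`, and `None`); `"current"` and `kernel_set.runtime.Stream` are left out because their handling is not described here.

```python
def resolve_ptr(obj) -> int:
    """Hypothetical helper: int pointers pass through; tensor-like objects are validated."""
    if isinstance(obj, int):
        return obj
    if not hasattr(obj, "data_ptr"):
        raise TypeError(f"expected an int or an object with data_ptr(), got {type(obj).__name__}")
    # torch.Tensor exposes .is_cuda and .is_contiguous(); other objects may not.
    if getattr(obj, "is_cuda", True) is False:
        raise ValueError("tensor must be on a CUDA device")
    is_contiguous = getattr(obj, "is_contiguous", None)
    if callable(is_contiguous) and not is_contiguous():
        raise ValueError("tensor must be contiguous; call .contiguous() first")
    return int(obj.data_ptr())


def resolve_stream(stream, first_arg) -> int:
    """Hypothetical helper covering three of the documented stream forms."""
    if isinstance(stream, int):
        return stream                       # explicit raw handle
    if hasattr(stream, "cuda_stream"):
        return int(stream.cuda_stream)      # e.g. torch.cuda.Stream
    if stream is None:
        device = getattr(first_arg, "device", None)
        if device is not None and hasattr(first_arg, "data_ptr"):
            import torch                    # only reached for torch-style inputs
            return torch.cuda.current_stream(device).cuda_stream
        return 0                            # raw pointers: default stream
    raise TypeError(f"unsupported stream argument: {stream!r}")
```

Validating before extracting the pointer is what turns a silent out-of-bounds read on a strided view into a clear `ValueError`.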
```python
import torch
import kernel_set as ks

a = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 128, device="cuda", dtype=torch.bfloat16)
c = torch.empty(256, 128, device="cuda", dtype=torch.bfloat16)

# C = A @ B with (M, N, K) = (256, 128, 512). Run on a custom stream:
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())  # a and b were filled on the current stream
with torch.cuda.stream(s):
    ks.gemm.gemm(c, a, b, m=256, n=128, k=512, stream=s)
s.synchronize()
```
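A naive host-side GEMM is handy for checking small results. This sketch assumes row-major, contiguous operands with no transpose options, matching the shapes in the example (A is M x K, B is K x N, C is M x N); the variants `gemm_bias_act`, `gemm_batched`, and the quantized GEMMs are not covered.

```python
def gemm_ref(a, b, m, n, k):
    """C = A @ B on flat row-major lists: A is m x k, B is k x n, C is m x n."""
    if len(a) != m * k or len(b) != k * n:
        raise ValueError("operand sizes do not match (m, n, k)")
    c = [0.0] * (m * n)
    for i in range(m):
        for p in range(k):
            a_ip = a[i * k + p]              # reuse A[i][p] across the whole row of B
            for j in range(n):
                c[i * n + j] += a_ip * b[p * n + j]
    return c
```

The i-p-j loop order walks both `b` and `c` contiguously, the same access pattern a row-major kernel favors.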
dtype mapping (torch <-> ks_dtype)
| torch dtype | `ks.DType` |
|---|---|
| `float32` | `F32` |
| `float16` | `F16` |
| `bfloat16` | `BF16` |
| `float64` | `F64` |
| `float8_e4m3fn` | `F8E4M3` |
| `float8_e5m2` | `F8E5M2` |
| `int64` | `I64` |
| `int32` | `I32` |
| `int8` | `I8` |
| `uint8` | `U8` |
(`ks.DType.I4` is packed int4 with no native torch dtype; pass raw pointers.)
```python
ks.dtype_to_ks(torch.bfloat16)      # -> ks.DType.BF16
ks.ks_to_torch_dtype(ks.DType.F16)  # -> torch.float16
```
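Since `DType.I4` has no torch counterpart, the values have to be packed on the host before being uploaded and passed by pointer. The nibble order and signedness the kernels expect are not documented here, so this sketch assumes two's-complement values in `[-8, 7]`, packed two per byte with the low nibble first; `pack_int4` and `unpack_int4` are hypothetical helpers. Check the C headers before relying on this layout.

```python
def pack_int4(values) -> bytes:
    """Pack signed int4 values two per byte, low nibble first (assumed layout)."""
    values = list(values)
    if len(values) % 2:
        values.append(0)                      # pad to an even count
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        if not (-8 <= lo <= 7 and -8 <= hi <= 7):
            raise ValueError("int4 values must be in [-8, 7]")
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(data: bytes, count: int) -> list:
    """Inverse of pack_int4: sign-extend each nibble and drop any padding."""
    vals = []
    for byte in data:
        for nib in (byte & 0xF, byte >> 4):
            vals.append(nib - 16 if nib >= 8 else nib)
    return vals[:count]
```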
Error handling
Every C call returns a ks_status_t; non-zero statuses raise
kernel_set.KernelSetError, which carries the symbolic status name (from
ks_status_string) and the thread-local backend message (from
ks_last_error_string):
```python
try:
    ks.norm.rms_norm(out, x, w, eps=1e-6)
except ks.KernelSetError as e:
    print(e.status, e.status_name, e.backend_message)
```
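The status-to-exception mapping described above can be sketched as follows. `KernelSetErrorSketch` and `check_status` are hypothetical stand-ins, not the package's classes; the two callables stand in for `ks_status_string` and `ks_last_error_string`, whose exact ctypes return types are assumed. Because the backend message is thread-local, it has to be read on the same thread immediately after the failing call, which is why the check runs right where the status is returned.

```python
class KernelSetErrorSketch(RuntimeError):
    """Carries the three attributes documented for ks.KernelSetError."""

    def __init__(self, status, status_name, backend_message, where=""):
        self.status = status
        self.status_name = status_name
        self.backend_message = backend_message
        super().__init__(f"{where}: {status_name} (status {status}): {backend_message}")


def check_status(status, where, status_string, last_error_string):
    """Raise on a non-zero status, otherwise return it unchanged.

    status_string(status) and last_error_string() stand in for the C helpers
    ks_status_string / ks_last_error_string (with ctypes they would likely
    return bytes needing .decode(); an assumption).
    """
    if status != 0:
        raise KernelSetErrorSketch(status, status_string(status), last_error_string(), where)
    return status
```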
Low-level FFI
For full control, the raw ctypes layer is available, with `argtypes`/`restype` set on every C function:
```python
from kernel_set._lib import lib, check, KS_DTYPE_F16

check(lib.ks_rms_norm(out_ptr, x_ptr, w_ptr, rows, cols, 1e-6, KS_DTYPE_F16, 0),
      "ks_rms_norm")
```
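Setting `argtypes`/`restype` matters: without them, ctypes passes Python ints as C `int`, which cannot carry 64-bit device addresses, and refuses plain Python floats. The sketch below shows what such a declaration might look like for `ks_rms_norm`. The C signature is an assumption inferred from the call above, not read from the header, and the stub is a Python function exposed through the same prototype so the conversions can be observed without the library.

```python
import ctypes

# ASSUMED C signature, inferred from the call above rather than the header:
#   ks_status_t ks_rms_norm(void *out, const void *x, const void *w,
#                           int64_t rows, int64_t cols, float eps,
#                           ks_dtype_t dtype, void *stream);
ARGTYPES = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,  # out, x, w (device pointers)
    ctypes.c_int64, ctypes.c_int64,                      # rows, cols
    ctypes.c_float,                                      # eps
    ctypes.c_int,                                        # dtype enum value
    ctypes.c_void_p,                                     # stream handle (0/NULL = default)
]
RESTYPE = ctypes.c_int  # ks_status_t, assumed to be an int-sized enum


def declare(fn):
    """Attach the prototype to a CDLL attribute, e.g. declare(cdll.ks_rms_norm)."""
    fn.argtypes = ARGTYPES
    fn.restype = RESTYPE
    return fn


# Stand-in for the real symbol: records the converted arguments it receives.
calls = []


@ctypes.CFUNCTYPE(RESTYPE, *ARGTYPES)
def fake_ks_rms_norm(out, x, w, rows, cols, eps, dtype, stream):
    calls.append((out, x, w, rows, cols, eps, dtype, stream))
    return 0  # pretend success
```

Note that `eps` is rounded to float32 on the way in, so `1e-6` arrives as roughly `9.99999997e-07`.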
API surface
| Module | Functions |
|---|---|
| `kernel_set.runtime` | `version`, `backend_name`, `dtype_size_bits`, `dtype_name`, `device_count`, `get/set_device`, `get_device_properties`, `Stream`, `stream_create/destroy/synchronize`, `malloc_device`, `free_device`, `memcpy`, `memset_device` |
| `kernel_set.norm` | `rms_norm`, `rms_norm_residual`, `layer_norm`, `rms_norm_backward`, `layer_norm_backward` |
| `kernel_set.activation` | `silu`, `gelu`, `relu`, `swiglu`, `swiglu_packed`, `geglu`, `swiglu_backward` |
| `kernel_set.attention` | `flash_attn`, `flash_attn_varlen`, `paged_attn_decode`, `reshape_and_cache`, `mla_decode`, `flash_attn_backward` |
| `kernel_set.gemm` | `gemm`, `gemm_bias_act`, `gemm_batched`, `gemm_w8a8`, `gemm_w4a16` |
| `kernel_set.moe` | `gate_softmax_topk`, `gate_sigmoid_group_topk`, `compute_permutation`, `permute`, `unpermute`, `grouped_gemm` |
| `kernel_set.rope` | `rope_inplace`, `rope`, `rope_gather`, `rope_backward` |
| `kernel_set.quant` | `quantize_fp8`, `dequantize_fp8`, `quantize_int8`, `dequantize_int8`, `dequantize_int4` |
| `kernel_set.sampling` | `softmax`, `log_softmax`, `argmax`, `sample` |
| `kernel_set.embedding` | `embedding_lookup`, `embedding_backward` |
| `kernel_set.elementwise` | `add`, `mul`, `add_residual`, `scale`, `cast`, `axpby` |
| `kernel_set.loss` | `cross_entropy`, `fused_linear_cross_entropy` |
| `kernel_set.optimizer` | `adamw`, `sgd_momentum`, `global_grad_norm` |
Enums: ks.DType, ks.Activation, ks.QuantMode, ks.MemcpyKind, ks.Status.
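For checking `sampling.softmax`, `sampling.log_softmax`, and `loss.cross_entropy` on small inputs, a numerically stable host reference is a useful baseline. The reduction (mean here), ignore-index handling, and label smoothing of the real `cross_entropy` are not documented in this README, so the mean reduction below is an assumption.

```python
import math


def log_softmax_ref(row):
    """Stable log-softmax: subtract the row max before exponentiating."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))  # log-sum-exp
    return [v - lse for v in row]


def softmax_ref(row):
    """Softmax derived from the stable log-softmax."""
    return [math.exp(v) for v in log_softmax_ref(row)]


def cross_entropy_ref(logits_rows, targets):
    """Mean over rows of -log p(target); the mean reduction is an assumption."""
    losses = [-log_softmax_ref(row)[t] for row, t in zip(logits_rows, targets)]
    return sum(losses) / len(losses)
```

Subtracting the row maximum keeps `exp` in range even for logits in the thousands, which is the same trick a fused GPU kernel has to apply.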
File details
Details for the file kernel_set-0.1.0-py3-none-manylinux_2_35_x86_64.whl.
File metadata
- Download URL: kernel_set-0.1.0-py3-none-manylinux_2_35_x86_64.whl
- Upload date:
- Size: 18.5 MB
- Tags: Python 3, manylinux: glibc 2.35+ x86-64
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.11.14
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | `788f8fa6fcedb99fb9ce8a86c1a8e30cbbee192c99ef5465182d4e242fc3ab88` |
| MD5 | `1e24f697a29db15a3e08aab8f68603b1` |
| BLAKE2b-256 | `5ae4c6a75188854d4ff8bc7b0a0beaa8c2c2f4b8cfca67d03f60dba3277dbdd8` |
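To confirm a downloaded wheel matches the published SHA256 digest, hash the file in chunks with the standard library. `sha256_file` and `verify_wheel` are illustrative helper names; the expected digest is copied from the table above.

```python
import hashlib

# SHA256 published for kernel_set-0.1.0-py3-none-manylinux_2_35_x86_64.whl
EXPECTED_SHA256 = "788f8fa6fcedb99fb9ce8a86c1a8e30cbbee192c99ef5465182d4e242fc3ab88"


def sha256_file(path, chunk_size=1 << 20) -> str:
    """Stream the file through SHA-256 so large wheels need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_wheel(path, expected=EXPECTED_SHA256) -> bool:
    """True if the file's digest matches (case-insensitive hex comparison)."""
    return sha256_file(path) == expected.lower()
```

pip can enforce the same check at install time in hash-checking mode, e.g. a requirements line `kernel_set==0.1.0 --hash=sha256:<digest>` installed with `pip install --require-hashes -r requirements.txt`.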