CUDA kernel library for Kestrel

kestrel-kernels

Precompiled CUDA kernels for Kestrel, a high-performance inference engine for Moondream, the world's most efficient vision-language model.

License: These kernels are provided for use with Kestrel only. Other use is not permitted.

These kernels are optimized for NVIDIA Ada/Hopper GPUs (SM89/SM90) and distributed as precompiled shared libraries for fast installation without CUDA compilation.

Kernel Library

CUDA Kernels (compiled via CMake)

These kernels are implemented in CUDA C++ and compiled during wheel build.

activation - GELU Residual Activation

Computes the fused gated activation GELU(h) * (g + 1) used in MoE expert layers. The input tensor is split in half: h passes through GELU, and g acts as a gate with a +1 bias.

| Tokens | CUDA | PyTorch (eager) | torch.compile | vs PyTorch |
|---|---|---|---|---|
| 1 | 3.8 us | 64 us | 63 us | 17x |
| 64 | 2.9 us | 49 us | 69 us | 17x |
| 740 | 3.5 us | 49 us | 68 us | 14x |
| 1024 | 3.9 us | 49 us | 68 us | 13x |
| 2048 | 5.1 us | 49 us | 68 us | 10x |

PyTorch eager launches separate kernels for slice, erf, multiply, and add, with intermediate tensors hitting global memory. Our kernel fuses everything into a single pass. torch.compile is slower than eager here, likely because the dynamic x[:, :hidden] slicing prevents effective fusion.
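The fused semantics can be sketched in plain Python. This is only a scalar reference for what the kernel computes, not the implementation (the real kernel operates on BF16 tensors in a single vectorized pass); `gelu_residual` is a hypothetical helper name:

```python
import math

def gelu(x: float) -> float:
    # Exact (erf-based) GELU, matching PyTorch's default formulation.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_residual(row: list[float]) -> list[float]:
    # The input row is split in half: h = row[:n], g = row[n:].
    # Output is GELU(h) * (g + 1), the fused gated activation.
    n = len(row) // 2
    h, g = row[:n], row[n:]
    return [gelu(hi) * (gi + 1.0) for hi, gi in zip(h, g)]
```

With g = 0 the gate reduces to 1, so the output is plain GELU(h); the +1 bias lets a zero-initialized gate start as an identity on the activation.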

fused_linear_residual - Linear + Bias + Residual

Fused out = x @ W.T + bias + residual using cuBLASLt epilogues.

| Crops | Tokens | CUDA | PyTorch (eager) | vs PyTorch |
|---|---|---|---|---|
| 1 | 729 | 9.0 us | 24 us | 2.7x |
| 2 | 1458 | 12 us | 24 us | 2.0x |
| 4 | 2916 | 16 us | 29 us | 1.8x |
| 8 | 5832 | 46 us | 50 us | 1.1x |
| 13 | 9477 | 44 us | 77 us | 1.7x |

cuBLASLt epilogues fuse bias addition and residual into the matmul, avoiding extra kernel launches and memory traffic.
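As a reference for what the fused epilogue computes, here is a minimal pure-Python sketch (hypothetical helper name; the kernel performs this as one cuBLASLt GEMM on BF16/FP32 tensors rather than explicit loops):

```python
def linear_bias_residual(x, W, bias, residual):
    # Reference for out = x @ W.T + bias + residual.
    # x: [M, K], W: [N, K] (row-major weights, hence W.T),
    # bias: [N], residual: [M, N].
    M, K, N = len(x), len(W[0]), len(W)
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = sum(x[m][k] * W[n][k] for k in range(K))
            # Bias and residual are folded into the matmul epilogue
            # on-chip, so they cost no extra kernel launch.
            out[m][n] = acc + bias[n] + residual[m][n]
    return out
```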

fused_mlp - Fused MLP with cuBLASLt

Fused out = residual + gelu(x @ W1.T + b1) @ W2.T + b2 using cuBLASLt epilogues.

| Crops | Tokens | CUDA | PyTorch (eager) | vs PyTorch |
|---|---|---|---|---|
| 1 | 729 | 43 us | 56 us | 1.3x |
| 2 | 1458 | 72 us | 89 us | 1.2x |
| 4 | 2916 | 97 us | 124 us | 1.3x |
| 8 | 5832 | 214 us | 259 us | 1.2x |
| 13 | 9477 | 283 us | 379 us | 1.3x |

The MLP is matmul-dominated, so the speedup is modest; the gain comes from fusing the GELU and the residual add into cuBLASLt epilogues.

kv_cache_write - KV Cache Write with FP8 Quantization

Writes BF16 key/value tensors to FP8 paged KV cache with quantization.

| Tokens | Kestrel | vLLM | PyTorch (eager) | vs vLLM | vs PyTorch |
|---|---|---|---|---|---|
| 1 | 3.7 us | 4.9 us | 67 us | 1.3x | 18x |
| 8 | 3.5 us | 4.8 us | 35 us | 1.4x | 10x |
| 64 | 3.7 us | 4.8 us | 35 us | 1.3x | 9x |
| 256 | 4.1 us | 4.8 us | 36 us | 1.2x | 9x |
| 1024 | 8.6 us | 9.7 us | 51 us | 1.1x | 6x |
| 4096 | 31 us | 46 us | 124 us | 1.5x | 4x |

Fused K/V processing and optimized vectorization provide 1.1-1.5x speedup over vLLM's implementation.

layernorm_cuda - Fast LayerNorm Forward

Optimized LayerNorm forward pass for common hidden dimensions.

Vision Encoder (N=1152):

| Crops | Tokens | CUDA | PyTorch (eager) | vs PyTorch |
|---|---|---|---|---|
| 1 | 729 | 3.9 us | 8.4 us | 2.2x |
| 2 | 1458 | 4.2 us | 8.4 us | 2.0x |
| 4 | 2916 | 5.5 us | 10 us | 1.8x |
| 8 | 5832 | 8.3 us | 18 us | 2.1x |
| 13 | 9477 | 18 us | 28 us | 1.6x |

Text Decoder (N=2048):

| Context | Tokens | CUDA | PyTorch (eager) | vs PyTorch |
|---|---|---|---|---|
| decode | 1 | 4.2 us | 8.4 us | 2.0x |
| prefill | 740 | 3.7 us | 8.4 us | 2.3x |

Specialized kernels for N=1152 and N=2048 process 4 rows per block with warp-only reductions, avoiding shared-memory overhead. Two epilogue strategies trade register pressure against memory bandwidth.
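The per-row computation being optimized is standard LayerNorm; a scalar reference sketch (hypothetical helper name, FP32 math; the kernel performs the same mean/variance reduction with warp shuffles):

```python
import math

def layernorm_row(x, weight, bias, eps=1e-5):
    # y = (x - mean) / sqrt(var + eps) * weight + bias over one row
    # of the hidden dimension (N=1152 or N=2048 in Kestrel).
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * w + b for v, w, b in zip(x, weight, bias)]
```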

moe_sum - MoE Output Summation

Sums the weighted outputs from top-k MoE experts back into a single hidden state per token. Computes out[t] = sum(expert_outputs[t, 0:k]) where each token selects k=8 experts.

| Context | Tokens | CUDA | PyTorch (eager) | vs PyTorch |
|---|---|---|---|---|
| decode | 1 | 3.0 us | 5.6 us | 1.9x |
| batch 4 | 4 | 3.0 us | 5.4 us | 1.8x |
| batch 16 | 16 | 2.9 us | 5.3 us | 1.8x |
| prefill | 740 | 5.5 us | 10 us | 1.9x |
| long | 1024 | 10 us | 15 us | 1.5x |

Uses vectorized 16-byte loads (8 bf16 values at once) and a fully unrolled k=8 reduction, with FP32 accumulation for better numerical stability than bf16 accumulation. Note: vLLM has a similar kernel, but it only supports topk=2,3,4 and falls back to PyTorch for topk=8.
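The reduction itself is simple; a pure-Python reference of the semantics (hypothetical helper name, FP32-style accumulation; the kernel does this with vectorized bf16 loads):

```python
def moe_sum(expert_outputs, k=8):
    # out[t] = sum(expert_outputs[t][0:k]): collapse the k weighted
    # expert outputs of each token back into one hidden state.
    out = []
    for token_rows in expert_outputs:  # shape [k][hidden] per token
        hidden = len(token_rows[0])
        acc = [0.0] * hidden           # FP32 accumulator
        for e in range(k):
            for j in range(hidden):
                acc[j] += token_rows[e][j]
        out.append(acc)
    return out
```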

rotary_embedding - Rotary Position Embedding

Applies rotary position embedding to query and key tensors (n_heads=32, head_dim=64).

| Context | Tokens | Kestrel | vLLM | PyTorch (eager) | vs vLLM | vs PyTorch |
|---|---|---|---|---|---|---|
| decode | 1 | 3.3 us | 4.9 us | 118 us | 1.5x | 36x |
| batch 4 | 4 | 3.1 us | 4.5 us | 117 us | 1.5x | 38x |
| batch 16 | 16 | 3.1 us | 4.7 us | 117 us | 1.5x | 38x |
| prefill | 740 | 5.0 us | 8.0 us | 119 us | 1.6x | 24x |

Uses vectorized bfloat162 pair processing, shared-memory caching of cos/sin values, and FP32 math for numerical stability. A split-head kernel for decode increases SM utilization at small batch sizes.
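For reference, rotary embedding rotates consecutive element pairs of each head vector by a position-dependent angle. A scalar sketch (hypothetical helper; the interleaved pairing and base theta=10000 shown here are the common convention and an assumption about this kernel's layout):

```python
import math

def rope_rotate(vec, position, theta=10000.0):
    # Rotate each (x0, x1) pair of a head vector (head_dim=64 in
    # Kestrel) by angle = position / theta^(2i/d).
    d = len(vec)
    out = [0.0] * d
    for i in range(d // 2):
        angle = position / theta ** (2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = x0 * c - x1 * s
        out[2 * i + 1] = x0 * s + x1 * c
    return out
```

Since each pair undergoes a pure rotation, the vector's norm is preserved, which is why FP32 cos/sin math suffices for stability.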

fp8_quant - FP8 Quantization

Converts BF16 tensors to FP8 (e4m3fn) with per-row dynamic scale computation. Used for quantizing MoE activations before FP8 GEMM.

| Context | Rows | CUDA | PyTorch (eager) | vs PyTorch |
|---|---|---|---|---|
| decode | 8 | 3.1 us | 53 us | 17x |
| batch 4 | 32 | 3.1 us | 52 us | 17x |
| batch 16 | 128 | 3.1 us | 52 us | 17x |
| prefill | 5920 | 6.6 us | 67 us | 10x |

Two kernel variants: warp-per-row for large batches (better SM utilization) and block-per-row for small batches. Both use vectorized 16-byte loads/stores and a fused absmax reduction.

tau_tail - TAU Attention Scaling

Applies per-head TAU scaling to Q and V in packed QKV. Computes scale = tanh(tok_linear) + tau_pos_table[position] then scales each head: Q *= scale_q, V *= scale_v.

| Context | Tokens | CUDA | PyTorch (eager) | vs PyTorch |
|---|---|---|---|---|
| decode | 1 | 4.6 us | 45 us | 10x |
| batch 4 | 4 | 4.4 us | 46 us | 10x |
| batch 16 | 16 | 9.0 us | 88 us | 10x |
| prefill | 740 | 6.5 us | 63 us | 10x |
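The per-head scale formula can be sketched directly (hypothetical helper; the split into Q-heads and V-heads within the packed QKV is handled by the kernel and omitted here):

```python
import math

def tau_scales(tok_linear, tau_pos, position):
    # scale[h] = tanh(tok_linear[h]) + tau_pos[position][h], one scale
    # per head; the kernel then applies Q *= scale_q and V *= scale_v
    # in place in the packed QKV tensor.
    return [math.tanh(t) + tau_pos[position][h]
            for h, t in enumerate(tok_linear)]
```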

CuTe DSL Kernels (precompiled for wheel distribution)

These kernels are written in NVIDIA CuTe DSL (Python) and precompiled to .so files during wheel build. The kernel source templates are excluded from wheel distribution.

topk - Bitonic Top-K Selection

GPU top-k selection using bitonic sort network with optional fused softmax.

| Context | Tokens | Kestrel | Quack | PyTorch (eager) | vs Quack | vs PyTorch |
|---|---|---|---|---|---|---|
| decode | 1 | 23 us | 29 us | 17 us | 1.3x | 0.8x |
| batch 16 | 16 | 22 us | 27 us | 17 us | 1.2x | 0.8x |
| prefill | 740 | 22 us | 28 us | 17 us | 1.2x | 0.7x |

Note: Currently slower than PyTorch for N=64, k=8. PyTorch uses radix-based QuickSelect which is more efficient for small N. Algorithm should be revisited.

Python API:

from kestrel_kernels.topk import topk_fwd

values, indices = topk_fwd(scores, k=8, softmax=True)

sampling - Top-p Token Sampling

CuTe DSL rejection-based top-p sampler for probability tensors.

Runtime dispatch uses the CuTe kernel path by default on CUDA, with fallback retained for unsupported cases and runtime errors.
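As a reference for the distribution being sampled, here is the standard sort-and-truncate formulation of top-p (nucleus) sampling in pure Python. This is a hypothetical sketch of the semantics only; the CuTe kernel uses a rejection method that is equivalent in distribution but avoids the full sort:

```python
import random

def top_p_sample(probs, top_p, rng=random):
    # Keep the smallest prefix of tokens (by descending probability)
    # whose cumulative mass reaches top_p, then sample from that set
    # proportionally to the original probabilities.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    r = rng.random() * cum
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r <= acc:
            return i
    return kept[-1]
```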

The benchmarks below are H100 (SM90) dispatch-level timings (uniform random generation plus kernel launch), measured with heavy warmup and interleaved randomized runs:

| Shape (batch, vocab) | Kestrel CuTe | FlashInfer | vs FlashInfer |
|---|---|---|---|
| (1, 51200) | 17.37 us | 20.78 us | 1.20x |
| (4, 51200) | 21.17 us | 21.84 us | 1.03x |
| (128, 51200) | 38.96 us | 42.44 us | 1.09x |
| (32, 1024) | 15.25 us | 20.50 us | 1.34x |

Python API:

from kestrel_kernels.sampling import top_p_sampling_from_probs

sampled_ids = top_p_sampling_from_probs(probs, top_p, generator=generator)

cute_moe - MoE Matrix Multiplications

Grouped GEMM kernels for Mixture-of-Experts layers, written in CuTe DSL for H100 (SM90). Supports BF16 and FP8 (W8A8) precision with both warp-level and WGMMA variants, automatically selected based on batch size.

FP8 W8A8 Full MoE Layer (up + activation + down + sum, E=64, k=8, with CUDA Graphs):

| Context | Tokens | Kestrel | vLLM (Triton) | vs vLLM |
|---|---|---|---|---|
| decode | 1 | 29 us | 51 us | 1.72x |
| batch 4 | 4 | 79 us | 103 us | 1.30x |
| batch 16 | 16 | 146 us | 169 us | 1.16x |
| prefill | 740 | 245 us | 481 us | 1.96x |

Python API:

from kestrel_kernels import (
    invoke_cute_moe_up,
    invoke_cute_moe_down,
    invoke_cute_moe_up_fp8,
    invoke_cute_moe_down_fp8,
)

# BF16 up projection
out_up = invoke_cute_moe_up(
    hidden_states, w1, w2,
    topk_weights, topk_ids,
    sorted_token_ids, expert_ids, num_tokens_post_pad,
)

# BF16 down projection
out_down = invoke_cute_moe_down(
    moe_out, w3,
    topk_weights, topk_ids,
    sorted_token_ids, expert_ids, num_tokens_post_pad,
)

moe_align - MoE Token Alignment

Prepares sorted token indices for block-sparse MoE operations. Given topk_ids, outputs sorted token IDs grouped by expert for block-sparse matmul.

| Context | Tokens | Kestrel | vLLM | vs vLLM |
|---|---|---|---|---|
| decode | 1 | 6.7 us | 9.8 us | 1.5x |
| batch 4 | 4 | 6.5 us | 9.8 us | 1.5x |
| batch 16 | 16 | 7.0 us | 10 us | 1.4x |
| prefill | 740 | 12 us | 9.2 us | 0.8x |
| long | 1024 | 12 us | 9.5 us | 0.8x |

Uses an optimized single-CTA shared-memory histogram for decode (numel < 1024). The prefill path still needs optimization.
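The alignment step can be sketched in pure Python (hypothetical helper; here `topk_ids` is taken as a flat list of expert assignments, and padded slots use the sentinel `len(topk_ids)`, both assumptions about the real layout):

```python
def moe_align_reference(topk_ids, num_experts, block_size):
    # Group token slots by expert and pad each group to a multiple of
    # block_size so the grouped GEMM can run fixed-size blocks.
    # Returns (sorted_token_ids, expert_ids, num_tokens_post_pad),
    # analogous to the kernel's output buffers.
    pad = len(topk_ids)                     # sentinel for padded slots
    buckets = [[] for _ in range(num_experts)]
    for slot, expert in enumerate(topk_ids):
        buckets[expert].append(slot)
    sorted_ids, expert_ids = [], []
    for e, bucket in enumerate(buckets):
        if not bucket:
            continue
        padded = -(-len(bucket) // block_size) * block_size  # ceil
        sorted_ids.extend(bucket + [pad] * (padded - len(bucket)))
        expert_ids.extend([e] * (padded // block_size))
    return sorted_ids, expert_ids, len(sorted_ids)
```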

Python API:

from kestrel_kernels.moe_align import moe_align_block_size

moe_align_block_size(
    topk_ids, num_experts, block_size,
    sorted_token_ids, expert_ids, num_tokens_post_pad,
    expert_map,  # optional for expert parallelism
)

gelu_residual - GELU Residual Activation (CuTe DSL)

CuTe DSL implementation of the GELU residual activation for BF16. Computes the fused gated activation GELU(h) * (g + 1) used in MoE expert layers, using vectorized memory access and streaming stores.

| Context | Rows | CuTe | CUDA | PyTorch | vs CUDA | vs PyTorch |
|---|---|---|---|---|---|---|
| decode | 8 | 2.3 us | 2.5 us | 7.5 us | 1.10x | 3.3x |
| batch 4 | 32 | 2.4 us | 3.0 us | 8.6 us | 1.24x | 3.6x |
| batch 16 | 128 | 2.6 us | 2.9 us | 8.9 us | 1.09x | 3.4x |
| prefill | 5920 | 9.9 us | 11.2 us | 55.9 us | 1.14x | 5.6x |

fp8_quant_cute - FP8 Quantization (CuTe DSL)

CuTe DSL implementation of FP8 row-wise quantization. Converts BF16 tensors to FP8 (e4m3fn) with per-row dynamic scaling.

hidden=1024 (MoE down projection input):

| Context | Rows | CuTe | CUDA | vs CUDA |
|---|---|---|---|---|
| decode | 8 | 2.5 us | 2.7 us | 1.09x |
| batch 4 | 32 | 2.8 us | 3.0 us | 1.07x |
| batch 16 | 128 | 2.8 us | 3.0 us | 1.08x |
| prefill | 5920 | 5.3 us | 6.6 us | 1.23x |

hidden=2048 (MoE up projection input):

| Context | Rows | CuTe | CUDA | vs CUDA |
|---|---|---|---|---|
| decode | 8 | 2.6 us | 2.7 us | 1.02x |
| batch 4 | 32 | 2.9 us | 3.0 us | 1.04x |
| batch 16 | 128 | 2.9 us | 3.0 us | 1.04x |
| prefill | 5920 | 8.2 us | 10.7 us | 1.31x |

flash_attn - Flash Attention (Prefill & Decode)

Flash Attention kernels written in CuTe DSL, with a dedicated decode path optimized for paged FP8 KV cache. 1.3-2.5x faster than FlashInfer on typical Moondream workloads.

  • FP8 KV cache with per-tensor scaling
  • Paged KV (page_size=1) for fine-grained memory management
  • CUDA graph compatible
  • Causal and prefix-LM masking, variable-length sequences, GQA/MQA
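For reference, what the kernel computes is standard scaled dot-product attention, softmax(Q K^T / sqrt(d)) V with optional causal masking. A naive pure-Python sketch of the semantics (hypothetical helper; the CuTe kernel produces the same result in a single tiled pass without materializing the score matrix):

```python
import math

def attention(q, k, v, causal=True):
    # q, k: [seq][d]; v: [seq][d_v]. One softmax row per query.
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if causal and j > i:
                scores.append(float("-inf"))   # mask future positions
            else:
                scores.append(scale * sum(a * b for a, b in zip(qi, kj)))
        m = max(scores)                        # max-subtraction for stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out
```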

FP8 KV Paged Decode (with CUDA Graphs):

| Batch | KV Len | Kestrel | FlashInfer | vs FlashInfer |
|---|---|---|---|---|
| 1 | 740 | 9.6 us | 12.9 us | 1.34x |
| 1 | 1024 | 8.7 us | 13.1 us | 1.50x |
| 4 | 740 | 17.1 us | 23.9 us | 1.40x |
| 8 | 512 | 10.0 us | 25.2 us | 2.51x |
| 16 | 256 | 9.6 us | 17.6 us | 1.83x |
| 32 | 128 | 11.8 us | 26.5 us | 2.24x |

FP8 KV Paged Prefill:

| Seq Len | Kestrel | FlashInfer | vs FlashInfer |
|---|---|---|---|
| 740 | 19.9 us | 47.6 us | 2.40x |
| 1024 | 27.3 us | 58.9 us | 2.16x |

Python API:

from kestrel_kernels.flash_attn.cute import flash_attn_func, flash_attn_varlen_func

# Fixed-length attention
out = flash_attn_func(q, k, v, causal=True)

# Variable-length attention
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    causal=True,
)
