
kestrel-kernels

Precompiled CUDA kernels for Kestrel, a high-performance inference engine for Moondream, the world's most efficient vision-language model.

License: These kernels are provided for use with Kestrel only. Other use is not permitted.

These kernels target NVIDIA Ampere/Ada/Hopper GPUs (SM80/SM86/SM89/SM90) and are distributed as precompiled shared libraries for fast installation without CUDA compilation.

Kernel Library

CUDA Kernels (compiled via CMake)

These kernels are implemented in CUDA C++ and compiled during the wheel build.

activation - GELU Residual Activation

Computes the fused gated activation GELU(h) * (g + 1) used in MoE expert layers. The input tensor is split in half along the hidden dimension: h passes through GELU, and g acts as a gate with a +1 bias.

Tokens | CUDA | PyTorch (eager) | torch.compile | vs PyTorch
1 | 3.8 us | 64 us | 63 us | 17x
64 | 2.9 us | 49 us | 69 us | 17x
740 | 3.5 us | 49 us | 68 us | 14x
1024 | 3.9 us | 49 us | 68 us | 13x
2048 | 5.1 us | 49 us | 68 us | 10x

PyTorch eager launches separate kernels for the slice, erf, multiply, and add, with intermediate tensors round-tripping through global memory. Our kernel fuses everything into a single pass. torch.compile is slower than eager here, likely because the dynamic x[:, :hidden] slicing prevents effective fusion.
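
For reference, a minimal PyTorch sketch of what this kernel computes (function and tensor names here are illustrative, not the kernel's API):

import torch
import torch.nn.functional as F

def gelu_residual_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the hidden dimension in half: h is activated, g gates with a +1 bias.
    hidden = x.shape[-1] // 2
    h, g = x[..., :hidden], x[..., hidden:]
    return F.gelu(h) * (g + 1)  # exact (erf-based) GELU, as in the eager path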

fused_linear_residual - Linear + Bias + Residual

Fused out = x @ W.T + bias + residual using cuBLASLt epilogues.

Crops | Tokens | CUDA | PyTorch (eager) | vs PyTorch
1 | 729 | 9.0 us | 24 us | 2.7x
2 | 1458 | 12 us | 24 us | 2.0x
4 | 2916 | 16 us | 29 us | 1.8x
8 | 5832 | 46 us | 50 us | 1.1x
13 | 9477 | 44 us | 77 us | 1.7x

cuBLASLt epilogues fuse bias addition and residual into the matmul, avoiding extra kernel launches and memory traffic.
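
For comparison, the unfused eager-PyTorch version of the same computation (names illustrative); each op below is a separate kernel launch:

import torch

def linear_residual_reference(x, W, bias, residual):
    # Matmul, bias add, and residual add each launch separately in eager mode;
    # the cuBLASLt epilogue folds both adds into the GEMM itself.
    return x @ W.T + bias + residual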

fused_mlp - Fused MLP with cuBLASLt

Fused out = residual + gelu(x @ W1.T + b1) @ W2.T + b2 using cuBLASLt epilogues.

Crops | Tokens | CUDA | PyTorch (eager) | vs PyTorch
1 | 729 | 43 us | 56 us | 1.3x
2 | 1458 | 72 us | 89 us | 1.2x
4 | 2916 | 97 us | 124 us | 1.3x
8 | 5832 | 214 us | 259 us | 1.2x
13 | 9477 | 283 us | 379 us | 1.3x

The MLP is matmul-dominated, so the speedup is modest. The gain comes from fusing the GELU and the residual add into cuBLASLt epilogues.
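
A reference sketch of the full computation in eager PyTorch (names illustrative):

import torch
import torch.nn.functional as F

def fused_mlp_reference(x, W1, b1, W2, b2, residual):
    # Two GEMMs with a GELU between them and a residual add at the end;
    # the fused kernel folds the GELU and the add into cuBLASLt epilogues.
    return residual + F.gelu(x @ W1.T + b1) @ W2.T + b2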

kv_cache_write - KV Cache Write with FP8 Quantization

Writes BF16 key/value tensors to FP8 paged KV cache with quantization.

Tokens | Kestrel | vLLM | PyTorch (eager) | vs vLLM | vs PyTorch
1 | 3.7 us | 4.9 us | 67 us | 1.3x | 18x
8 | 3.5 us | 4.8 us | 35 us | 1.4x | 10x
64 | 3.7 us | 4.8 us | 35 us | 1.3x | 9x
256 | 4.1 us | 4.8 us | 36 us | 1.2x | 9x
1024 | 8.6 us | 9.7 us | 51 us | 1.1x | 6x
4096 | 31 us | 46 us | 124 us | 1.5x | 4x

Fused K/V processing and optimized vectorization provide a 1.1-1.5x speedup over vLLM's implementation.
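
A rough PyTorch sketch of the semantics; the slot_mapping layout and per-tensor scale handling below are assumptions for illustration, since the real kernel writes directly into Kestrel's paged cache:

import torch

def kv_cache_write_reference(k, v, k_cache, v_cache, slot_mapping, k_scale, v_scale):
    # Quantize BF16 K/V to FP8 e4m3fn with per-tensor scales, then scatter each
    # token's entry into its cache slot (page_size=1: one slot per token).
    k_cache[slot_mapping] = (k.float() / k_scale).to(torch.float8_e4m3fn)
    v_cache[slot_mapping] = (v.float() / v_scale).to(torch.float8_e4m3fn)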

layernorm_cuda - Fast LayerNorm Forward

Optimized LayerNorm forward pass for common hidden dimensions.

Vision Encoder (N=1152):

Crops | Tokens | CUDA | PyTorch (eager) | vs PyTorch
1 | 729 | 3.9 us | 8.4 us | 2.2x
2 | 1458 | 4.2 us | 8.4 us | 2.0x
4 | 2916 | 5.5 us | 10 us | 1.8x
8 | 5832 | 8.3 us | 18 us | 2.1x
13 | 9477 | 18 us | 28 us | 1.6x

Text Decoder (N=2048):

Context | Tokens | CUDA | PyTorch (eager) | vs PyTorch
decode | 1 | 4.2 us | 8.4 us | 2.0x
prefill | 740 | 3.7 us | 8.4 us | 2.3x

Specialized kernels for N=1152 and N=2048 process 4 rows per block with warp-only reductions, avoiding shared-memory overhead. Two epilogue strategies trade register pressure against memory bandwidth.
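
Semantically this matches the standard PyTorch forward pass (the eps value below is an assumption):

import torch.nn.functional as F

def layernorm_reference(x, weight, bias, eps=1e-5):
    # Standard LayerNorm over the last (hidden) dimension.
    return F.layer_norm(x, (x.shape[-1],), weight, bias, eps)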

moe_sum - MoE Output Summation

Sums the weighted outputs from the top-k MoE experts back into a single hidden state per token. Computes out[t] = sum(expert_outputs[t, 0:k]), where each token selects k=8 experts.

Context | Tokens | CUDA | PyTorch (eager) | vs PyTorch
decode | 1 | 3.0 us | 5.6 us | 1.9x
batch 4 | 4 | 3.0 us | 5.4 us | 1.8x
batch 16 | 16 | 2.9 us | 5.3 us | 1.8x
prefill | 740 | 5.5 us | 10 us | 1.9x
long | 1024 | 10 us | 15 us | 1.5x

Vectorized 16-byte loads (8 bf16 values at once) and a fully unrolled k=8 reduction. FP32 accumulation provides better numerical stability than bf16 accumulation. Note: vLLM has a similar kernel, but it only supports topk=2,3,4 and falls back to PyTorch for topk=8.
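
A reference sketch, assuming a [tokens, k, hidden] layout for the already-weighted expert outputs:

import torch

def moe_sum_reference(expert_outputs: torch.Tensor) -> torch.Tensor:
    # expert_outputs: [tokens, k, hidden], already weighted by router scores.
    # Accumulate in FP32 for stability, then cast back to the input dtype.
    return expert_outputs.float().sum(dim=1).to(expert_outputs.dtype)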

rotary_embedding - Rotary Position Embedding

Applies rotary position embedding to query and key tensors (n_heads=32, head_dim=64).

Context | Tokens | Kestrel | vLLM | PyTorch (eager) | vs vLLM | vs PyTorch
decode | 1 | 3.3 us | 4.9 us | 118 us | 1.5x | 36x
batch 4 | 4 | 3.1 us | 4.5 us | 117 us | 1.5x | 38x
batch 16 | 16 | 3.1 us | 4.7 us | 117 us | 1.5x | 38x
prefill | 740 | 5.0 us | 8.0 us | 119 us | 1.6x | 24x

Vectorized bfloat162 pair processing, shared-memory caching of cos/sin values, and FP32 math for numerical stability. A split-head kernel for decode increases SM utilization at small batch sizes.
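
A reference sketch of the rotation, assuming the common half-split (non-interleaved) pair layout; the layout the kernel actually uses is not specified here:

import torch

def apply_rope_reference(x, cos, sin):
    # x: [tokens, n_heads, head_dim]; cos/sin: [tokens, head_dim // 2].
    # Rotate each (x1, x2) pair by its position-dependent angle.
    x1, x2 = x.chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)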

fp8_quant - FP8 Quantization

Converts BF16 tensors to FP8 (e4m3fn) with per-row dynamic scale computation. Used for quantizing MoE activations before FP8 GEMM.

Context | Rows | CUDA | PyTorch (eager) | vs PyTorch
decode | 8 | 3.1 us | 53 us | 17x
batch 4 | 32 | 3.1 us | 52 us | 17x
batch 16 | 128 | 3.1 us | 52 us | 17x
prefill | 5920 | 6.6 us | 67 us | 10x

Two kernel variants: warp-per-row for large batches (better SM utilization) and block-per-row for small batches. Vectorized 16-byte loads/stores with a fused absmax reduction.
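
A reference sketch of the per-row quantization; 448 is the maximum representable value of e4m3fn:

import torch

def fp8_quant_reference(x: torch.Tensor):
    # Per-row dynamic scale: map each row's absmax onto the e4m3fn maximum (448).
    scale = x.abs().amax(dim=-1, keepdim=True).float() / 448.0
    scale = scale.clamp(min=1e-12)  # guard against all-zero rows
    return (x.float() / scale).to(torch.float8_e4m3fn), scale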

tau_tail - TAU Attention Scaling

Applies per-head TAU scaling to Q and V in the packed QKV tensor. Computes scale = tanh(tok_linear) + tau_pos_table[position], then scales each head: Q *= scale_q, V *= scale_v.

Context | Tokens | CUDA | PyTorch (eager) | vs PyTorch
decode | 1 | 4.6 us | 45 us | 10x
batch 4 | 4 | 4.4 us | 46 us | 10x
batch 16 | 16 | 9.0 us | 88 us | 10x
prefill | 740 | 6.5 us | 63 us | 10x
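
A rough PyTorch sketch of the scaling described above; splitting the inputs into separate Q/V terms and the shapes below are assumptions for illustration:

import torch

def tau_tail_reference(q, v, tok_linear_q, tok_linear_v, tau_pos_q, tau_pos_v, positions):
    # Per-head scale: a token-dependent tanh term plus a position-table lookup.
    scale_q = torch.tanh(tok_linear_q) + tau_pos_q[positions]  # [tokens, n_heads]
    scale_v = torch.tanh(tok_linear_v) + tau_pos_v[positions]
    # q, v: [tokens, n_heads, head_dim]; broadcast each scale over head_dim.
    return q * scale_q.unsqueeze(-1), v * scale_v.unsqueeze(-1)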

CuTe DSL Kernels (precompiled for wheel distribution)

These kernels are written in NVIDIA's CuTe DSL (Python) and precompiled to .so files during the wheel build. The kernel source templates are excluded from the wheel distribution.

topk - Bitonic Top-K Selection

GPU top-k selection using bitonic sort network with optional fused softmax.

Context | Tokens | Kestrel | Quack | PyTorch (eager) | vs Quack | vs PyTorch
decode | 1 | 23 us | 29 us | 17 us | 1.3x | 0.8x
batch 16 | 16 | 22 us | 27 us | 17 us | 1.2x | 0.8x
prefill | 740 | 22 us | 28 us | 17 us | 1.2x | 0.7x

Note: currently slower than PyTorch for N=64, k=8. PyTorch uses a radix-based QuickSelect, which is more efficient at small N; the algorithm should be revisited.

Python API:

from kestrel_kernels.topk import topk_fwd

values, indices = topk_fwd(scores, k=8, softmax=True)

sampling - Top-p Token Sampling

CuTe DSL rejection-based top-p sampler for probability tensors.

Runtime dispatch uses the CuTe kernel path by default on CUDA, with a fallback retained for unsupported cases and runtime errors.

Benchmarks below are H100 (SM90) dispatch-level timings (uniform random generation plus kernel launch), measured with heavy warmup and interleaved randomized runs:

Shape (batch, vocab) | Kestrel CuTe | FlashInfer | vs FlashInfer
(1, 51200) | 17.37 us | 20.78 us | 1.20x
(4, 51200) | 21.17 us | 21.84 us | 1.03x
(128, 51200) | 38.96 us | 42.44 us | 1.09x
(32, 1024) | 15.25 us | 20.50 us | 1.34x

Python API:

from kestrel_kernels.sampling import top_p_sampling_from_probs

sampled_ids = top_p_sampling_from_probs(probs, top_p, generator=generator)

cute_moe - MoE Matrix Multiplications

Grouped GEMM kernels for Mixture-of-Experts layers, written in CuTe DSL for H100 (SM90). Supports BF16 and FP8 (W8A8) precision, with warp-level and WGMMA variants selected automatically based on batch size.

FP8 W8A8 Full MoE Layer (up + activation + down + sum, E=64, k=8, with CUDA Graphs):

Context | Tokens | Kestrel | vLLM (Triton) | vs vLLM
decode | 1 | 29 us | 51 us | 1.72x
batch 4 | 4 | 79 us | 103 us | 1.30x
batch 16 | 16 | 146 us | 169 us | 1.16x
prefill | 740 | 245 us | 481 us | 1.96x

Python API:

from kestrel_kernels import (
    invoke_cute_moe_up,
    invoke_cute_moe_down,
    invoke_cute_moe_up_fp8,
    invoke_cute_moe_down_fp8,
)

# BF16 up projection
out_up = invoke_cute_moe_up(
    hidden_states, w1, w2,
    topk_weights, topk_ids,
    sorted_token_ids, expert_ids, num_tokens_post_pad,
)

# BF16 down projection
out_down = invoke_cute_moe_down(
    moe_out, w3,
    topk_weights, topk_ids,
    sorted_token_ids, expert_ids, num_tokens_post_pad,
)

moe_align - MoE Token Alignment

Prepares sorted token indices for block-sparse MoE operations: given topk_ids, it outputs token IDs sorted and grouped by expert for the block-sparse matmul.

Context | Tokens | Kestrel | vLLM | vs vLLM
decode | 1 | 6.7 us | 9.8 us | 1.5x
batch 4 | 4 | 6.5 us | 9.8 us | 1.5x
batch 16 | 16 | 7.0 us | 10 us | 1.4x
prefill | 740 | 12 us | 9.2 us | 0.8x
long | 1024 | 12 us | 9.5 us | 0.8x

Uses an optimized single-CTA shared-memory histogram for decode (numel < 1024). The prefill path still needs optimization.

Python API:

from kestrel_kernels.moe_align import moe_align_block_size

moe_align_block_size(
    topk_ids, num_experts, block_size,
    sorted_token_ids, expert_ids, num_tokens_post_pad,
    expert_map,  # optional for expert parallelism
)

gelu_residual - GELU Residual Activation (CuTe DSL)

CuTe DSL implementation of the GELU residual activation for BF16, computing the same fused gated activation GELU(h) * (g + 1) as the CUDA activation kernel above. Uses vectorized memory access and streaming stores.

Context | Rows | CuTe | CUDA | PyTorch | vs CUDA | vs PyTorch
decode | 8 | 2.3 us | 2.5 us | 7.5 us | 1.10x | 3.3x
batch 4 | 32 | 2.4 us | 3.0 us | 8.6 us | 1.24x | 3.6x
batch 16 | 128 | 2.6 us | 2.9 us | 8.9 us | 1.09x | 3.4x
prefill | 5920 | 9.9 us | 11.2 us | 55.9 us | 1.14x | 5.6x

fp8_quant_cute - FP8 Quantization (CuTe DSL)

CuTe DSL implementation of FP8 row-wise quantization. Converts BF16 tensors to FP8 (e4m3fn) with per-row dynamic scaling.

hidden=1024 (MoE down projection input):

Context | Rows | CuTe | CUDA | vs CUDA
decode | 8 | 2.5 us | 2.7 us | 1.09x
batch 4 | 32 | 2.8 us | 3.0 us | 1.07x
batch 16 | 128 | 2.8 us | 3.0 us | 1.08x
prefill | 5920 | 5.3 us | 6.6 us | 1.23x

hidden=2048 (MoE up projection input):

Context | Rows | CuTe | CUDA | vs CUDA
decode | 8 | 2.6 us | 2.7 us | 1.02x
batch 4 | 32 | 2.9 us | 3.0 us | 1.04x
batch 16 | 128 | 2.9 us | 3.0 us | 1.04x
prefill | 5920 | 8.2 us | 10.7 us | 1.31x

flash_attn - Flash Attention (Prefill & Decode)

Flash Attention kernels written in CuTe DSL, with a dedicated decode path optimized for paged FP8 KV cache; 1.3-2.5x faster than FlashInfer on typical Moondream workloads. Features:

  • FP8 KV cache with per-tensor scaling
  • Paged KV (page_size=1) for fine-grained memory management
  • CUDA graph compatible
  • Causal and prefix-LM masking, variable-length sequences, GQA/MQA

FP8 KV Paged Decode (with CUDA Graphs):

Batch | KV Len | Kestrel | FlashInfer | vs FlashInfer
1 | 740 | 9.6 us | 12.9 us | 1.34x
1 | 1024 | 8.7 us | 13.1 us | 1.50x
4 | 740 | 17.1 us | 23.9 us | 1.40x
8 | 512 | 10.0 us | 25.2 us | 2.51x
16 | 256 | 9.6 us | 17.6 us | 1.83x
32 | 128 | 11.8 us | 26.5 us | 2.24x

FP8 KV Paged Prefill:

Seq Len | Kestrel | FlashInfer | vs FlashInfer
740 | 19.9 us | 47.6 us | 2.40x
1024 | 27.3 us | 58.9 us | 2.16x

Python API:

from kestrel_kernels.flash_attn.cute import flash_attn_func, flash_attn_varlen_func

# Fixed-length attention
out = flash_attn_func(q, k, v, causal=True)

# Variable-length attention
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    causal=True,
)
