Skip to main content

Unapologetically SM120-only CuTe DSL kernels for NVFP4 GEMM and MoE.

Project description

b12x is an SM120-only CuTe DSL kernel library for Blackwell NVFP4 dense GEMM, routed Mixture-of-Experts, and paged attention inference.

It is intentionally narrow. This is not a generic CUDA kernel collection or a full model-serving stack. It does not intend to target any other GPU architectures, including SM100. It is a focused package for a small number of hand-tuned, high-performance SM120 kernels plus the runtime glue needed to launch them cleanly from PyTorch and sglang.

python -m pip install b12x
git clone <repo-url>
cd b12x
python -m pip install -e '.[dev]'
  • Blackwell SM120 GPU

  • CUDA 13 toolchain

  • Python >=3.10,<4.0

  • CUDA 13 PyTorch, torch>=2.10.0

  • nvidia-cutlass-dsl[cu13]==4.4.1

  • FlashInfer available if you want reference and benchmark comparisons, but it's not a runtime dependency

  • Qwen3.5-397B A17B NVFP4 checkpoint available through B12X_MODEL_PATH for the end-to-end MoE benchmark

  • b12x.attention

    • Primary SM120 paged attention backend (split-KV, BF16/FP8 KV, exact host planning)
  • b12x.cute

    • Low-level CuTe and FP4 helpers
  • b12x.gemm

    • Standalone dense NVFP4 GEMM
  • b12x.integration

    • Public runtime entrypoints: b12x_moe_fp4, b12x_paged_attention_forward, create_paged_attention_plan
  • b12x.moe.fused

    • Static, micro, and dynamic fused MoE kernels and reference paths
  • b12x.quant

    • Torch-side NVFP4 packing and quantization helpers
  • b12x.sglang

    • Thin sglang integration shims

Paged attention now routes through the primary b12x.attention.paged backend. It is narrow by design and tuned for the Blackwell serving matrix this repo cares about.

  • b12x.integration.attention.create_paged_attention_plan builds an exact-shape launch plan for one paged attention configuration.
  • allocate_paged_attention_workspace_for_plan allocates reusable scratch buffers for a plan.
  • allocate_paged_attention_workspace_pool provides a caller-owned pool that partitions scratch by CUDA stream for variable-shape workloads.
  • b12x_paged_attention_forward executes the kernel given a plan and workspace.
  • Page size is fixed at 64.
  • Supported KV dtypes: BF16, FP16, FP8 E4M3.
  • FP8 KV uses raw-byte staging with in-kernel descale. Per-head descale tensors are required for FP8 KV.
  • Split-KV chunking is automatic by default. fixed_split_size pins chunk size in pages for exact-shape benchmarking or graph replay.
  • GQA with arbitrary ratios is supported.
  • During CUDA graph capture, output= must be caller-owned and stable across replays.

The paged attention planner, split/merge structure, and benchmark methodology were developed by studying FlashInfer's paged attention kernels. b12x ships its own SM120-first implementation and does not depend on FlashInfer at runtime.

  • b12x.integration.tp_moe.b12x_moe_fp4 requires a caller-owned workspace.
  • b12x selects its fused MoE backend from shape alone:
    • compact routed workloads use the static or micro backend
    • all larger routed workloads use dynamic
  • Use allocate_tp_moe_workspace(...) for one caller-owned launch workspace.
  • Use allocate_tp_moe_workspace_pool() for variable-size, chunked, or eager routing-aware dynamic workloads.
  • Keep one workspace pool per process/device, and let the pool partition scratch by CUDA stream internally.
  • During CUDA graph capture, output= must also be caller-owned and stable across replays.
Variable Default Description
B12X_ATTN=TURBO off Enable MXFP8 PV accumulation for FP8 KV configs (higher throughput, slight accuracy trade-off).
B12X_FAST_MATH 1 Enable fast-math MoE paths.
B12X_MODEL_PATH Path to Qwen3.5 NVFP4 checkpoint for end-to-end MoE benchmarks.
B12X_STATIC_COMPACT_CUTOVER_PAIRS auto Override static→dynamic cutover threshold (routed pairs).
B12X_MICRO_CUTOVER_TOKENS auto Override micro→static cutover threshold (tokens).
B12X_DYNAMIC_ENABLE_MULTICTA 1 Enable multi-CTA dynamic launches.
B12X_DYNAMIC_CHUNK_MULTIPLIER 1 Dynamic backend chunk size multiplier.
B12X_{STATIC,MICRO,DYNAMIC}_MAX_ACTIVE_CLUSTERS auto Override max active clusters per backend.
B12X_{STATIC,MICRO,DYNAMIC}_REUSE_COMPILED 1 Reuse compiled kernels across shapes within a backend.
  • benchmarks/benchmark_moe.py

    • End-to-end Qwen3.5-397B TP=4 MoE benchmark
    • micro batch profile: [1, 2, 4, 8]
    • eager-prefill batch profile: [16384, 32768]
    • sglang-single-request batch profile: [1, 23, 80]
    • chunked-prefill batch profile: [8192, 16384, 24576, 32768]
  • benchmarks/benchmark_paged_attention.py

    • Paged attention vs FlashInfer across decode and extend shapes
  • scripts/sweep_decode_graph_policy.py

    • Two-stage decode graph tuning sweep:
      • stage 1 picks graph_ctas_per_sm for a specific batch size
      • stage 2 fills the dense per-page chunk table for that same batch size
  • benchmarks/benchmark_dense_gemm.py

    • Dense FP4 GEMM vs FlashInfer CUTLASS with CUDA graph replay on GLM-5.1 dense MLP TP=8 shapes
  • benchmarks/benchmark_mxfp8_pv.py

    • MXFP8 PV microbenchmark (turbo mode throughput)
  • tests/test_paged_attention_workspace_api.py

    • Public paged attention plan, workspace, and wrapper correctness
  • tests/test_attention_cuda_graphs.py

    • CUDA graph capture and replay for paged attention, including FP8 KV and small GQA ratios
  • tests/test_attention_paged_forward.py

    • Primary paged forward kernel exactness against the reference path
  • tests/test_attention_paged_merge.py

    • Persistent split-merge exactness
  • tests/test_attention_paged_planner.py

    • Exact host plan metadata and explicit chunk-table coverage
  • tests/test_attention_paged_traits.py

    • Forward-trait selection for the supported serving families
  • tests/test_tp_moe_reference.py

    • Independent oracle-backed MoE correctness test
  • tests/test_moe_equivalence.py

    • Real-weight smoke and CUDA-graph replay routing-safety checks
  • tests/test_gemm_stack.py

    • Dense GEMM exactness vs FlashInfer/cuDNN
B12X_MODEL_PATH=/path/to/Qwen3.5-397B-A17B-NVFP4 python benchmarks/benchmark_moe.py

B12X_MODEL_PATH=/path/to/Qwen3.5-397B-A17B-NVFP4 python benchmarks/benchmark_moe.py --no-cuda-graph

B12X_MODEL_PATH=/path/to/Qwen3.5-397B-A17B-NVFP4 python benchmarks/benchmark_moe.py --include-routing

B12X_MODEL_PATH=/path/to/Qwen3.5-397B-A17B-NVFP4 python benchmarks/benchmark_moe.py --batch-size-profile sglang-single-request

B12X_MODEL_PATH=/path/to/Qwen3.5-397B-A17B-NVFP4 python benchmarks/benchmark_moe.py --batch-size-profile eager-prefill --no-cuda-graph

B12X_MODEL_PATH=/path/to/Qwen3.5-397B-A17B-NVFP4 python benchmarks/benchmark_moe.py --batch-size-profile chunked-prefill

B12X_MODEL_PATH=/path/to/Qwen3.5-397B-A17B-NVFP4 python benchmarks/benchmark_moe.py --graph-mode multi-layer --reference none --validate none

python benchmarks/benchmark_paged_attention.py

python benchmarks/benchmark_dense_gemm.py

pytest tests/test_attention_cuda_graphs.py tests/test_paged_attention_workspace_api.py

python tests/test_tp_moe_reference.py --impls b12x --scale-contract per-expert

pytest tests/test_moe_equivalence.py

Decode CTA/Chunk Sweep

Use scripts/sweep_decode_graph_policy.py to tune decode graph policy for one specific batch size. The batch size is controlled by --batch-list; if you pass a single value, the sweep only runs that one batch.

Environment:

source ~/projects/sglang/.venv/bin/activate
export CUTE_DSL_ARCH=sm_120a

Template:

python scripts/sweep_decode_graph_policy.py \
  --kv-dtype bf16 \
  --batch-list <batch_size> \
  --page-start 1 \
  --page-stop 4096 \
  --capture-page-count 4096 \
  --candidate-ctas-per-sm 1,16 \
  --candidate-splits 1,512 \
  --parallel-workers 8 \
  --replays 100 \
  --probe-batch-replays 10 \
  --ci-level 0.99 \
  --output /tmp/<kv_dtype>_decode_graph_policy_bs<batch_size>.json \
  --summary

Example for batch size 1:

python scripts/sweep_decode_graph_policy.py \
  --kv-dtype bf16 \
  --batch-list 1 \
  --page-start 1 \
  --page-stop 4096 \
  --capture-page-count 4096 \
  --candidate-ctas-per-sm 1,16 \
  --candidate-splits 1,512 \
  --parallel-workers 8 \
  --replays 100 \
  --probe-batch-replays 10 \
  --ci-level 0.99 \
  --output /tmp/bf16_decode_graph_policy_bs1.json \
  --summary

Notes:

  • The sweep writes the final combined result to --output.
  • It also writes an incremental JSONL checkpoint next to it as *.checkpoint.jsonl.
  • If you already know the decode CTA and only want the dense chunk fill, add --fixed-cta <n> to skip the CTA search stage.

After the sweep finishes, generate the registered tuning module with:

python scripts/generate_decode_policy_tuning.py \
  --input /tmp/<kv_dtype>_decode_graph_policy_bs<batch_size>.json

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

b12x-0.8.2.tar.gz (323.2 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

b12x-0.8.2-py3-none-any.whl (367.2 kB view details)

Uploaded Python 3

File details

Details for the file b12x-0.8.2.tar.gz.

File metadata

  • Download URL: b12x-0.8.2.tar.gz
  • Upload date:
  • Size: 323.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for b12x-0.8.2.tar.gz
Algorithm Hash digest
SHA256 eebb711c931fcd79cff4fdef7b6d27082c1f8a97bea560c6c70835e12aa90013
MD5 0fb2eeba532e81feedc167ac4c939b5b
BLAKE2b-256 4464ce3be38a388438da54b71e7f9aca3eaf2d60d766ab2f7371b540d62c883f

See more details on using hashes here.

File details

Details for the file b12x-0.8.2-py3-none-any.whl.

File metadata

  • Download URL: b12x-0.8.2-py3-none-any.whl
  • Upload date:
  • Size: 367.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.3

File hashes

Hashes for b12x-0.8.2-py3-none-any.whl
Algorithm Hash digest
SHA256 42fda0d78ed08d95a17650ebc463017eed491916f61f11d0b4b896e9a29a7f6b
MD5 ded0f8d252de5cc3b0dadcbf1d66a413
BLAKE2b-256 0e330cf3276b3408ede6e0a14055ffa44aae841a78284e19dc392616514a6f3d

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page