Lightweight runtime correctness checker for custom CUDA/Triton kernels via statistical sampling and outlier-aware comparison.
cuda-kernel-verifier
Runtime correctness checker for custom CUDA / Triton kernels - ~200 lines of logic.
Attach a single decorator to any forward function and the library will periodically re-run the same inputs through a known-correct implementation in a background thread, comparing results with torch.allclose. Zero impact on the training graph. Works with raw kernels, Triton ops, torch.autograd.Function, or any nn.Module, including models and layers compiled with torch.compile. The enqueue call is decorated with @torch.compiler.disable so it is always a clean graph break with no interference with compiled regions.
How it works
forward(x) ──► kernel result ──► returned to caller immediately
                    │
                    ▼  (background thread, non-blocking)
              outlier check
                    │
             ┌──────┴──────┐
             │ outlier?    │ not outlier?
             │             │
             ▼             ▼
          enqueue     random sample gate
                      (execution_sample_probability)
                           │
                           ▼
                     ground_truth(x)
                           │
                           ▼
                     torch.allclose?
                       │        │
                      yes       no
                       │        │
                    discard   failure_callback(...)
Sampling
The checker does not run the ground truth on every call. That would negate the point of writing a fast kernel. Instead, each call passes through two gates before work is enqueued:
- Outlier gate - if the current input is detected as an outlier (see below), it is enqueued unconditionally, so unusual inputs are never skipped.
- Random gate - otherwise, the call is enqueued with probability execution_sample_probability (default 0.5). Tune this down for large models where verification overhead matters.
The comparison itself runs in a single daemon background thread so the main training loop is never blocked. You can adjust the sampling rate at any point during a run with EquivalenceChecker.set_execution_sample_probability(p), or stop verification entirely with EquivalenceChecker.stop().
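The two gates and the background worker can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: SamplingVerifier, should_enqueue, and submit are hypothetical names, plain Python values stand in for tensors, and exact equality stands in for torch.allclose.

```python
import queue
import random
import threading


class SamplingVerifier:
    """Hypothetical stand-in for the gate-and-enqueue logic described above."""

    def __init__(self, sample_probability=0.5, rng=None):
        self.sample_probability = sample_probability
        self._rng = rng or random.Random()
        self._queue = queue.Queue()
        self.mismatches = []
        # A single daemon thread does all verification work.
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def should_enqueue(self, is_outlier):
        # Gate 1: outliers are always verified.
        if is_outlier:
            return True
        # Gate 2: everything else is sampled at random.
        return self._rng.random() < self.sample_probability

    def submit(self, is_outlier, result, ground_truth, args):
        """Called on the hot path; returns True if the call was enqueued."""
        if self.should_enqueue(is_outlier):
            self._queue.put((result, ground_truth, args))
            return True
        return False

    def _worker(self):
        while True:
            item = self._queue.get()
            if item is None:  # sentinel pushed by stop()
                return
            result, ground_truth, args = item
            expected = ground_truth(*args)
            # Assumption: exact equality replaces torch.allclose here.
            if result != expected:
                self.mismatches.append((result, expected))

    def stop(self):
        # The queue is FIFO, so pending items are processed before the sentinel.
        self._queue.put(None)
        self._thread.join()
```

Setting sample_probability to 0.0 in this sketch means only outliers reach the queue, which makes the behaviour of each gate easy to check by hand.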
Outlier detection
ExponentialRunningCentroidExecutionOutlierDetector tracks the distribution of activations seen so far and flags batches that look statistically different from the norm.
Algorithm:
- Maintain a running centroid via exponential moving average: centroid ← α · mean(batch) + (1 − α) · centroid. Default α = 0.01 (slow drift, stable reference).
- Compute the L2 distance of each sample in the batch from the centroid.
- Append distances to a rolling window of up to max_distances values (default 10 000).
- A batch is an outlier when mean(distances) / quantile(all_distances, p) ≥ outlier_threshold. Defaults: p = 0.95, outlier_threshold = 0.8.
- The first batch is always treated as an outlier so the centroid can be seeded before any comparison.
This means the verifier is biased toward checking inputs that are unusual (the cases most likely to expose a kernel bug) while randomly sampling the rest.
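The algorithm steps can be followed in a small pure-Python sketch. It is an illustration under assumptions, not the library's code: CentroidOutlierSketch and quantile are hypothetical names, a batch is a list of equal-length lists of floats rather than a tensor, the steps run in the order listed above, the seeding batch records no distances, and a zero quantile is treated as "not an outlier".

```python
import math
from collections import deque


def quantile(values, p):
    """Linearly interpolated quantile of a non-empty sequence, 0 <= p <= 1."""
    ordered = sorted(values)
    pos = p * (len(ordered) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return ordered[lo] * (1 - frac) + ordered[hi] * frac


class CentroidOutlierSketch:
    def __init__(self, percentile=0.95, max_distances=10_000,
                 exponential_alpha=1e-2, outlier_threshold=0.8):
        self.percentile = percentile
        self.alpha = exponential_alpha
        self.outlier_threshold = outlier_threshold
        self.centroid = None
        # Rolling window: old distances fall off once maxlen is reached.
        self.distances = deque(maxlen=max_distances)

    def is_outlier(self, batch):
        dim = len(batch[0])
        batch_mean = [sum(row[i] for row in batch) / len(batch)
                      for i in range(dim)]
        if self.centroid is None:
            # First batch: seed the centroid and always report an outlier.
            self.centroid = batch_mean
            return True
        # Exponential moving average of the batch means.
        a = self.alpha
        self.centroid = [a * m + (1 - a) * c
                         for m, c in zip(batch_mean, self.centroid)]
        # L2 distance of every sample from the centroid.
        dists = [math.dist(row, self.centroid) for row in batch]
        self.distances.extend(dists)
        scale = quantile(self.distances, self.percentile)
        if scale == 0.0:
            return False  # assumption: degenerate window, nothing to compare
        return (sum(dists) / len(dists)) / scale >= self.outlier_threshold
```

Because the ratio compares a batch mean against a high quantile of the history, a batch is flagged in this sketch once its typical distance approaches the largest distances seen so far.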
Installation
Requires CUDA. Install PyTorch for CUDA first, then the package:
pip install torch --index-url https://download.pytorch.org/whl/cu126
pip install cuda-kernel-verifier
Quick start
import torch
from cuda_kernel_verifier import equivalent, EquivalenceChecker, FailureCallbackArgs

# Known-correct reference implementation.
def ground_truth(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=1)

# Called with both outputs when the comparison fails.
def on_mismatch(args: FailureCallbackArgs) -> None:
    diff = (args.original_result - args.ground_truth_result).abs().max().item()
    raise AssertionError(f"Kernel diverged! max abs diff = {diff:.6f}")

@equivalent(ground_truth, on_mismatch, rtol=1e-1, atol=1e-6)
def my_fast_row_sum(x: torch.Tensor) -> torch.Tensor:
    # my_cuda_row_sum_kernel is a placeholder for your own kernel.
    return my_cuda_row_sum_kernel(x)

EquivalenceChecker.start(execution_sample_probability=0.5)
result = my_fast_row_sum(torch.randn(128, 512, device="cuda"))
EquivalenceChecker.stop()
Attaching to torch.autograd.Function
from torch.autograd import Function
from cuda_kernel_verifier import equivalent, FailureCallbackArgs

def sum_ground_truth(ctx, x):
    return x.sum(dim=1)

def on_mismatch(args: FailureCallbackArgs) -> None:
    raise AssertionError("kernel diverged!")

class FastRowSum(Function):
    @staticmethod
    @equivalent(sum_ground_truth, on_mismatch, rtol=1e-1, atol=1e-6)
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return my_cuda_kernel(x)
The decorator wraps the static method, so ctx is passed through transparently. Just mirror the full signature in the ground truth and ignore ctx with _ if needed.
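Why the ground truth has to mirror the full signature can be seen in a simplified, synchronous sketch of such a decorator. Everything here is a hypothetical stand-in (equivalent_sketch, FakeCtx, and the list-based row sum are not part of the library, and the real check runs in a background thread): the point is only that the wrapper forwards the same positional arguments, ctx included, to both functions.

```python
import functools


def equivalent_sketch(ground_truth, on_mismatch):
    """Hypothetical, synchronous stand-in for the decorator."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            result = fn(*args, **kwargs)
            # The reference receives exactly the same arguments, ctx included.
            expected = ground_truth(*args, **kwargs)
            if result != expected:
                on_mismatch(result, expected)
            return result
        return wrapper
    return decorator


class FakeCtx:
    """Minimal object playing the role of the autograd context."""
    def save_for_backward(self, *tensors):
        self.saved = tensors


def row_sum_ground_truth(_ctx, x):
    # ctx is accepted to match the signature, then ignored.
    return [sum(row) for row in x]


failures = []


class FastRowSum:
    @staticmethod
    @equivalent_sketch(row_sum_ground_truth,
                       lambda got, want: failures.append((got, want)))
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return [sum(row) for row in x]


out = FastRowSum.forward(FakeCtx(), [[1, 2], [3, 4]])
```

If row_sum_ground_truth took only x, the call inside the wrapper would fail with a TypeError, which is why the leading parameter is kept and named _ctx.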
Custom outlier detector
from cuda_kernel_verifier import (
    equivalent,
    ExponentialRunningCentroidExecutionOutlierDetector,
)

detector = ExponentialRunningCentroidExecutionOutlierDetector(
    percentile=0.99,
    outlier_threshold=0.9,
    exponential_alpha=5e-3,
)

# ground_truth and on_mismatch are the functions defined in Quick start.
@equivalent(ground_truth, on_mismatch, outlier_detector=detector)
def my_kernel(x):
    ...
API reference
equivalent(ground_truth_function, failure_callback=None, *, rtol=1e-2, atol=1e-8, outlier_detector=None)
Decorator factory. Returns a decorator that wraps the target function.
| Parameter | Description |
|---|---|
| ground_truth_function | Known-correct implementation with the same signature. |
| failure_callback | Called with FailureCallbackArgs on mismatch. Required. |
| rtol | Relative tolerance for torch.allclose (default 1e-2). |
| atol | Absolute tolerance for torch.allclose (default 1e-8). |
| outlier_detector | Outlier strategy. Defaults to ExponentialRunningCentroidExecutionOutlierDetector. |
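To get a feel for rtol and atol: torch.allclose treats two tensors as close when every element satisfies |input − other| ≤ atol + rtol · |other|. The scalar sketch below applies that criterion with the defaults listed above. The helper names are hypothetical, and treating the ground-truth output as the "other" operand is an assumption about how the library calls torch.allclose.

```python
def within_tolerance(actual, expected, rtol=1e-2, atol=1e-8):
    """Scalar form of the torch.allclose criterion."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def all_within_tolerance(actual, expected, rtol=1e-2, atol=1e-8):
    """Element-wise check over two equal-length sequences of floats."""
    return len(actual) == len(expected) and all(
        within_tolerance(a, e, rtol, atol) for a, e in zip(actual, expected)
    )


# Around 100, rtol=1e-2 allows an error of about 1.0.
ok_large = within_tolerance(100.5, 100.0)
bad_large = within_tolerance(102.0, 100.0)

# Around 0, the relative term vanishes and only atol=1e-8 remains.
bad_near_zero = within_tolerance(1e-6, 0.0)
ok_near_zero = within_tolerance(1e-6, 0.0, atol=1e-6)
```

The near-zero case is the one to watch: with the default atol, outputs that should be zero must match almost exactly, so kernels that accumulate in reduced precision may need a larger atol, as the Quick start does with atol=1e-6.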
EquivalenceChecker
Class-level singleton that manages the background thread and queue.
| Method | Description |
|---|---|
| start(max_execution_queue_size=0, execution_sample_probability=0.5) | Start the background thread. Resets all outlier detectors. |
| stop() | Stop the thread and drain the queue. |
| is_running() | Returns True if the checker is active. |
| set_execution_sample_probability(p) | Adjust sampling rate at runtime. |
ExponentialRunningCentroidExecutionOutlierDetector
| Parameter | Default | Description |
|---|---|---|
| percentile | 0.95 | Quantile used as the distance scale reference. |
| max_distances | 10_000 | Rolling window size for historical distances. |
| exponential_alpha | 1e-2 | EMA factor for the running centroid. |
| outlier_threshold | 0.8 | Fraction of the percentile scale that triggers outlier classification. |
FailureCallbackArgs
Dataclass passed to the failure callback.
| Field | Type | Description |
|---|---|---|
| original_result | torch.Tensor | Output of the kernel under test (detached). |
| ground_truth_result | torch.Tensor | Output of the reference function. |
Full example
See examples/mnist_triton.py for a complete MNIST training loop using a Triton row-sum kernel validated in real time.
File details
Details for the file cuda_kernel_verifier-1.0.2.tar.gz.
File metadata
- Download URL: cuda_kernel_verifier-1.0.2.tar.gz
- Upload date:
- Size: 8.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 838444b2cd0130d4b71c47fe7a9ca8511bc06e4149dcde30c1794a1649cbc02c |
| MD5 | 91ffd05301b5b2ff25a2e62f9ea23d2f |
| BLAKE2b-256 | 49179fc3a202b62cc930014ea47f34fade9fa4bb158a77874a859540c25a0eca |
File details
Details for the file cuda_kernel_verifier-1.0.2-py3-none-any.whl.
File metadata
- Download URL: cuda_kernel_verifier-1.0.2-py3-none-any.whl
- Upload date:
- Size: 8.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3b417c2722564e4a46f09415b660bcfaa23c5f3e328a1f22359d94ae4564b783 |
| MD5 | f16cbf32016b9da905f21b5b739c9961 |
| BLAKE2b-256 | 58b4a9598adeb5bd3ec68945a146b1f4ffce1c3f90a0876341d403b04f40f812 |