Pluggable QAT and quantized inference for GRU, with structured-matrix hidden weights and persistent Triton kernels
gru-qat
Pluggable QAT and quantized inference for GRU, with structured-matrix hidden weights and a multi-step persistent Triton kernel.
- Why: cuDNN's GRU is a closed kernel; we cannot insert fake-quant. To do QAT with arbitrary quantization granularities (per-channel, per-group, fine-grained int4) we own the cell.
- What: a manually-unrolled GRU cell where every quantizable quantity is a FakeQuantize module that can be swapped without touching the cell code. Reference path is pure PyTorch; accelerated path is Triton.
- Plus: hidden weights can be parameterized as Diagonal (one vector per gate), Monarch (block-diagonal), Butterfly (O(H log H) twiddle), Circulant, or LDR (low-displacement rank) structured matrices, with matching Triton kernels for Diagonal, Monarch and Butterfly.
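To make the swappable fake-quant idea concrete, here is a minimal sketch of per-channel fake quantization with a straight-through estimator in plain PyTorch. `RoundSTE` and `fake_quant_per_channel` are illustrative names, not this package's API, and the symmetric abs-max scale is an assumption made for the example.

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Round in the forward pass; pass the gradient through unchanged."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Straight-through estimator: treat round() as identity for gradients.
        return grad_out


def fake_quant_per_channel(w, bits=8, axis=0):
    """Quantize-dequantize w with one symmetric scale per slice along `axis`."""
    qmax = 2 ** (bits - 1) - 1
    reduce_dims = tuple(d for d in range(w.dim()) if d != axis)
    # The scale is detached, so only w receives a gradient in this sketch.
    absmax = w.detach().abs().amax(dim=reduce_dims, keepdim=True)
    scale = absmax.clamp_min(1e-8) / qmax
    q = torch.clamp(RoundSTE.apply(w / scale), -qmax, qmax)
    return q * scale


w = torch.randn(4, 16, requires_grad=True)
w_q = fake_quant_per_channel(w, bits=8, axis=0)
w_q.sum().backward()  # w.grad is populated despite the non-differentiable round
```

Because the rounding step is isolated in one function, a different granularity (per-group, int4) only changes how `scale` is computed, which is the property the pluggable design relies on.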
Read first
- SCOPE.md: what's in, what's out, key design decisions.
- DEVELOPMENT.md: file map, phase status, bench numbers, upgrade pathways.
Install
pip install gru-qat # core: dense reference + QAT (CPU/CUDA)
pip install "gru-qat[triton]" # + persistent Triton fast path (CUDA)
pip install "gru-qat[structured]" # + Monarch / Butterfly / LDR (torch-structured)
The core package depends only on torch and numpy. triton and
torch-structured are
optional extras — the dense reference path, dense QAT, and the
diagonal/circulant kinds work without either. structure.py imports
torch-structured lazily, so kind="dense" never touches it. The
structured extra requires torch-structured>=1.2.1.
From source (development)
uv sync
uv pip install -e ".[dev]" # tests, mypy, ruff
pytest -q
Dense QAT layer
import torch
from gru_qat import GRULayer, PRESETS
layer = GRULayer(
input_size=512, hidden_size=512,
recipe=PRESETS["int8_per_channel"],
gate_layout="fused",
pre_batch_input=True, # one big GEMM for x @ W_i across T
compile_step=True, # torch.compile fuses the elementwise body
).cuda()
out, h_T = layer(x, h0)
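The pre_batch_input comment refers to hoisting the input projection out of the time loop. A small NumPy sketch of why that is valid, assuming a time-major (T, B, I) input and a concatenated r/z/n weight (both are assumptions for illustration, not a statement about the package's internal layout):

```python
import numpy as np

rng = np.random.default_rng(0)
T, B, I, H = 5, 3, 8, 4
x = rng.standard_normal((T, B, I))
W_i = rng.standard_normal((I, 3 * H))  # three gates concatenated (assumed layout)

# Per-step projection: T small GEMMs inside the recurrence.
gi_steps = np.stack([x[t] @ W_i for t in range(T)])

# Pre-batched projection: fold (T, B) into one axis and do a single GEMM.
gi_batched = (x.reshape(T * B, I) @ W_i).reshape(T, B, 3 * H)
```

The input projection has no dependence on the hidden state, so both forms give the same values; only the hidden-side product has to stay inside the loop.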
Triton-accelerated dense (persistent kernel)
from gru_qat.triton_kernels.scan import gru_scan_persistent
# Inside training loop:
w = layer.cell.quantize_weights()
gi = layer.cell.input_projection(x, w)
out = gru_scan_persistent(gi, h0, w.Wh_cat, w.bh_cat)
Structured hidden weights with Triton (Monarch, fastest matmul-based variant)
from gru_qat import GRULayer, QuantRecipe, QuantizerConfig, StructureConfig
layer = GRULayer(
input_size=512, hidden_size=512,
recipe=QuantRecipe(
weight=QuantizerConfig(bits=32, axis=0, name="W_id"),
input_act=QuantizerConfig(bits=32, name="x_id"),
hidden=QuantizerConfig(bits=8, name="h_q"), # int8 hidden quant
),
gate_layout="fused",
structure_hidden=StructureConfig(kind="monarch", nblocks=8),
use_triton="auto", # routes through the persistent monarch kernel
).cuda()
# QAT flow:
for x in train_loader:
    out, _ = layer(x)
    loss = ...
    loss.backward()
layer.calibrate(val_loader, n_batches=64)
layer.freeze()
out, _ = layer(x) # now runs through the Triton kernel with frozen scales
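The calibrate -> freeze flow above can be pictured with a toy observer. This is a sketch under assumptions: `AbsMaxObserver` and its methods are hypothetical names, and a running abs-max is only one possible calibration statistic; the package's own observers may work differently.

```python
import numpy as np


class AbsMaxObserver:
    """Hypothetical observer: track a running abs-max, then freeze the scale."""

    def __init__(self, bits=8):
        self.qmax = 2 ** (bits - 1) - 1
        self.abs_max = 0.0
        self.frozen = False

    def observe(self, x):
        # Calibration: widen the range while not frozen; ignore data afterwards.
        if not self.frozen:
            self.abs_max = max(self.abs_max, float(np.abs(x).max()))

    def freeze(self):
        self.frozen = True

    @property
    def scale(self):
        return max(self.abs_max, 1e-8) / self.qmax

    def fake_quantize(self, x):
        q = np.clip(np.rint(x / self.scale), -self.qmax, self.qmax)
        return q * self.scale


obs = AbsMaxObserver(bits=8)
for batch in (np.array([1.0, -2.0]), np.array([4.0, 0.5])):
    obs.observe(batch)           # calibration batches
obs.freeze()
obs.observe(np.array([100.0]))   # no effect: the scale is fixed after freeze
y = obs.fake_quantize(np.array([4.0, 100.0]))  # 100.0 saturates at the frozen range
```

Once the scale is frozen it is a constant, which is what lets an inference kernel take it as a plain scalar argument.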
Structured hidden weights — Butterfly (smallest params)
from gru_qat import GRULayer, StructureConfig
layer = GRULayer(
H, H, recipe=...,
gate_layout="fused",
structure_hidden=StructureConfig(kind="butterfly"),
use_triton="auto",
).cuda()
# Same calibrate -> freeze -> forward flow.
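For intuition on the O(H log H) parameter count, here is a NumPy sketch of a butterfly matrix-vector product. It assumes the textbook radix-2 form (log2(H) sparse factors, each holding H/2 independent 2×2 blocks); `butterfly_matvec` and the twiddle layout are illustrative and may not match how torch-structured stores its parameters.

```python
import numpy as np


def butterfly_matvec(twiddle, x):
    """Apply log2(H) butterfly factors to x.

    twiddle: (log2(H), H // 2, 2, 2) array of 2x2 blocks; x: (H,) vector.
    """
    n_stages, half = twiddle.shape[:2]
    H = x.shape[0]
    assert H == 2 * half and H == 1 << n_stages
    y = x.astype(np.float64)
    for s in range(n_stages):
        stride = 1 << s
        # Element i is paired with element i + stride inside blocks of 2 * stride.
        v = y.reshape(-1, 2, stride)
        top = v[:, 0, :].reshape(-1)
        bot = v[:, 1, :].reshape(-1)
        t = twiddle[s]
        new_top = t[:, 0, 0] * top + t[:, 0, 1] * bot
        new_bot = t[:, 1, 0] * top + t[:, 1, 1] * bot
        y = np.stack(
            [new_top.reshape(-1, stride), new_bot.reshape(-1, stride)], axis=1
        ).reshape(H)
    return y


rng = np.random.default_rng(0)
H = 8
twiddle = rng.standard_normal((3, H // 2, 2, 2))
y = butterfly_matvec(twiddle, rng.standard_normal(H))

dense_params = H * H            # 64
butterfly_params = twiddle.size  # 2 * H * log2(H) = 48
```

Under the same assumption, H=512 gives 2 · 512 · 9 = 9216 parameters per hidden matrix versus 262144 for a dense one.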
Structured hidden weights — Diagonal (smallest & fastest)
from gru_qat import GRULayer, StructureConfig
layer = GRULayer(
H, H, recipe=...,
gate_layout="fused",
structure_hidden=StructureConfig(kind="diagonal"),
use_triton="auto",
).cuda()
# Same calibrate -> freeze -> forward flow.
kind="diagonal" collapses each H×H hidden matrix to a length-H
vector. Per-step recurrence becomes elementwise w_h * h instead of a
matmul — 3H weight scalars total on the hidden side, O(H) FLOPs per
step. The persistent Triton kernel has no matmul on the hidden side,
no cross-program reduction, and runs fully in registers across the
T-loop. Good fit when you want a very small recurrence (e.g. for an
embedded model) and are happy treating the hidden update as
hidden-unit-independent (similar in spirit to IndRNN / diagonal SSMs).
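A reference-style NumPy sketch of the diagonal recurrence described above. It assumes the PyTorch GRU gate convention (r, z, n order, with the reset gate applied to the hidden-side term of n) and a precomputed input projection `gi`; the function names are illustrative, not the package's API.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def diagonal_gru_scan(gi, h0, w_h, b_h):
    """gi: (T, B, 3H) input projection; w_h, b_h: (3, H) per-gate vectors."""
    T, B, _ = gi.shape
    h = h0
    out = np.empty((T, B, h0.shape[-1]))
    for t in range(T):
        gi_r, gi_z, gi_n = np.split(gi[t], 3, axis=-1)
        # Hidden side is elementwise: each unit only sees its own previous state.
        r = sigmoid(gi_r + w_h[0] * h + b_h[0])
        z = sigmoid(gi_z + w_h[1] * h + b_h[1])
        n = np.tanh(gi_n + r * (w_h[2] * h + b_h[2]))
        h = (1.0 - z) * n + z * h
        out[t] = h
    return out, h


def dense_gru_scan(gi, h0, W_h, b_h):
    """Same recurrence with full (3, H, H) hidden matrices, for comparison."""
    T, B, _ = gi.shape
    h = h0
    out = np.empty((T, B, h0.shape[-1]))
    for t in range(T):
        gi_r, gi_z, gi_n = np.split(gi[t], 3, axis=-1)
        r = sigmoid(gi_r + h @ W_h[0].T + b_h[0])
        z = sigmoid(gi_z + h @ W_h[1].T + b_h[1])
        n = np.tanh(gi_n + r * (h @ W_h[2].T + b_h[2]))
        h = (1.0 - z) * n + z * h
        out[t] = h
    return out, h


rng = np.random.default_rng(0)
T, B, H = 6, 2, 4
gi = rng.standard_normal((T, B, 3 * H))
h0 = rng.standard_normal((B, H))
w_h = rng.standard_normal((3, H))   # 3H hidden-side weight scalars in total
b_h = rng.standard_normal((3, H))

out_diag, _ = diagonal_gru_scan(gi, h0, w_h, b_h)
W_dense = np.stack([np.diag(w_h[k]) for k in range(3)])
out_dense, _ = dense_gru_scan(gi, h0, W_dense, b_h)
```

The two scans agree when the dense matrices are diagonal, which shows the diagonal kind is a restriction of the dense cell rather than a different recurrence.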
Status
All originally-planned phases (0–5) complete. The dense and structured paths are feature-complete:
| feature | status |
|---|---|
| STE primitives + FakeQuantize variants | ✓ |
| Dense GRUCellQuant parity vs nn.GRUCell (< 1e-5) | ✓ |
| Fake-quant insertion in cell (all 6 weight + 3 activation points) | ✓ |
| GRULayer with calibration → freeze flow | ✓ |
| Triton multi-step persistent kernel (dense, fp32, fp32 + frozen int8 QAT) | ✓ |
| Structured hidden weights (Diagonal / Monarch / Butterfly / Circulant / LDR) | ✓ |
| Triton persistent kernel for Diagonal (fp32 + QAT) | ✓ |
| Triton persistent kernel for Monarch (fp32 + QAT) | ✓ |
| Triton persistent kernel for Butterfly (fp32 + QAT) | ✓ |
The suite spans parity, QAT, calibration, structured-matrix, and
per-kernel strict numerical tests (CUDA-only Triton tests skip
automatically when no GPU is present). Run pytest -q for the full set,
or pytest -m "not slow" to skip the long-T parity sweeps.
Train-step speed at (T=64, B=32, H=512) — fp32
| variant | ms/iter | vs cuDNN |
|---|---|---|
| cuDNN nn.GRU (dense, no quant) | 4.4 | 1.0× |
| GRULayer dense + torch.compile | 38.7 | 8.8× |
| dense Triton persistent | 8.8 | 1.9× |
| Monarch persistent (nblocks=4) | 5.8 | 1.3× |
| Monarch persistent (nblocks=8) | 2.0 | 0.45× (2.2× faster) |
| Butterfly persistent | 20.3 | 4.6× |
| Diagonal persistent | ~1.1 | ~0.25× (4× faster) |
For QAT (frozen int8 hidden), expect ~10–30% overhead on top of the fp32 number depending on path.
Numerical parity vs PyTorch reference at (T=64, B=32, H=512)
Measured at bench shape against the GRULayer(use_triton=False)
PyTorch reference path (which itself matches torch.nn.GRUCell to
< 1e-5). Both sides use torch.set_float32_matmul_precision("high")
so the Triton kernels and PyTorch's matmul see the same TF32 inputs.
"fwd"/"dx"/"dh0" are max relative diffs; "weight-grad" is the worst
per-parameter dWh / twiddle / b_h* rel diff.
| variant | regime | fwd | dx | dh0 | weight-grad |
|---|---|---|---|---|---|
| Dense Triton persistent | fp32 | 4e-4 | 4e-4 | 8e-4 | 1e-3 |
| Dense Triton persistent | int8 QAT (hidden) | 8% | 7% | 9% | — |
| Monarch persistent, nb=4 | fp32 | 3e-4 | 5e-4 | 7e-4 | 2e-3 |
| Monarch persistent, nb=4 | int8 QAT (hidden) | 8% | 7% | 6% | 3% |
| Monarch persistent, nb=8 | fp32 | 2e-4 | 4e-4 | 6e-4 | 2e-3 |
| Monarch persistent, nb=8 | int8 QAT (hidden) | 8% | 5% | 8% | 3% |
| Butterfly persistent | fp32 | 3e-2 | 3e-3 | 1e-3 | 2e-3 |
| Butterfly persistent | int8 QAT (hidden) | 15% | 15% | 1% | 8% |
| Diagonal persistent | fp32 | 1e-6 | 4e-5 | 2e-7 | 2e-6 |
| Diagonal persistent | int8 QAT (hidden) | 0 | 3e-5 | 2e-7 | 1e-6 |
QAT rows for the matmul-based variants show ~5–15% relative drift
because each step's round(x/scale) flips at half-integer boundaries
when tl.dot's TF32 reduction order disagrees with cuBLAS by
O(scale) — a single rounding flip per ~100 positions, amplified by
the recurrence over T=64 steps. Not a kernel bug: torch.round and
tl.extra.libdevice.rint are bit-identical on the same fp32 input
(verified across 1M values + half-integer perturbations). Forcing
Triton to input_precision="ieee" would tighten the QAT rows to
~1e-3 at the cost of ~2-4× slower matmul; current choice is speed
over bit-parity. The butterfly fp32 row's 3e-2 fwd is the same
story (kernel TF32 vs torch_structured's CUDA op).
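The rounding-flip mechanism is easy to reproduce in isolation. In this NumPy sketch the scale of 0.1 and the 1e-6 perturbation are made-up values standing in for a quantizer scale and a reduction-order difference; `fake_quant` is an illustrative helper, not the kernel code.

```python
import numpy as np

scale = np.float32(0.1)


def fake_quant(x, scale):
    # Quantize-dequantize with round-to-nearest-even, like torch.round.
    return np.rint(x / scale) * scale


# Two inputs on either side of the half-integer boundary x / scale = 2.5.
x_a = np.float32(0.25) - np.float32(1e-6)
x_b = np.float32(0.25) + np.float32(1e-6)

y_a = fake_quant(x_a, scale)  # rounds down to level 2
y_b = fake_quant(x_b, scale)  # rounds up to level 3

input_gap = float(x_b - x_a)   # about 2e-6
output_gap = float(y_b - y_a)  # about one full quantization step
```

An input difference far below the quantization step turns into a full-step output difference, and the recurrence then carries that step forward through the remaining timesteps.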
The diagonal variant has no matmul on the hidden side, so it avoids this effect entirely: every multiplication is elementwise and Triton emits the same FMA order as the PyTorch reference. The QAT fwd row is bit-identical (rel diff = 0); the remaining fp32 and gradient entries are at fp32 rounding level (between ~2e-7 and ~4e-5 in the table above).
Layout
src/gru_qat/
    __init__.py             public API
    ste.py                  STE autograd functions
    quantizers.py           FakeQuantize + observers
    calibration.py          calibrate(module, loader, n_batches)
    structure.py            StructureConfig + make_structured_linear
    gru_cell.py             GRUCellQuant (single step, optionally structured)
    gru_layer.py            GRULayer (multi-step + Triton dispatch)
    triton_kernels/
        scan.py             dense persistent fwd+bwd kernels
        scan_diagonal.py    Diagonal persistent fwd+bwd kernels
        scan_monarch.py     Monarch persistent fwd+bwd kernels
        scan_butterfly.py   Butterfly persistent fwd+bwd kernels
See DEVELOPMENT.md for the file-by-file design
and the per-phase commit history.