
Pluggable QAT and quantized inference for GRU, with structured-matrix hidden weights and persistent Triton kernels


gru-qat

Pluggable QAT and quantized inference for GRU, with structured-matrix hidden weights and a multi-step persistent Triton kernel.

  • Why: cuDNN's GRU is a closed kernel, so fake-quant cannot be inserted inside it. To do QAT with arbitrary quantization granularities (per-channel, per-group, fine-grained int4), we implement the cell ourselves.
  • What: a manually unrolled GRU cell in which every quantizable quantity passes through a FakeQuantize module that can be swapped without touching the cell code. The reference path is pure PyTorch; the accelerated path is Triton.
  • Plus: hidden weights can be parameterized as Diagonal (one vector per gate), Monarch (block-diagonal factors interleaved with a permutation), Butterfly (O(H log H) twiddle factors), Circulant, or LDR (low-displacement rank) structured matrices, with matching Triton kernels for Diagonal, Monarch and Butterfly.
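To make "manually unrolled cell with swap points" concrete, here is a minimal sketch of one GRU step written out by hand, with four illustrative swap points that default to nn.Identity (fp32). GRUCellSketch and the fq_* attribute names are hypothetical, not the package's GRUCellQuant API (which, per its status list, exposes 6 weight and 3 activation points). The gate order r, z, n follows torch.nn.GRUCell, so the sketch can be checked against it.

```python
import torch
import torch.nn as nn


class GRUCellSketch(nn.Module):
    """One GRU step written out by hand so each quantizable tensor has a swap point."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        H = hidden_size
        # Fused gate layout [r | z | n], matching torch.nn.GRUCell.
        self.W_i = nn.Parameter(torch.randn(3 * H, input_size) * 0.1)
        self.W_h = nn.Parameter(torch.randn(3 * H, H) * 0.1)
        self.b_i = nn.Parameter(torch.zeros(3 * H))
        self.b_h = nn.Parameter(torch.zeros(3 * H))
        # Swap points (hypothetical names): assign any module, e.g. a fake-quantizer.
        self.fq_wi = nn.Identity()
        self.fq_wh = nn.Identity()
        self.fq_x = nn.Identity()
        self.fq_h = nn.Identity()

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        gi = self.fq_x(x) @ self.fq_wi(self.W_i).t() + self.b_i
        gh = self.fq_h(h) @ self.fq_wh(self.W_h).t() + self.b_h
        i_r, i_z, i_n = gi.chunk(3, dim=-1)
        h_r, h_z, h_n = gh.chunk(3, dim=-1)
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        # Blend with the unquantized previous state (one possible design choice).
        return (1 - z) * n + z * h
```

Swapping a quantizer is then a plain attribute assignment (e.g. `cell.fq_h = my_fake_quant`), and the cell body never changes.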

Read first

  1. SCOPE.md — what's in, what's out, key design decisions.
  2. DEVELOPMENT.md — file map, phase status, bench numbers, upgrade pathways.

Install

pip install gru-qat                 # core: dense reference + QAT (CPU/CUDA)
pip install "gru-qat[triton]"       # + persistent Triton fast path (CUDA)
pip install "gru-qat[structured]"   # + Monarch / Butterfly / LDR (torch-structured)

The core package depends only on torch and numpy. triton and torch-structured are optional extras — the dense reference path, dense QAT, and the diagonal/circulant kinds work without either. structure.py imports torch-structured lazily, so kind="dense" never touches it. The structured extra requires torch-structured>=1.2.1.
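A minimal sketch of the lazy-import pattern described above, so that torch-only kinds never import the optional extra. lazy_import, structured_backend and the import name torch_structured are assumptions for illustration, not the package's actual structure.py API.

```python
import importlib
from types import ModuleType
from typing import Dict, Optional

_loaded: Dict[str, ModuleType] = {}


def lazy_import(module_name: str, extra: str) -> ModuleType:
    """Import an optional dependency on first use and cache it."""
    if module_name not in _loaded:
        try:
            _loaded[module_name] = importlib.import_module(module_name)
        except ImportError as exc:
            raise ImportError(
                f"{module_name!r} is required here; "
                f'install it with: pip install "gru-qat[{extra}]"'
            ) from exc
    return _loaded[module_name]


# Kinds implemented with plain torch never trigger the optional import.
_TORCH_ONLY_KINDS = {"dense", "diagonal", "circulant"}


def structured_backend(kind: str) -> Optional[ModuleType]:
    if kind in _TORCH_ONLY_KINDS:
        return None
    # Assumption: the torch-structured extra is importable as "torch_structured".
    return lazy_import("torch_structured", extra="structured")
```

The import cost and the failure mode are both deferred to the first structured layer that actually needs the extra, and the error message tells the user which extra to install.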

From source (development)

uv sync
uv pip install -e ".[dev]"   # tests, mypy, ruff
pytest -q

Dense QAT layer

import torch
from gru_qat import GRULayer, PRESETS

layer = GRULayer(
    input_size=512, hidden_size=512,
    recipe=PRESETS["int8_per_channel"],
    gate_layout="fused",
    pre_batch_input=True,        # one big GEMM for x @ W_i across T
    compile_step=True,           # torch.compile fuses the elementwise body
).cuda()

# x: CUDA input sequence, h0: initial hidden state (see GRULayer for expected shapes)
out, h_T = layer(x, h0)
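Judging by its name, PRESETS["int8_per_channel"] fake-quantizes weights with one scale per channel. A minimal sketch of what such a quantizer typically computes, assuming symmetric int8, max-abs scales along axis=0, and a straight-through estimator (STE); fake_quant_per_channel is a hypothetical name and the package's observers may choose scales differently.

```python
import torch


def fake_quant_per_channel(w: torch.Tensor, bits: int = 8, axis: int = 0) -> torch.Tensor:
    """Symmetric per-channel fake quantization with a straight-through estimator."""
    qmax = 2 ** (bits - 1) - 1                                  # 127 for int8
    w_d = w.detach()
    reduce_dims = tuple(d for d in range(w.dim()) if d != axis)
    amax = w_d.abs().amax(dim=reduce_dims, keepdim=True)
    scale = amax.clamp_min(1e-8) / qmax                         # one scale per channel
    q = torch.clamp(torch.round(w_d / scale), -qmax, qmax) * scale
    # Forward value is q; (q - w_d) carries no gradient, so d(out)/dw is the identity.
    return w + (q - w_d)
```

The forward pass sees weights snapped to the int8 grid, while the backward pass treats round and clamp as the identity, which is what lets ordinary optimizers train through the quantizer.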

Triton-accelerated dense (persistent kernel)

from gru_qat.triton_kernels.scan import gru_scan_persistent

# Inside training loop:
w = layer.cell.quantize_weights()
gi = layer.cell.input_projection(x, w)
out = gru_scan_persistent(gi, h0, w.Wh_cat, w.bh_cat)
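The persistent kernel consumes a precomputed input projection gi, so only the hidden-side matmul stays inside the time loop (the same split that pre_batch_input=True uses on the PyTorch path). A pure-PyTorch sketch of those semantics, assuming a time-major (T, B, ...) layout and torch.nn.GRU's r, z, n gate order for the fused Wh_cat / bh_cat; input_projection_reference and gru_scan_reference are hypothetical names, not package functions.

```python
import torch


def input_projection_reference(x: torch.Tensor, W_i: torch.Tensor, b_i: torch.Tensor) -> torch.Tensor:
    """x: (T, B, I) -> gi: (T, B, 3H) via a single GEMM over all T*B rows."""
    T, B, I = x.shape
    return (x.reshape(T * B, I) @ W_i.t() + b_i).reshape(T, B, -1)


def gru_scan_reference(gi, h0, Wh_cat, bh_cat):
    """Sequential scan; the hidden-side matmul is the only per-step GEMM."""
    h = h0
    outs = []
    for t in range(gi.shape[0]):
        gh = h @ Wh_cat.t() + bh_cat                       # (B, 3H), gate order r, z, n
        i_r, i_z, i_n = gi[t].chunk(3, dim=-1)
        h_r, h_z, h_n = gh.chunk(3, dim=-1)
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        h = (1 - z) * n + z * h
        outs.append(h)
    return torch.stack(outs), h
```

Hoisting the input projection turns T small GEMMs into one large one; the loop that remains is what a persistent kernel keeps on-chip across all T steps.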

Structured hidden weights with Triton: Monarch (fastest matmul-based kind)

import torch
from gru_qat import GRULayer, QuantRecipe, QuantizerConfig, StructureConfig

layer = GRULayer(
    input_size=512, hidden_size=512,
    recipe=QuantRecipe(
        weight=QuantizerConfig(bits=32, axis=0, name="W_id"),   # 32 bits: effectively unquantized
        input_act=QuantizerConfig(bits=32, name="x_id"),        # 32 bits: effectively unquantized
        hidden=QuantizerConfig(bits=8, name="h_q"),             # int8 hidden quant
    ),
    gate_layout="fused",
    structure_hidden=StructureConfig(kind="monarch", nblocks=8),
    use_triton="auto",   # routes through the persistent monarch kernel
).cuda()
opt = torch.optim.Adam(layer.parameters(), lr=1e-3)

# QAT flow:
for x in train_loader:
    out, _ = layer(x.cuda())
    loss = ...           # task loss computed from `out`
    opt.zero_grad()
    loss.backward()
    opt.step()

layer.calibrate(val_loader, n_batches=64)
layer.freeze()
out, _ = layer(x)        # now runs through the Triton kernel with frozen scales
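Monarch matrices replace a dense H×H map with block-diagonal factors interleaved with a fixed permutation, so a matvec costs two batched small matmuls. A minimal sketch of one common form, M = Pᵀ L P R with H = b·m, where R holds b blocks of m×m, L holds m blocks of b×b, and P is the (b, m) to (m, b) transpose permutation. How the package's nblocks maps onto b and m, and the name monarch_apply, are assumptions.

```python
import torch


def monarch_apply(x: torch.Tensor, R: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Return x @ M.T for M = P^T L P R, with H = b * m.

    x: (N, H); R: (b, m, m) block-diagonal factor; L: (m, b, b) block-diagonal factor.
    """
    N = x.shape[0]
    b, m, _ = R.shape
    xv = x.reshape(N, b, m)
    y = torch.einsum("kij,nkj->nki", R, xv)        # mix within each of the b blocks
    y = y.transpose(1, 2)                          # permutation P: (N, b, m) -> (N, m, b)
    z = torch.einsum("ilk,nik->nil", L, y)         # mix across blocks: m groups of size b
    return z.transpose(1, 2).reshape(N, b * m)     # undo the permutation
```

Storage is H·(b + m) values instead of H², yet every output still depends on every input, since the second factor mixes across the blocks formed by the first.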

Structured hidden weights: Butterfly (O(H log H) parameters)

from gru_qat import GRULayer, StructureConfig

layer = GRULayer(
    H, H, recipe=...,
    gate_layout="fused",
    structure_hidden=StructureConfig(kind="butterfly"),
    use_triton="auto",
).cuda()
# Same calibrate -> freeze -> forward flow.
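The O(H log H) parameter count comes from stacking log2 H sparse stages. A minimal sketch of one common real-valued butterfly parameterization: at stage s, each index pair (i, i + 2^s) inside blocks of size 2^(s+1) is mixed by its own 2×2 twiddle, for 2H·log2 H parameters in total (9,216 instead of 262,144 at H = 512). butterfly_apply and this twiddle layout are assumptions; torch-structured and the package's kernel may order stages or store twiddles differently.

```python
import math

import torch


def butterfly_apply(x: torch.Tensor, twiddle: torch.Tensor) -> torch.Tensor:
    """x: (N, H) with H a power of two; twiddle: (log2 H, H // 2, 2, 2).

    Stage s mixes each pair (i, i + 2**s) inside blocks of size 2**(s + 1)
    with its own 2x2 twiddle, so log2 H stages connect every input to every output.
    """
    N, H = x.shape
    for s in range(twiddle.shape[0]):
        stride = 2 ** s
        groups = H // (2 * stride)
        xv = x.reshape(N, groups, 2, stride)           # [..., 0, j] pairs with [..., 1, j]
        t = twiddle[s].reshape(groups, stride, 2, 2)
        x = torch.einsum("gjac,ngcj->ngaj", t, xv).reshape(N, H)
    return x
```

Each stage costs O(H) per vector, so a full product costs O(H log H) instead of O(H²).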

Structured hidden weights: Diagonal (smallest and fastest)

from gru_qat import GRULayer, StructureConfig

layer = GRULayer(
    H, H, recipe=...,
    gate_layout="fused",
    structure_hidden=StructureConfig(kind="diagonal"),
    use_triton="auto",
).cuda()
# Same calibrate -> freeze -> forward flow.

kind="diagonal" collapses each H×H hidden matrix to a length-H vector, so the per-step recurrence becomes an elementwise w_h * h instead of a matmul: 3H weight scalars in total on the hidden side and O(H) FLOPs per step (per sequence). The persistent Triton kernel has no hidden-side matmul and no cross-program reduction, and it runs entirely in registers across the T-loop. It is a good fit when you want a very small recurrence (e.g. for an embedded model) and can accept a hidden update in which each hidden unit evolves independently of the others (similar in spirit to IndRNN and diagonal SSMs).
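A minimal sketch of one diagonal step, assuming the fused [r | z | n] layout and a precomputed input projection gi; diagonal_gru_step is a hypothetical name, not the kernel's interface. The hidden side is a single elementwise multiply-add, and the result equals a dense step whose per-gate hidden matrices are diag(w).

```python
import torch


def diagonal_gru_step(gi: torch.Tensor, h: torch.Tensor, w_h: torch.Tensor, b_h: torch.Tensor) -> torch.Tensor:
    """gi: (B, 3H) precomputed input projection; h: (B, H); w_h, b_h: (3H,) vectors."""
    gh = w_h * h.repeat(1, 3) + b_h          # elementwise, no matmul: [w_r*h | w_z*h | w_n*h]
    i_r, i_z, i_n = gi.chunk(3, dim=-1)
    h_r, h_z, h_n = gh.chunk(3, dim=-1)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)
    return (1 - z) * n + z * h
```

Because hidden unit j reads only h[:, j], each program can own a slice of h for the whole sequence, which is why no cross-program reduction is needed.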

Status

All originally planned phases (0–5) are complete. The dense and structured paths are feature-complete:

  • STE primitives + FakeQuantize variants
  • Dense GRUCellQuant parity vs nn.GRUCell (< 1e-5)
  • Fake-quant insertion in cell (all 6 weight + 3 activation points)
  • GRULayer with calibration → freeze flow
  • Triton multi-step persistent kernel (dense, fp32, fp32 + frozen int8 QAT)
  • Structured hidden weights (Diagonal / Monarch / Butterfly / Circulant / LDR)
  • Triton persistent kernel for Diagonal (fp32 + QAT)
  • Triton persistent kernel for Monarch (fp32 + QAT)
  • Triton persistent kernel for Butterfly (fp32 + QAT)
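The calibration → freeze flow can be pictured with a minimal per-tensor observer: while observing it tracks the running max of |x|, and once frozen its scale stops moving. MinMaxFakeQuant and calibrate below are hypothetical simplifications (symmetric, per-tensor, max-abs), not the package's observers or its calibrate(module, loader, n_batches) signature.

```python
import torch
import torch.nn as nn


class MinMaxFakeQuant(nn.Module):
    """Per-tensor symmetric fake quantizer: track max |x| while observing, then freeze."""

    def __init__(self, bits: int = 8) -> None:
        super().__init__()
        self.qmax = 2 ** (bits - 1) - 1
        self.register_buffer("amax", torch.zeros(()))
        self.observing = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.observing:
            self.amax = torch.maximum(self.amax, x.detach().abs().max())
        scale = self.amax.clamp_min(1e-8) / self.qmax
        q = torch.clamp(torch.round(x.detach() / scale), -self.qmax, self.qmax) * scale
        return x + (q - x.detach())            # STE: identity gradient


@torch.no_grad()
def calibrate(fq: MinMaxFakeQuant, batches) -> None:
    fq.observing = True
    for x in batches:
        fq(x)
    fq.observing = False                       # freeze: the scale no longer moves
```

Keeping amax in a registered buffer means the frozen scale is saved with the module's state_dict and moves with .cuda(), which is what a frozen-scale inference kernel would read.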

The suite spans parity, QAT, calibration, structured-matrix, and per-kernel strict numerical tests (CUDA-only Triton tests skip automatically when no GPU is present). Run pytest -q for the full set, or pytest -m "not slow" to skip the long-T parity sweeps.

Train-step speed at (T=64, B=32, H=512), fp32

variant                            ms/iter   time vs cuDNN (lower is faster)
cuDNN nn.GRU (dense, no quant)     4.4       1.0×
GRULayer dense + torch.compile     38.7      8.8×
dense Triton persistent            8.8       1.9×
Monarch persistent (nblocks=4)     5.8       1.3×
Monarch persistent (nblocks=8)     2.0       0.45× (2.2× faster)
Butterfly persistent               20.3      4.6×
Diagonal persistent                ~1.1      ~0.25× (4× faster)

For QAT (frozen int8 hidden), expect roughly 10–30% overhead on top of the fp32 time, depending on the path.

Numerical parity vs PyTorch reference at (T=64, B=32, H=512)

Measured at the bench shape against the GRULayer(use_triton=False) PyTorch reference path (which itself matches torch.nn.GRUCell to < 1e-5). Both sides use torch.set_float32_matmul_precision("high"), so the Triton kernels and PyTorch's matmuls see the same TF32 inputs. "fwd", "dx" and "dh0" are max relative diffs; "weight-grad" is the worst per-parameter relative diff across the dWh / twiddle / b_h* gradients.
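The text does not pin down how "max relative diff" is normalized. A minimal sketch under one assumption, a normwise definition (largest absolute deviation over the reference's largest magnitude), plus the worst-case reduction over parameters; an elementwise max(|a − b| / |b|) would be another reasonable reading. Both function names are hypothetical.

```python
from typing import Iterable, Tuple

import torch


def max_rel_diff(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12) -> float:
    """Largest absolute deviation, normalized by the reference's largest magnitude."""
    return ((a - b).abs().max() / b.abs().max().clamp_min(eps)).item()


def worst_param_rel_diff(pairs: Iterable[Tuple[torch.Tensor, torch.Tensor]]) -> float:
    """Worst case over (candidate_grad, reference_grad) pairs, one pair per parameter."""
    return max(max_rel_diff(a, b) for a, b in pairs)
```

A normwise metric stays finite when individual reference entries are near zero, which matters for gradients.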

variant                      regime              fwd     dx      dh0     weight-grad
Dense Triton persistent      fp32                4e-4    4e-4    8e-4    1e-3
Dense Triton persistent      int8 QAT (hidden)   8%      7%      9%
Monarch persistent, nb=4     fp32                3e-4    5e-4    7e-4    2e-3
Monarch persistent, nb=4     int8 QAT (hidden)   8%      7%      6%      3%
Monarch persistent, nb=8     fp32                2e-4    4e-4    6e-4    2e-3
Monarch persistent, nb=8     int8 QAT (hidden)   8%      5%      8%      3%
Butterfly persistent         fp32                3e-2    3e-3    1e-3    2e-3
Butterfly persistent         int8 QAT (hidden)   15%     15%     1%      8%
Diagonal persistent          fp32                1e-6    4e-5    2e-7    2e-6
Diagonal persistent          int8 QAT (hidden)   0       3e-5    2e-7    1e-6

QAT rows for the matmul-based variants show ~5–15% relative drift. At each step, round(x/scale) can flip at a half-integer boundary: when tl.dot's TF32 reduction order differs even slightly from cuBLAS's, a value of x/scale sitting near k + 0.5 rounds the other way, which changes the quantized value by a full quantization step (scale). This happens roughly once per ~100 positions, and the recurrence amplifies it over T=64 steps. It is not a kernel bug: torch.round and tl.extra.libdevice.rint are bit-identical on the same fp32 input (verified across 1M values plus half-integer perturbations). Forcing Triton to input_precision="ieee" would tighten the QAT rows to ~1e-3 at the cost of a ~2–4× slower matmul; the current choice favors speed over bit-parity. The butterfly fp32 row's 3e-2 fwd diff comes from the same precision mismatch (the kernel's TF32 matmul vs torch_structured's CUDA op).
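The rounding flip described above can be reproduced in isolation: two values of x/scale that differ by a single fp32 ulp, straddling 2.5, land one full quantization step apart after torch.round (which rounds half to even). The scale value is arbitrary, chosen only for illustration.

```python
import torch

scale = 0.05
# x / scale as produced by two reduction orders that disagree by a single fp32 ulp.
v_a = torch.tensor(2.5)
v_b = torch.nextafter(v_a, torch.tensor(3.0))

q_a = torch.round(v_a) * scale      # round-half-to-even: 2.5 -> 2
q_b = torch.round(v_b) * scale      # just above 2.5 -> 3

input_gap = (v_b - v_a).item()      # about 2.4e-7
output_gap = (q_b - q_a).item()     # one full quantization step
```

An input discrepancy of about 2.4e-7 becomes an output discrepancy equal to scale, and inside a recurrence that error is fed back into every later step.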

The diagonal variant has no matmul on the hidden side, so it avoids this problem entirely: every multiplication is elementwise, and Triton emits the same FMA order as the PyTorch reference. The QAT fwd row is bit-identical (rel diff = 0), and the fp32 and gradient rows sit at fp32 machine precision (~1e-5 / ~1e-7).

Layout

src/gru_qat/
  __init__.py             public API
  ste.py                  STE autograd functions
  quantizers.py           FakeQuantize + observers
  calibration.py          calibrate(module, loader, n_batches)
  structure.py            StructureConfig + make_structured_linear
  gru_cell.py             GRUCellQuant (single step, optionally structured)
  gru_layer.py            GRULayer (multi-step + Triton dispatch)
  triton_kernels/
    scan.py               dense persistent fwd+bwd kernels
    scan_diagonal.py      Diagonal persistent fwd+bwd kernels
    scan_monarch.py       Monarch persistent fwd+bwd kernels
    scan_butterfly.py     Butterfly persistent fwd+bwd kernels

See DEVELOPMENT.md for the file-by-file design and the per-phase commit history.

Download files

Source Distribution

gru_qat-0.0.1.tar.gz (956.8 kB)

Built Distribution

gru_qat-0.0.1-py3-none-any.whl (69.4 kB)

File details

Details for the file gru_qat-0.0.1.tar.gz.

File metadata

  • Download URL: gru_qat-0.0.1.tar.gz
  • Size: 956.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for gru_qat-0.0.1.tar.gz
Algorithm Hash digest
SHA256 675a041d8e5e317ff2d268ccd3684ead1cab4fbcecb74cd09af27e45a732ae3a
MD5 ec13d099a81aa5fa3b1417c74534e0d7
BLAKE2b-256 1ec4c612fc2b4ff113f7c93f83139583c64e81ad590e3bd871e39274a49aa3da

File details

Details for the file gru_qat-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: gru_qat-0.0.1-py3-none-any.whl
  • Size: 69.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.11

File hashes

Hashes for gru_qat-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 9861b47414e4243c5f91ad4e8c2c847b0ebffe1f96ea4627b7dbf2b204872bbb
MD5 a169ccb6837e96a6874963053e3bf320
BLAKE2b-256 d4fae99d017d33c78b0895401cd382e33d021f3cc4810af2ae2347376796b931
