Pluggable QAT and quantized inference for GRU, with structured-matrix hidden weights and persistent Triton kernels

gru-qat

Pluggable QAT and quantized inference for GRU, with structured-matrix hidden weights and a multi-step persistent Triton kernel.

  • Why: cuDNN's GRU is a closed kernel; we cannot insert fake-quant. To do QAT with arbitrary quantization granularities (per-channel, per-group, fine-grained int4) we own the cell.
  • What: a manually-unrolled GRU cell where every quantizable quantity is a FakeQuantize module that can be swapped without touching the cell code. Reference path is pure PyTorch; accelerated path is Triton.
  • Plus: hidden weights can be parameterized as Diagonal (one vector per gate), Monarch (block-diagonal), Butterfly (O(H log H) twiddle), Circulant, or LDR (low-displacement rank) structured matrices, with matching Triton kernels for Diagonal, Monarch and Butterfly.
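To make "fake-quant" and "per-channel granularity" concrete, here is a minimal pure-Python sketch of symmetric quantize-then-dequantize with one scale per output channel. The function names are hypothetical and are not the package's FakeQuantize API. During training, the backward pass would treat the rounding as identity (straight-through estimator); that part is not shown.

```python
def fake_quantize(values, bits=8):
    """Symmetric quantize-then-dequantize of one channel (a list of floats)."""
    qmax = 2 ** (bits - 1) - 1              # 127 for int8, 7 for int4
    absmax = max(abs(v) for v in values)
    scale = absmax / qmax if absmax > 0 else 1.0
    out = []
    for v in values:
        q = round(v / scale)                # round-half-to-even, like torch.round
        q = max(-qmax, min(qmax, q))        # clamp to the signed integer range
        out.append(q * scale)               # back to float: the "fake" part
    return out, scale


def fake_quantize_per_channel(rows, bits=8):
    """Give each row (output channel) its own scale."""
    return [fake_quantize(row, bits)[0] for row in rows]


weights = [[0.50, -0.25, 0.10],     # channel 0: small range
           [8.00, -4.00, 1.00]]     # channel 1: 16x larger range
w_q = fake_quantize_per_channel(weights, bits=8)
```

With a single shared scale, channel 0 would be quantized on channel 1's much coarser grid; per-channel scales keep each row's error within half of its own step size.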

Read first

  1. SCOPE.md — what's in, what's out, key design decisions.
  2. DEVELOPMENT.md — file map, phase status, bench numbers, upgrade pathways.

Install

pip install gru-qat                 # core: dense reference + QAT (CPU/CUDA)
pip install "gru-qat[triton]"       # + persistent Triton fast path (CUDA)
pip install "gru-qat[structured]"   # + Monarch / Butterfly / LDR (torch-structured)

The core package depends only on torch and numpy. triton and torch-structured are optional extras — the dense reference path, dense QAT, and the diagonal/circulant kinds work without either. structure.py imports torch-structured lazily, so kind="dense" never touches it. The structured extra requires torch-structured>=1.2.1.
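The lazy-import behaviour described above can be sketched roughly as follows. This is an illustration of the pattern, not the contents of structure.py: `load_backend`, the `_NEEDS_BACKEND` set and the import name `torch_structured` are assumptions made for the example.

```python
import importlib

# Assumption for this sketch: the kinds that need the optional backend.
_NEEDS_BACKEND = {"monarch", "butterfly", "ldr"}


def load_backend(kind, module_name="torch_structured"):
    """Return the optional backend module for `kind`, or None if not needed."""
    if kind not in _NEEDS_BACKEND:
        return None                     # dense / diagonal / circulant: no import attempted
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f'kind="{kind}" needs the optional structured extra; '
            'try: pip install "gru-qat[structured]"'
        ) from exc


backend = load_backend("dense")         # None, and nothing optional was imported
```

Because the import happens inside the function, a missing optional dependency only surfaces when a kind that needs it is requested.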

From source (development)

uv sync
uv pip install -e ".[dev]"   # tests, mypy, ruff
pytest -q

Dense QAT layer

import torch
from gru_qat import GRULayer, PRESETS

layer = GRULayer(
    input_size=512, hidden_size=512,
    recipe=PRESETS["int8_per_channel"],
    gate_layout="fused",
    pre_batch_input=True,        # one big GEMM for x @ W_i across T
    compile_step=True,           # torch.compile fuses the elementwise body
).cuda()
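# x (input sequence) and h0 (initial hidden state) are supplied by the caller, on the same device as the layer.
out, h_T = layer(x, h0)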

Triton-accelerated dense (persistent kernel)

from gru_qat.triton_kernels.scan import gru_scan_persistent

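# Inside the training loop (layer, x, h0 as in the dense example above):
w = layer.cell.quantize_weights()            # weights after the weight quantizers
gi = layer.cell.input_projection(x, w)       # input-side projection of x
out = gru_scan_persistent(gi, h0, w.Wh_cat, w.bh_cat)   # hidden-side recurrence in the Triton kernel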

Structured hidden weights with Triton (Monarch, fastest matmul-based variant)

from gru_qat import GRULayer, QuantRecipe, QuantizerConfig, StructureConfig

layer = GRULayer(
    input_size=512, hidden_size=512,
    recipe=QuantRecipe(
        weight=QuantizerConfig(bits=32, axis=0, name="W_id"),
        input_act=QuantizerConfig(bits=32, name="x_id"),
        hidden=QuantizerConfig(bits=8, name="h_q"),       # int8 hidden quant
    ),
    gate_layout="fused",
    structure_hidden=StructureConfig(kind="monarch", nblocks=8),
    use_triton="auto",   # routes through the persistent monarch kernel
).cuda()

# QAT flow:
for x in train_loader:
    out, _ = layer(x)
    loss = ...
    loss.backward()

layer.calibrate(val_loader, n_batches=64)
layer.freeze()
out, _ = layer(x)        # now runs through the Triton kernel with frozen scales
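The calibrate -> freeze step can be pictured with a small observer that tracks activation statistics and then pins a scale. This is a pure-Python sketch under assumptions: `AbsMaxObserver` is a hypothetical class, and the package's observers in quantizers.py may use different statistics.

```python
class AbsMaxObserver:
    """Hypothetical observer: tracks the largest |x| seen, then freezes a scale."""

    def __init__(self, bits=8):
        self.qmax = 2 ** (bits - 1) - 1
        self.absmax = 0.0
        self.frozen_scale = None

    def observe(self, batch):
        if self.frozen_scale is not None:
            return                      # frozen: statistics no longer move
        self.absmax = max(self.absmax, max(abs(v) for v in batch))

    def freeze(self):
        self.frozen_scale = self.absmax / self.qmax if self.absmax > 0 else 1.0
        return self.frozen_scale


obs = AbsMaxObserver(bits=8)
for batch in ([0.1, -0.4], [1.27, 0.3], [-0.9, 0.2]):   # stand-in for val_loader
    obs.observe(batch)
scale = obs.freeze()
obs.observe([100.0])                    # ignored after freeze
```

Once the scale is frozen it is a constant, which is what lets an inference kernel bake it in.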

Structured hidden weights — Butterfly (smallest params)

from gru_qat import GRULayer, StructureConfig

layer = GRULayer(
    H, H, recipe=...,
    gate_layout="fused",
    structure_hidden=StructureConfig(kind="butterfly"),
    use_triton="auto",
).cuda()
# Same calibrate -> freeze -> forward flow.

Structured hidden weights — Diagonal (smallest & fastest)

from gru_qat import GRULayer, StructureConfig

layer = GRULayer(
    H, H, recipe=...,
    gate_layout="fused",
    structure_hidden=StructureConfig(kind="diagonal"),
    use_triton="auto",
).cuda()
# Same calibrate -> freeze -> forward flow.

kind="diagonal" collapses each H×H hidden matrix to a length-H vector. The per-step recurrence becomes the elementwise product w_h * h instead of a matmul: 3H weight scalars in total on the hidden side (one vector per gate) and O(H) FLOPs per step. The persistent Triton kernel has no matmul on the hidden side, needs no cross-program reduction, and runs fully in registers across the T-loop. It is a good fit when you want a very small recurrence (e.g. for an embedded model) and can accept that each hidden unit's recurrent term depends only on its own previous value (similar in spirit to IndRNN / diagonal SSMs).
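As an illustration of the diagonal recurrence, here is one GRU step in pure Python with the hidden-side matmul replaced by a per-unit multiply. The gate equations follow the torch.nn.GRUCell convention; the function name and the dict-of-lists layout are assumptions for this sketch, not the package's internals.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def diagonal_gru_step(gi, h, w_h, b_h):
    """One step; gi, w_h, b_h are dicts of length-H lists keyed by gate 'r', 'z', 'n'."""
    h_new = []
    for j in range(len(h)):
        # The hidden-side "matmul" is one multiply per gate: w_h[g][j] * h[j].
        gh_r = w_h["r"][j] * h[j] + b_h["r"][j]
        gh_z = w_h["z"][j] * h[j] + b_h["z"][j]
        gh_n = w_h["n"][j] * h[j] + b_h["n"][j]
        r = sigmoid(gi["r"][j] + gh_r)
        z = sigmoid(gi["z"][j] + gh_z)
        n = math.tanh(gi["n"][j] + r * gh_n)
        h_new.append((1.0 - z) * n + z * h[j])
    return h_new


H = 3
gi = {g: [0.2, -0.1, 0.4] for g in "rzn"}      # precomputed input projections
w_h = {"r": [0.5] * H, "z": [-0.3] * H, "n": [0.8] * H}
b_h = {g: [0.0] * H for g in "rzn"}
h1 = diagonal_gru_step(gi, [0.1, 0.2, 0.3], w_h, b_h)
h1_perturbed = diagonal_gru_step(gi, [0.1, 0.2, 9.0], w_h, b_h)
```

Changing only the last entry of the previous hidden state leaves the first two outputs untouched, which is the hidden-unit independence the diagonal kind trades for speed.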

Status

All originally planned phases (0–5) are complete. The dense and structured paths are feature-complete:

| feature | status |
| --- | --- |
| STE primitives + FakeQuantize variants | complete |
| Dense GRUCellQuant parity vs nn.GRUCell (< 1e-5) | complete |
| Fake-quant insertion in cell (all 6 weight + 3 activation points) | complete |
| GRULayer with calibration → freeze flow | complete |
| Triton multi-step persistent kernel (dense, fp32, fp32 + frozen int8 QAT) | complete |
| Structured hidden weights (Diagonal / Monarch / Butterfly / Circulant / LDR) | complete |
| Triton persistent kernel for Diagonal (fp32 + QAT) | complete |
| Triton persistent kernel for Monarch (fp32 + QAT) | complete |
| Triton persistent kernel for Butterfly (fp32 + QAT) | complete |

The suite spans parity, QAT, calibration, structured-matrix, and per-kernel strict numerical tests (CUDA-only Triton tests skip automatically when no GPU is present). Run pytest -q for the full set, or pytest -m "not slow" to skip the long-T parity sweeps.

Train-step speed at (T=64, B=32, H=512) — fp32

| variant | ms/iter | time vs cuDNN (lower is faster) |
| --- | --- | --- |
| cuDNN nn.GRU (dense, no quant) | 4.4 | 1.0× |
| GRULayer dense + torch.compile | 38.7 | 8.8× |
| dense Triton persistent | 8.8 | 1.9× |
| Monarch persistent (nblocks=4) | 5.8 | 1.3× |
| Monarch persistent (nblocks=8) | 2.0 | 0.45× (2.2× faster) |
| Butterfly persistent | 20.3 | 4.6× |
| Diagonal persistent | ~1.1 | ~0.25× (4× faster) |

For QAT (frozen int8 hidden), expect ~10–30% overhead on top of the fp32 number depending on path.

Numerical parity vs PyTorch reference at (T=64, B=32, H=512)

Measured at the benchmark shape against the GRULayer(use_triton=False) PyTorch reference path (which itself matches torch.nn.GRUCell to < 1e-5). Both sides use torch.set_float32_matmul_precision("high") so that the Triton kernels and PyTorch's matmul see the same TF32 inputs. "fwd", "dx" and "dh0" are max relative diffs; "weight-grad" is the worst per-parameter relative diff over dWh / twiddle / b_h*.

| variant | regime | fwd | dx | dh0 | weight-grad |
| --- | --- | --- | --- | --- | --- |
| Dense Triton persistent | fp32 | 4e-4 | 4e-4 | 8e-4 | 1e-3 |
| Dense Triton persistent | int8 QAT (hidden) | 8% | 7% | 9% | |
| Monarch persistent, nb=4 | fp32 | 3e-4 | 5e-4 | 7e-4 | 2e-3 |
| Monarch persistent, nb=4 | int8 QAT (hidden) | 8% | 7% | 6% | 3% |
| Monarch persistent, nb=8 | fp32 | 2e-4 | 4e-4 | 6e-4 | 2e-3 |
| Monarch persistent, nb=8 | int8 QAT (hidden) | 8% | 5% | 8% | 3% |
| Butterfly persistent | fp32 | 3e-2 | 3e-3 | 1e-3 | 2e-3 |
| Butterfly persistent | int8 QAT (hidden) | 15% | 15% | 1% | 8% |
| Diagonal persistent | fp32 | 1e-6 | 4e-5 | 2e-7 | 2e-6 |
| Diagonal persistent | int8 QAT (hidden) | 0 | 3e-5 | 2e-7 | 1e-6 |
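One plausible reading of "max relative diff" in the table above is the largest absolute error normalized by the largest reference magnitude. The exact definition used by the test suite is not stated here, so the helper below is an assumption, written in pure Python over flat lists.

```python
def max_rel_diff(test, ref, eps=1e-12):
    """max|test - ref| / max|ref| over two flat lists of equal length."""
    assert len(test) == len(ref)
    num = max(abs(t - r) for t, r in zip(test, ref))
    den = max(abs(r) for r in ref)
    return num / (den + eps)            # eps guards an all-zero reference


ref = [1.0, -2.0, 0.5]
test = [1.0, -2.001, 0.5]
diff = max_rel_diff(test, ref)          # about 5e-4
```

Normalizing by the largest reference value, rather than element by element, keeps near-zero entries from dominating the metric.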

QAT rows for the matmul-based variants show ~5–15% relative drift because each step's round(x/scale) flips at half-integer boundaries when tl.dot's TF32 reduction order disagrees with cuBLAS by O(scale): roughly one rounding flip per ~100 positions, amplified by the recurrence over T=64 steps. This is not a kernel bug: torch.round and tl.extra.libdevice.rint are bit-identical on the same fp32 input (verified across 1M values plus half-integer perturbations). Forcing Triton to input_precision="ieee" would tighten the QAT rows to ~1e-3 at the cost of a ~2–4× slower matmul; the current choice favors speed over bit-parity. The butterfly fp32 row's 3e-2 fwd has the same cause (kernel TF32 vs torch_structured's CUDA op).
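The rounding-flip mechanism can be reproduced in a few lines. In this sketch the two accumulator values are made-up stand-ins for the same dot product computed in two reduction orders; they differ by 2e-9 but sit on opposite sides of a half-integer boundary after division by the scale.

```python
def quantize(x, scale):
    return round(x / scale)             # round-half-to-even, as torch.round does


scale = 0.1
acc_a = 0.25 - 1e-9                     # one reduction order (illustrative value)
acc_b = 0.25 + 1e-9                     # another order: differs by only 2e-9
q_a = quantize(acc_a, scale)            # x/scale just below 2.5, rounds down
q_b = quantize(acc_b, scale)            # x/scale just above 2.5, rounds up
jump = abs(q_b - q_a) * scale           # dequantized outputs differ by one full step
```

A tiny input disagreement becomes a difference of one whole quantization step, and in a recurrence that step feeds into every later timestep.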

The diagonal variant has no matmul on the hidden side, so it avoids this issue entirely: every multiplication is elementwise and Triton emits the same FMA order as the PyTorch reference. The QAT fwd row is bit-identical (rel diff = 0); the fp32 and gradient rows are at fp32 machine precision (~1e-5 / ~1e-7).

Layout

src/gru_qat/
  __init__.py             public API
  ste.py                  STE autograd functions
  quantizers.py           FakeQuantize + observers
  calibration.py          calibrate(module, loader, n_batches)
  structure.py            StructureConfig + make_structured_linear
  gru_cell.py             GRUCellQuant (single step, optionally structured)
  gru_layer.py            GRULayer (multi-step + Triton dispatch)
  triton_kernels/
    scan.py               dense persistent fwd+bwd kernels
    scan_diagonal.py      Diagonal persistent fwd+bwd kernels
    scan_monarch.py       Monarch persistent fwd+bwd kernels
    scan_butterfly.py     Butterfly persistent fwd+bwd kernels

See DEVELOPMENT.md for the file-by-file design and the per-phase commit history.
