Skip to main content

mgqa - Pytorch

Project description



An open-source implementation of multi-grouped query attention, from the paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints".

Paper Link


  • Lucidrains
  • Agorians


pip install mgqa


import torch
from mgqa.transformer import MGQATransformer, Decoder

# Batch of one sequence of 1024 token ids drawn from a 20000-token vocabulary.
x = torch.randint(0, 20000, (1, 1024))

model = MGQATransformer(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 12,
        heads = 8,
        attn_kv_heads = 2 # say you want 4 query heads to attend to 1 key / value head
    )
)

result = model(x)


  • A potential Triton implementation that may or may not work; I don't have GPUs to test it. If it doesn't work and you fix it, please let me know so we can provide this useful attention kernel.
# !pip3 install -U --index-url triton-nightly
# !pip3 install torch

import torch
import triton
import triton.language as tl

@triton.jit
def max_fn(x, y):
    """Return the element-wise maximum of ``x`` and ``y`` via ``tl.math.max``.

    Decorated with ``@triton.jit`` because ``triton.language`` operations can
    only be evaluated inside JIT-compiled code; call this helper from other
    Triton kernels, not from host Python.
    """
    return tl.math.max(x, y)

@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    """Forward attention kernel: one program handles one BLOCK_M-row query tile.

    ``program_id(0)`` selects the query tile and ``program_id(1)`` the
    flattened (batch, head) index.  The program walks over K/V in BLOCK_N
    column tiles while keeping a running row maximum ``m_i``, a running
    normaliser ``l_i`` and an un-normalised accumulator ``acc``.

    Writes:
        Out: the normalised attention output tile, stored as float16.
        L:   per-row ``m_i + log2(l_i)`` (base-2 log-sum-exp), read again by
             the backward kernel.

    NOTE(review): the same ``off_hz * stride_qh`` offset is applied to Q, K, V
    and Out, so all four are assumed to share one contiguous layout with the
    same number of heads -- confirm at the call site.  No boundary checks are
    performed, so N_CTX is assumed to be a whole number of tiles.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # K is addressed transposed (head_dim x N_CTX) so that dot(q, k) yields q @ k^T.
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(tl.float16)
    # loop over k, v and update accumulator
    lo = 0
    # With causal masking, key tiles past this query tile contribute nothing.
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if IS_CAUSAL:
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        qk += tl.dot(q, k)
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back l and m
    acc = acc / l_i[:, None]
    l_ptrs = L + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(tl.float16))

@triton.jit
def _bwd_preprocess(
    Out, DO,
    Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    """Compute ``Delta[row] = sum(Out[row, :] * DO[row, :])`` for BLOCK_M rows.

    Each program handles BLOCK_M consecutive rows, addressing ``Out`` and
    ``DO`` as ``row * D_HEAD + col``.

    NOTE(review): that addressing assumes both tensors are contiguous with
    D_HEAD as the innermost dimension and that the total row count is a
    multiple of BLOCK_M (no masking is applied) -- confirm at the call site.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    # compute
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_m, delta)

@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    CAUSAL: tl.constexpr,
):
    """Backward attention kernel: one program per flattened (batch, head).

    For every key/value tile the program loops over the query tiles,
    recomputes the probabilities from ``L`` (the per-row value saved by the
    forward kernel), accumulates ``dk``/``dv`` for the tile and performs a
    load-add-store update of ``DQ``.

    Args (beyond the tensors and strides):
        L: per-row value written by the forward kernel.
        D: per-row ``sum(Out * DO)`` produced by ``_bwd_preprocess``.
        num_block: number of BLOCK_M tiles along the sequence.

    NOTE(review): every tensor is offset with Q's batch/head strides, and V/DV
    rows use Q's row strides, so all tensors are assumed to share Q's layout.
    ``DQ`` is read before being updated, so it must be zero-initialised by the
    caller.  ``dv``/``dk``/``offs_n`` are sized with BLOCK_M and ``offs_m``
    with BLOCK_N, which is only consistent when BLOCK_M == BLOCK_N.
    """
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    qk_scale = sm_scale * 1.44269504
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_qz + off_h * stride_qh
    V += off_z * stride_qz + off_h * stride_qh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_qz + off_h * stride_qh
    DV += off_z * stride_qz + off_h * stride_qh
    for start_n in range(0, num_block):
        # With causal masking, query rows before this key tile see none of it.
        if CAUSAL:
            lo = start_n * BLOCK_M
        else:
            lo = 0
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        l_ptrs = L + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            if CAUSAL:
                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
            else:
                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(k))
            qk *= qk_scale
            l_i = tl.load(l_ptrs + offs_m_curr)
            p = tl.math.exp2(qk - l_i[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)

# Module-level, uninitialised 128-element tensor allocated on the CUDA device.
# Nothing in the visible source reads it.
# NOTE(review): this line requires a CUDA device at import time -- confirm it is still needed.
empty = torch.empty(128, device="cuda")

class _mgqa_attention(torch.autograd.Function):
    """Autograd wrapper that runs grouped-query attention through the Triton kernels.

    The kernels address Q, K, V and the output with one shared (batch, head)
    offset, so they need K and V to carry as many heads as Q.  Launching them
    once per ``torch.chunk`` group with the full grid (as an earlier revision
    did) offsets the base pointer while still iterating over every head, which
    reads past the tensors and overwrites ``o``.  Instead, each key/value head
    is repeated for the query heads that share it and the kernels are launched
    once; ``backward`` folds the repeated key/value gradients back by summing.
    """

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, num_groups):
        """Compute attention for ``q`` against (possibly fewer-headed) ``k``/``v``.

        Args:
            q: query tensor, indexed as [batch, heads, seq, head_dim].
            k, v: key/value tensors, indexed as [batch, kv_heads, seq, head_dim].
            causal: whether to apply a causal mask (compile-time kernel flag).
            sm_scale: softmax scaling factor applied to ``q @ k^T``.
            num_groups: number of head groups; must divide both ``heads`` and
                ``kv_heads``.  Query heads are mapped onto key/value heads in
                contiguous runs of ``heads // kv_heads``.

        Returns:
            Tensor shaped like ``q``.

        Raises:
            AssertionError: if the head dims differ or are not in {16, 32, 64, 128}.
            ValueError: if head counts, group count or sequence length are
                incompatible with the kernels.

        NOTE(review): the forward kernel casts q to float16 and stores float16,
        so half-precision CUDA inputs are assumed -- confirm.
        """
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        BLOCK_M = 128
        BLOCK_N = 64
        heads, kv_heads, seq_len = q.shape[1], k.shape[1], q.shape[2]
        if v.shape[1] != kv_heads or k.shape[2] != seq_len or v.shape[2] != seq_len:
            raise ValueError("k and v must share head count, and sequence length with q")
        if heads % kv_heads or heads % num_groups or kv_heads % num_groups:
            raise ValueError("num_groups must divide both head counts, and kv heads must divide query heads")
        # The kernels load and store whole tiles without boundary checks, and
        # backward uses 128-row tiles, so partial tiles would go out of bounds.
        if seq_len % BLOCK_M:
            raise ValueError(f"sequence length must be a multiple of {BLOCK_M}")

        # Give every query head its own (shared) key/value head, contiguously,
        # so the kernels' common batch/head offset is valid for all tensors.
        repeats = heads // kv_heads
        q = q.contiguous()
        if repeats > 1:
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)
        else:
            k = k.contiguous()
            v = v.contiguous()

        o = torch.empty_like(q)
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        num_warps = 4 if Lk <= 64 else 8
        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
            IS_CAUSAL=causal,
            num_warps=num_warps,
            num_stages=4,
        )

        ctx.save_for_backward(q, k, v, o, L)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        ctx.kv_heads = kv_heads
        ctx.num_groups = num_groups
        return o

    @staticmethod
    def backward(ctx, do, num_groups=None):
        """Return gradients for ``(q, k, v, causal, sm_scale, num_groups)``.

        Autograd calls this with ``ctx`` and one gradient per forward output
        only, so ``num_groups`` is optional, kept solely for signature
        compatibility, and ignored.  One value is returned per forward input;
        the three non-tensor inputs get ``None``.
        """
        BLOCK = 128

        q, k, v, o, L = ctx.saved_tensors

        do = do.contiguous()

        # The backward kernel accumulates into dq, so it must start at zero.
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        delta = torch.empty_like(L)

        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do,
            delta,
            BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do,
            dq, dk, dv,
            L, delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
            CAUSAL=ctx.causal,
            num_stages=1,
        )

        # Fold gradients of the repeated key/value heads back onto the
        # original heads: repeat_interleave placed copies of head h at
        # indices [h * repeats, (h + 1) * repeats).
        batch, heads, seq_len, head_dim = dk.shape
        if heads != ctx.kv_heads:
            repeats = heads // ctx.kv_heads
            dk = dk.view(batch, ctx.kv_heads, repeats, seq_len, head_dim).sum(dim=2)
            dv = dv.view(batch, ctx.kv_heads, repeats, seq_len, head_dim).sum(dim=2)
        return dq.to(q.dtype), dk, dv, None, None, None

attention = _mgqa_attention.apply

# Initialize random inputs.
# The Triton kernels run on the GPU, cast to float16, and process the sequence
# in 128-row tiles without boundary checks, so the inputs are half-precision
# CUDA tensors whose sequence length is a multiple of 128.
q = torch.randn(10, 8, 128, 64, dtype=torch.float16, device="cuda")  # [batch_size, num_heads, seq_length, head_dim]
k = torch.randn(10, 8, 128, 64, dtype=torch.float16, device="cuda")  # [batch_size, num_heads, seq_length, head_dim]
v = torch.randn(10, 8, 128, 64, dtype=torch.float16, device="cuda")  # [batch_size, num_heads, seq_length, head_dim]

# Set other parameters
causal = False
sm_scale = 0.1
num_groups = 4  # Number of groups to divide the query heads into

# Apply the attention
output = attention(q, k, v, causal, sm_scale, num_groups)





@misc{2305.13245,
Author = {Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebrón and Sumit Sanghai},
Title = {GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints},
Year = {2023},
Eprint = {arXiv:2305.13245},
}

Project details

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mgqa-0.0.5.tar.gz (27.1 kB view details)

Uploaded Source

Built Distribution

mgqa-0.0.5-py3-none-any.whl (24.8 kB view details)

Uploaded Python 3

File details

Details for the file mgqa-0.0.5.tar.gz.

File metadata

  • Download URL: mgqa-0.0.5.tar.gz
  • Upload date:
  • Size: 27.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for mgqa-0.0.5.tar.gz
Algorithm Hash digest
SHA256 9acbb77f472724228dfe96288af7101e25c5d90fecc6ed9d417328f766dbc8a7
MD5 c65e32fd82e7346a3228be11b0f03964
BLAKE2b-256 ea0f29785c77e14bd1b9cc2bf03b5a1ed85cd6066379fee9566bf5c982fd0012

See more details on using hashes here.

File details

Details for the file mgqa-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: mgqa-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 24.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for mgqa-0.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 6ed6930710caa7a85408508913cb0529bfd19dc663c21ea20a3c6d8b7fa77da6
MD5 fcdbe509e14160c6a90cb211bc6caf71
BLAKE2b-256 7e0ad5c40ce8c9e6ecba5dc16e24d7ba8f48e63b64f3b42a16f6ef4f6bd8fc55

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page