grouped-query-attention-pytorch

(Unofficial) PyTorch implementation of grouped-query attention (GQA) from GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

[Plot: compare-attention-mechanisms]

Includes:

  • scaled_dot_product_gqa: a functional GQA counterpart to F.scaled_dot_product_attention
  • MultiheadGQA: a drop-in replacement for nn.MultiheadAttention with GQA support
  • convert_t5_to_gqa: converts pretrained T5 models from huggingface/transformers to use GQA
  • GQATransformer / GQATransformerLM: prototype encoder-decoder Transformer models built on GQA

To do:

  • Fine-tuning code for T5 GQA models
  • Reproduce fine-tuning results from GQA paper, figures 3 and 5

Install

PyPI:

pip install grouped-query-attention-pytorch

From source:

pip install "grouped-query-attention-pytorch @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"

For contributors:

# Install all dev dependencies (tests, T5 support, etc.)
pip install "grouped-query-attention-pytorch[test,t5] @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"
# Setup pre-commit hooks
pre-commit install

Benchmark

I attempt to reproduce the runtime benchmarks from the GQA paper (Figure 6). Unfortunately, I don't have access to the same hardware, so the comparison isn't perfect. (They use multiple high-end GPUs, and I use a single 2080 Ti.) Even with different hardware, though, it is clear that runtime scales similarly with the number of GQA groups.

For more details, see scripts/README.md

[Benchmark plots — left: this repo; right: original paper (GQA paper, Figure 6)]
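If you just want a rough version of this comparison without the full scripts/ setup, a minimal timing sketch along these lines should work. It assumes a CUDA device and only the scaled_dot_product_gqa signature shown in the Usage section below; the sequence length, head counts, and iteration counts are made up for illustration.

import torch

from grouped_query_attention_pytorch.attention import scaled_dot_product_gqa


def time_gqa(kv_heads: int, iters: int = 50) -> float:
    """Rough per-call latency (ms) for scaled_dot_product_gqa with `kv_heads` KV heads."""
    # shapes: (batch_size, seq_len, num_heads, head_dim)
    q = torch.randn(1, 2048, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 2048, kv_heads, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 2048, kv_heads, 64, device="cuda", dtype=torch.float16)

    # Warmup, so kernel launch/caching overhead is excluded from the measurement.
    for _ in range(5):
        scaled_dot_product_gqa(q, k, v)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        scaled_dot_product_gqa(q, k, v)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


# The number of KV heads is the number of GQA groups (8 query heads throughout).
for kv_heads in (1, 2, 4, 8):
    print(f"kv_heads={kv_heads}: {time_gqa(kv_heads):.2f} ms / forward pass")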

Usage

scaled_dot_product_gqa

See: attention.py

Intended to be a drop-in replacement for F.scaled_dot_product_attention with support for GQA.

NOTE: The built-in F.scaled_dot_product_attention will be much faster when you're not using grouped queries -- especially for torch>=2.0, which uses flash attention under the hood. However, this benchmark shows that the naive scaled_dot_product_gqa is faster than flash attention when the number of GQA groups is small. 🔥

import torch

from grouped_query_attention_pytorch.attention import scaled_dot_product_gqa

# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)

out, attn_weights = scaled_dot_product_gqa(
    query,
    key,
    value,
    is_causal=True,  # default: False
    need_weights=True,  # default: False, which returns 'attn_weights=None'
)
print(out.shape)  # (batch_size, q_seq_len, kv_heads, head_dim)
# torch.Size([1, 256, 2, 64])
print(attn_weights.shape)  # (batch_size, q_seq_len, kv_seq_len, kv_heads)
# torch.Size([1, 256, 128, 2])

MultiheadGQA

See: attention.py

Intended to be a drop-in replacement for nn.MultiheadAttention with support for GQA.

NOTE: The same performance advice from scaled_dot_product_gqa (above) applies here as well.

import torch

from grouped_query_attention_pytorch.attention import MultiheadGQA

mha = MultiheadGQA(
    embed_dim=512, query_heads=8, kv_heads=2, device="cuda", dtype=torch.float16
)

# shapes: (batch_size, seq_len, embed_dim)
query = torch.randn(1, 256, 512, device="cuda", dtype=torch.float16)
key = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)
value = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)

out, attn_weights = mha(
    query,
    key,
    value,
    is_causal=True, # default: False
    need_weights=True, # default: False, which returns 'attn_weights=None'
)
print(out.shape)  # (batch_size, q_seq_len, embed_dim)
# torch.Size([1, 256, 512])
print(attn_weights.shape)  # (batch_size, q_seq_len, kv_seq_len, kv_heads)
# torch.Size([1, 256, 128, 2])

T5

See: t5.py

Convert a pretrained T5 model from huggingface/transformers to use GQA. The resulting model can be used and trained with the Huggingface Transformers library, just like an ordinary T5 model.

from transformers import T5ForConditionalGeneration, T5Tokenizer

from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa

# Initialize a pre-trained T5 model
t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
# Convert attention layers to GQA
t5_gqa = convert_t5_to_gqa(t5, kv_heads=2, inplace=False)  # default: inplace=False

# Generate some text with the converted model
input_ids = tokenizer(
    "translate English to German: The house is wonderful.", return_tensors="pt"
).input_ids
outputs = t5_gqa.generate(input_ids, max_new_tokens=25)
text = tokenizer.batch_decode(outputs[0], skip_special_tokens=False)
print(text)
# The correct answer is:  ['<pad>', 'Das', 'Haus', 'ist', 'wunderbar', '.', '</s>']
# NOTE: The original T5 model produces this answer, and so does GQA when we use the
# maximum number of KV heads (kv_heads=8 in this example), which effectively makes
# GQA equivalent to the original T5 model with MHA.  The text quickly degrades as
# we reduce the number of heads.
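
Fine-tuning code is not included here yet (see the to-do list above), but because the converted model is still an ordinary T5ForConditionalGeneration, it can be trained with the usual Hugging Face / PyTorch machinery. A minimal sketch of one training step, using a single toy example and a hypothetical learning rate:

import torch

from transformers import T5ForConditionalGeneration, T5Tokenizer

from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa

tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
model = convert_t5_to_gqa(
    T5ForConditionalGeneration.from_pretrained("t5-small"), kv_heads=2
)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # hypothetical learning rate

# A single toy example; real fine-tuning would iterate over a DataLoader.
inputs = tokenizer(
    "translate English to German: The house is wonderful.", return_tensors="pt"
)
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

# T5ForConditionalGeneration computes the usual seq2seq cross-entropy loss when
# `labels` is passed, exactly as it does for the unconverted model.
loss = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    labels=labels,
).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()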

GQATransformer

I also provide a prototype implementation of an (untrained) encoder-decoder Transformer model, which uses GQA instead of MHA. This is mostly for reference/educational purposes, but in principle it could be used as a drop-in replacement for nn.Transformer.

See: transformer.py

import torch

from grouped_query_attention_pytorch.transformer import GQATransformer, GQATransformerLM

device = torch.device("cuda")
dtype = torch.float16

net = GQATransformer(
    d_model=512,  # required
    nhead=8,  # required
    kv_heads=2,  # required
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len, d_model)
x = torch.randn(1, 256, 512, device=device, dtype=dtype)
with torch.no_grad():
    y = net.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 256, 512])

num_tokens = 10000  # usually obtained from the tokenizer
lm = GQATransformerLM(
    num_tokens=num_tokens,  # required
    d_model=512,  # required
    nhead=8,  # required
    kv_heads=2,  # required
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)
# shape: (batch_size, seq_len)
x = torch.randint(0, num_tokens, (1, 256), device=device, dtype=torch.long)
with torch.no_grad():
    y = lm.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)  # (batch_size, seq_len, num_tokens)
# torch.Size([1, 256, 10000])
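
For completeness, the logits above can be turned into a standard next-token language-modeling loss. This is just a sketch, not part of the library; it reuses x, y, and num_tokens from the snippet above, and the forward pass would need to run without torch.no_grad() to actually backpropagate.

import torch.nn.functional as F

# Predict token t+1 from tokens <= t: shift logits and targets by one position.
logits = y[:, :-1, :].float().reshape(-1, num_tokens)  # (batch * (seq_len - 1), num_tokens)
targets = x[:, 1:].reshape(-1)                         # (batch * (seq_len - 1),)
loss = F.cross_entropy(logits, targets)
print(loss.item())  # roughly log(num_tokens) for an untrained model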

