grouped-query-attention-pytorch
Project description
grouped-query-attention-pytorch
(Unofficial) PyTorch implementation of grouped-query attention (GQA) from GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
Includes:
- Scaled dot-product attention with GQA support. (See: scaled_dot_product_gqa usage)
- GQA multi-head attention layer. (See: MultiheadGQA usage)
- Code to convert a pretrained T5 model to use GQA. (See: T5 usage)
- Prototype (untrained) GQA encoder-decoder models: GQATransformer, GQATransformerLM. (See: GQATransformer usage)
- Code to reproduce the runtime benchmarks from the GQA paper, Figure 6. (See: scripts/README.md)
To do:
- Fine-tuning code for T5 GQA models
- Reproduce fine-tuning results from GQA paper, figures 3,5
Install
PyPI:
pip install grouped-query-attention-pytorch
From source:
pip install "grouped-query-attention-pytorch @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"
For contributors:
# Install all dev dependencies (tests, T5 support, etc.)
pip install "grouped-query-attention-pytorch[test,t5] @ git+ssh://git@github.com/fkodom/grouped-query-attention-pytorch.git"
# Setup pre-commit hooks
pre-commit install
Benchmark
I attempt to reproduce the runtime benchmarks from the GQA paper (Figure 6). Unfortunately, I don't have access to the same hardware, so the comparison isn't perfect. (They use multiple high-end GPUs, and I use a single 2080 Ti.) Even with different hardware, though, it is clear that runtime scales similarly with the number of GQA groups.
For more details, see scripts/README.md
Left: this repo. Right: original paper.
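The rough idea of the benchmark can be sketched in a few lines: fix the number of query heads and sweep the number of KV heads (i.e. GQA groups), timing the forward pass for each. This is only an illustrative timing loop under assumed settings (the sequence length, warmup, and repetition counts here are arbitrary, not the values used in scripts/); see scripts/README.md for the actual benchmark.

import torch
from grouped_query_attention_pytorch.attention import scaled_dot_product_gqa

device, dtype = "cuda", torch.float16
batch, seq_len, num_heads, head_dim = 1, 2048, 8, 64
query = torch.randn(batch, seq_len, num_heads, head_dim, device=device, dtype=dtype)

for kv_heads in (1, 2, 4, 8):  # 8 == MHA, 1 == MQA
    key = torch.randn(batch, seq_len, kv_heads, head_dim, device=device, dtype=dtype)
    value = torch.randn(batch, seq_len, kv_heads, head_dim, device=device, dtype=dtype)

    # Warmup, then time a handful of forward passes with CUDA events.
    for _ in range(3):
        scaled_dot_product_gqa(query, key, value)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(10):
        scaled_dot_product_gqa(query, key, value)
    end.record()
    torch.cuda.synchronize()
    print(f"kv_heads={kv_heads}: {start.elapsed_time(end) / 10:.2f} ms per forward pass")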
Usage
scaled_dot_product_gqa
See: attention.py
Intended to be a drop-in replacement for F.scaled_dot_product_attention, with support for GQA.
NOTE: The built-in F.scaled_dot_product_attention will be much faster when you're not using grouped queries -- especially for torch>=2.0, which uses flash attention under the hood. However, this benchmark shows that the naive scaled_dot_product_gqa is faster than flash attention when the number of GQA groups is small. 🔥
import torch
from grouped_query_attention_pytorch.attention import scaled_dot_product_gqa
# shapes: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 128, 2, 64, device="cuda", dtype=torch.float16)
out, attn_weights = scaled_dot_product_gqa(
    query,
    key,
    value,
    is_causal=True,  # default: False
    need_weights=True,  # default: False, which returns 'attn_weights=None'
)
print(out.shape)  # (batch_size, q_seq_len, kv_heads, head_dim)
# torch.Size([1, 256, 2, 64])
print(attn_weights.shape)  # (batch_size, q_seq_len, kv_seq_len, kv_heads)
# torch.Size([1, 256, 128, 2])
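For intuition, the GQA mechanism from the paper amounts to sharing each key/value head across a group of query heads. The standalone sketch below (plain PyTorch, not this library's internal code) broadcasts the KV heads up to the query-head count and runs ordinary attention. It follows the paper's formulation with one output per query head, so its output shape intentionally differs from the return values documented above.

import torch
import torch.nn.functional as F

# Same input shapes as above: (batch_size, seq_len, num_heads, head_dim)
query = torch.randn(1, 256, 8, 64)
key = torch.randn(1, 128, 2, 64)
value = torch.randn(1, 128, 2, 64)

# Each KV head serves a group of 8 // 2 = 4 query heads.
group_size = query.size(2) // key.size(2)
key_x = key.repeat_interleave(group_size, dim=2)
value_x = value.repeat_interleave(group_size, dim=2)

# With the KV heads broadcast, this is ordinary multi-head attention.
# F.scaled_dot_product_attention expects (batch, heads, seq_len, head_dim).
q, k, v = (t.transpose(1, 2) for t in (query, key_x, value_x))
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
print(out.shape)  # torch.Size([1, 256, 8, 64]) -- one output per query head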
MultiheadGQA
See: attention.py
Intended to be a drop-in replacement for nn.MultiheadAttention, with support for GQA.
NOTE: The same performance advice from scaled_dot_product_gqa (above) applies here as well.
import torch
from grouped_query_attention_pytorch.attention import MultiheadGQA

mha = MultiheadGQA(
    embed_dim=512, query_heads=8, kv_heads=2, device="cuda", dtype=torch.float16
)

# shapes: (batch_size, seq_len, embed_dim)
query = torch.randn(1, 256, 512, device="cuda", dtype=torch.float16)
key = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)
value = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)

out, attn_weights = mha(
    query,
    key,
    value,
    is_causal=True,  # default: False
    need_weights=True,  # default: False, which returns 'attn_weights=None'
)
print(out.shape) # (batch_size, q_seq_len, embed_dim)
# torch.Size([1, 256, 512])
print(attn_weights.shape) # (batch_size, q_seq_len, kv_seq_len, kv_heads)
# torch.Size([1, 256, 128, 2])
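The practical benefit of fewer KV heads is smaller key/value projections (and, at inference time, a smaller KV cache). A quick way to see this is to compare parameter counts against a standard MHA layer of the same width. This is a rough comparison under assumptions: it assumes the constructor defaults shown above are sufficient on CPU, and exact numbers depend on each layer's projection layout.

from torch import nn
from grouped_query_attention_pytorch.attention import MultiheadGQA

gqa = MultiheadGQA(embed_dim=512, query_heads=8, kv_heads=2)
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print(f"MultiheadGQA (kv_heads=2):   {num_params(gqa):,} parameters")
print(f"nn.MultiheadAttention (MHA): {num_params(mha):,} parameters")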
T5
See: t5.py
Convert a pretrained T5 model from huggingface/transformers to use GQA. The resulting model can be used and trained with the Huggingface Transformers library, just like an ordinary T5 model.
from transformers import T5ForConditionalGeneration, T5Tokenizer
from grouped_query_attention_pytorch.t5 import convert_t5_to_gqa
# Initialize a pre-trained T5 model
t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
# Convert attention layers to GQA
t5_gqa = convert_t5_to_gqa(t5, kv_heads=2, inplace=False) # default: inplace=False
# Generate some text with the converted model
input_ids = tokenizer(
    "translate English to German: The house is wonderful.", return_tensors="pt"
).input_ids
outputs = t5_gqa.generate(input_ids, max_new_tokens=25)
# Decode token-by-token (outputs[0] is a 1-D tensor of token ids), keeping special tokens
text = tokenizer.batch_decode(outputs[0], skip_special_tokens=False)
print(text)
# The correct answer is: ['<pad>', 'Das', 'Haus', 'ist', 'wunderbar', '.', '</s>']
# NOTE: The original T5 model produces this answer, and so does GQA when we use the
# maximum number of KV heads (kv_heads=8 in this example), which effectively makes
# GQA equivalent to the original T5 model with MHA. The text quickly degrades as
# we reduce the number of heads.
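For reference, the GQA paper constructs the grouped KV heads from an MHA checkpoint by mean-pooling the key/value projection heads within each group. The sketch below illustrates that idea on a single projection weight; it is not the code in t5.py, and the head-grouping convention (consecutive heads per group) is an assumption made for illustration.

import torch

num_heads, kv_heads, head_dim, d_model = 8, 2, 64, 512
group_size = num_heads // kv_heads  # 4 query heads share each KV head

# A key (or value) projection weight laid out as in nn.Linear: one block of
# `head_dim` output rows per head, stacked along dim 0.
w_mha = torch.randn(num_heads * head_dim, d_model)

# Mean-pool the heads within each group to get the GQA projection weight.
w_grouped = w_mha.view(kv_heads, group_size, head_dim, d_model)
w_gqa = w_grouped.mean(dim=1).reshape(kv_heads * head_dim, d_model)
print(w_gqa.shape)  # torch.Size([128, 512])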
GQATransformer
I also provide a prototype implementation of an (untrained) encoder-decoder Transformer model, which uses GQA instead of MHA. This is mostly for reference/educational purposes, but in principle it could be used as a drop-in replacement for nn.Transformer.
See: transformer.py
import torch
from grouped_query_attention_pytorch.transformer import GQATransformer, GQATransformerLM

device = torch.device("cuda")
dtype = torch.float16
net = GQATransformer(
    d_model=512,  # required
    nhead=8,  # required
    kv_heads=2,  # required
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)

# shape: (batch_size, seq_len, d_model)
x = torch.randn(1, 256, 512, device=device, dtype=dtype)
with torch.no_grad():
    y = net.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 256, 512])
num_tokens = 10000 # usually obtained from the tokenizer
lm = GQATransformerLM(
    num_tokens=num_tokens,  # required
    d_model=512,  # required
    nhead=8,  # required
    kv_heads=2,  # required
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    activation="relu",
    layer_norm_eps=1e-5,
    device=device,
    dtype=dtype,
)

# shape: (batch_size, seq_len)
x = torch.randint(0, num_tokens, (1, 256), device=device, dtype=torch.long)
with torch.no_grad():
    y = lm.forward(x, is_causal=True)  # default: is_causal=True
print(y.shape)
# torch.Size([1, 256, 10000]), i.e. (batch_size, seq_len, num_tokens)
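Since GQATransformerLM returns a logit for every vocabulary token at every position, greedy next-token prediction from the output above is just an argmax at the last position. This is illustrative only and continues the snippet above; the repo does not provide a built-in generation loop.

next_token = y[:, -1, :].argmax(dim=-1)  # greedy choice of the next token id
print(next_token.shape)  # torch.Size([1]) -- one predicted token id per batch element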
Project details
Download files
Source Distribution
Built Distribution
File details
Details for the file grouped_query_attention_pytorch-0.3.0.tar.gz.
File metadata
- Download URL: grouped_query_attention_pytorch-0.3.0.tar.gz
- Upload date:
- Size: 531.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5820ea1ea3c8e63b47f57b7900747cea1461f84041aa57655801bdc366e50871
MD5 | 9926c9b2fd3815030eebd69ea69c8eb4
BLAKE2b-256 | 7c1da4b0c46ae2f8c33db6eb6d5e5ab52cf097a1bd0d42e8be2349d4c370e770
File details
Details for the file grouped_query_attention_pytorch-0.3.0-py3-none-any.whl.
File metadata
- Download URL: grouped_query_attention_pytorch-0.3.0-py3-none-any.whl
- Upload date:
- Size: 526.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.12.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | d75cacffeab3dded0bc7ab6bb60e4902e8db7247f70c66e0a4d27fafc2a67331
MD5 | 0466f64748860f744eca8f5ea6bf41b8
BLAKE2b-256 | 100c5df79ac5ffa40e1865b9c1c013ad05d499b7d65436ebec9d2e675c071222