Simplex random feature attention in PyTorch for both training and inference
srf-attention
Simplex Random Feature attention, in PyTorch
A Prelude
Why? What? Huh?
Softmax attention ate the world. But now it's eating our wallets. Luckily enough for us wordcels, those nifty shape rotators realized that even though softmax isn't stationary, it's amenable to Monte Carlo methods. Translation: we can retrofit pretrained LLMs for recurrent inference! Smarter men than I proceeded to publish this, this, and that. This repo is a PyTorch implementation of "that", with some syntactic sugar added to aid digestion. It's intended to be used for ERPTRI, but do with it what you will.
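The Monte Carlo trick fits in a few lines of plain PyTorch. This is a sketch, not the `srf_attention` implementation, and every function name in it is hypothetical. It assumes positive random features $\phi(x) = \exp(Wx - \|x\|^2/2)/\sqrt{m}$, with the rows of $W$ built simplex-style: regular-simplex directions, one shared random rotation, and chi-distributed norms. Each row is still marginally a standard Gaussian, so $\phi(q)\cdot\phi(k)$ is an unbiased estimate of $\exp(q\cdot k)$. For softmax's $\exp(q\cdot k/\sqrt{d})$, scale $q$ and $k$ by $d^{-1/4}$ first.

```python
import math
import torch

def simplex_directions(d: int) -> torch.Tensor:
    """d unit vectors in R^d with pairwise cosine -1/(d-1), i.e. a regular simplex."""
    eye = torch.eye(d, dtype=torch.float64)
    s = eye - eye.mean(dim=0, keepdim=True)  # center the standard basis vectors
    return s / s.norm(dim=1, keepdim=True)

def random_rotation(d: int, gen: torch.Generator) -> torch.Tensor:
    """Haar-random orthogonal matrix via QR with a sign fix."""
    a = torch.randn(d, d, generator=gen, dtype=torch.float64)
    q, r = torch.linalg.qr(a)
    return q * torch.sign(torch.diagonal(r)).unsqueeze(0)

def draw_simplex_features(d: int, gen: torch.Generator) -> torch.Tensor:
    """m = d feature vectors w_i = rho_i * R s_i, with rho_i ~ chi_d (hypothetical helper)."""
    directions = simplex_directions(d) @ random_rotation(d, gen).T  # rows are R s_i
    # The norm of a standard d-dimensional Gaussian vector is chi_d distributed
    norms = torch.randn(d, d, generator=gen, dtype=torch.float64).norm(dim=1, keepdim=True)
    return directions * norms

def positive_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """phi(x) = exp(w_i . x - |x|^2 / 2) / sqrt(m), so E[phi(q) . phi(k)] = exp(q . k)."""
    half_sq_norm = 0.5 * (x * x).sum(dim=-1, keepdim=True)
    return torch.exp(x @ w.T - half_sq_norm) / math.sqrt(w.shape[0])

def estimate_exp_kernel(q, k, n_draws=500, seed=0) -> float:
    """Monte Carlo estimate of exp(q . k) for single vectors q, k of shape (1, d)."""
    gen = torch.Generator().manual_seed(seed)
    total = 0.0
    for _ in range(n_draws):
        w = draw_simplex_features(q.shape[-1], gen)
        total += (positive_features(q, w) @ positive_features(k, w).T).item()
    return total / n_draws

gen = torch.Generator().manual_seed(1)
q = 0.1 * torch.randn(1, 16, generator=gen, dtype=torch.float64)
k = 0.1 * torch.randn(1, 16, generator=gen, dtype=torch.float64)
exact = torch.exp(q @ k.T).item()
approx = estimate_exp_kernel(q, k)
print(f"exact={exact:.4f}  simplex RF estimate={approx:.4f}")
```

Coupling the directions into a simplex doesn't change the expectation. It only makes the features within one draw negatively correlated, which is where the variance reduction comes from.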
What is this good for?
Well, it really ain't for you open-sourcerers. You're bottlenecked by weight I/O. But for those running large-batch inference, e.g., as part of a synthetic data pipeline, KV cache I/O dominates the cost for sequences > ~700 tokens. ERPTRI efficiently [sic] drops the KV cache size of any pretrained auto-regressive Transformer from $O(LD)$ to $O(D^2)$, where $L$ is the sequence length and $D$ the head dimension. This repo implements the PyTorch modules necessary for the fine-tuning phase of ERPTRI, and for efficient inference.
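To put numbers on $O(LD)$ vs $O(D^2)$, here's some back-of-the-envelope arithmetic. It rests on assumptions, not on how `srf_attention` actually lays out its cache. Per layer and head, softmax attention caches $2LD$ values: one key and one value per past token. A random-feature recurrence keeps $S = \sum_t \phi(k_t) v_t^\top$ (an $m \times D$ matrix) plus a normalizer $z = \sum_t \phi(k_t)$ (length $m$), with no $L$ term at all. The sketch takes $m = D$ as in the usage below, FP16 (2-byte) values, and a Llama-2-7B-shaped config (32 layers, 32 heads, head dim 128). The function names are hypothetical.

```python
import math

def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, bytes_per_elem=2):
    """Softmax attention caches one key and one value vector per past token."""
    return 2 * seq_len * head_dim * n_heads * n_layers * bytes_per_elem

def rf_state_bytes(n_layers, n_heads, head_dim, n_features=None, bytes_per_elem=2):
    """Assumed random-feature state per head: S (n_features x head_dim) plus
    the normalizer z (n_features). Note there is no seq_len term."""
    m = head_dim if n_features is None else n_features
    return (m * head_dim + m) * n_heads * n_layers * bytes_per_elem

def break_even_len(head_dim, n_features=None):
    """Smallest seq_len at which the per-head KV cache is at least as big as the state."""
    m = head_dim if n_features is None else n_features
    return math.ceil((m * head_dim + m) / (2 * head_dim))

# Llama-2-7B-shaped config: 32 layers, 32 heads, head dim 128, FP16 values
for seq_len in (512, 4096, 32768):
    kv = kv_cache_bytes(seq_len, 32, 32, 128)
    rf = rf_state_bytes(32, 32, 128)
    print(f"L={seq_len:>6}: KV cache {kv / 2**20:8.1f} MiB vs RF state {rf / 2**20:.1f} MiB")
print("break-even length:", break_even_len(128))
```

Under these assumptions, a 4096-token sequence needs 2 GiB of KV cache but only about 32 MiB of state. The state stays the same size however long the sequence gets.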
Next steps
Venture forth and conquer. But first, fine-tune under an ordinary NLL loss on the original pretraining distribution, after performing the appropriate model surgery. Here's the RedPajama subset that was used for the Llama 2 retrofit.
Installation
Insta-wheel:
```bash
pip install git+https://github.com/alexjlevenston/srf-attention
```
Usage
```python
import torch
from srf_attention import Attention

device = 'cpu'
B, H, L, D = (1, 8, 1024, 128)
q, k, v = [torch.randn(B, H, L, D, device=device).requires_grad_() for _ in range(3)]

# CHUNK_SIZE controls the memory/compute tradeoff of the attention computation
CHUNK_SIZE = 1024

# Simplex Random Feature (SRF) attention module.
# All intermediate computations are done in FP32, but cached values are FP16.
# Recomputes the attention matrix in the backward pass instead of storing it.
# Use one instance per layer.
attn = Attention(d=D, n_features=D, causal=True, device=device)

# Disable auto-redraw before training begins:
attn.redraw_on_call_(False)

# On each training step, call redraw_() to resample the random features:
attn.redraw_()

# During fine-tuning, replace your softmax attention function with this:
o = attn(q, k, v, mode='train', attn_fn='torch', chunk_size=CHUNK_SIZE)

# That's it! Now just fine-tune.
```
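The chunk-size tradeoff and the recurrent-inference trick both come from the same identity. Here's a sketch in plain PyTorch. It is not the `srf_attention` API: the function names are hypothetical, and reading `chunk_size` this way is an assumption. Start from non-negative feature maps `phi_q` and `phi_k`, e.g. from positive random features. Causal random-feature attention can then be computed in chunks, carrying a running $m \times d$ state `S` and an $m$-vector `z` across chunk boundaries. Bigger chunks mean more parallel work and larger intra-chunk matrices. `chunk_size=1` is exactly the token-by-token recurrence you'd run at decode time, where only `S` and `z` need caching.

```python
import torch

def causal_rf_attention_reference(phi_q, phi_k, v):
    """Quadratic reference: o_t = sum_{s<=t} (phi_q_t . phi_k_s) v_s / sum_{s<=t} phi_q_t . phi_k_s."""
    scores = phi_q @ phi_k.transpose(-1, -2)                  # (..., L, L)
    L = scores.shape[-1]
    causal = torch.ones(L, L, dtype=torch.bool, device=scores.device).tril()
    scores = scores.masked_fill(~causal, 0.0)
    return (scores @ v) / scores.sum(dim=-1, keepdim=True)

def causal_rf_attention_chunked(phi_q, phi_k, v, chunk_size):
    """Same output, carrying an (m x d) state S and an m-vector z across chunks."""
    *batch, L, m = phi_k.shape
    d = v.shape[-1]
    S = phi_k.new_zeros((*batch, m, d))  # running sum of phi(k_s) v_s^T
    z = phi_k.new_zeros((*batch, m))     # running sum of phi(k_s)
    outputs = []
    for start in range(0, L, chunk_size):
        end = min(start + chunk_size, L)
        qc, kc, vc = phi_q[..., start:end, :], phi_k[..., start:end, :], v[..., start:end, :]
        n = end - start
        causal = torch.ones(n, n, dtype=torch.bool, device=qc.device).tril()
        intra = (qc @ kc.transpose(-1, -2)).masked_fill(~causal, 0.0)   # within-chunk scores
        numerator = qc @ S + intra @ vc                                   # (..., n, d)
        denominator = (qc @ z.unsqueeze(-1)).squeeze(-1) + intra.sum(dim=-1)  # (..., n)
        outputs.append(numerator / denominator.unsqueeze(-1))
        # Fold this chunk into the state; earlier tokens are never revisited
        S = S + kc.transpose(-1, -2) @ vc
        z = z + kc.sum(dim=-2)
    return torch.cat(outputs, dim=-2)

gen = torch.Generator().manual_seed(0)
# Non-negative stand-ins for feature maps, shape (batch, heads, L, m)
phi_q = torch.rand(2, 3, 10, 4, generator=gen, dtype=torch.float64) + 0.1
phi_k = torch.rand(2, 3, 10, 4, generator=gen, dtype=torch.float64) + 0.1
v = torch.randn(2, 3, 10, 5, generator=gen, dtype=torch.float64)
reference = causal_rf_attention_reference(phi_q, phi_k, v)
for chunk_size in (1, 3, 10):
    out = causal_rf_attention_chunked(phi_q, phi_k, v, chunk_size)
    print(chunk_size, torch.allclose(out, reference))
```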
Example
Here's an example, using the HF Transformers diff I wrote to retrofit Llama with SRF attention:
```python
# Make sure the TILE_SIZE env var is set; I use TILE_SIZE=256
import torch
# Install using `pip install git+https://github.com/alexjlevenston/transformers-llama-srf`
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')

for module in model.modules():
    if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
        module.use_fast_attn_(True)
        module.attn_fn.redraw_on_call_(False)

def resample_rfs(model):
    for module in model.modules():
        if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
            module.attn_fn.redraw_(next(model.parameters()).device)

optimizer = YourOptimizerHere(model.parameters())  # placeholder

for step, batch in enumerate(imaginary_dataset):  # placeholder dataset
    inputs, targets = batch
    # Always resample random features manually,
    # because auto-resampling causes issues with checkpointing
    resample_rfs(model)
    outputs = model(inputs)
    logits = outputs.logits.reshape(-1, outputs.logits.shape[-1])
    # targets['input_ids'] must already be shifted one token relative to inputs
    loss = torch.nn.functional.cross_entropy(logits, targets['input_ids'].reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
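The loop above leaves target construction to `imaginary_dataset`. If your batches only carry unshifted `input_ids`, the "ordinary NLL loss" is next-token cross-entropy. Drop the last logit and the first token, so that position t is scored against token t+1. Here's a minimal sketch. `next_token_nll` is a hypothetical helper, and masking padding through an `attention_mask` is an assumption about your batch format.

```python
import torch
import torch.nn.functional as F

def next_token_nll(logits, input_ids, attention_mask=None):
    """Mean NLL of token t+1 given positions <= t, for logits of shape (B, L, V)."""
    shift_logits = logits[:, :-1, :]          # position t predicts token t+1 ...
    shift_labels = input_ids[:, 1:].clone()   # ... so the first token is never a target
    if attention_mask is not None:
        # Padding targets are set to -100, which cross_entropy ignores
        shift_labels[attention_mask[:, 1:] == 0] = -100
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.shape[-1]),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )

gen = torch.Generator().manual_seed(0)
logits = torch.randn(2, 6, 11, generator=gen)
input_ids = torch.randint(0, 11, (2, 6), generator=gen)
attention_mask = torch.ones(2, 6, dtype=torch.long)
attention_mask[1, 4:] = 0  # pretend the second sequence ends in two padding tokens
print(next_token_nll(logits, input_ids), next_token_nll(logits, input_ids, attention_mask))
```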
Hashes for srf_attention-1.0.16-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2a869d7d5b0dc1698b6ec4a3aff32e982d357f4741d432de94a1995b11bb500d |
| MD5 | 80a7f83d6275033bdb46fb65d461a612 |
| BLAKE2b-256 | 4976dca37b93522a712b8db314f31f746a5d93417f1349b6f4284366c8b8dcd9 |