Simplex random feature attention in PyTorch for both training and inference
Project description
srf-attention
Simplex Random Feature attention, in PyTorch
A Prelude
Why?
Softmax attention ate the world. But now it's eating our wallets. Luckily enough for us wordcels, those nifty shape rotators realized that even though softmax isn't :wave: technically :wave: stationary, it's amenable to Monte Carlo methods. Translation: we can retrofit pretrained LLMs for recurrent inference! Smarter men than I proceeded to publish this, this, and that. This repo is a PyTorch implementation of "that", with some syntactic sugar added to aid digestion. It's intended to be used for ERPTRI, but do with it what you will.
What is this good for?
Well, it really ain't for you open-sourcerers. You're bottlenecked by weight I/O. But for those running large-batch inference, e.g as part of a data pipeline, KV cache I/O is the limiter for sequences > ~700 tokens. ERPTRI efficiently [sic] drops the KV cache size of any pretrained auto-regressive Transformer from $O(LD)
$ to $O(D^2)
$. This repo implements the PyTorch modules necessary for the fine-tuning phase of ERPTRI, and for efficient inference.
Next steps
Venture forth and conquer.
Installation
pip install git+https://github.com/alexjlevenston/srf-attention
Usage
import torch
from srf_attention import FastAttention, simplex_random_matrix
device = 'cpu'
B, H, L, D = (1, 8, 1024, 128)
# Generate some simplex random features
srfs = simplex_random_matrix(nb_rows = D, nb_features = D, normalize = False, device = device)
# Or just use the FastAttention module
attn = FastAttention(head_dim = D, nb_features = D, causal = True).to(device)
# For training, automatically redraw features for each forward pass
# False by default
attn.redraw_on_call_(True)
# Placeholder queries, keys, and values:
q, k, v = [torch.empty(B, H, L, D) for _ in range(3)]
# For training, naive torch:
o = attn(q=q, k=k, v=v, mode='train', attn_fn='torch')
# For training, w/ flash-attention-no-softmax:
o = attn(q=q, k=k, v=v, mode='train', attn_fn='flash')
# For inference, disable auto-redraw:
attn.redraw_on_call_(False)
# For inference, prefill, parallel:
o, kv, key_maxima, denominator = (q=q, k=k, v=v, mode='prefill', attn_fn='parallel')
# For inference, prefill, chunkwise-recurrent:
o, kv, key_maxima, denominator = (q=q, k=k, v=v, mode='prefill', attn_fn='chunkwise-recurrent')
# For inference, prefill, recurrent:
o, kv, key_maxima, denominator = (q=q, k=k, v=v, mode='prefill', attn_fn='recurrent')
# For inference, generation:
q = torch.empty(B, H, 1, D)
denominator = torch.empty(B, H, 1, D)
o, kv, key_maxima, denominator = (q=q, kv=kv, key_maxima=key_maxima, denominator=denominator, mode='generation')
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Hashes for srf_attention-1.0.8-py3-none-any.whl
Algorithm | Hash digest | |
---|---|---|
SHA256 | ef13c806b6bfb25741ace2a028ec0ed24eb57b15811ed5c441e47f90b7ffe370 |
|
MD5 | eae27871d70152d194c85d4172c2168e |
|
BLAKE2b-256 | fb25d94a4ed8b2ce0e28093158726eee77e74b5355ec3ee0317fde295dede983 |