
srf-attention

Simplex Random Feature (SRF) attention in PyTorch, for both training and inference

A Prelude

Why? What? Huh?

Softmax attention ate the world. But now it's eating our wallets. Luckily enough for us wordcels, those nifty shape rotators realized that even though softmax isn't stationary, it's amenable to Monte Carlo methods. Translation: we can retrofit pretrained LLMs for recurrent inference! Smarter men than I proceeded to publish this, this, and that. This repo is a PyTorch implementation of "that", with some syntactic sugar added to aid digestion. It's intended to be used for ERPTRI, but do with it what you will.

What is this good for?

Well, it really ain't for you open-sourcerers. You're bottlenecked by weight I/O. But for those running large-batch inference, e.g. as part of a synthetic data pipeline, KV cache I/O dominates the cost for sequences longer than ~700 tokens. ERPTRI efficiently [sic] drops the KV cache size of any pretrained auto-regressive Transformer from $O(LD)$ to $O(D^2)$. This repo implements the PyTorch modules necessary for the fine-tuning phase of ERPTRI, and for efficient inference.
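
For intuition, here's a minimal sketch of why the recurrent state is $O(D^2)$. This is generic random-feature (linear) attention, not this package's API: each decoding step folds the current key/value outer product into a fixed-size matrix, so nothing in the cache grows with sequence length.

import torch

def rf_attention_step(phi_q, phi_k, v, S, z):
  # phi_q, phi_k: random-feature maps of the current query/key, shape (D_feat,)
  # v: current value, shape (D,)
  # S: running sum of outer(phi(k), v), shape (D_feat, D) -- the entire "KV cache"
  # z: running sum of phi(k), shape (D_feat,) -- normalizer state
  S = S + torch.outer(phi_k, v)
  z = z + phi_k
  out = (phi_q @ S) / (phi_q @ z + 1e-6)
  return out, S, z

The state (S, z) is D_feat x D + D_feat floats no matter how long the sequence gets, versus the L x D keys and values a softmax KV cache has to keep around.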

Next steps

Venture forth and conquer. But first, fine-tune under an ordinary NLL loss on the original pretraining distribution, after performing the appropriate model surgery. Here's the RedPajama subset that was used for the Llama 2 retrofit.

Installation

Insta-wheel:

pip install git+https://github.com/alexjlevenston/srf-attention

Usage

import torch
from srf_attention import Attention

device = 'cpu'

B, H, L, D = (1, 8, 1024, 128)

q, k, v = [torch.randn(B, H, L, D).requires_grad_() for _ in range(3)]

# CHUNK_SIZE controls the memory/compute tradeoff of the attention computation
CHUNK_SIZE = 1024

# Simplex Random Feature (SRF) Attention module
# All intermediate computations are done in FP32, but cached values are FP16.
# Recomputes the attention matrix in the backward pass instead of storing it:
attn = Attention(d=D, n_features=D, causal=True, device=device)

# Use one instance per layer,
# and disable auto-redraw prior to beginning training:
attn.redraw_on_call_(False)

# On each training step, call redraw_() to resample the random features:
attn.redraw_()

# During fine-tuning, replace your softmax attention function with this:
o = attn(q, k, v, mode='train', attn_fn='torch', chunk_size=CHUNK_SIZE)

# That's it! Now just fine-tune.
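
For reference, here's roughly what the train-mode call above approximates (a naive, quadratic-time sketch, not the package's chunked implementation): softmax attention with exp(q . k) replaced by a dot product of non-negative feature maps phi(q) . phi(k).

import torch

def naive_causal_rf_attention(q, k, v, phi):
  # q, k, v: (B, H, L, D); phi maps (..., D) -> (..., D_feat) with non-negative outputs
  fq, fk = phi(q), phi(k)
  scores = fq @ fk.transpose(-2, -1)                      # (B, H, L, L)
  scores = scores * torch.tril(torch.ones_like(scores))   # causal mask
  return (scores @ v) / (scores.sum(dim=-1, keepdim=True) + 1e-6)

The chunk_size knob in the real module only controls how this computation is tiled (memory vs. recompute); it shouldn't affect the output.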

Example

Here's an example, using the HF Transformers diff I wrote to retrofit Llama with SRF attention:

# Make sure the TILE_SIZE env var is set; I use TILE_SIZE=256
import torch
# install using `pip install git+https://github.com/alexjlevenston/transformers-llama-srf`
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')

for module in model.modules():
  if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
    module.use_fast_attn_(True)
    module.attn_fn.redraw_on_call_(False)

def resample_rfs(model):
  for module in model.modules():
    if isinstance(module, transformers.models.llama.modeling_llama.LlamaAttention):
      module.attn_fn.redraw_(next(model.parameters()).device)

optimizer = YourOptimizerHere()

for step, batch in enumerate(imaginary_dataset):
  inputs, targets = batch
  # Always resample random features manually,
  # because auto-resampling causes issues with checkpointing
  resample_rfs(model)
  outputs = model(inputs)
  logits = outputs.logits.reshape(-1, outputs.logits.shape[-1])
  loss = torch.nn.functional.cross_entropy(logits, targets['input_ids'].reshape(-1))
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()
