Simplex random feature attention in PyTorch for both training and inference

srf-attention

Simplex Random Feature attention, in PyTorch

A Prelude

Why?

Softmax attention ate the world. But now it's eating our wallets. Luckily enough for us wordcels, those nifty shape rotators realized that even though softmax isn't :wave: technically :wave: stationary, it's amenable to Monte Carlo methods. Translation: we can retrofit pretrained LLMs for recurrent inference! Smarter men than I proceeded to publish this, this, and that. This repo is a PyTorch implementation of "that", with some syntactic sugar added to aid digestion. It's intended to be used for ERPTRI, but do with it what you will.
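For the curious, here is a minimal sketch of the Monte Carlo claim: the softmax kernel exp(q·k) equals the expectation of a product of positive random features, so averaging over random projections approximates it. This sketch draws plain i.i.d. Gaussian projections, not the simplex-coupled ones this repo is about, and every function name in it is made up for illustration.

```python
import math
import random

def softmax_kernel(q, k):
    """Exact (unnormalized) softmax kernel exp(q . k)."""
    return math.exp(sum(a * b for a, b in zip(q, k)))

def positive_feature(x, w):
    """One positive random feature: exp(w . x - |x|^2 / 2)."""
    wx = sum(a * b for a, b in zip(w, x))
    sq_norm = sum(a * a for a in x)
    return math.exp(wx - sq_norm / 2)

def mc_softmax_kernel(q, k, nb_features, rng):
    """Average phi(q) * phi(k) over i.i.d. projections w ~ N(0, I).

    Assumption: plain i.i.d. Gaussian draws, not simplex random features.
    """
    total = 0.0
    for _ in range(nb_features):
        w = [rng.gauss(0.0, 1.0) for _ in q]
        total += positive_feature(q, w) * positive_feature(k, w)
    return total / nb_features

rng = random.Random(0)
q, k = [0.3, -0.2], [0.1, 0.4]
exact = softmax_kernel(q, k)
estimate = mc_softmax_kernel(q, k, nb_features=20000, rng=rng)
print(f"exact={exact:.4f} estimate={estimate:.4f}")
```

Because the features factor into a query part and a key part, the keys can be summarized once and reused for every later query, which is what makes recurrent inference possible.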

What is this good for?

Well, it really ain't for you open-sourcerers. You're bottlenecked by weight I/O. But for those running large-batch inference, e.g. as part of a data pipeline, KV cache I/O is the limiter for sequences > ~700 tokens. ERPTRI efficiently [sic] drops the KV cache size of any pretrained auto-regressive Transformer from $O(LD)$ to $O(D^2)$, where $L$ is the sequence length and $D$ is the head dimension. This repo implements the PyTorch modules necessary for the fine-tuning phase of ERPTRI, and for efficient inference.
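To put rough numbers on the $O(LD)$ versus $O(D^2)$ claim, here is a back-of-the-envelope count of cached elements per head and per sequence. It assumes the recurrent state is one `nb_features x head_dim` matrix plus one normalizer vector, and it ignores dtype and any extra bookkeeping. The layout this package actually stores may differ, and both helper names are hypothetical.

```python
def kv_cache_elements(seq_len, head_dim):
    """Standard attention: one key and one value vector per cached token."""
    return 2 * seq_len * head_dim

def recurrent_state_elements(head_dim, nb_features):
    """Kernelized attention (assumed layout): a running
    (nb_features x head_dim) sum plus an nb_features-sized normalizer."""
    return nb_features * head_dim + nb_features

L, D = 1024, 128  # same L and D as in the usage example
standard = kv_cache_elements(L, D)
recurrent = recurrent_state_elements(D, nb_features=D)
print(standard, recurrent, round(standard / recurrent, 1))
```

Under these assumptions the standard cache grows with every generated token, while the recurrent state stays the same size no matter how long the sequence gets.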

Next steps

Venture forth and conquer.

Installation

pip install git+https://github.com/alexjlevenston/srf-attention

Usage

import torch
from srf_attention import FastAttention, simplex_random_matrix

device = 'cpu'

B, H, L, D = (1, 8, 1024, 128)

# Generate some simplex random features
srfs = simplex_random_matrix(nb_rows = D, nb_features = D, normalize = False, device = device)

# Or just use the FastAttention module
attn = FastAttention(head_dim = D, nb_features = D, causal = True).to(device)
# For training, automatically redraw features for each forward pass
# False by default
attn.redraw_on_call_(True)

# Placeholder queries, keys, and values:
q, k, v = [torch.empty(B, H, L, D) for _ in range(3)]

# For training, naive torch:
o = attn(q=q, k=k, v=v, mode='train', attn_fn='torch')
# For training, w/ flash-attention-no-softmax:
o = attn(q=q, k=k, v=v, mode='train', attn_fn='flash')

# For inference, disable auto-redraw:
attn.redraw_on_call_(False)
# For inference, prefill, parallel:
o, kv, key_maxima, denominator = attn(q=q, k=k, v=v, mode='prefill', attn_fn='parallel')
# For inference, prefill, chunkwise-recurrent:
o, kv, key_maxima, denominator = attn(q=q, k=k, v=v, mode='prefill', attn_fn='chunkwise-recurrent')
# For inference, prefill, recurrent:
o, kv, key_maxima, denominator = attn(q=q, k=k, v=v, mode='prefill', attn_fn='recurrent')
# For inference, generation (one new query token at a time):
q = torch.empty(B, H, 1, D)
denominator = torch.empty(B, H, 1, D)
o, kv, key_maxima, denominator = attn(q=q, kv=kv, key_maxima=key_maxima, denominator=denominator, mode='generation')
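
The prefill and generation modes above rest on one fact: causal kernelized attention can be computed all at once, or token by token with a fixed-size running state. The dependency-free sketch below checks that equivalence on toy data. It is not this package's code. `phi` is a stand-in elementwise-exp feature map rather than simplex random features, the function names are hypothetical, and it skips the numerical stabilization a real implementation needs.

```python
import math
import random

def phi(x):
    """Stand-in positive feature map (elementwise exp); NOT the repo's SRF map."""
    return [math.exp(a) for a in x]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def causal_parallel(qs, ks, vs):
    """Quadratic form: each query attends to all keys up to its own position."""
    outs = []
    for i, q in enumerate(qs):
        fq = phi(q)
        weights = [dot(fq, phi(ks[j])) for j in range(i + 1)]
        denom = sum(weights)
        outs.append([sum(w * vs[j][c] for j, w in enumerate(weights)) / denom
                     for c in range(len(vs[0]))])
    return outs

def causal_recurrent(qs, ks, vs):
    """Recurrent form: carry a fixed-size state (kv, z) from token to token."""
    m, d = len(phi(ks[0])), len(vs[0])
    kv = [[0.0] * d for _ in range(m)]  # running sum of outer(phi(k), v)
    z = [0.0] * m                       # running sum of phi(k)
    outs = []
    for q, k, v in zip(qs, ks, vs):
        fk = phi(k)
        for r in range(m):
            z[r] += fk[r]
            for c in range(d):
                kv[r][c] += fk[r] * v[c]
        fq = phi(q)
        denom = dot(fq, z)
        outs.append([sum(fq[r] * kv[r][c] for r in range(m)) / denom
                     for c in range(d)])
    return outs

rng = random.Random(0)
L, D = 6, 4
qs, ks, vs = ([[rng.uniform(-1, 1) for _ in range(D)] for _ in range(L)]
              for _ in range(3))

par = causal_parallel(qs, ks, vs)
rec = causal_recurrent(qs, ks, vs)
max_diff = max(abs(a - b) for pa, pr in zip(par, rec) for a, b in zip(pa, pr))
print(f"max abs difference: {max_diff:.2e}")
```

The recurrent version never looks back at earlier keys or values. Everything it needs is in `kv` and `z`, whose sizes do not depend on sequence length.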

