Infini-Transformer - Pytorch
Implementation of Infini-Transformer in Pytorch. The authors use a linear attention scheme to compress past memories and report state-of-the-art results on several long-context benchmarks.
Although unlikely to beat Ring Attention, I think it is worth exploring, as the techniques are orthogonal.
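At the core is a per-head compressive memory: each head keeps a (dim_head x dim_head) matrix plus a normalization vector, written to with linear attention and read from by the incoming queries. Below is a minimal sketch of the paper's update and retrieval equations in plain torch. It is not the library's implementation; the function names, shapes, and the clamp epsilon are illustrative, and the delta-rule branch corresponds to the use_mem_delta_rule option used in the examples further down.

import torch
import torch.nn.functional as F

def elu_p1(x):
    # the paper's feature map for linear attention
    return F.elu(x) + 1.

def retrieve(q, mem, z):
    # q: (n, dk), mem: (dk, dv), z: (dk,)  ->  (n, dv)
    q = elu_p1(q)
    return (q @ mem) / (q @ z).clamp(min = 1e-6).unsqueeze(-1)

def update(k, v, mem, z, use_delta_rule = True):
    # fold the current segment's keys / values into the compressive memory
    k = elu_p1(k)
    if use_delta_rule:
        # delta rule: only write what the memory cannot already retrieve
        v = v - (k @ mem) / (k @ z).clamp(min = 1e-6).unsqueeze(-1)
    return mem + k.t() @ v, z + k.sum(dim = 0)

n, dk, dv = 1024, 128, 128
mem, z = torch.zeros(dk, dv), torch.zeros(dk)

q, k, v = (torch.randn(n, d) for d in (dk, dk, dv))

from_memory = retrieve(q, mem, z)  # read long-range context for the current queries
mem, z = update(k, v, mem, z)      # then write the current segment into the memory

In the paper, this retrieved memory is then mixed with ordinary local softmax attention through a learned gate, which the sketch omits.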
Install
$ pip install infini-transformer-pytorch
Usage
import torch
from infini_transformer_pytorch import InfiniTransformer

transformer = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,  # high head dimension may be part of the reason they got good results (kv memory has high capacity)
    heads = 8,
    use_mem_delta_rule = True
)

x = torch.randint(0, 256, (1, 1024))

# pass the returned memories into each subsequent call, and ask for updated memories on the final one
logits1, _, mem1 = transformer(x, return_new_memories = False)
logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False)
logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True)
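Continuing from the snippet above, processing a long sequence by hand then amounts to chunking it into segments and threading the returned memories through successive calls, roughly as in the following sketch. The 1024-token segment length is arbitrary, and passing past_memories = None on the first call is assumed to behave like omitting the argument.

long_seq = torch.randint(0, 256, (1, 4096))

memories = None
logits_per_segment = []

for segment in long_seq.split(1024, dim = -1):
    logits, _, memories = transformer(
        segment,
        past_memories = memories,
        return_new_memories = True  # fold each segment into the memories before moving on
    )
    logits_per_segment.append(logits)

This is essentially what the InfiniTransformerWrapper below automates during training.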
Training a transformer with recurrence usually trips up a lot of researchers, so to make it easy, just wrap the model with InfiniTransformerWrapper:
import torch
from infini_transformer_pytorch import (
    InfiniTransformer,
    InfiniTransformerWrapper
)

# model and wrapper

model = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,
    heads = 8,
    use_mem_delta_rule = True
)

wrapper = InfiniTransformerWrapper(
    model,
    segment_length = 512,
    detach_mems_every_num_segments = 2  # greater than 1 so the network can learn how to 'write' to the fast weight memories
).cuda()

# mock input

seq = torch.randint(0, 256, (2, 10000)).cuda()  # can be an arbitrarily long sequence

# training

loss = wrapper(
    seq,
    backward = True  # will automatically segment and accumulate gradients when it detaches the memories
)

# after much data...

# calculating eval loss

with torch.no_grad():
    wrapper.eval()
    eval_loss = wrapper(seq)

# generating is as easy as

output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1])

output.shape  # (2, 8192 - 1)
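The wrapper handles segmentation and, with backward = True, the backward pass, but stepping an optimizer is assumed to be left to the caller. A minimal training loop might look like the sketch below; the optimizer choice, learning rate, and random mock batches are illustrative, not part of the library.

from torch.optim import Adam

optim = Adam(wrapper.parameters(), lr = 2e-4)

for _ in range(100):
    seq = torch.randint(0, 256, (2, 10000)).cuda()  # stand-in for a real batch of token ids

    loss = wrapper(seq, backward = True)  # segments the sequence and accumulates gradients

    optim.step()
    optim.zero_grad()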
Testing
Train an autoregressive language model on enwik8:
$ python train.py
Todo
- detach_mems_every_num_segments hyperparameter is too confusing, get rid of it
- experiment with enhanced recurrence, perhaps with a linear projection (talking heads on kv, or a linear projection on k and v separately) before sending the memories to the layer before
- working example with enwik8
Citations
@inproceedings{Munkhdalai2024LeaveNC,
    title  = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention},
    author = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal},
    year   = {2024},
    url    = {https://api.semanticscholar.org/CorpusID:269033427}
}
File details
Details for the file infini_transformer_pytorch-0.1.6.tar.gz.
File metadata
- Download URL: infini_transformer_pytorch-0.1.6.tar.gz
- Upload date:
- Size: 36.7 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.20
File hashes
Algorithm | Hash digest
---|---
SHA256 | b962ab5a1d870b7d78fe516e226a2dd93d87148768b9d0dde6e9bf861dea84c5
MD5 | eb6c6c8243f1bb39fe8a29be9cf477da
BLAKE2b-256 | 229f30c8f7131108de55d98a9ff968df6b2270df47d09626a10f3d5ada9d40f4
File details
Details for the file infini_transformer_pytorch-0.1.6-py3-none-any.whl.
File metadata
- Download URL: infini_transformer_pytorch-0.1.6-py3-none-any.whl
- Upload date:
- Size: 9.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.1.1 CPython/3.9.20
File hashes
Algorithm | Hash digest
---|---
SHA256 | 492cfa0d6b558e37f455ff8e3b32938ca31eb0666989ca11f3120b4654bbe0fa
MD5 | c4404e359f69a4bc49100a071f9ef298
BLAKE2b-256 | 8a25f858277ecbe4a9d19a98eb586b620e9166572a857fb94c5f6a234b5cda75