# Infini-Transformer - Pytorch
Implementation of Infini-Transformer in Pytorch. The authors use a linear attention scheme to compress past key / values into a fixed-size memory, and demonstrate multiple SOTA results on long-context benchmarks.

Although unlikely to beat Ring Attention, I think it is worth exploring, as the techniques are orthogonal.
## Install

```bash
$ pip install infini-transformer-pytorch
```
## Usage

```python
import torch
from infini_transformer_pytorch import InfiniTransformer

transformer = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,  # high head dimension may be part of the reason they got good results (kv has high capacity)
    heads = 8,
    rotary_emb_linear_attn = True
)

x = torch.randint(0, 256, (1, 1024))

logits1, cached_kv1, mem1 = transformer(x, return_new_memories = False)
logits2, cached_kv2, mem2 = transformer(x, past_memories = mem1, cached_kv = cached_kv1, return_new_memories = False)
logits3, cached_kv3, mem3 = transformer(x, past_memories = mem2, cached_kv = cached_kv2, return_new_memories = True)
```
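The `past_memories` being threaded through above are the paper's compressive memories: each attention layer keeps a fixed-size matrix that past keys and values are written into, and queries read from it via linear attention. A minimal sketch of that read / write mechanism, assuming the paper's ELU + 1 feature map (the function names here are illustrative, not the library's actual API):

```python
import torch
import torch.nn.functional as F

def update_memory(k, v, mem, norm):
    # write: accumulate sigma(k)^T v into a (d_k, d_v) matrix whose size
    # is independent of sequence length, plus a running normalizer z
    sk = F.elu(k) + 1                       # sigma(k), kept positive
    mem = mem + sk.transpose(-2, -1) @ v    # (d_k, d_v)
    norm = norm + sk.sum(dim = -2)          # (d_k,)
    return mem, norm

def retrieve(q, mem, norm):
    # read: sigma(q) M / (sigma(q) . z) — linear attention over the memory
    sq = F.elu(q) + 1
    numer = sq @ mem                                            # (n, d_v)
    denom = (sq * norm).sum(dim = -1, keepdim = True).clamp(min = 1e-6)
    return numer / denom
```

Because the write is a sum of outer products, the memory stays the same size no matter how many segments have been absorbed, which is what makes "infinite" context cheap.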
Training a transformer with recurrence usually trips up a lot of researchers. To make it easy, just wrap your model with `InfiniTransformerWrapper`:
```python
import torch
from infini_transformer_pytorch import (
    InfiniTransformer,
    InfiniTransformerWrapper
)

# model and wrapper

model = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,
    heads = 8,
    rotary_emb_linear_attn = True
)

wrapper = InfiniTransformerWrapper(
    model,
    segment_length = 512,
    detach_mems_every_num_segments = 2  # greater than 1 so the network can learn how to 'write' to the fast weight memories
).cuda()

# mock input

seq = torch.randint(0, 256, (2, 10000)).cuda()  # can be an arbitrarily long sequence

# training

loss = wrapper(
    seq,
    backward = True  # will automatically segment and accumulate gradients as it detaches the memories
)

# after much data...

# calculating eval loss

with torch.no_grad():
    wrapper.eval()
    eval_loss = wrapper(seq)

# generating is as easy as

output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1])
output.shape  # (2, 8192 - 1)
```
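What the wrapper is doing during training is truncated backpropagation through time: the sequence is split into fixed-length segments, losses are accumulated for `detach_mems_every_num_segments` segments, gradients are backpropagated, and the memories are then detached so the graph does not grow without bound. A minimal sketch of that pattern (`train_on_long_seq` and `step_fn` are illustrative stand-ins, not the library's API):

```python
import torch

def train_on_long_seq(seq, segment_length, detach_every, step_fn):
    # step_fn(segment, mem) -> (loss, new_mem); stand-in for the model's
    # forward over one segment with the current compressive memories
    mem = None
    accumulated = 0.
    n_since_detach = 0
    for segment in seq.split(segment_length, dim = -1):
        loss, mem = step_fn(segment, mem)
        accumulated = accumulated + loss
        n_since_detach += 1
        if n_since_detach == detach_every:
            # backprop through the last `detach_every` segments only
            (accumulated / n_since_detach).backward()
            mem = mem.detach()   # truncate BPTT at the memory boundary
            accumulated = 0.
            n_since_detach = 0
    return mem
```

Detaching only every other segment (rather than every segment) is what lets gradients flow through at least one memory write, so the network can learn *how* to write to the memories, which is the point of `detach_mems_every_num_segments = 2` above.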
## Testing

Train an autoregressive language model on enwik8:

```bash
$ python train.py
```
## Todo

- working example with enwik8
## Citations

```bibtex
@inproceedings{Munkhdalai2024LeaveNC,
    title  = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention},
    author = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal},
    year   = {2024},
    url    = {https://api.semanticscholar.org/CorpusID:269033427}
}
```