# Infini-Transformer - Pytorch
Implementation of Infini-Transformer in PyTorch. The authors use a linear attention scheme to compress past key/value memories into a fixed-size state, and report state-of-the-art results on multiple long-context benchmarks.

Although unlikely to beat Ring Attention, I think it is worth exploring, as the techniques are orthogonal.
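The compression idea can be sketched in a few lines: a linear-attention memory is a fixed-size matrix that segments are written into and read back from, so the state does not grow with context length. A stdlib-only illustration under that assumption (function names are mine, not this package's API):

```python
from math import exp

def elu_plus_one(x):
    # positive feature map commonly used by linear attention
    return x + 1.0 if x > 0 else exp(x)

def update_memory(mem, z, keys, values):
    """Write a segment of (key, value) rows into the compressive memory.

    mem is a d_k x d_v matrix (list of lists) and z a length-d_k normalizer.
    A sketch of the linear-attention memory write, not the library's code:
    mem stays d_k x d_v however long the segment is.
    """
    for k, v in zip(keys, values):
        sk = [elu_plus_one(x) for x in k]
        for i, ski in enumerate(sk):
            z[i] += ski
            for j, vj in enumerate(v):
                mem[i][j] += ski * vj
    return mem, z

def read_memory(mem, z, query, eps=1e-6):
    # retrieve a value vector for one query from the compressed memory
    sq = [elu_plus_one(x) for x in query]
    denom = sum(si * zi for si, zi in zip(sq, z)) + eps
    d_v = len(mem[0])
    return [sum(sq[i] * mem[i][j] for i in range(len(sq))) / denom
            for j in range(d_v)]
```

The point is that the memory is d_k × d_v no matter how many tokens have been written, which is what lets this style of attention carry arbitrarily long context at fixed cost.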
## Install

```bash
$ pip install infini-transformer-pytorch
```
Usage
import torch
from infini_transformer_pytorch import InfiniTransformer
transformer = InfiniTransformer(
num_tokens = 256,
dim = 512,
depth = 8,
dim_head = 128, # high head dimension may be part of the reason they got good results (kv has high capacity)
heads = 8,
use_mem_delta_rule = True
)
x = torch.randint(0, 256, (1, 1024))
logits1, _, mem1 = transformer(x, return_new_memories = False)
logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False)
logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True)
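The `use_mem_delta_rule` flag corresponds to the delta-rule memory update described in the paper (and in the Yang et al. citation below): before writing, the memory is first queried with the incoming keys, and only the difference between the new values and the retrieved ones is written, so the same association is not accumulated twice. A minimal stdlib sketch of that rule (helper names are illustrative, not the package's internals):

```python
from math import exp

def sigma(vec):
    # elu(x) + 1 feature map, keeps activations positive
    return [x + 1.0 if x > 0 else exp(x) for x in vec]

def retrieve(mem, z, key, eps=1e-6):
    # read what the memory currently stores for this key
    sk = sigma(key)
    denom = sum(si * zi for si, zi in zip(sk, z)) + eps
    return [sum(sk[i] * mem[i][j] for i in range(len(sk))) / denom
            for j in range(len(mem[0]))]

def delta_rule_write(mem, z, key, value):
    # delta rule: write only the difference between the incoming value
    # and the value the memory already returns for this key
    old = retrieve(mem, z, key)
    delta = [v - o for v, o in zip(value, old)]
    sk = sigma(key)
    for i, ski in enumerate(sk):
        z[i] += ski
        for j, dj in enumerate(delta):
            mem[i][j] += ski * dj
    return mem, z
```

Writing the same (key, value) pair a second time leaves the memory essentially unchanged, since the retrieved value already matches and the delta is near zero.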
Training a transformer with recurrence usually trips up a lot of researchers, so to make it easy, just wrap it with `InfiniTransformerWrapper`:
```python
import torch
from infini_transformer_pytorch import (
    InfiniTransformer,
    InfiniTransformerWrapper
)

# model and wrapper

model = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,
    heads = 8,
    use_mem_delta_rule = True
)

wrapper = InfiniTransformerWrapper(
    model,
    segment_length = 512,
    detach_mems_every_num_segments = 2  # greater than 1 so the network can learn how to 'write' to the fast weight memories
).cuda()

# mock input

seq = torch.randint(0, 256, (2, 10000)).cuda()  # can be an arbitrarily long sequence

# training

loss = wrapper(
    seq,
    backward = True  # will automatically segment and accumulate gradients when it detaches the memories
)

# after much data...

# calculating eval loss

with torch.no_grad():
    wrapper.eval()
    eval_loss = wrapper(seq)

# generating is as easy as

output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1])

output.shape # (2, 8192 - 1)
```
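Conceptually, what the wrapper automates is a truncated-backpropagation schedule: the long sequence is split into fixed-length segments, gradients flow across a few consecutive segments, and every `detach_mems_every_num_segments` segments the memories are detached so backpropagation through time is cut off. A small sketch of that schedule (an assumption about the control flow, not the wrapper's actual code):

```python
def segment_schedule(seq_len, segment_length, detach_every):
    """Yield (start, end, detach_after) for each training segment.

    detach_after marks the segments after which the memories would be
    detached and the accumulated gradients applied. A sketch of the
    schedule implied by the hyperparameters, not the wrapper's internals.
    """
    num_segments = -(-seq_len // segment_length)  # ceiling division
    for i in range(num_segments):
        start = i * segment_length
        end = min(start + segment_length, seq_len)
        # detach every `detach_every` segments, and always at the end
        detach_after = (i + 1) % detach_every == 0 or (i + 1) == num_segments
        yield start, end, detach_after
```

With `seq_len = 10000`, `segment_length = 512` and `detach_every = 2`, this yields 20 segments with the memories detached 10 times, so each optimizer step sees gradients flowing through two consecutive segments of recurrence.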
## Testing

Train an autoregressive language model on enwik8:

```bash
$ python train.py
```
## Todo

- `detach_mems_every_num_segments` hyperparameter is too confusing, get rid of it
- experiment with enhanced recurrence, perhaps with a linear projection (talking heads on kv or linear projection on k, v separately) before sending the memories to the layer before
- working example with enwik8
Citations
@inproceedings{Munkhdalai2024LeaveNC,
title = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention},
author = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:269033427}
}
@article{Yang2024ParallelizingLT,
title = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length},
author = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.06484},
url = {https://api.semanticscholar.org/CorpusID:270371554}
}