
Infini-Transformer - Pytorch

Implementation of Infini-Transformer in Pytorch. The authors use a linear attention scheme to compress past key / values into a fixed-size memory, and demonstrate multiple state-of-the-art results on long-context benchmarks.

Although unlikely to beat Ring Attention, I think it is worth exploring, as the techniques are orthogonal.
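
For reference, below is a rough, self-contained sketch of the compressive memory described in the paper (written from the paper's equations, not this library's internals): each attention head keeps an associative matrix and a normalizer, reads from them with a linear-attention kernel (ELU + 1), and writes the current segment's keys / values back in, optionally with the delta-rule correction. The function names and the small epsilon are illustrative assumptions.

import torch
import torch.nn.functional as F

def elu_plus_one(x):
    # non-negative kernel feature map used by the paper
    return F.elu(x) + 1.

def update_memory(mem, norm, k, v, use_delta_rule = True):
    # mem:  (heads, dim_k, dim_v) - running associative memory
    # norm: (heads, dim_k)        - running normalizer
    # k:    (heads, seq, dim_k), v: (heads, seq, dim_v) - current segment
    sk = elu_plus_one(k)

    if use_delta_rule:
        # delta rule: subtract what the memory can already retrieve,
        # so only the novel part of v gets written
        retrieved = torch.einsum('h n d, h d e -> h n e', sk, mem)
        denom = torch.einsum('h n d, h d -> h n', sk, norm).unsqueeze(-1)
        v = v - retrieved / (denom + 1e-6)

    mem = mem + torch.einsum('h n d, h n e -> h d e', sk, v)
    norm = norm + sk.sum(dim = 1)
    return mem, norm

def retrieve(mem, norm, q):
    # linear-attention readout of everything compressed so far
    sq = elu_plus_one(q)
    out = torch.einsum('h n d, h d e -> h n e', sq, mem)
    denom = torch.einsum('h n d, h d -> h n', sq, norm).unsqueeze(-1)
    return out / (denom + 1e-6)

heads, seq, dim_k, dim_v = 8, 1024, 128, 128
mem, norm = torch.zeros(heads, dim_k, dim_v), torch.zeros(heads, dim_k)

q, k, v = (torch.randn(heads, seq, dim_k) for _ in range(3))
mem, norm = update_memory(mem, norm, k, v)   # compress the segment into memory
past = retrieve(mem, norm, q)                # (8, 1024, 128) readout for the next segment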

Yannic Kilcher's explanation

Install

$ pip install infini-transformer-pytorch

Usage

import torch
from infini_transformer_pytorch import InfiniTransformer

transformer = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,  # high head dimension may be part of the reason they got good results (kv has high capacity)
    heads = 8,
    use_mem_delta_rule = True
)

x = torch.randint(0, 256, (1, 1024))

logits1, _, mem1 = transformer(x, return_new_memories = False)
logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False)
logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True)

Training a transformer with recurrence usually trips up a lot of researchers, so to make it easy, just wrap the model with InfiniTransformerWrapper.

import torch

from infini_transformer_pytorch import (
    InfiniTransformer,
    InfiniTransformerWrapper
)

# model and wrapper

model = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,
    heads = 8,
    use_mem_delta_rule = True
)

wrapper = InfiniTransformerWrapper(
    model,
    segment_length = 512,
    detach_mems_every_num_segments = 2 # greater than 1 so the network can learn how to 'write' to the fast weight memories
).cuda()

# mock input

seq = torch.randint(0, 256, (2, 10000)).cuda() # can be arbitrarily long sequence

# training

loss = wrapper(
    seq,
    backward = True # will automatically segment and accumulate gradients when it detaches the memories
)

# after much data...

# calculating eval loss

with torch.no_grad():
    wrapper.eval()
    eval_loss = wrapper(seq)

# generating is as easy as

output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1])

output.shape # (2, 8192 - 1)
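
Since backward = True only segments the sequence and accumulates gradients, the surrounding training loop just has to step and zero an optimizer. A minimal sketch of what that might look like (the optimizer, learning rate, gradient clipping, and random data below are placeholder assumptions, not part of this package):

import torch
from torch.optim import Adam

optimizer = Adam(wrapper.parameters(), lr = 1e-4)

for _ in range(100):
    seq = torch.randint(0, 256, (2, 10000)).cuda()  # stand-in for real long-sequence batches

    loss = wrapper(
        seq,
        backward = True  # gradients are accumulated internally, segment by segment
    )

    torch.nn.utils.clip_grad_norm_(wrapper.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()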

Testing

Train an autoregressive character-level model on enwik8

$ python train.py

Todo

  • detach_mems_every_num_segments hyperparameter is too confusing, get rid of it
  • experiment with enhanced recurrence, perhaps with a linear projection (talking heads on kv, or a linear projection on k and v separately) before sending the memories to the layer before (a rough sketch follows this list)
  • working example with enwik8
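
As a rough illustration of the enhanced-recurrence idea in the list above (not something implemented in this package), a talking-heads style projection could mix the per-head memories before they are handed to the layer below; everything here, including the mixing matrix, is hypothetical.

import torch
from torch import nn

heads, dim_k, dim_v = 8, 128, 128

# hypothetical learned (heads x heads) mixing matrix, initialized to identity
head_mix = nn.Parameter(torch.eye(heads))

mem = torch.randn(heads, dim_k, dim_v)  # memory matrices produced by layer l

# mix memories across heads before sending them to layer l - 1
prev_layer_mem = torch.einsum('h d e, h g -> g d e', mem, head_mix)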

Citations

@inproceedings{Munkhdalai2024LeaveNC,
    title   = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention},
    author  = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:269033427}
}

