
Infini-Transformer - Pytorch

Implementation of Infini-Transformer in Pytorch. The authors use a linear-attention-based compressive memory to store context from past segments, and report state-of-the-art results on several long-context benchmarks.

Although unlikely to beat Ring Attention, I think it is worth exploring, as the techniques are orthogonal.

Yannic Kilcher's explanation
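The compressive memory is easiest to see in isolation. Below is a minimal NumPy sketch of the math from the Infini-attention paper, not this library's implementation: keys and values are written into a fixed-size matrix `M` (with normalizer `z`) via a kernelized outer product, and reads are kernelized lookups into `M`. The delta-rule variant (what `use_mem_delta_rule = True` corresponds to) first subtracts whatever the memory already returns for each key before writing, so repeated keys update rather than double-count. All function names here are illustrative.

```python
import numpy as np

def sigma(x):
    # ELU(x) + 1 keeps the kernel features positive, as in the paper
    return np.where(x > 0, x + 1.0, np.exp(x))

def retrieve(M, z, q, eps = 1e-6):
    # read past values for queries q: (sigma(q) @ M) / (sigma(q) @ z)
    sq = sigma(q)                                  # (n, d_k)
    return (sq @ M) / (sq @ z + eps)[:, None]      # (n, d_v)

def update(M, z, k, v, delta_rule = True, eps = 1e-6):
    sk = sigma(k)                                  # (n, d_k)
    if delta_rule:
        # subtract what the memory already retrieves for these keys
        v = v - (sk @ M) / (sk @ z + eps)[:, None]
    M = M + sk.T @ v                               # (d_k, d_v) outer-product write
    z = z + sk.sum(axis = 0)                       # (d_k,) running normalizer
    return M, z

d_k, d_v = 4, 3
M, z = np.zeros((d_k, d_v)), np.zeros(d_k)
k = np.random.randn(8, d_k)
v = np.random.randn(8, d_v)
M, z = update(M, z, k, v)
out = retrieve(M, z, k)   # (8, 3): values recalled for the same keys
```

Note the memory is O(d_k * d_v) per head regardless of sequence length, which is the whole point: past segments are folded into `M` rather than kept as a growing KV cache.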

Install

$ pip install infini-transformer-pytorch

Usage

import torch
from infini_transformer_pytorch import InfiniTransformer

transformer = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,  # high head dimension may be part of the reason they got good results (kv has high capacity)
    heads = 8,
    use_mem_delta_rule = True
)

x = torch.randint(0, 256, (1, 1024))

logits1, _, mem1 = transformer(x, return_new_memories = False)
logits2, _, mem2 = transformer(x, past_memories = mem1, return_new_memories = False)
logits3, _, mem3 = transformer(x, past_memories = mem2, return_new_memories = True)

Training a transformer with recurrence trips up a lot of researchers, so to make it easy, just wrap the model with InfiniTransformerWrapper

import torch

from infini_transformer_pytorch import (
    InfiniTransformer,
    InfiniTransformerWrapper
)

# model and wrapper

model = InfiniTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    dim_head = 128,
    heads = 8,
    use_mem_delta_rule = True
)

wrapper = InfiniTransformerWrapper(
    model,
    segment_length = 512,
    detach_mems_every_num_segments = 2 # greater than 1 so the network can learn how to 'write' to the fast weight memories
).cuda()

# mock input

seq = torch.randint(0, 256, (2, 10000)).cuda() # can be arbitrarily long sequence

# training

loss = wrapper(
    seq,
    backward = True # will automatically segment and accumulate gradients when it detaches the memories
)
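For reference, what backward = True automates is ordinary truncated backpropagation through time: run a segment, carry the memories forward, and detach them every few segments so gradients flow only a bounded distance into the past. Here is a minimal PyTorch sketch of that pattern on a toy recurrent cell (a hypothetical stand-in, not the wrapper's actual code):

```python
import torch

torch.manual_seed(0)

cell = torch.nn.Linear(16, 16)                     # toy stand-in for a layer with memory
opt = torch.optim.Adam(cell.parameters(), lr = 1e-3)

seq = torch.randn(8, 16)                           # 8 "segments" of features
mem = torch.zeros(16)                              # carried state, analogous to past_memories
detach_every = 2                                   # analogous to detach_mems_every_num_segments

opt.zero_grad()
window_loss = 0.
for i, segment in enumerate(seq):
    mem = torch.tanh(cell(segment + mem))          # read old memory, write new one
    window_loss = window_loss + mem.pow(2).mean()  # stand-in loss

    if (i + 1) % detach_every == 0:
        window_loss.backward()                     # gradients accumulate for this window
        mem = mem.detach()                         # cut the graph: truncated BPTT
        window_loss = 0.

opt.step()
```

Setting the detach interval greater than 1 matters for the same reason here as in the wrapper: with the graph cut every segment, the write to memory would never receive a gradient from a later read, so the network could not learn how to store useful state.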

# after much data...

# calculating eval loss

with torch.no_grad():
    wrapper.eval()
    eval_loss = wrapper(seq)

# generating is as easy as

output = wrapper.generate(seq_len = 8192, prompt = seq[:, :1])

output.shape # (2, 8192 - 1)

Testing

Train an autoregressive language model on enwik8

$ python train.py

Todo

  • detach_mems_every_num_segments hyperparameter is too confusing, get rid of it
  • experiment with enhanced recurrence, perhaps with a linear projection (talking heads on kv or linear projection on k, v separately) before sending the memories to the layer before
  • working example with enwik8

Citations

@inproceedings{Munkhdalai2024LeaveNC,
    title   = {Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention},
    author  = {Tsendsuren Munkhdalai and Manaal Faruqui and Siddharth Gopal},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:269033427}
}

@article{Yang2024ParallelizingLT,
    title   = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length},
    author  = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2406.06484},
    url     = {https://api.semanticscholar.org/CorpusID:270371554}
}
