TPTT: Transforming Pretrained Transformers into Titans

😊 TPTT

TPTT is a modular Python library designed to inject efficient linearized attention (LiZA) mechanisms, such as Memory as Gate (described in Titans), into pretrained transformers 🤗.
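
As a rough illustration of the Memory as Gate idea, the sketch below blends the output of the original softmax attention with the output of a linearized attention branch using a single weight. This is an assumption about the general shape of the mechanism, not TPTT's actual implementation: the function name `memory_as_gate` is hypothetical, and the `mag_weight` parameter is only borrowed from the `--mag_weight` flag used in the Docker example.

```python
def memory_as_gate(softmax_out, linear_out, mag_weight=0.5):
    """Blend two attention outputs element-wise (hypothetical helper).

    mag_weight = 0.0 keeps only the original softmax attention output,
    mag_weight = 1.0 keeps only the linearized attention output.
    """
    if not 0.0 <= mag_weight <= 1.0:
        raise ValueError("mag_weight must lie in [0, 1]")
    if len(softmax_out) != len(linear_out):
        raise ValueError("both outputs must have the same length")
    return [
        (1.0 - mag_weight) * s + mag_weight * l
        for s, l in zip(softmax_out, linear_out)
    ]


# Toy vectors standing in for one token's attention outputs.
mixed = memory_as_gate([1.0, 2.0], [3.0, 6.0], mag_weight=0.5)
print(mixed)
```

With a weight of 0, the injected model reproduces the base model's attention output exactly, which is one way to see why only light fine-tuning is needed to bring the memory branch in gradually.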


Features

  • Flexible Attention Injection: Seamlessly wrap and augment standard Transformer attention layers with linearized attention variants for latent memory.
  • Support for Linear Attention: Includes implementations of DeltaNet and DeltaProduct with optional recurrent nonlinearity between chunks.
  • Modular Design: Easily extend or customize operators and integration strategies.
  • Compatibility: Designed to integrate with Hugging Face Transformers and similar PyTorch models.
  • Low-Compute Alignment: Requires only lightweight fine-tuning after injection, enabling efficient memory integration without heavy retraining.

Important: After injecting the LiZA module, the model must be fine-tuned so that it aligns with and makes effective use of the memory mechanism.

[Figure: TPTT overview]

Note: The Order 2 Delta-Product attention mechanism is as expressive as Titans.
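
To make the delta rule and its order-2 product variant concrete, here is a minimal sketch in plain Python. It assumes the standard DeltaNet update, S ← S + β (v − S k) kᵀ, and treats an order-n Delta-Product step as n such updates applied for a single token. The function names are hypothetical, and nested lists are used for readability; this is not TPTT's chunked PyTorch implementation.

```python
def delta_rule_step(S, k, v, beta=1.0):
    """One delta-rule update of the memory matrix S (rows: value dims, cols: key dims)."""
    # What the memory currently returns for key k.
    pred = [sum(s_ij * k_j for s_ij, k_j in zip(row, k)) for row in S]
    # Prediction error that the update writes back along k.
    err = [v_i - p_i for v_i, p_i in zip(v, pred)]
    return [
        [s_ij + beta * e_i * k_j for s_ij, k_j in zip(row, k)]
        for row, e_i in zip(S, err)
    ]


def delta_product_step(S, keys, values, betas):
    """Order-n step: apply n delta-rule updates for a single token."""
    for k, v, beta in zip(keys, values, betas):
        S = delta_rule_step(S, k, v, beta)
    return S


def read(S, q):
    """Query the memory: o = S q."""
    return [sum(s_ij * q_j for s_ij, q_j in zip(row, q)) for row in S]


S = [[0.0, 0.0], [0.0, 0.0]]
# Order-2 step: two (key, value) writes for the same token.
S = delta_product_step(
    S,
    keys=[[1.0, 0.0], [0.0, 1.0]],
    values=[[2.0, 3.0], [5.0, 7.0]],
    betas=[1.0, 1.0],
)
print(read(S, [1.0, 0.0]))

# Writing a new value under an existing unit-norm key overwrites the old one.
S = delta_rule_step(S, k=[1.0, 0.0], v=[-1.0, 4.0])
print(read(S, [1.0, 0.0]))
```

With unit-norm keys and β = 1, each write replaces whatever the memory previously stored under that key, which is the behavior that distinguishes the delta rule from purely additive linear attention.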

Installation and Usage

pip install tptt

Titanesque Documentation

  • TPTT-LiZA_Training:
    Instructions for training TPTT-based models with LoRA and advanced memory management.

  • TPTT_LiZA_Evaluation:
    Guide for evaluating language models with LightEval and Hugging Face Transformers.

  • TPTT_LiZA_FromScratch:
    Integrating the LinearAttention module into PyTorch deep learning projects.

Basic usage:

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import tptt
from tptt import save_tptt_safetensors, get_tptt_model, load_tptt_safetensors
from torch import nn

##### Transforming into Titans (Tptt)
base_model_name = "Qwen/Qwen2.5-1.5B"
base_config = AutoConfig.from_pretrained(base_model_name)
tptt_config = tptt.TpttConfig(
    base_model_config=base_config,
    base_model_name=base_model_name,
    # lora_config=lora_config,  # optional LoRA settings
)
# Build the model from the TPTT config defined above
model = tptt.TpttModel(tptt_config)
# manual local save (path and name are placeholders to set yourself)
save_tptt_safetensors(model, path, name)

##### Pretrained Titans from Transformer
repo_id = "ffurfaro/Titans-Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

##### More custom for other Model (BERT, ViT, etc.)
model, linear_cache = get_tptt_model(model, config) # you can activate Bidirectional
model = load_tptt_safetensors(repo_or_path, model) # from saved LoRA only

##### Using LinearAttention from scratch
num_layers = 2  # example value
layers = nn.ModuleList([
    tptt.LinearAttention(hidden_dim=64, num_heads=4)
    for _ in range(num_layers)
])

Some scripts are available here
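
The basic usage above loads a pretrained Titans checkpoint but stops before generating text. The sketch below shows one way to do that with the standard Hugging Face Transformers API. It assumes that the repository also ships a tokenizer loadable with `AutoTokenizer`; the helper name `generate_text` and the default of 32 new tokens are illustrative choices, not part of TPTT.

```python
def generate_text(repo_id, prompt, max_new_tokens=32):
    """Load a checkpoint from the Hub and greedily continue `prompt` (hypothetical helper)."""
    # Imported lazily so that defining the helper does not require the heavy dependencies.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Assumption: the repo provides a tokenizer alongside the model weights.
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

For example, `generate_text("ffurfaro/Titans-Llama-3.2-1B", "Linear attention is")` downloads the checkpoint on first use and returns the prompt followed by the model's continuation. Note that `trust_remote_code=True` executes code from the repository, so it should only be used with sources you trust.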


Results examples

[Figure: results plot]

More details in the paper.

Development

  • Code is organized into modular components under the src/tptt directory.
  • Use pytest for testing and Sphinx for documentation (see the linked documentation 🔥).
  • Contributions and feature requests are welcome!

Requirements

  • Python 3.11+
  • PyTorch
  • einops
  • Transformers
  • PEFT

See requirements.txt for the full list.
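
A quick way to confirm that an environment meets the requirements listed above is a small standard-library check like the one below. The function name `check_environment` is hypothetical, and the import names (`torch`, `einops`, `transformers`, `peft`) are the usual module names for those packages; requirements.txt remains the authoritative list.

```python
import sys
from importlib.util import find_spec

# Usual import names for the packages listed in the requirements.
REQUIRED_MODULES = ("torch", "einops", "transformers", "peft")


def check_environment(min_python=(3, 11)):
    """Return a dict mapping each requirement to True (satisfied) or False."""
    report = {"python": sys.version_info >= min_python}
    for name in REQUIRED_MODULES:
        # find_spec returns None when a top-level module cannot be located.
        report[name] = find_spec(name) is not None
    return report


if __name__ == "__main__":
    for item, ok in check_environment().items():
        print(f"{item}: {'ok' if ok else 'missing'}")
```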


Docker Usage

Build and run TPTT with Docker:

# Build the image
docker build -t tptt .

# Run training (with GPU support)
docker run -it --gpus all \
  -v $(pwd)/data:/data \
  -v $(pwd)/outputs:/outputs \
  tptt python -m train \
    --model_name "meta-llama/Llama-3.2-1B" \
    --method delta_rule \
    --mag_weight 0.5

For more details, see the Dockerfile.
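
When launching several runs, it can help to assemble the same `docker run` invocation programmatically. The sketch below rebuilds the command shown above as an argument list, using only the flags that appear in that command. The helper name `build_train_command` and its defaults are illustrative; the set of accepted options is defined by the project's training script, not by this snippet.

```python
import shlex
from pathlib import Path


def build_train_command(model_name, method="delta_rule", mag_weight=0.5,
                        workdir=".", image="tptt"):
    """Return the docker run command as a list of arguments (hypothetical helper)."""
    root = Path(workdir).resolve()
    return [
        "docker", "run", "-it", "--gpus", "all",
        # Mount local data and output directories into the container.
        "-v", f"{root / 'data'}:/data",
        "-v", f"{root / 'outputs'}:/outputs",
        image, "python", "-m", "train",
        "--model_name", model_name,
        "--method", method,
        "--mag_weight", str(mag_weight),
    ]


cmd = build_train_command("meta-llama/Llama-3.2-1B")
# shlex.join quotes arguments safely for copy-pasting into a shell.
print(shlex.join(cmd))
```

Passing the list directly to `subprocess.run(cmd, check=True)` avoids shell quoting problems when the working directory contains spaces.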

Acknowledgements

Discovering the OpenSparseLLMs/Linearization (🚀 linear-flash-attention-based) project inspired this work and motivated me to create a fully modular, Delta-rule style PyTorch version.

Citation

If you use TPTT in your academic work, please cite:

@article{furfaro2025tptt,
  title={TPTT: Transforming Pretrained Transformers into Titans},
  author={Furfaro, Fabien},
  journal={arXiv preprint arXiv:2506.17671},
  year={2025}
}

Contact

For questions or support, please open an issue on the GitHub repository or contact the maintainer.
