
Unofficial PyTorch Implementation of Titans: Learning to Memorize at Test Time


This is an unofficial PyTorch implementation of the paper "Titans: Learning to Memorize at Test Time" by Ali Behrouz, Peilin Zhong, and Vahab Mirrokni.

Overview

Titans is a novel neural architecture that combines attention-based short-term memory with a neural long-term memory module. The architecture addresses the limitations of both recurrent models (which compress data into fixed-size memory) and attention mechanisms (which have quadratic complexity).

Key Features

  • Neural Long-term Memory: A module that learns to memorize historical context
  • Persistent Memory: Learnable tokens that encode task-specific knowledge
  • Three Architectural Variants:
    • MAC (Memory as Context): Uses memory as context for attention
    • MAG (Memory as Gate): Combines memory with core branch using gating
    • MAL (Memory as Layer): Integrates memory as a separate layer
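To make the MAL variant's sliding-window attention concrete, here is a minimal sketch (an illustration, not the repository's implementation) of a causal sliding-window mask, assuming each position attends to itself and the previous `window - 1` tokens:

```python
import numpy as np

def sliding_window_mask(seq_len: int, window: int) -> np.ndarray:
    """Boolean mask where True means key position j is visible to query i.

    Each query i attends only to keys j with i - window < j <= i,
    i.e. causal attention restricted to the last `window` tokens.
    """
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

mask = sliding_window_mask(6, 3)
# Row 4 attends to positions 2, 3, and 4 only.
```

Restricting each query to a fixed window keeps the cost linear in sequence length, which is why this pattern pairs naturally with a long-term memory that carries information beyond the window.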

Installation

From PyPI (Recommended)

pip install titans-unofficial

From Source (Development)

# Clone the repository
git clone https://github.com/Shehryar718/titans-unofficial.git
cd titans-unofficial

# Install in development mode with all extras
pip install -e ".[dev]"

Requirements

  • Python 3.10+
  • PyTorch 2.0+
  • transformers (for tokenization)
  • numpy
  • pytest (for running tests)

Project Structure

titans-unofficial/
├── titans/
│   ├── __init__.py
│   ├── models/
│   │   ├── titans_base.py    # Base class for all variants
│   │   ├── titans_mac.py     # Memory as Context implementation
│   │   ├── titans_mag.py     # Memory as Gate implementation
│   │   └── titans_mal.py     # Memory as Layer implementation
│   └── utils/
│       ├── memory.py         # Neural Memory Module
│       ├── attention.py      # Attention mechanisms
│       └── persistent_memory.py  # Persistent Memory implementation
├── examples/
│   ├── text_classification.py  # Text classification example
│   ├── language_modeling.py    # Language modeling example
│   └── fine_tuning.py         # Fine-tuning example
├── pytests/
│   └── test_memory.py         # Tests for memory module
├── requirements.txt
├── LICENSE
└── README.md

Usage

Text Classification

from titans import TitansMAC, TitansMAG, TitansMAL
from examples.text_classification import TitansForClassification

# Initialize model
model = TitansForClassification(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    num_classes=2,
    memory_depth=2,
    persistent_tokens=8,
    window_size=16,
    model_type="mal"  # Choose from: "mac", "mag", "mal"
)

# Train and evaluate (run from the shell)
python examples/text_classification.py

Language Modeling

from titans import TitansMAC, TitansMAG, TitansMAL
from examples.language_modeling import TitansForLanguageModeling

# Initialize model
model = TitansForLanguageModeling(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    memory_depth=2,
    persistent_tokens=16,
    window_size=128,
    model_type="mac"
)

# Train and generate text (run from the shell)
python examples/language_modeling.py

Architecture Details

Neural Memory Module

The neural memory module consists of:

  • Key/Value/Query projections for memory access
  • Multi-layer perceptron for memory processing
  • Momentum-based update mechanism with configurable parameters
  • Weight decay for forgetting mechanism
  • Gradient scaling for numerical stability
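The update rule behind these bullets can be sketched in a few lines. The following is a simplified NumPy sketch, assuming a plain linear memory and fixed coefficients (the actual module uses an MLP and learned, data-dependent values; `eta`, `theta`, and `alpha` here are stand-ins for momentum, gradient scale, and weight decay):

```python
import numpy as np

def memory_update(W, S, k, v, eta=0.9, theta=0.1, alpha=0.01):
    """One momentum-based write into a linear associative memory W.

    Simplified form of the paper's rule:
        S_t = eta * S_{t-1} - theta * grad   (momentum / "surprise")
        W_t = (1 - alpha) * W_{t-1} + S_t    (weight decay = forgetting)
    with inner loss ||W k - v||^2 for a key/value pair (k, v).
    """
    grad = 2.0 * np.outer(W @ k - v, k)  # gradient of ||W k - v||^2 w.r.t. W
    S = eta * S - theta * grad           # momentum-accumulated surprise
    W = (1.0 - alpha) * W + S            # decayed write into memory
    return W, S

d = 4
W, S = np.zeros((d, d)), np.zeros((d, d))
k = np.array([1.0, 0.0, 0.0, 0.0])
v = np.array([0.0, 1.0, 0.0, 0.0])
for _ in range(100):
    W, S = memory_update(W, S, k, v)
# After repeated writes, reading the memory with k approximately
# retrieves v (up to a small bias introduced by the weight decay).
```

The weight-decay term is what gives the memory a forgetting mechanism: old associations fade unless they are rewritten.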

Variants

  1. MAC (Memory as Context)

    • Memory output serves as additional context
    • Efficient for tasks requiring long-range dependencies
    • Parallel processing with chunked computation
    • Configurable chunk size and parallel processing
  2. MAG (Memory as Gate)

    • Gating mechanism to combine memory with core processing
    • Adaptive balance between short and long-term memory
    • Enhanced numerical stability
    • Improved gradient flow through gating
  3. MAL (Memory as Layer)

    • Memory integrated as a separate layer
    • Direct memory access at each layer
    • Sliding window attention for efficiency
    • Layer-wise memory updates
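To make the MAG gating idea concrete, here is a minimal NumPy sketch of one plausible gated combination (an assumption for illustration; the repository and paper may parameterize and condition the gate differently):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mag_combine(attn_out, mem_out, W_g, b_g):
    """Gated blend of the short-term (attention) and long-term (memory) branches.

    A learned gate g decides, per feature, how much of each branch to keep:
        g = sigmoid(attn_out @ W_g + b_g)
        y = g * attn_out + (1 - g) * mem_out
    Conditioning the gate on the attention output is one choice among several.
    """
    g = sigmoid(attn_out @ W_g + b_g)
    return g * attn_out + (1.0 - g) * mem_out

d = 8
attn_out = np.ones((2, d))   # stand-in for the attention branch output
mem_out = np.zeros((2, d))   # stand-in for the memory branch output
W_g = np.zeros((d, d))       # zero projection -> gate is sigmoid(0) = 0.5
y = mag_combine(attn_out, mem_out, W_g, b_g=0.0)
# With a zero gate projection, y is the elementwise mean of the two branches.
```

Because the gate is a smooth function of the input, gradients flow through both branches, which is what the "improved gradient flow through gating" bullet refers to.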

Example Tasks

The repository includes implementations for:

  • Text Classification (Binary and multi-class)
  • Language Modeling with test-time adaptation
  • Fine-tuning with early stopping

Each example demonstrates different aspects of the Titans architecture:

  • Memory reset between epochs for fresh adaptation
  • Efficient batch processing with dynamic batching
  • Gradient scaling for numerical stability
  • Early stopping and model checkpointing
  • Proper memory state management

Testing

# Run all tests
pytest pytests/

# Run specific test file
pytest pytests/test_memory.py

Citation

This repository provides an unofficial implementation of the Titans architecture.
If you reference this work, please cite the original paper:

@article{behrouz2024titans,
  title={Titans: Learning to Memorize at Test Time},
  author={Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2501.00663},
  year={2024}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
