
Unofficial PyTorch implementation of Titans: Learning to Memorize at Test Time


Unofficial Implementation of Titans: Learning to Memorize at Test Time


This is an unofficial PyTorch implementation of the paper "Titans: Learning to Memorize at Test Time" by Ali Behrouz, Peilin Zhong, and Vahab Mirrokni.

Overview

Titans is a neural architecture that combines attention-based short-term memory with a neural long-term memory module. It addresses the limitations of both recurrent models, which compress the whole history into a fixed-size memory, and attention mechanisms, whose cost grows quadratically with sequence length.

Key Features

  • Neural Long-term Memory: A module that learns to memorize historical context at test time
  • Persistent Memory: Learnable tokens that encode task-specific knowledge
  • Three Architectural Variants:
    • MAC (Memory as Context): Uses memory as context for attention
    • MAG (Memory as Gate): Combines memory with the core (attention) branch using gating
    • MAL (Memory as Layer): Integrates memory as a separate layer
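In the paper, persistent memory is a set of learnable, input-independent tokens that are prepended to the input sequence. The plain-Python sketch below shows only that concatenation step. It assumes vectors are lists of floats and uses a hypothetical helper name, `prepend_persistent_tokens`, which is not part of the `titans` package API.

```python
from typing import List

Vector = List[float]


def prepend_persistent_tokens(persistent: List[Vector], sequence: List[Vector]) -> List[Vector]:
    """Return [p_1, ..., p_Np] followed by the input tokens (hypothetical helper).

    In a real model the persistent vectors would be trainable parameters shared
    across all inputs; they do not depend on the input itself.
    """
    if persistent and sequence and len(persistent[0]) != len(sequence[0]):
        raise ValueError("persistent tokens and input tokens must share d_model")
    # Copy the vectors so callers cannot mutate the parameters through the output.
    return [list(p) for p in persistent] + [list(x) for x in sequence]


# Two persistent tokens and three input tokens, with d_model = 2
persistent = [[0.1, 0.2], [0.3, 0.4]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
extended = prepend_persistent_tokens(persistent, tokens)
print(len(extended))  # 5
```

If this package follows the paper's design, `persistent_tokens=8` in the classification example would lengthen every sequence by eight positions before attention.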

Installation

From PyPI (Recommended)

pip install titans-unofficial

From Source (Development)

# Clone the repository
git clone https://github.com/Shehryar718/titans-unofficial.git
cd titans-unofficial

# Install in development mode with all extras
pip install -e ".[dev]"

Requirements

  • Python 3.10+
  • PyTorch 2.0+
  • transformers (for tokenization)
  • numpy
  • pytest (for running tests)

Project Structure

titans-unofficial/
├── titans/
│   ├── __init__.py
│   ├── models/
│   │   ├── titans_base.py    # Base class for all variants
│   │   ├── titans_mac.py     # Memory as Context implementation
│   │   ├── titans_mag.py     # Memory as Gate implementation
│   │   └── titans_mal.py     # Memory as Layer implementation
│   └── utils/
│       ├── memory.py         # Neural Memory Module
│       ├── attention.py      # Attention mechanisms
│       └── persistent_memory.py  # Persistent Memory implementation
├── examples/
│   ├── text_classification.py  # Text classification example
│   ├── language_modeling.py    # Language modeling example
│   └── fine_tuning.py         # Fine-tuning example
├── pytests/
│   └── test_memory.py         # Tests for memory module
├── requirements.txt
├── LICENSE
└── README.md

Usage

Text Classification

from titans import TitansMAC, TitansMAG, TitansMAL
# examples/ lives in the repository, not in the titans package: run from the repository root
from examples.text_classification import TitansForClassification

# Initialize model
model = TitansForClassification(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    num_classes=2,
    memory_depth=2,
    persistent_tokens=8,
    window_size=16,
    model_type="mal"  # Choose from: "mac", "mag", "mal"
)

# Train and evaluate
python examples/text_classification.py

Language Modeling

from titans import TitansMAC, TitansMAG, TitansMAL
# examples/ lives in the repository, not in the titans package: run from the repository root
from examples.language_modeling import TitansForLanguageModeling

# Initialize model
model = TitansForLanguageModeling(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    memory_depth=2,
    persistent_tokens=16,
    window_size=128,
    model_type="mac"
)

# Train and generate text
python examples/language_modeling.py

Architecture Details

Neural Memory Module

The neural memory module consists of:

  • Key/Value/Query projections for memory access
  • Multi-layer perceptron for memory processing
  • Momentum-based update mechanism with configurable parameters
  • Weight decay as a forgetting mechanism
  • Gradient scaling for numerical stability
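The update rule described in the Titans paper can be sketched as follows. With k_t and v_t the key and value projections of token x_t, the memory is trained online on the associative loss ℓ = ‖M(k_t) − v_t‖²; a momentum ("surprise") term accumulates gradients, S_t = η·S_{t−1} − θ·∇ℓ, and weight decay forgets old content, M_t = (1 − α)·M_{t−1} + S_t. This is a minimal plain-Python sketch, not this package's code. It assumes a single linear layer as the memory (the list above describes an MLP), fixed η, θ and α (the paper makes them data-dependent), and no gradient scaling. The class name `LinearNeuralMemory` is hypothetical.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Matrix-vector product M @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def zeros(rows: int, cols: int) -> Matrix:
    return [[0.0] * cols for _ in range(rows)]


class LinearNeuralMemory:
    """Toy linear memory M(k) = M @ k with momentum and weight decay (hypothetical)."""

    def __init__(self, dim: int, lr: float = 0.1, momentum: float = 0.9, decay: float = 0.01):
        self.M = zeros(dim, dim)  # memory weights M_t
        self.S = zeros(dim, dim)  # momentum ("surprise") S_t
        # theta (step size), eta (momentum), alpha (forgetting / weight decay)
        self.lr, self.momentum, self.decay = lr, momentum, decay

    def loss(self, k: Vector, v: Vector) -> float:
        """Associative loss ||M k - v||^2."""
        return sum((p - t) ** 2 for p, t in zip(matvec(self.M, k), v))

    def update(self, k: Vector, v: Vector) -> float:
        """One test-time update; returns the loss measured before the update."""
        err = [p - t for p, t in zip(matvec(self.M, k), v)]  # M_{t-1} k - v
        before = sum(e * e for e in err)
        n = len(self.M)
        for i in range(n):
            for j in range(n):
                grad = 2.0 * err[i] * k[j]  # d/dM_ij of ||M k - v||^2
                self.S[i][j] = self.momentum * self.S[i][j] - self.lr * grad
                self.M[i][j] = (1.0 - self.decay) * self.M[i][j] + self.S[i][j]
        return before

    def retrieve(self, q: Vector) -> Vector:
        """Read from memory with a forward pass only (no weight update)."""
        return matvec(self.M, q)


mem = LinearNeuralMemory(dim=2)
k, v = [1.0, 0.0], [0.0, 1.0]
first = mem.update(k, v)  # 1.0: the memory starts at zero, so the pair is fully unexplained
for _ in range(200):
    mem.update(k, v)
print(first, mem.loss(k, v))
print(mem.retrieve(k))  # close to [0.0, 1.0]
```

After repeated updates the retrieved value settles slightly below 1.0, because weight decay keeps pulling the weights toward zero while the gradient pulls them toward the stored value.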

Variants

  1. MAC (Memory as Context)

    • Memory output serves as additional context
    • Efficient for tasks requiring long-range dependencies
    • Parallel processing with chunked computation
    • Configurable chunk size and parallel processing
  2. MAG (Memory as Gate)

    • Gating mechanism to combine memory with core processing
    • Adaptive balance between short and long-term memory
    • Enhanced numerical stability
    • Improved gradient flow through gating
  3. MAL (Memory as Layer)

    • Memory integrated as a separate layer
    • Direct memory access at each layer
    • Sliding window attention for efficiency
    • Layer-wise memory updates
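Two of the mechanisms listed above can be illustrated in isolation: the causal sliding-window mask behind MAL's sliding window attention, and a sigmoid gate that blends the attention and memory branches in the spirit of MAG. This is a sketch under assumptions. The exact gate used by this package is not documented here, so the convex combination g·attn + (1 − g)·mem is one illustrative choice, and `sliding_window_mask` and `gated_combine` are hypothetical names. The `window_size` argument in the usage examples presumably controls the window width.

```python
import math
from typing import List


def sliding_window_mask(seq_len: int, window: int) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j.

    Causal sliding window: each position sees itself and the previous
    window - 1 positions, so attention cost grows linearly with seq_len.
    """
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gated_combine(attn_out: List[float], mem_out: List[float], gate_logits: List[float]) -> List[float]:
    """Elementwise g * attn + (1 - g) * mem with g = sigmoid(logit) (hypothetical MAG-style gate)."""
    out = []
    for a, m, z in zip(attn_out, mem_out, gate_logits):
        g = sigmoid(z)  # g near 1 favors attention, g near 0 favors memory
        out.append(g * a + (1.0 - g) * m)
    return out


mask = sliding_window_mask(5, window=2)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
print(gated_combine([1.0, 1.0], [0.0, 0.0], [0.0, 0.0]))  # [0.5, 0.5]
```

In a trained model the gate logits would be produced from the input, letting each position choose how much to rely on short-term attention versus long-term memory.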

Example Tasks

The repository includes implementations for:

  • Text Classification (binary and multi-class)
  • Language Modeling with test-time adaptation
  • Fine-tuning with early stopping

Each example demonstrates different aspects of the Titans architecture:

  • Memory reset between epochs for fresh adaptation
  • Efficient batch processing with dynamic batching
  • Gradient scaling for numerical stability
  • Early stopping and model checkpointing
  • Proper memory state management
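The early stopping and per-epoch memory reset listed above can be sketched as a patience counter around the epoch loop. This plain-Python sketch assumes validation losses are already computed (they are passed in as a list) and that the model exposes some way to clear its long-term memory. `EarlyStopping`, `train`, and the `reset_memory` callback are hypothetical names, not this package's API.

```python
from typing import Callable, List, Optional


class EarlyStopping:
    """Stop when validation loss has not improved for `patience` epochs (hypothetical helper)."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if self.best is None or val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement: a checkpoint would be saved here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def train(val_losses: List[float], reset_memory: Callable[[], None], patience: int = 2) -> int:
    """Simulated training loop; returns the number of epochs actually run."""
    stopper = EarlyStopping(patience=patience)
    for epoch, val_loss in enumerate(val_losses, start=1):
        reset_memory()  # start each epoch from a fresh long-term memory state
        # A real loop would train for one epoch and evaluate here; the loss is supplied instead.
        if stopper.step(val_loss):
            return epoch
    return len(val_losses)


resets: List[int] = []
epochs_run = train([1.0, 0.8, 0.85, 0.9, 0.7], reset_memory=lambda: resets.append(1), patience=2)
print(epochs_run, len(resets))  # 4 4
```

With `patience=2`, the run stops after epoch 4, the second consecutive epoch without improvement over 0.8, so the later 0.7 is never reached; the memory was reset at the start of each of those four epochs.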

Testing

# Run all tests
pytest pytests/

# Run specific test file
pytest pytests/test_memory.py

Citation

This repository provides an unofficial implementation of the Titans architecture.
If you reference this work, please cite the original paper:

@article{behrouz2024titans,
  title={Titans: Learning to Memorize at Test Time},
  author={Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2501.00663},
  year={2024}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.


Download files

Download the file for your platform.

Source Distribution

titans_unofficial-1.0.5.tar.gz (27.2 kB)

Built Distribution

titans_unofficial-1.0.5-py3-none-any.whl (4.9 kB)

File details

Details for the file titans_unofficial-1.0.5.tar.gz.

File metadata

  • Download URL: titans_unofficial-1.0.5.tar.gz
  • Size: 27.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.16

File hashes

Hashes for titans_unofficial-1.0.5.tar.gz
Algorithm Hash digest
SHA256 9e34185a754e99b4e4e50f1fb88c6101d68ed2c2f134398c44655b832f4a6bb7
MD5 1e3dba417c5bfe34f0810b89524358af
BLAKE2b-256 d87835324ea86a905131fcb5598b29dd8f56dab86c258c8ddf9d76daeacd6fc7

File details

Details for the file titans_unofficial-1.0.5-py3-none-any.whl.

File hashes

Hashes for titans_unofficial-1.0.5-py3-none-any.whl
Algorithm Hash digest
SHA256 c393ea76d341effd6eb3420687be0a45dff77ba5d0f466e08518c7021bc9f86e
MD5 b122197f0a69b6f23e82ab8ffd921502
BLAKE2b-256 732bd34a5fa36ef9ce656732200ab11dcc483cec483081f19fb4d6ec9bf1aa42
