
Unofficial PyTorch implementation of Titans: Learning to Memorize at Test Time


Unofficial Implementation of Titans: Learning to Memorize at Test Time


This is an unofficial PyTorch implementation of the paper "Titans: Learning to Memorize at Test Time" by Ali Behrouz, Peilin Zhong, and Vahab Mirrokni.

Overview

Titans is a neural architecture that combines attention, used as short-term memory, with a neural long-term memory module. The architecture addresses the limitations of both recurrent models (which compress the entire history into a fixed-size memory) and attention mechanisms (whose cost grows quadratically with sequence length).

Key Features

  • Neural Long-term Memory: A module that learns to memorize historical context
  • Persistent Memory: Learnable tokens that encode task-specific knowledge
  • Three Architectural Variants:
    • MAC (Memory as Context): Uses memory as context for attention
    • MAG (Memory as Gate): Combines memory with the core branch through gating
    • MAL (Memory as Layer): Integrates memory as a separate layer
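Persistent memory is input-independent: the same learnable vectors are prepended to every sequence before attention sees it. The following is a minimal NumPy sketch of that idea, assuming a sequence is a `(seq_len, d_model)` array; the function name `prepend_persistent` is hypothetical and not part of this package's API, and in a real model the persistent tokens would be trainable parameters (for example an `nn.Parameter` in PyTorch) rather than a fixed array.

```python
import numpy as np


def prepend_persistent(x: np.ndarray, persistent: np.ndarray) -> np.ndarray:
    """Prepend persistent-memory tokens to a sequence (hypothetical helper).

    x:          (seq_len, d_model) data-dependent token representations
    persistent: (n_persistent, d_model) tokens shared by every input
    returns:    (n_persistent + seq_len, d_model)
    """
    if x.shape[1] != persistent.shape[1]:
        raise ValueError("persistent tokens and inputs must share d_model")
    return np.concatenate([persistent, x], axis=0)


rng = np.random.default_rng(0)
d_model, n_persistent = 128, 8
# Stand-in for learned parameters; in training these would receive gradients.
persistent = rng.normal(scale=0.02, size=(n_persistent, d_model))
x = rng.normal(size=(16, d_model))

out = prepend_persistent(x, persistent)
print(out.shape)  # (24, 128)
```

Because these tokens do not depend on the input, attention can always read them as task-level knowledge, whereas the neural long-term memory changes as data arrives.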

Installation

From PyPI (Recommended)

pip install titans-unofficial

From Source (Development)

# Clone the repository
git clone https://github.com/Shehryar718/titans-unofficial.git
cd titans-unofficial

# Install in editable (development) mode with the dev extras
pip install -e ".[dev]"

Requirements

  • Python 3.10+
  • PyTorch 2.0+
  • transformers (for tokenization)
  • numpy
  • pytest (for running tests)

Project Structure

titans-unofficial/
├── titans/
│   ├── __init__.py
│   ├── models/
│   │   ├── titans_base.py    # Base class for all variants
│   │   ├── titans_mac.py     # Memory as Context implementation
│   │   ├── titans_mag.py     # Memory as Gate implementation
│   │   └── titans_mal.py     # Memory as Layer implementation
│   └── utils/
│       ├── memory.py         # Neural Memory Module
│       ├── attention.py      # Attention mechanisms
│       └── persistent_memory.py  # Persistent Memory implementation
├── examples/
│   ├── text_classification.py  # Text classification example
│   ├── language_modeling.py    # Language modeling example
│   └── fine_tuning.py         # Fine-tuning example
├── pytests/
│   └── test_memory.py         # Tests for memory module
├── requirements.txt
├── LICENSE
└── README.md

Usage

Text Classification

from titans import TitansMAC, TitansMAG, TitansMAL
# `examples` is a top-level directory of the repository, so this import
# works when running from the root of a source checkout.
from examples.text_classification import TitansForClassification

# Initialize model
model = TitansForClassification(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    num_classes=2,
    memory_depth=2,
    persistent_tokens=8,
    window_size=16,
    model_type="mal",  # Choose from: "mac", "mag", "mal"
)

To train and evaluate, run the example script from the repository root:

python examples/text_classification.py

Language Modeling

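from titans import TitansMAC, TitansMAG, TitansMAL
# Run from the root of a source checkout so that `examples` is importable.
from examples.language_modeling import TitansForLanguageModeling

# Initialize model
model = TitansForLanguageModeling(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    memory_depth=2,
    persistent_tokens=16,
    window_size=128,
    model_type="mac",  # Choose from: "mac", "mag", "mal"
)

To train the model and generate text, run the language modeling example from the repository root:

python examples/language_modeling.py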

Architecture Details

Neural Memory Module

The neural memory module consists of:

  • Key/Value/Query projections for memory access
  • Multi-layer perceptron whose weights store the memory
  • Momentum-based update mechanism with configurable parameters
  • Weight decay acting as the forgetting mechanism
  • Gradient scaling for numerical stability
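The Titans paper trains the memory online with an associative loss, ℓ = ||M(k_t) − v_t||², using momentum (the paper's accumulated "surprise") and weight decay as forgetting: S_t = η·S_{t−1} − θ·∇ℓ and M_t = (1 − α)·M_{t−1} + S_t. The sketch below illustrates that rule under stated assumptions: the memory is a single linear map `W` so the gradient has a closed form, the keys and values are given directly instead of coming from learned projections, and fixed scalars `eta`, `theta`, `alpha` replace the paper's data-dependent η_t, θ_t, α_t. The class name is hypothetical; this is not the code in `titans/utils/memory.py`.

```python
import numpy as np


class LinearMemorySketch:
    """Toy long-term memory M(k) = W @ k, updated online while it is used (hypothetical)."""

    def __init__(self, d: int, eta: float = 0.9, theta: float = 0.1, alpha: float = 0.01):
        self.W = np.zeros((d, d))  # memory parameters
        self.S = np.zeros((d, d))  # momentum buffer (accumulated "surprise")
        self.eta = eta             # momentum decay (fixed stand-in for eta_t)
        self.theta = theta         # step size on the current gradient (stand-in for theta_t)
        self.alpha = alpha         # weight decay / forgetting rate (stand-in for alpha_t)

    def loss(self, k: np.ndarray, v: np.ndarray) -> float:
        err = self.W @ k - v       # associative loss ||M(k) - v||^2
        return float(err @ err)

    def update(self, k: np.ndarray, v: np.ndarray) -> None:
        # d/dW ||W k - v||^2 = 2 (W k - v) k^T  (closed form only for a linear memory)
        grad = 2.0 * np.outer(self.W @ k - v, k)
        self.S = self.eta * self.S - self.theta * grad  # S_t = eta*S_{t-1} - theta*grad
        self.W = (1.0 - self.alpha) * self.W + self.S   # M_t = (1-alpha)*M_{t-1} + S_t

    def retrieve(self, q: np.ndarray) -> np.ndarray:
        return self.W @ q          # read-only lookup, no update


rng = np.random.default_rng(0)
d = 8
k = rng.normal(size=d)
k /= np.linalg.norm(k)  # unit-norm key keeps the fixed step size stable
v = rng.normal(size=d)

mem = LinearMemorySketch(d)
before = mem.loss(k, v)
for _ in range(200):
    mem.update(k, v)
after = mem.loss(k, v)
print(after < before)  # True
```

With a nonzero `alpha`, the memory settles slightly short of storing `v` exactly; that residual is the price of forgetting. In the paper the memory is a deeper MLP and the gradient comes from automatic differentiation.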

Variants

  1. MAC (Memory as Context)

    • Memory output serves as additional context
    • Efficient for tasks requiring long-range dependencies
    • Chunked computation, with segments processed in parallel
    • Configurable chunk size
  2. MAG (Memory as Gate)

    • Gating mechanism to combine memory with core processing
    • Adaptive balance between short- and long-term memory
    • Enhanced numerical stability
    • Improved gradient flow through gating
  3. MAL (Memory as Layer)

    • Memory integrated as a separate layer
    • Direct memory access at each layer
    • Sliding window attention for efficiency
    • Layer-wise memory updates
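The three variants differ mainly in where the memory read enters the computation. The NumPy sketch below shows only the data flow, under simplifying assumptions: attention has no learned projections, the memory read is a fixed linear map `W`, the memory is not updated between MAC segments (in the paper it is), and the MAG gate is an assumed elementwise sigmoid mix rather than the paper's exact gating. All function names are hypothetical and not taken from `titans/models/`.

```python
from typing import Optional

import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention(x: np.ndarray, window: Optional[int] = None) -> np.ndarray:
    """Causal self-attention without projections; optional sliding window."""
    n, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    i = np.arange(n)[:, None]  # query positions
    j = np.arange(n)[None, :]  # key positions
    mask = j > i               # causal: hide future positions
    if window is not None:
        mask = mask | (i - j >= window)  # sliding window: hide positions too far back
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ x


def memory(x: np.ndarray, W: np.ndarray) -> np.ndarray:
    return x @ W.T             # stand-in for the long-term memory read M(x)


def mac(segment: np.ndarray, persistent: np.ndarray, W: np.ndarray) -> np.ndarray:
    # Memory as Context: attend over [persistent | memory read | segment]
    ctx = np.concatenate([persistent, memory(segment, W), segment], axis=0)
    return attention(ctx)[-len(segment):]  # keep outputs for the current segment


def mag(x: np.ndarray, W: np.ndarray, window: int) -> np.ndarray:
    # Memory as Gate: sliding-window attention branch mixed with the memory branch
    y, m = attention(x, window=window), memory(x, W)
    g = 1.0 / (1.0 + np.exp(-(y + m)))     # assumed sigmoid gate
    return g * y + (1.0 - g) * m


def mal(x: np.ndarray, W: np.ndarray, window: int) -> np.ndarray:
    # Memory as Layer: memory layer first, then sliding-window attention
    return attention(memory(x, W), window=window)


rng = np.random.default_rng(0)
d = 16
x = rng.normal(size=(32, d))
persistent = rng.normal(size=(4, d))
W = rng.normal(scale=0.1, size=(d, d))

# MAC processes the sequence chunk by chunk (4 segments of 8 tokens here).
mac_out = np.concatenate([mac(seg, persistent, W) for seg in np.split(x, 4)], axis=0)
mag_out = mag(x, W, window=8)
mal_out = mal(x, W, window=8)
print(mac_out.shape, mag_out.shape, mal_out.shape)  # (32, 16) (32, 16) (32, 16)
```

Because this sketch never updates `W`, the MAC segments are independent and could be computed in parallel; with a live memory, each segment would read the state left by the previous one.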

Example Tasks

The repository includes implementations for:

  • Text Classification (binary and multi-class)
  • Language Modeling with test-time adaptation
  • Fine-tuning with early stopping

The examples also demonstrate these training and memory-handling practices:

  • Memory reset between epochs for fresh adaptation
  • Efficient batch processing with dynamic batching
  • Gradient scaling for numerical stability
  • Early stopping and model checkpointing
  • Proper memory state management
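Early stopping with checkpointing usually follows a patience rule: stop once the validation loss has failed to improve for a set number of evaluations. The following is a generic sketch of that pattern, not code taken from `examples/fine_tuning.py`; the class name and its `patience` and `min_delta` parameters are illustrative assumptions.

```python
class EarlyStopping:
    """Stop when validation loss has not improved by min_delta for `patience` evaluations."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement: a checkpoint would typically be saved here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
losses = [0.9, 0.7, 0.72, 0.71, 0.65]
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        print(f"stopping at epoch {epoch}, best={stopper.best}")  # stopping at epoch 3, best=0.7
        break
```

Note that the improvement at epoch 4 is never seen: patience trades some chance of later gains for shorter training.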

Testing

# Run all tests
pytest pytests/

# Run specific test file
pytest pytests/test_memory.py

Citation

This repository provides an unofficial implementation of the Titans architecture.
If you reference this work, please cite the original paper:

@article{behrouz2024titans,
  title={Titans: Learning to Memorize at Test Time},
  author={Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2501.00663},
  year={2024}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

