# Unofficial PyTorch Implementation of Titans: Learning to Memorize at Test Time
This is an unofficial PyTorch implementation of the paper "Titans: Learning to Memorize at Test Time" by Ali Behrouz, Peilin Zhong, and Vahab Mirrokni.
## Overview
Titans is a novel neural architecture that combines attention-based short-term memory with a neural long-term memory module. The architecture addresses the limitations of both recurrent models (which compress data into fixed-size memory) and attention mechanisms (which have quadratic complexity).
## Key Features
- Neural Long-term Memory: A module that learns to memorize historical context
- Persistent Memory: Learnable tokens that encode task-specific knowledge
- Three Architectural Variants:
  - MAC (Memory as Context): uses memory output as additional context for attention
  - MAG (Memory as Gate): combines memory with the core branch via gating
  - MAL (Memory as Layer): integrates memory as a separate layer
## Installation

### From PyPI (Recommended)

```bash
pip install titans-unofficial
```

### From Source (Development)

```bash
# Clone the repository
git clone https://github.com/Shehryar718/titans-unofficial.git
cd titans-unofficial

# Install in development mode with all extras
pip install -e ".[dev]"
```
## Requirements
- Python 3.8+
- PyTorch 2.0+
- transformers (for tokenization)
- numpy
- pytest (for running tests)
## Project Structure

```
titans-unofficial/
├── titans/
│   ├── __init__.py
│   ├── models/
│   │   ├── titans_base.py          # Base class for all variants
│   │   ├── titans_mac.py           # Memory as Context implementation
│   │   ├── titans_mag.py           # Memory as Gate implementation
│   │   └── titans_mal.py           # Memory as Layer implementation
│   └── utils/
│       ├── memory.py               # Neural Memory Module
│       ├── attention.py            # Attention mechanisms
│       └── persistent_memory.py    # Persistent Memory implementation
├── examples/
│   ├── text_classification.py      # Text classification example
│   ├── language_modeling.py        # Language modeling example
│   └── fine_tuning.py              # Fine-tuning example
├── pytests/
│   └── test_memory.py              # Tests for memory module
├── requirements.txt
├── LICENSE
└── README.md
```
## Usage

### Text Classification

```python
from titans import TitansMAC, TitansMAG, TitansMAL
from examples.text_classification import TitansForClassification

# Initialize model
model = TitansForClassification(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    num_classes=2,
    memory_depth=2,
    persistent_tokens=8,
    window_size=16,
    model_type="mal",  # choose from: "mac", "mag", "mal"
)
```

To train and evaluate:

```bash
python examples/text_classification.py
```
### Language Modeling

```python
from titans import TitansMAC, TitansMAG, TitansMAL
from examples.language_modeling import TitansForLanguageModeling

# Initialize model
model = TitansForLanguageModeling(
    vocab_size=30000,
    d_model=128,
    n_layers=2,
    n_heads=4,
    memory_depth=2,
    persistent_tokens=16,
    window_size=128,
    model_type="mac",
)
```

To train and generate text:

```bash
python examples/language_modeling.py
```
## Architecture Details

### Neural Memory Module
The neural memory module consists of:
- Key/Value/Query projections for memory access
- Multi-layer perceptron for memory processing
- Momentum-based update mechanism with configurable parameters
- Weight decay for forgetting mechanism
- Gradient scaling for numerical stability
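The momentum-plus-decay update can be sketched in a few lines. This is a simplified illustration of the general scheme from the Titans paper (a "surprise" signal accumulated with momentum, and weight decay acting as forgetting); the function name and default hyperparameters below are illustrative, not this repository's actual API:

```python
import torch

def memory_step(M, S_prev, grad, eta=0.9, theta=0.01, alpha=0.001):
    """One memory update: momentum over the surprise (gradient) signal,
    plus weight decay so old context is gradually forgotten.

    Illustrative sketch only; names and defaults are assumptions.
    """
    S = eta * S_prev - theta * grad   # momentum-based surprise accumulation
    M_new = (1.0 - alpha) * M + S     # weight decay acts as the forgetting gate
    return M_new, S
```

Here `grad` stands for the gradient of the memory module's associative loss with respect to its weights, and `alpha` controls how quickly old context decays.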
### Variants

**MAC (Memory as Context)**

- Memory output serves as additional context for attention
- Efficient for tasks requiring long-range dependencies
- Parallel processing with chunked computation
- Configurable chunk size and parallelism

**MAG (Memory as Gate)**

- Gating mechanism combines memory with the core processing branch
- Adaptive balance between short- and long-term memory
- Enhanced numerical stability
- Improved gradient flow through the gate

**MAL (Memory as Layer)**

- Memory integrated as a separate layer
- Direct memory access at each layer
- Sliding-window attention for efficiency
- Layer-wise memory updates
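For MAG, one plausible shape for the gating step is a learned sigmoid gate that interpolates between the attention (short-term) branch and the memory (long-term) branch. The module below is a hedged sketch of that idea, not the repository's actual implementation:

```python
import torch
import torch.nn as nn

class MemoryGate(nn.Module):
    """Sigmoid gate blending attention output with memory output (sketch)."""

    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(2 * d_model, d_model)

    def forward(self, attn_out: torch.Tensor, mem_out: torch.Tensor) -> torch.Tensor:
        # Gate in (0, 1): near 1 -> rely on short-term attention,
        # near 0 -> rely on long-term memory.
        g = torch.sigmoid(self.proj(torch.cat([attn_out, mem_out], dim=-1)))
        return g * attn_out + (1.0 - g) * mem_out
```

Because the gate is a convex combination, every output element lies between the corresponding attention and memory values, which keeps the blend numerically well-behaved.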
## Example Tasks
The repository includes implementations for:
- Text Classification (Binary and multi-class)
- Language Modeling with test-time adaptation
- Fine-tuning with early stopping
Each example demonstrates different aspects of the Titans architecture:
- Memory reset between epochs for fresh adaptation
- Efficient batch processing with dynamic batching
- Gradient scaling for numerical stability
- Early stopping and model checkpointing
- Proper memory state management
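The reset-and-early-stop pattern listed above can be sketched independently of any model code. Everything below is a generic training-loop skeleton; names like `reset_memory` and `run_epoch` are placeholders, not this repository's API:

```python
def train_with_resets(reset_memory, run_epoch, max_epochs=20, patience=3, min_delta=1e-4):
    """Generic loop: reset memory each epoch, stop when validation loss plateaus."""
    best, stale = float("inf"), 0
    for epoch in range(max_epochs):
        reset_memory()               # fresh long-term memory for this epoch
        val_loss = run_epoch(epoch)  # train + validate, return validation loss
        if val_loss < best - min_delta:
            best, stale = val_loss, 0  # improvement: checkpointing would go here
        else:
            stale += 1
            if stale >= patience:      # early stopping
                break
    return best
```

The reset keeps the test-time-learned memory from carrying stale state across epochs, while `patience` bounds how long training continues without improvement.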
## Testing

```bash
# Run all tests
pytest pytests/

# Run a specific test file
pytest pytests/test_memory.py
```
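New tests follow the usual pytest conventions (files under `pytests/`, functions named `test_*`). As an illustration, a property test of the forgetting behaviour described earlier could look like the sketch below; it exercises the generic weight-decay property, not the repository's exact API:

```python
import torch

def test_weight_decay_shrinks_memory():
    # With zero surprise, weight decay alone should shrink memory toward zero.
    torch.manual_seed(0)
    M = torch.randn(4, 4)
    alpha = 0.1
    M_next = (1.0 - alpha) * M
    assert M_next.norm() < M.norm()
    assert torch.allclose(M_next, 0.9 * M)
```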
## Citation
This repository provides an unofficial implementation of the Titans architecture.
If you reference this work, please cite the original paper:
```bibtex
@article{behrouz2024titans,
  title={Titans: Learning to Memorize at Test Time},
  author={Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
  journal={arXiv preprint arXiv:2501.00663},
  year={2024}
}
```
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## File details

Details for the file `titans_unofficial-1.0.1.tar.gz`.

### File metadata
- Download URL: titans_unofficial-1.0.1.tar.gz
- Upload date:
- Size: 26.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16
### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `66592669201deefb9c0d017509e32180aa0897ac191c0368c560227e22134478` |
| MD5 | `8501cc53a450e8f4a06b1283d090ec39` |
| BLAKE2b-256 | `3ab9c7746250dd4c68e6e354a204e41e064718321856656ccee3269e06061b0e` |
## File details

Details for the file `titans_unofficial-1.0.1-py3-none-any.whl`.

### File metadata
- Download URL: titans_unofficial-1.0.1-py3-none-any.whl
- Upload date:
- Size: 4.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.16
### File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | `eb9e113e6cc778c32117ad4f90bed1c7139b6ee48bb1ee196ce93ff484aa6167` |
| MD5 | `babf76d682dd7f49a563b89dea755fdc` |
| BLAKE2b-256 | `4fe88269f297b94fe8b29dd44c704cdb7d4e88ea0474b2bfc8f5b7005b31a865` |