PyTorch implementation of Titans: Learning to Memorize at Test Time (Behrouz, Zhong & Mirrokni, 2024)
Titans: Learning to Memorize at Test Time
A clean, highly optimized PyTorch implementation of the Titans architecture from:
Titans: Learning to Memorize at Test Time
Ali Behrouz, Peilin Zhong, Vahab Mirrokni — Google Research, 2024
arXiv:2501.00663
What's Inside
| Module | Description |
|---|---|
| `NeuralMemory` | Deep MLP that learns to memorize via gradient descent with momentum + weight-decay forgetting (§3) |
| `PersistentMemory` | Learnable task-knowledge tokens prepended to every sequence (§3.3) |
| `TitansMAC` | Memory as a Context — retrieves long-term memory as a prefix to the attention window (§4.1) |
| `TitansMAG` | Memory as a Gate — SWA ⊗ NeuralMemory gated branch (§4.2) |
| `TitansMAL` | Memory as a Layer — sequential LMM → SWA stack (§4.3) |
| `TitansLMM` | Standalone LMM — neural memory without attention (§4.3) |
Installation
```bash
# Install directly from GitHub
pip install git+https://github.com/Neuranox/titans-memory.git

# Or clone and install locally (editable — recommended for development)
git clone https://github.com/Neuranox/titans-memory.git
cd titans-memory
pip install -e .
```
Quick Start
```python
import torch
from titans import TitansMAC, TitansMAG, TitansMAL, TitansLMM
from titans.utils import TitansConfig, build_model, count_parameters

# ── Build from config ──────────────────────────────────────────────────
cfg = TitansConfig.small(variant="MAC")   # ~170 M params
cfg.vocab_size = 32_000
model = build_model(cfg)
print(f"Parameters: {count_parameters(model):,}")

# ── Forward pass ───────────────────────────────────────────────────────
input_ids = torch.randint(0, 32_000, (2, 512))
labels = input_ids.clone()
out = model(input_ids, labels=labels)
print(out["logits"].shape)   # (2, 512, 32000)
print(out["loss"].item())

# ── Generation ─────────────────────────────────────────────────────────
prompt = torch.randint(0, 32_000, (1, 8))
generated = model.generate(prompt, max_new_tokens=50, top_k=50)
```
All Four Variants
```python
VOCAB = 32_000
D = 512

models = {
    "LMM": TitansLMM(VOCAB, d_model=D, n_layers=12, mem_layers=2),
    "MAC": TitansMAC(VOCAB, d_model=D, n_layers=12, mem_layers=2, chunk_size=128),
    "MAG": TitansMAG(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
    "MAL": TitansMAL(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
}
```
TitansConfig — Paper-Scale Presets
```python
from titans.utils import TitansConfig, build_model

cfg = TitansConfig.tiny(variant="MAC")     # ~30 M  — quick experiments
cfg = TitansConfig.small(variant="MAC")    # ~170 M — paper Table 1
cfg = TitansConfig.medium(variant="MAC")   # ~340 M — paper Table 1
cfg = TitansConfig.large(variant="MAC")    # ~760 M — paper Table 1

# JSON save / load
cfg.to_json("config.json")
cfg = TitansConfig.from_json("config.json")
```
Training
```python
from titans.utils.training import build_optimizer, get_cosine_schedule_with_warmup

optim = build_optimizer(model, lr=4e-4, weight_decay=0.1)   # AdamW, no wd on bias/norm
sched = get_cosine_schedule_with_warmup(
    optim, warmup_steps=2000, total_steps=100_000, min_lr_ratio=0.1)

for batch in dataloader:
    out = model(batch["input_ids"], labels=batch["labels"])
    out["loss"].backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step(); sched.step(); optim.zero_grad()
```
See examples/02_training_loop.py for a complete runnable example.
Architecture Overview
Titans (MAC) — Memory as a Context
```text
For each segment S^(t):
  h_t    = M*_{t-1}(q_t)         ← retrieve long-term memory
  S̃^(t)  = [P || h_t || S^(t)]   ← augment with persistent + history
  y_t    = Attention(S̃^(t))      ← full causal attention over window
  M_t    = M_{t-1}.update(y_t)   ← write: gradient descent w/ momentum
  o_t    = y_t ⊗ M*_t(y_t)       ← gated output
```
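The per-segment MAC loop above can be sketched end-to-end with toy stand-ins — a linear map in place of the deep MLP memory, plain (non-causal) softmax attention, and a simple reconstruction loss as the surprise signal. Nothing here uses the actual `titans` API; it only mirrors the dataflow:

```python
import torch

torch.manual_seed(0)
d = 8
memory = torch.zeros(d, d)                    # stand-in for the deep MLP memory M
persistent = torch.randn(4, d)                # P: persistent task tokens
attend = lambda s: torch.softmax(s @ s.T, dim=-1) @ s   # stand-in (non-causal) attention

for segment in torch.randn(3, 16, d):         # segments S^(1), S^(2), S^(3)
    h = segment @ memory                      # h_t = M*_{t-1}(q_t): retrieve
    s = torch.cat([persistent, h, segment])   # S~^(t) = [P || h_t || S^(t)]
    y = attend(s)                             # attention over the augmented window
    # write: one gradient-descent step on a stand-in reconstruction surprise
    W = memory.clone().requires_grad_(True)
    loss = ((y @ W - y) ** 2).mean()
    memory = (W - 0.1 * torch.autograd.grad(loss, W)[0]).detach()
    o = y * (y @ memory)                      # o_t = y_t ⊗ M*_t(y_t)
```

The point of the sketch is the ordering: read from the previous memory state, attend over the augmented segment, then write the attended output back before gating.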
Titans (MAG) — Memory as a Gate
```text
x̃ = [P || x]
y = SW-Attn(x̃)   ← precise short-term memory (sliding window)
o = y ⊗ M(x̃)     ← gated with neural long-term memory
```
Titans (MAL) — Memory as a Layer
```text
x̃ = [P || x]
y = M(x̃)         ← memory compresses context
o = SW-Attn(y)   ← attention refines compressed representation
```
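Structurally, MAG and MAL differ only in how the two branches compose: parallel with an elementwise gate versus sequential. A toy sketch with stand-in functions for the attention and memory branches (not the package API, and persistent tokens omitted):

```python
import torch

swa = lambda x: torch.tanh(x)      # stand-in for the sliding-window attention branch
mem = lambda x: torch.sigmoid(x)   # stand-in for the neural memory branch

x = torch.randn(2, 16, 8)          # (batch, seq, d_model)
mag_out = swa(x) * mem(x)          # MAG: o = y ⊗ M(x̃) — parallel branches, gated
mal_out = swa(mem(x))              # MAL: o = SW-Attn(M(x̃)) — memory first, then attention
```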
Neural Memory — Key Equations
| Component | Equation | Description |
|---|---|---|
| Momentary surprise | ∇ℓ(M_{t-1}; x_t) | How unexpected is x_t? |
| Surprise with momentum | S_t = η_t S_{t-1} − θ_t ∇ℓ(M_{t-1}; x_t) | Eq. 10 — momentum carries past surprise forward |
| Forgetting gate | M_t = (1−α_t) M_{t-1} + S_t | Eq. 13 — weight-decay-style forgetting |
| Retrieval | y_t = M*(q_t) | Eq. 15 — inference-time read, no update |
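The three update equations can be exercised on a toy linear memory. This is a minimal sketch: the paper's memory is a deep MLP and the gates η_t, θ_t, α_t are data-dependent, whereas here the memory is a single matrix and the gates are fixed scalars:

```python
import torch

torch.manual_seed(0)
d = 4
W_true = torch.randn(d, d)            # hidden key→value association to memorize
M = torch.zeros(d, d)                 # memory M_t (a linear map here, deep MLP in the paper)
S = torch.zeros(d, d)                 # momentum state S_t
eta, theta, alpha = 0.5, 0.01, 0.001  # fixed stand-ins for the data-dependent gates

for _ in range(500):
    k = torch.randn(1, d)
    v = k @ W_true                    # (key, value) pair arriving at test time
    grad = 2 * k.T @ (k @ M - v)      # momentary surprise ∇ℓ, with ℓ = ||M(k) − v||²
    S = eta * S - theta * grad        # Eq. 10: surprise with momentum
    M = (1 - alpha) * M + S           # Eq. 13: update with weight-decay forgetting

# Eq. 15: retrieval is a plain forward pass — no update
q = torch.randn(1, d)
y = q @ M
```

After enough online steps the memory's weights approximate the underlying association, while the α decay keeps old content from accumulating without bound.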
Running Tests
```bash
cd titans-memory
pip install -e ".[dev]"
pytest
```
Project Structure
```text
titans-memory/
├── titans/
│   ├── __init__.py              ← public API
│   ├── memory/
│   │   ├── neural_memory.py     ← NeuralMemory (LMM core)
│   │   └── persistent_memory.py
│   ├── models/
│   │   ├── lmm.py               ← TitansLMM
│   │   ├── mac.py               ← TitansMAC
│   │   ├── mag.py               ← TitansMAG
│   │   └── mal.py               ← TitansMAL
│   ├── ops/
│   │   ├── scan.py              ← parallel associative scan
│   │   └── attention.py         ← causal + sliding-window attention
│   └── utils/
│       ├── config.py            ← TitansConfig dataclass
│       ├── factory.py           ← build_model()
│       └── training.py          ← optimizer + LR schedule helpers
├── tests/
│   ├── test_scan.py
│   ├── test_memory.py
│   └── test_models.py
├── examples/
│   ├── 01_quickstart.py
│   ├── 02_training_loop.py
│   └── 03_memory_standalone.py
├── pyproject.toml
├── setup.py
└── README.md
```
Citation
```bibtex
@article{behrouz2024titans,
  title   = {Titans: Learning to Memorize at Test Time},
  author  = {Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
  journal = {arXiv preprint arXiv:2501.00663},
  year    = {2024}
}
```