
PyTorch implementation of Titans: Learning to Memorize at Test Time (Behrouz, Zhong & Mirrokni, 2024)


Titans: Learning to Memorize at Test Time


A clean, highly optimized PyTorch implementation of the Titans architecture from:

Titans: Learning to Memorize at Test Time
Ali Behrouz, Peilin Zhong, Vahab Mirrokni — Google Research, 2024
arXiv:2501.00663

(Figure: Titans architecture overview)


What's Inside

| Module | Description |
| --- | --- |
| `NeuralMemory` | Deep MLP that learns to memorize via gradient descent with momentum and weight-decay forgetting (§3) |
| `PersistentMemory` | Learnable task-knowledge tokens prepended to every sequence (§3.3) |
| `TitansMAC` | Memory as a Context — retrieves long-term memory as a prefix to the attention window (§4.1) |
| `TitansMAG` | Memory as a Gate — sliding-window attention gated (⊗) with the neural memory branch (§4.2) |
| `TitansMAL` | Memory as a Layer — sequential LMM → SWA stack (§4.3) |
| `TitansLMM` | Standalone LMM — neural memory without attention (§4.3) |

Installation

# Install directly from GitHub
pip install git+https://github.com/Neuranox/titans-memory.git

# Or clone and install locally (editable — recommended for development)
git clone https://github.com/Neuranox/titans-memory.git
cd titans-memory
pip install -e .

Quick Start

import torch
from titans import TitansMAC, TitansMAG, TitansMAL, TitansLMM
from titans.utils import TitansConfig, build_model, count_parameters

# ── Build from config ──────────────────────────────────────────────────
cfg   = TitansConfig.small(variant="MAC")   # ~170 M params
cfg.vocab_size = 32_000
model = build_model(cfg)
print(f"Parameters: {count_parameters(model):,}")

# ── Forward pass ───────────────────────────────────────────────────────
input_ids = torch.randint(0, 32_000, (2, 512))
labels    = input_ids.clone()

out = model(input_ids, labels=labels)
print(out["logits"].shape)   # (2, 512, 32000)
print(out["loss"].item())

# ── Generation ─────────────────────────────────────────────────────────
prompt    = torch.randint(0, 32_000, (1, 8))
generated = model.generate(prompt, max_new_tokens=50, top_k=50)

All Four Variants

VOCAB = 32_000
D     = 512

models = {
    "LMM": TitansLMM(VOCAB, d_model=D, n_layers=12, mem_layers=2),
    "MAC": TitansMAC(VOCAB, d_model=D, n_layers=12, mem_layers=2, chunk_size=128),
    "MAG": TitansMAG(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
    "MAL": TitansMAL(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
}

TitansConfig — Paper-Scale Presets

from titans.utils import TitansConfig, build_model

cfg = TitansConfig.tiny(variant="MAC")    # ~30 M  — quick experiments
cfg = TitansConfig.small(variant="MAC")   # ~170 M — paper Table 1
cfg = TitansConfig.medium(variant="MAC")  # ~340 M — paper Table 1
cfg = TitansConfig.large(variant="MAC")   # ~760 M — paper Table 1

# JSON save / load
cfg.to_json("config.json")
cfg = TitansConfig.from_json("config.json")

Training

from titans.utils.training import build_optimizer, get_cosine_schedule_with_warmup

optim = build_optimizer(model, lr=4e-4, weight_decay=0.1)          # AdamW, no wd on bias/norm
sched = get_cosine_schedule_with_warmup(optim,
            warmup_steps=2000, total_steps=100_000, min_lr_ratio=0.1)

for batch in dataloader:
    out  = model(batch["input_ids"], labels=batch["labels"])
    out["loss"].backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step(); sched.step(); optim.zero_grad()

See examples/02_training_loop.py for a complete runnable example.
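The warmup-plus-cosine decay used above follows the standard formula; this is a minimal pure-Python stand-in for the schedule math (not the package's `get_cosine_schedule_with_warmup` implementation), with the same `warmup_steps`, `total_steps`, and `min_lr_ratio` knobs:

```python
import math

def cosine_lr(step, *, base_lr=4e-4, warmup_steps=2000,
              total_steps=100_000, min_lr_ratio=0.1):
    """Linear warmup to base_lr, then cosine decay to base_lr * min_lr_ratio."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps               # linear ramp
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))    # goes 1 -> 0
    return base_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine)

peak  = cosine_lr(2000)     # peak LR right after warmup
floor = cosine_lr(100_000)  # decayed to min_lr_ratio * base_lr
```

The `min_lr_ratio` floor keeps the learning rate from collapsing to zero at the end of training.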


Architecture Overview

Titans (MAC) — Memory as a Context
───────────────────────────────────────────────────────────
 For each segment S^(t):
   h_t     = M*_{t-1}(q_t)           ← retrieve long-term memory
   S̃^(t)  = [P || h_t || S^(t)]     ← augment with persistent + history
   y_t     = Attention(S̃^(t))        ← full causal attention over window
   M_t     = M_{t-1}.update(y_t)    ← write: gradient descent w/ momentum
   o_t     = y_t ⊗ M*_t(y_t)        ← gated output
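The control flow above can be mocked in a few lines of plain Python. Everything here is a toy stand-in — tokens are floats, "attention" is a mean, and the memory write is an exponential moving average rather than the gradient-based update — the point is only the retrieve → augment → attend → write pattern per segment:

```python
def toy_mac(segments, persistent):
    """Toy Memory-as-a-Context loop over segments S^(t)."""
    memory, outputs = 0.0, []
    for seg in segments:
        h = memory                            # retrieve  h_t = M*_{t-1}(q_t)
        augmented = persistent + [h] + seg    # augment   [P || h_t || S^(t)]
        y = sum(augmented) / len(augmented)   # "attend"  (mean as stand-in)
        memory = 0.5 * memory + 0.5 * y       # write     (EMA as stand-in for the update)
        outputs.append(y * memory)            # gate      o_t = y_t * M*_t(y_t)
    return outputs

outs = toy_mac([[1.0, 2.0], [3.0, 4.0]], persistent=[0.0])
```

Note how the second segment's output depends on memory written during the first — that carry-over is what lets MAC attend over contexts longer than one window.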

Titans (MAG) — Memory as a Gate
───────────────────────────────────────────────────────────
   x̃  = [P || x]
   y   = SW-Attn(x̃)  ← precise short-term memory (sliding window)
   o   = y ⊗ M(x̃)   ← gated with neural long-term memory

Titans (MAL) — Memory as a Layer
───────────────────────────────────────────────────────────
   x̃  = [P || x]
   y   = M(x̃)         ← memory compresses context
   o   = SW-Attn(y)   ← attention refines compressed representation
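The practical difference between MAG and MAL is where the memory sits relative to attention: in parallel (a multiplicative gate) versus in series. With hypothetical toy stand-ins for SW-Attn and M (not the package API), the two compositions are:

```python
def sw_attn(xs):   # stand-in for sliding-window attention
    return [x + 1.0 for x in xs]

def memory(xs):    # stand-in for the neural long-term memory M
    return [2.0 * x for x in xs]

def mag(xs):       # MAG: attention and memory run in parallel, then gate
    return [y * m for y, m in zip(sw_attn(xs), memory(xs))]

def mal(xs):       # MAL: memory compresses first, attention refines on top
    return sw_attn(memory(xs))
```

In MAG both branches see the same input and the gate mixes them; in MAL attention only ever sees the memory's output.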

Neural Memory — Key Equations

| Component | Equation | Description |
| --- | --- | --- |
| Momentary surprise | ∇ℓ(M_{t-1}; x_t) | How unexpected is x_t? |
| Surprise with momentum | S_t = η_t S_{t-1} − θ_t ∇ℓ(M_{t-1}; x_t) | Eq. 10 — momentum carries past surprise forward |
| Forgetting gate | M_t = (1 − α_t) M_{t-1} + S_t | Eq. 13 — weight-decay-style forgetting |
| Retrieval | y_t = M*(q_t) | Eq. 15 — inference only, no update |
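For intuition, the two update equations can be stepped by hand on scalars; this sketch treats the memory state and gradient as plain floats purely to make the arithmetic visible (the real `NeuralMemory` applies them per parameter tensor):

```python
def memory_step(M, S, grad, *, eta=0.9, theta=0.1, alpha=0.01):
    """One neural-memory update:
       S_t = eta * S_{t-1} - theta * grad     (Eq. 10, surprise with momentum)
       M_t = (1 - alpha) * M_{t-1} + S_t      (Eq. 13, weight-decay forgetting)"""
    S = eta * S - theta * grad
    M = (1.0 - alpha) * M + S
    return M, S

M, S = 0.0, 0.0
for grad in [1.0, 0.5, 0.0]:  # momentum keeps writing even when grad hits zero
    M, S = memory_step(M, S, grad)
```

With η_t near 1 the surprise state S keeps updating the memory for several steps after a surprising token, while α_t slowly decays old content.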

Running Tests

cd titans-memory
pip install -e ".[dev]"
pytest

Project Structure

titans-memory/
├── titans/
│   ├── __init__.py           ← public API
│   ├── memory/
│   │   ├── neural_memory.py  ← NeuralMemory (LMM core)
│   │   └── persistent_memory.py
│   ├── models/
│   │   ├── lmm.py            ← TitansLMM
│   │   ├── mac.py            ← TitansMAC
│   │   ├── mag.py            ← TitansMAG
│   │   └── mal.py            ← TitansMAL
│   ├── ops/
│   │   ├── scan.py           ← parallel associative scan
│   │   └── attention.py      ← causal + sliding-window attention
│   └── utils/
│       ├── config.py         ← TitansConfig dataclass
│       ├── factory.py        ← build_model()
│       └── training.py       ← optimizer + LR schedule helpers
├── tests/
│   ├── test_scan.py
│   ├── test_memory.py
│   └── test_models.py
├── examples/
│   ├── 01_quickstart.py
│   ├── 02_training_loop.py
│   └── 03_memory_standalone.py
├── pyproject.toml
├── setup.py
└── README.md

Citation

@article{behrouz2024titans,
  title   = {Titans: Learning to Memorize at Test Time},
  author  = {Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
  journal = {arXiv preprint arXiv:2501.00663},
  year    = {2024}
}



Download files

Download the file for your platform.

Source Distribution

titans_memory-0.1.0.tar.gz (24.4 kB)

Uploaded Source

Built Distribution


titans_memory-0.1.0-py3-none-any.whl (27.4 kB)

Uploaded Python 3

File details

Details for the file titans_memory-0.1.0.tar.gz.

File metadata

  • Download URL: titans_memory-0.1.0.tar.gz
  • Upload date:
  • Size: 24.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.9

File hashes

Hashes for titans_memory-0.1.0.tar.gz
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | 03531636084a6547ab7b69ecbcc63dbe184dd6fd5d6723afd39ecf4cd35efe8c |
| MD5 | e6770692e5461a81fc1591824ce906a8 |
| BLAKE2b-256 | 67cbdfdef2a9a936e0111dd16e1be114734f100af27248160afd363ce1123ce7 |


File details

Details for the file titans_memory-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: titans_memory-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 27.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.9

File hashes

Hashes for titans_memory-0.1.0-py3-none-any.whl
| Algorithm | Hash digest |
| --- | --- |
| SHA256 | bab9f3cdd247f98351f3c234d81b048b9135ce1c57771e8037360202e81bcfa4 |
| MD5 | 28bc9e24dbed0d7123531e641458966e |
| BLAKE2b-256 | 95081577060f2d657e3a73ec6d8578aedea2a25862992cdaf475733bf3e2bc88 |

