
PyTorch implementation of Titans: Learning to Memorize at Test Time (Behrouz, Zhong & Mirrokni, 2024)


Titans: Learning to Memorize at Test Time


A clean, highly-optimized PyTorch implementation of the Titans architecture from:

Titans: Learning to Memorize at Test Time
Ali Behrouz, Peilin Zhong, Vahab Mirrokni — Google Research, 2024
arXiv:2501.00663

(Figure: Titans architecture overview)


What's Inside

| Module | Description |
|--------|-------------|
| `NeuralMemory` | Deep MLP that learns to memorize at test time via gradient descent with momentum and weight-decay forgetting (§3) |
| `PersistentMemory` | Learnable task-knowledge tokens prepended to every sequence (§3.3) |
| `TitansMAC` | Memory as a Context: long-term memory retrievals are prepended to each attention window (§4.1) |
| `TitansMAG` | Memory as a Gate: sliding-window attention gated elementwise with the neural-memory branch (§4.2) |
| `TitansMAL` | Memory as a Layer: neural memory feeding sequentially into sliding-window attention (§4.3) |
| `TitansLMM` | Standalone LMM: the neural long-term memory on its own, without attention (§4.3) |
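As a rough sketch of how the hybrid variants differ, the three composition patterns can be shown with scalar stand-ins (the function names `swa`, `mem`, and `gate` are illustrative placeholders, not this package's API, and real modules operate on tensors):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Scalar stand-ins for the real submodules
swa = lambda h: 0.5 * h      # sliding-window attention branch
mem = lambda h: 2.0 * h      # neural-memory branch
gate = lambda h: sigmoid(h)  # data-dependent gate

def mag_block(h):
    """MAG: attention output blended with the memory output via a gate."""
    g = gate(h)
    return g * swa(h) + (1.0 - g) * mem(h)

def mal_block(h):
    """MAL: memory layer feeds sequentially into attention."""
    return swa(mem(h))
```

MAC differs again: instead of mixing branch outputs, it changes what each attention window sees (memory retrievals become prefix tokens, §4.1).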

Installation

# Install via PyPI (Recommended)
pip install titans-memory

# Or clone and install locally for development
git clone https://github.com/Neuranox/titans-memory.git
cd titans-memory
pip install -e .

Quick Start

import torch
from titans import TitansMAC, TitansMAG, TitansMAL, TitansLMM
from titans.utils import TitansConfig, build_model, count_parameters

# ── Build from config ──────────────────────────────────────────────────
cfg   = TitansConfig.small(variant="MAC")   # ~170 M params
cfg.vocab_size = 32_000
model = build_model(cfg)
print(f"Parameters: {count_parameters(model):,}")

# ── Forward pass ───────────────────────────────────────────────────────
input_ids = torch.randint(0, 32_000, (2, 512))
labels    = input_ids.clone()

out = model(input_ids, labels=labels)
print(out["logits"].shape)   # (2, 512, 32000)
print(out["loss"].item())

# ── Generation ─────────────────────────────────────────────────────────
prompt    = torch.randint(0, 32_000, (1, 8))
generated = model.generate(prompt, max_new_tokens=50, top_k=50)

All Four Variants

VOCAB = 32_000
D     = 512

models = {
    "LMM": TitansLMM(VOCAB, d_model=D, n_layers=12, mem_layers=2),
    "MAC": TitansMAC(VOCAB, d_model=D, n_layers=12, mem_layers=2, chunk_size=128),
    "MAG": TitansMAG(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
    "MAL": TitansMAL(VOCAB, d_model=D, n_layers=12, mem_layers=2, window=512),
}
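The `chunk_size` argument to `TitansMAC` controls how the sequence is split for attention. The per-chunk layout from §4.1 can be sketched over plain token lists (a simplified illustration of the sequence each attention call sees, not the package's internals):

```python
def mac_layout(tokens, chunk_size, persistent, retrieved):
    """For each chunk, attention sees: persistent tokens,
    then memory retrievals for that chunk, then the chunk itself."""
    windows = []
    for i in range(0, len(tokens), chunk_size):
        chunk = tokens[i : i + chunk_size]
        windows.append(persistent + retrieved + chunk)
    return windows

# Five tokens, chunk_size=2, one persistent token "p", one retrieval "m"
windows = mac_layout(list(range(5)), 2, ["p"], ["m"])
```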

TitansConfig — Paper-Scale Presets

from titans.utils import TitansConfig, build_model

cfg = TitansConfig.tiny(variant="MAC")    # ~30 M  — quick experiments
cfg = TitansConfig.small(variant="MAC")   # ~170 M — paper Table 1
cfg = TitansConfig.medium(variant="MAC")  # ~340 M — paper Table 1
cfg = TitansConfig.large(variant="MAC")   # ~760 M — paper Table 1

# JSON save / load
cfg.to_json("config.json")
cfg = TitansConfig.from_json("config.json")
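The JSON round-trip above can be sketched with a stdlib dataclass (field names here are illustrative, not the exact `TitansConfig` schema):

```python
import json
import os
import tempfile
from dataclasses import dataclass, asdict

@dataclass
class DemoConfig:
    variant: str = "MAC"
    d_model: int = 512
    n_layers: int = 12
    vocab_size: int = 32_000

    def to_json(self, path):
        # Serialize all dataclass fields to a JSON file
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2)

    @classmethod
    def from_json(cls, path):
        # Rebuild the config from the saved field dict
        with open(path) as f:
            return cls(**json.load(f))

path = os.path.join(tempfile.gettempdir(), "demo_config.json")
cfg = DemoConfig(vocab_size=50_000)
cfg.to_json(path)
assert DemoConfig.from_json(path) == cfg
```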

Training

from titans.utils.training import build_optimizer, get_cosine_schedule_with_warmup

optim = build_optimizer(model, lr=4e-4, weight_decay=0.1)          # AdamW, no wd on bias/norm
sched = get_cosine_schedule_with_warmup(optim,
            warmup_steps=2000, total_steps=100_000, min_lr_ratio=0.1)

for batch in dataloader:
    out  = model(batch["input_ids"], labels=batch["labels"])
    out["loss"].backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step(); sched.step(); optim.zero_grad()

See examples/02_training_loop.py for a complete runnable example.
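The warmup-plus-cosine schedule can be written as a pure function of the step count (a generic formulation of linear warmup followed by cosine decay to `min_lr_ratio * base_lr`, not necessarily this package's exact implementation):

```python
import math

def lr_at(step, base_lr=4e-4, warmup_steps=2000,
          total_steps=100_000, min_lr_ratio=0.1):
    """Linear warmup to base_lr, then cosine decay to min_lr_ratio * base_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0 over training
    return base_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine)
```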


Neural Memory — Key Equations

| Component | Equation | Description |
|-----------|----------|-------------|
| Momentary surprise | ∇ℓ(M_{t−1}; x_t) | how unexpected x_t is under the current memory |
| Surprise with momentum | S_t = η_t S_{t−1} − θ_t ∇ℓ(M_{t−1}; x_t) | Eq. 10: past surprise decays while new surprise accumulates |
| Forgetting gate | M_t = (1 − α_t) M_{t−1} + S_t | Eq. 13: weight-decay-style forgetting |
| Retrieval | y_t = M*(q_t) | Eq. 15: forward pass only, no parameter update |
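These update equations can be traced with scalars (tensor shapes and the loss ℓ are simplified away; `grad` stands in for the momentary surprise ∇ℓ(M_{t−1}; x_t), and `eta`, `theta`, `alpha` for the data-dependent gates):

```python
def memory_step(M_prev, S_prev, grad, eta, theta, alpha):
    """One neural-memory update."""
    S_t = eta * S_prev - theta * grad   # Eq. 10: surprise with momentum
    M_t = (1 - alpha) * M_prev + S_t    # Eq. 13: weight-decay forgetting
    return M_t, S_t

# Toy trajectory: constant surprise signal, mild forgetting
M, S = 0.0, 0.0
for _ in range(3):
    M, S = memory_step(M, S, grad=1.0, eta=0.9, theta=0.5, alpha=0.1)
```

With `alpha > 0` the memory decays old content each step, so it forgets stale associations instead of saturating; retrieval (Eq. 15) is just a forward pass through M with no update.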

Running Tests

cd titans-memory
pip install -e ".[dev]"
pytest

Project Structure

titans-memory/
├── titans/
│   ├── __init__.py           ← public API
│   ├── memory/
│   │   ├── neural_memory.py  ← NeuralMemory (LMM core)
│   │   └── persistent_memory.py
│   ├── models/
│   │   ├── lmm.py            ← TitansLMM
│   │   ├── mac.py            ← TitansMAC
│   │   ├── mag.py            ← TitansMAG
│   │   └── mal.py            ← TitansMAL
│   ├── ops/
│   │   ├── scan.py           ← parallel associative scan
│   │   └── attention.py      ← causal + sliding-window attention
│   └── utils/
│       ├── config.py         ← TitansConfig dataclass
│       ├── factory.py        ← build_model()
│       └── training.py       ← optimizer + LR schedule helpers
├── tests/
│   ├── test_scan.py
│   ├── test_memory.py
│   └── test_models.py
├── examples/
│   ├── 01_quickstart.py
│   ├── 02_training_loop.py
│   └── 03_memory_standalone.py
├── pyproject.toml
├── setup.py
└── README.md

Citation

@article{behrouz2024titans,
  title   = {Titans: Learning to Memorize at Test Time},
  author  = {Behrouz, Ali and Zhong, Peilin and Mirrokni, Vahab},
  journal = {arXiv preprint arXiv:2501.00663},
  year    = {2024}
}
