
Titans architecture and MIRAS framework for test-time memorization in long-context sequence modeling

Project description

Titans-MIRAS

Requires Python 3.9+ and PyTorch 2.0+. License: MIT.

An implementation of the Titans architecture and the MIRAS framework from Google Research, enabling test-time memorization and neural long-term memory for long-context sequence modeling.

Overview

Titans introduces a neural long-term memory module that learns to memorize information as data streams in, enabling efficient handling of extremely long contexts (2M+ tokens). MIRAS provides a theoretical framework for designing memory mechanisms through four key components:

  1. Memory Architecture: Structure storing information (vector, matrix, or deep MLP)
  2. Attentional Bias: Learning objective determining what to prioritize
  3. Retention Gate: Regularizer balancing new learning vs. past knowledge
  4. Memory Algorithm: Optimization algorithm for updating memory (all four components are sketched in code below)
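
To make the decomposition concrete, here is a deliberately tiny, hypothetical one-step update showing where each component plugs in. None of these helper names come from the titans_miras API; this is a sketch of the idea, not the library's internals.

import torch
import torch.nn.functional as F

dim = 64
M = torch.zeros(dim, dim, requires_grad=True)    # 1. memory architecture: a single matrix

def attentional_bias(pred, value):               # 2. attentional bias: the inner loss
    return F.mse_loss(pred, value)               #    (DEFAULT uses MSE; YAAD would swap in Huber)

key, value = torch.randn(dim), torch.randn(dim)
loss = attentional_bias(M @ key, value)
(grad,) = torch.autograd.grad(loss, M)           # 4. memory algorithm: plain gradient descent here

alpha, lr = 0.05, 0.1                            # 3. retention gate: decay that balances
M = (1 - alpha) * M.detach() - lr * grad         #    old memory against the new write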

MIRAS Variants

Variant   Description                    Use Case
DEFAULT   Original Titans with MSE loss  General purpose
YAAD      Huber loss for robustness      Noisy/outlier-heavy data
MONETA    Generalized p-norms            Enhanced expressivity
MEMORA    Probability map constraints    Maximum stability

Installation

pip install python-titans-miras

Or install from source:

git clone https://github.com/jonlukewatts/titans-miras.git
cd titans-miras
pip install -e .

With development dependencies:

pip install -e ".[dev]"

Quick Start

import torch
from titans_miras import (
    MACTransformer,
    TransformerConfig,
    NeuralMemoryConfig,
    MIRASConfig,
    MIRASArchitecture,
)

# Configure the model
config = TransformerConfig(
    num_tokens=256,           # Vocabulary size (256 for byte-level)
    dim=512,                  # Model dimension
    depth=6,                  # Number of transformer layers
    segment_len=64,           # Segment length for memory operations
    neural_mem=NeuralMemoryConfig(
        dim=512,
        heads=8,
        depth=2,              # Depth of memory MLP
    ),
    miras=MIRASConfig(
        architecture=MIRASArchitecture.DEFAULT,  # or YAAD, MONETA, MEMORA
    ),
)

# Create model
model = MACTransformer(config=config)

# Forward pass
x = torch.randint(0, 256, (1, 512))  # (batch, seq_len)
logits = model(x)                     # (batch, seq_len, vocab_size)

# Training with loss
loss = model(x, return_loss=True)
loss.backward()
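
From here, a minimal training loop is the usual PyTorch pattern. The optimizer choice and hyperparameters below are illustrative assumptions, not values prescribed by the library:

optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(100):
    batch = torch.randint(0, 256, (4, 512))   # random byte-level stand-in for real data
    loss = model(batch, return_loss=True)
    loss.backward()
    optim.step()
    optim.zero_grad()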

Using Different MIRAS Variants

from titans_miras import MIRASConfig, MIRASArchitecture

# YAAD: Robust to outliers
yaad_config = MIRASConfig(
    architecture=MIRASArchitecture.YAAD,
    yaad_delta=1.0,  # Huber loss threshold
)

# MONETA: Generalized p-norms
moneta_config = MIRASConfig(
    architecture=MIRASArchitecture.MONETA,
    moneta_bias_p=1.5,
    moneta_gate_p=1.5,
)

# MEMORA: Probability map constraints
memora_config = MIRASConfig(
    architecture=MIRASArchitecture.MEMORA,
    memora_temperature=1.0,
)
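
The robustness claim behind YAAD is easy to verify in isolation: for a large residual, the MSE gradient grows linearly while the Huber gradient is capped at delta. This standalone snippet uses plain PyTorch, not titans_miras:

import torch
import torch.nn.functional as F

target = torch.tensor([100.0])                    # an extreme outlier

pred = torch.zeros(1, requires_grad=True)
(g_mse,) = torch.autograd.grad(F.mse_loss(pred, target), pred)

pred = torch.zeros(1, requires_grad=True)
(g_huber,) = torch.autograd.grad(F.huber_loss(pred, target, delta=1.0), pred)

print(g_mse.item(), g_huber.item())               # -200.0 vs -1.0: Huber's update is bounded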

Architecture

┌─────────────────────────────────────────────────────────┐
│                     MACTransformer                      │
├─────────────────────────────────────────────────────────┤
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐  │
│  │   Token     │    │   Neural    │    │  Segmented  │  │
│  │  Embedding  │───▶│   Memory    │───▶│  Attention  │  │
│  └─────────────┘    └─────────────┘    └─────────────┘  │
│                            │                  │         │
│                     ┌──────▼──────┐           │         │
│                     │   Memory    │           │         │
│                     │    MLP      │◀──────────┘         │
│                     │  (stores &  │                     │
│                     │  retrieves) │                     │
│                     └─────────────┘                     │
└─────────────────────────────────────────────────────────┘

The Neural Memory module stores information in the weights of a small MLP, which are updated online during the forward pass using gradient-based learning. This allows the model to "memorize" important information and retrieve it later.
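
As a conceptual sketch of that mechanism (this mirrors the idea, not the library's actual internals), memorization is a gradient step on an associative recall loss, and retrieval is an ordinary forward pass:

import torch
import torch.nn as nn
import torch.nn.functional as F

memory = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))
surprise = [torch.zeros_like(p) for p in memory.parameters()]
eta, theta, alpha = 0.9, 0.1, 0.01   # momentum, inner step size, forgetting rate

def memorize(key, value):
    # "Surprise" is the gradient of how badly the memory recalls value from key
    loss = F.mse_loss(memory(key), value)
    grads = torch.autograd.grad(loss, list(memory.parameters()))
    with torch.no_grad():
        for p, s, g in zip(memory.parameters(), surprise, grads):
            s.mul_(eta).add_(g, alpha=-theta)   # accumulate surprise with momentum
            p.mul_(1 - alpha).add_(s)           # forget a little, then write

def recall(key):
    with torch.no_grad():
        return memory(key)

k, v = torch.randn(64), torch.randn(64)
for _ in range(50):
    memorize(k, v)
print(F.mse_loss(recall(k), v).item())          # recall error drops as the pair is memorized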

Einops Notation

The codebase uses einops for tensor operations. Here are the dimension conventions:

Symbol  Meaning
b       batch
h       heads
bh      batch and heads (combined)
n       sequence length
d       feature dimension
c       intra-chunk position
w       number of memory network weight parameters
o       momentum orders
u       key/value updates per token

Example usage in the codebase:

from torch import einsum
from einops import rearrange

# Split input into attention heads
x = rearrange(x, 'b n (h d) -> b h n d', h=num_heads)

# Compute attention scores between queries and keys
attn = einsum('b h n d, b h m d -> b h n m', q, k)

# Merge heads back into the feature dimension
x = rearrange(x, 'b h n d -> b n (h d)')

Configuration Reference

TransformerConfig

Parameter                Type  Description
num_tokens               int   Vocabulary size
dim                      int   Model hidden dimension
depth                    int   Number of transformer layers
segment_len              int   Segment length for memory operations
heads                    int   Number of attention heads (default: 8)
ff_mult                  int   Feed-forward dimension multiplier (default: 4)
num_longterm_mem_tokens  int   Long-term memory tokens (default: 0)
num_persist_mem_tokens   int   Persistent memory tokens (default: 0)

NeuralMemoryConfig

Parameter   Type   Description
dim         int    Neural memory dimension
heads       int    Number of memory heads (default: 1)
depth       int    Depth of memory MLP (default: 2)
chunk_size  int    Chunk size for operations (default: 1)
momentum    bool   Use momentum in updates (default: True)
max_lr      float  Max learning rate for memory (default: 0.1)
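
For example, a memory tuned for chunked updates might be configured as follows (the field names are the documented ones above; the specific values are illustrative):

from titans_miras import NeuralMemoryConfig

mem_config = NeuralMemoryConfig(
    dim=512,
    heads=8,
    depth=2,        # two-layer memory MLP
    chunk_size=16,  # tokens processed per memory update
    momentum=True,  # momentum-based surprise updates
    max_lr=0.05,    # cap on the memory's inner learning rate
)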

See titans_miras/config.py for the complete configuration reference.

Training

For training models, use the unified training script:

python scripts/train.py --config path/to/config.yaml

See scripts/train.py for available options.

Experiments

The experiments/ directory contains scripts to reproduce paper results. Use the unified entry point:

# Compare MIRAS variants
python scripts/run_experiments.py miras-variants --scale small

# Memory depth ablation
python scripts/run_experiments.py memory-depth --scale small

# Language modeling (enwik8)
python scripts/run_experiments.py train-enwik8 --scale small

# Long-context evaluation
python scripts/run_experiments.py babilong --scale small

You can also run experiments directly:

python experiments/01_miras_variants/run_comparison.py --scale small

See experiments/README.md for detailed instructions.

Research Context

Titans: Learning to Memorize at Test Time

The Transformer architecture revolutionized sequence modeling with attention, but its computational cost grows quadratically with sequence length. Titans addresses this by introducing a neural long-term memory module that:

  • Uses a deep neural network (MLP) as memory, providing higher expressive power than fixed-size vectors/matrices
  • Learns to recognize and retain important relationships across extremely long sequences
  • Employs a "surprise metric" (gradient magnitude) to prioritize memorable information
  • Incorporates momentum and adaptive forgetting for stable long-term memory (see the update rule below)
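
In the paper's notation, the last two points combine into a two-step update, where the gates eta_t (surprise momentum), theta_t (step size), and alpha_t (forgetting) are data-dependent. A plain-text paraphrase of the Titans update rule:

S_t = eta_t * S_{t-1} - theta_t * grad(loss(M_{t-1}; x_t))   # past surprise + momentary surprise
M_t = (1 - alpha_t) * M_{t-1} + S_t                          # adaptive forgetting, then write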

MIRAS: A Unified Framework

MIRAS (Memory-Informed Retrieval and Storage) provides a theoretical blueprint showing that major sequence modeling architectures are essentially associative memory modules. Key insights:

  • Transformers, RNNs, and SSMs can be viewed through the lens of associative memory
  • Different loss functions (MSE, Huber, p-norms) lead to different memory properties
  • Retention gates act as regularizers balancing old vs. new information

Download files

Download the file for your platform.

Source Distribution

python_titans_miras-0.0.1.tar.gz (48.6 kB)


Built Distribution


python_titans_miras-0.0.1-py3-none-any.whl (54.3 kB)


File details

Details for the file python_titans_miras-0.0.1.tar.gz.

File metadata

  • Download URL: python_titans_miras-0.0.1.tar.gz
  • Size: 48.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.2.0 CPython/3.11.14

File hashes

Hashes for python_titans_miras-0.0.1.tar.gz
Algorithm    Hash digest
SHA256       51dea8892f096352cbcef49fabe6a42fd45896901443982bd0db3307e0148011
MD5          d870874e53bc5d12677e0e2fb7ed3c1d
BLAKE2b-256  e2daaef40a4e22f979e2f74f212546007cb14431edb170756b8f0fc73ca37411


File details

Details for the file python_titans_miras-0.0.1-py3-none-any.whl.

File hashes

Hashes for python_titans_miras-0.0.1-py3-none-any.whl
Algorithm    Hash digest
SHA256       3620421cc57401f799a4168033458d54c3d47d92d27897f84b8bf252dda88898
MD5          a1b51a5575854027f86dbcceecd52176
BLAKE2b-256  dc8caaa85b5e0da133c7b6b7808d75528ecd9db6df2eec52c2671cf58ac2605a

