Continual Learning SLM with TTT-E2E and Sparse Memory Finetuning

Continual Learning SLM

A continual learning system that enables Small Language Models to learn from new documents in real time by updating neural weights at inference time, while limiting catastrophic forgetting.

How It Works

This project implements three distinct continual learning strategies on top of Qwen2.5-1.5B:

Strategy 1: TTT-E2E (Test-Time Training End-to-End)

The core approach. The final 25% of transformer layers (zero-indexed layers 21-27, seven of the model's 28) are modified with a Dual-MLP architecture: each modified layer gets a frozen MLP (preserving the original model's capabilities) and a trainable MLP (absorbing new knowledge). When you feed the model a document, it runs mini-batch gradient descent to write knowledge directly into the trainable weights.

Sparse Memory Finetuning (TF-IDF gating) protects general-purpose neurons from being overwritten. Neurons that activate broadly across a calibration corpus are masked during gradient updates, while neurons specialized to the new content receive full gradients.
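The gating idea can be sketched in plain Python. This is a minimal illustration, not the project's tfidf_gate.py: the function names, the "activation greater than zero counts as active" rule, the smoothed IDF formula, and the peak-normalized term frequency are all assumptions made for the example. Only the threshold default (0.3) comes from the configuration.

```python
import math

def calibrate_idf(calibration_activations, active_eps=0.0):
    """Per-neuron IDF from a calibration corpus (assumed formula).

    calibration_activations: list of per-sample activation vectors.
    A neuron that fires on every sample gets IDF 0.
    """
    n_samples = len(calibration_activations)
    n_neurons = len(calibration_activations[0])
    idf = []
    for j in range(n_neurons):
        # Document frequency: number of samples on which neuron j is active.
        df = sum(1 for sample in calibration_activations if sample[j] > active_eps)
        idf.append(math.log((1 + n_samples) / (1 + df)))
    return idf

def gradient_mask(doc_activations, idf, threshold=0.3):
    """1.0 where the neuron's TF-IDF score reaches the threshold, else 0.0."""
    peak = max(abs(a) for a in doc_activations) or 1.0
    mask = []
    for a, w in zip(doc_activations, idf):
        tf = abs(a) / peak  # activation strength on the new document, in [0, 1]
        mask.append(1.0 if tf * w >= threshold else 0.0)
    return mask

def apply_mask(gradients, mask):
    """Zero the gradients of masked (general-purpose) neurons."""
    return [g * m for g, m in zip(gradients, mask)]
```

In this sketch a neuron that is active on every calibration sample always scores 0 and is therefore always masked, while a neuron that is rare in calibration but strongly active on the new document keeps its full gradient.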

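The alpha blending parameter controls the mix: alpha * frozen_out + (1 - alpha) * trainable_out. Alpha starts near 1.0 (output dominated by the frozen MLP) and decays toward a configured floor as documents are learned, gradually increasing the influence of newly learned knowledge. The initial value, decay rate, and floor are set under alpha in the configuration file.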
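A small sketch of the blend and its decay, assuming the mix is a convex combination, alpha * frozen_out + (1 - alpha) * trainable_out (consistent with the configuration comment that 1.0 means fully frozen), and that decay is a fixed subtraction per learning step. The function names are hypothetical; the default numbers are taken from the alpha section of the default configuration.

```python
def blend(frozen_out, trainable_out, alpha):
    """Convex mix of the two MLP outputs; alpha = 1.0 means fully frozen."""
    return [alpha * f + (1.0 - alpha) * t
            for f, t in zip(frozen_out, trainable_out)]

def decay_alpha(alpha, decay_rate=0.01, min_value=0.3):
    """Lower alpha by one step (assumed subtractive), never below the floor."""
    return max(min_value, alpha - decay_rate)

if __name__ == "__main__":
    alpha = 1.0
    for _ in range(3):
        alpha = decay_alpha(alpha)
    # After three learning steps the trainable branch contributes a few percent.
    print(blend([1.0, 2.0], [3.0, 4.0], alpha))
```

With these defaults the floor would be reached after roughly 70 learning steps, at which point the trainable branch carries 70% of the output.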

Strategy 2: JitRL MVP (Just-in-Time Retrieval Learning - Minimum Viable)

A lightweight retrieval-augmented approach. Instead of modifying weights, it indexes document chunks using TF-IDF and retrieves relevant passages at query time. Retrieved chunks are prepended as context, and a logit biaser nudges the model toward tokens that appear in the retrieved content.
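The retrieve-then-bias flow can be sketched without any dependencies. This is an illustrative stand-in, not the project's retriever.py or logit_bias.py: the whitespace tokenizer, the smoothed IDF, the scoring function, the toy word-level vocabulary, and the strength parameter are assumptions. A real model would bias subword token ids rather than whole words.

```python
import math
from collections import Counter

def tokenize(text):
    return text.lower().split()

def build_index(chunks):
    """Term counts per chunk plus a smoothed IDF table."""
    docs = [Counter(tokenize(c)) for c in chunks]
    df = Counter(term for d in docs for term in d)
    n = len(chunks)
    idf = {t: math.log((1 + n) / (1 + c)) + 1.0 for t, c in df.items()}
    return docs, idf

def retrieve(query, chunks, docs, idf, top_k=3):
    """Return the top_k chunks ranked by a TF-IDF overlap score."""
    q = Counter(tokenize(query))

    def score(d):
        dot = sum(q[t] * d[t] * idf.get(t, 0.0) ** 2 for t in q)
        norm = math.sqrt(sum((d[t] * idf[t]) ** 2 for t in d)) or 1.0
        return dot / norm

    ranked = sorted(range(len(chunks)), key=lambda i: score(docs[i]), reverse=True)
    return [chunks[i] for i in ranked[:top_k]]

def logit_bias(retrieved_chunks, vocab, strength=2.0):
    """Additive bias: `strength` for vocabulary entries seen in retrieved text."""
    seen = {t for c in retrieved_chunks for t in tokenize(c)}
    return [strength if tok in seen else 0.0 for tok in vocab]
```

The retrieved chunks would then be prepended to the prompt, and the bias vector added to the logits at each decoding step.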

Tradeoff: No weight modification means no forgetting risk, but knowledge is limited to what fits in the context window. It is very fast: roughly 0.002 s to learn and 2.4 s to generate.

Strategy 3: JitRL Full (Reward-Guided Logit Modulation)

A more sophisticated retrieval approach. Documents are encoded into hidden-state embeddings and stored in a knowledge store. At query time, the system retrieves relevant knowledge embeddings, computes a reward signal via cosine similarity, and modulates the model's logit distribution to favor knowledge-aligned tokens.
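One way to read this pipeline is sketched below with tiny hand-written vectors. It is an assumption-laden illustration, not the project's reward.py: the nearest-neighbour lookup, the use of the similarity as a scalar reward, the projection of the retrieved embedding through a toy language-model head, and the beta scale are all choices made for the example.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def modulate_logits(logits, query_state, store, lm_head, beta=1.0):
    """Shift logits toward tokens aligned with the closest stored embedding.

    store:   list of stored hidden-state embeddings (one vector each)
    lm_head: one weight row per vocabulary token (vocab_size x hidden_size)
    """
    # Retrieve the stored embedding most similar to the query state.
    best = max(store, key=lambda emb: cosine(query_state, emb))
    reward = cosine(query_state, best)
    # Project the retrieved embedding through the head to get per-token scores.
    token_scores = [sum(w * e for w, e in zip(row, best)) for row in lm_head]
    return [l + beta * reward * s for l, s in zip(logits, token_scores)]
```

Because the reward scales the whole shift, a poorly matching query (low cosine similarity) leaves the original logits almost unchanged.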

Tradeoff: More expressive than MVP but slower (~0.04 s to learn, ~33 s to generate). It is currently less accurate than MVP and needs hyperparameter tuning.

Requirements

  • Python 3.11+
  • NVIDIA GPU with 24GB+ VRAM (tested on A10)
  • CUDA toolkit

Installation

git clone https://github.com/jasperan/continual-learning.git
cd continual-learning
pip install -e ".[dev]"

This installs the package in editable mode with all dependencies (PyTorch, Transformers, JAX, Rich, Questionary, scikit-learn, etc.) plus dev tools (pytest).

Quick Start

continual-learning

This launches the interactive CLI. The model downloads automatically (~3GB) on first use.

Typical Workflow

  1. Select Chat with Model or Ask a Question - the model loads and injects DualMLP automatically
  2. Select Learn from Document - point it at a .txt, .md, or .jsonl file
  3. Select Chat with Model again - ask questions about what it just learned
  4. Select Run Benchmarks - measure accuracy and forgetting ratio against SQuAD holdout data
  5. Select Save Checkpoint - persist the learned state for later

Using Each Learning Strategy

TTT-E2E: Weight-Based Learning

Feed documents directly into the model's weights via the CLI:

| CLI Option | What It Does |
|---|---|
| Learn from Document | Runs TTT-E2E on a single file. Tokenizes the text, splits into mini-batches of 32 tokens, and performs gradient descent. Shows per-batch loss and token count as it learns. |
| Learn from Directory | Batch-learns all .txt, .md, and .jsonl files in a directory sequentially. |
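The batching described for Learn from Document can be sketched as follows. The defaults mirror the ttt section of the default configuration (32 tokens per mini-batch, 1 gradient step, 4096-token truncation); the function names and the step_fn callback, which stands in for one forward/backward/optimizer update and returns a loss, are hypothetical.

```python
def mini_batches(token_ids, batch_size=32, max_tokens=4096):
    """Truncate the document, then split it into fixed-size token windows."""
    token_ids = token_ids[:max_tokens]
    return [token_ids[i:i + batch_size]
            for i in range(0, len(token_ids), batch_size)]

def learn_document(token_ids, step_fn, gradient_steps=1):
    """Run step_fn on every mini-batch and collect the per-batch losses."""
    losses = []
    for batch in mini_batches(token_ids):
        for _ in range(gradient_steps):
            loss = step_fn(batch)  # hypothetical: one update, returns a float
        losses.append(loss)
    return losses
```

The last mini-batch may be shorter than 32 tokens; this sketch keeps it rather than padding or dropping it.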

After learning, the trainable MLP weights are updated and alpha is decayed. The model's responses immediately reflect the new knowledge.

JitRL MVP: Fast Retrieval + Logit Biasing

| CLI Option | What It Does |
|---|---|
| JitRL MVP (Learn Doc) | Indexes a document by chunking it and building a TF-IDF index. Then prompts you with a question - retrieves the top-3 most relevant chunks, prepends them as context, and applies logit biasing toward tokens found in the retrieved chunks. |

The MVP engine does not modify model weights. You can learn multiple documents and they accumulate in the TF-IDF index.

JitRL Full: Knowledge Store + Reward Modulation

| CLI Option | What It Does |
|---|---|
| JitRL Full (Learn Doc) | Encodes a document through the full model, captures the last hidden-state embeddings, and stores them in a knowledge store. At query time, it retrieves the closest knowledge embeddings via cosine similarity, computes a reward vector, and modulates the output logits through the model's language model head. |

Comparing All Three Strategies

| CLI Option | What It Does |
|---|---|
| Compare All Engines | A/B benchmarks across JitRL MVP and JitRL Full on the same document and QA pairs. Provide a document path, then enter question/answer pairs. The harness feeds the same data to each engine and reports accuracy, learn time, eval time, and tokens learned in a comparison table. |

Evaluation and Benchmarks

| CLI Option | What It Does |
|---|---|
| Run Benchmarks | Loads 50 items from SQuAD 2.0 validation set and evaluates the model's QA accuracy. Checks if the expected answer substring appears in the model's generated response. Reports accuracy and forgetting ratio if a baseline exists. |
| View Forgetting Metrics | Shows catastrophic forgetting indicators. Forgetting ratio = (before - after) / before. A value of 0 means no forgetting, negative means the model improved. Target: < 0.15. |
| Model Info | Displays architecture details: total/modified/frozen layers, total/trainable parameter counts, current alpha value, and whether TF-IDF gates are calibrated. |
| Learning History | Shows a table of all documents learned in the current session: file name, token count, final loss, and timestamp. |
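The two metrics above are simple enough to write down directly. The sketch below follows the stated definitions (substring match for accuracy, (before - after) / before for forgetting); the function names, the case-insensitive comparison, and the handling of a zero baseline are assumptions.

```python
def qa_accuracy(examples, generate):
    """Fraction of (question, answer) pairs whose answer appears in the output.

    generate: any callable mapping a question string to a response string.
    """
    if not examples:
        return 0.0
    hits = sum(1 for question, answer in examples
               if answer.lower() in generate(question).lower())
    return hits / len(examples)

def forgetting_ratio(before, after):
    """(before - after) / before; 0 means no forgetting, negative means gain."""
    if before == 0:
        return 0.0  # assumed convention: nothing to forget from a zero baseline
    return (before - after) / before
```

For example, a holdout accuracy that drops from 0.80 to 0.72 after learning gives a ratio of about 0.10, which is inside the stated target of < 0.15.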

Checkpointing

Checkpoints save only the trainable MLP weights and TF-IDF gate statistics (~50-100MB), not the full 3GB model.

| CLI Option | What It Does |
|---|---|
| Save Checkpoint | Saves trainable MLP state dicts, TF-IDF gate stats (IDF scores, document frequencies), alpha values, learning history, and config to a named subdirectory under checkpoints/. |
| Load Checkpoint | Presents a selection menu of saved checkpoints. Restores trainable weights, TF-IDF calibration, alpha values, and learning history. |
| List Checkpoints | Shows all saved checkpoint names. |

Configuration

Default settings are in configs/default.yaml. The CLI's Configure option lets you view and edit settings at runtime (changes persist to the YAML file).

model:
  name: "Qwen/Qwen2.5-1.5B"
  modified_layers_start: 21    # First layer to inject DualMLP
  modified_layers_end: 28      # Last layer (exclusive)
  device: "auto"               # "auto", "cuda", or "cpu"

ttt:
  learning_rate: 1.0e-5        # Adam learning rate for TTT-E2E
  mini_batch_size: 32          # Tokens per mini-batch
  gradient_steps: 1            # Gradient steps per mini-batch
  max_tokens_per_document: 4096  # Truncation limit

alpha:
  initial: 1.0                 # Starting blend weight (1.0 = fully frozen)
  decay_rate: 0.01             # Alpha decrease per learning step
  min_value: 0.3               # Floor for alpha decay

tfidf_gate:
  threshold: 0.3               # TF-IDF score below which gradients are masked
  calibration_samples: 2000    # Number of samples for IDF calibration
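The project structure describes config.py as a YAML loader with a defaults merge. A recursive merge of that kind might look like the sketch below; it works on already-parsed dictionaries so it needs no YAML library, and both the merge function and the trimmed DEFAULTS table are illustrative assumptions rather than the project's code.

```python
import copy

# A trimmed copy of the defaults shown above, as a parsed dictionary.
DEFAULTS = {
    "ttt": {"learning_rate": 1.0e-5, "mini_batch_size": 32},
    "alpha": {"initial": 1.0, "decay_rate": 0.01, "min_value": 0.3},
    "tfidf_gate": {"threshold": 0.3, "calibration_samples": 2000},
}

def merge(defaults, overrides):
    """Return defaults updated by overrides, recursing into nested sections."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged

if __name__ == "__main__":
    # A user file that only changes the alpha floor keeps every other default.
    print(merge(DEFAULTS, {"alpha": {"min_value": 0.5}})["alpha"])
```

Merging section by section means a user file can override a single key without restating the rest of its section.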

Running Tests

# All 101 tests (~10 seconds, no GPU needed)
python -m pytest tests/

# By component
python -m pytest tests/test_model/          # DualMLP, modified Qwen, TF-IDF gate
python -m pytest tests/test_training/       # TTT-E2E engine
python -m pytest tests/test_jitrl/          # JitRL MVP, Full, comparison harness
python -m pytest tests/test_evaluation/     # Benchmarks, forgetting metrics
python -m pytest tests/test_data/           # SQuAD pipeline, Oracle docs
python -m pytest tests/test_checkpointing/  # Checkpoint save/load
python -m pytest tests/test_cli/            # CLI menu and handlers
python -m pytest tests/test_config.py       # YAML config loading

# Single test by name
python -m pytest tests/test_model/test_dual_mlp.py -k "test_forward"

GPU Validation Scripts

End-to-end validation on real GPU hardware (requires A10 or equivalent with 24GB VRAM):

# Validates 4 milestones sequentially:
#   1. Architecture: Loads Qwen2.5-1.5B + DualMLP injection, verifies 7 modified layers
#   2. TTT-E2E: Learns a test document, verifies weights change and loss is recorded
#   3. Sparse Memory: Calibrates TF-IDF gates, learns domain docs, measures forgetting ratio (<0.15)
#   4. Oracle Docs: Fetches live Oracle documentation, learns from it, tests Oracle-specific Q&A
python scripts/validate_gpu.py

# Compares JitRL MVP vs Full on identical Oracle AI Vector Search content:
#   - Tests each engine individually (learn time, generate time, response quality)
#   - Runs comparison harness with 3 QA items, reports accuracy/timing side by side
python scripts/validate_jitrl.py

Project Structure

src/continual_learning/
├── model/
│   ├── dual_mlp.py          # DualMLP: frozen + trainable MLPs with alpha blending
│   ├── modified_qwen.py     # Loads Qwen2.5-1.5B and injects DualMLP into layers 21-27
│   └── tfidf_gate.py        # TF-IDF gate: calibrates IDF scores, computes gradient masks
├── training/
│   ├── ttt_engine.py        # TTT-E2E: mini-batch gradient descent with TF-IDF masking
│   └── calibration.py       # Collects activations and calibrates TF-IDF gates
├── evaluation/
│   ├── benchmarks.py        # QA accuracy evaluation on holdout sets
│   └── forgetting_metrics.py # Catastrophic forgetting ratio computation
├── data/
│   ├── streaming_qa.py      # SQuAD 2.0 loader with learn/holdout splits
│   └── oracle_docs.py       # Fetches, parses, and chunks Oracle documentation
├── jitrl/
│   ├── base.py              # Abstract BaseJitRLEngine interface (learn/generate/clear)
│   ├── mvp/
│   │   ├── engine.py        # JitRL MVP: TF-IDF retrieval + context prepending + logit bias
│   │   ├── retriever.py     # TF-IDF document retriever with chunking
│   │   └── logit_bias.py    # Computes per-token bias from retrieved chunks
│   ├── full/
│   │   ├── engine.py        # JitRL Full: hidden-state knowledge store + reward modulation
│   │   ├── knowledge_store.py # Stores and retrieves document embeddings by cosine similarity
│   │   └── reward.py        # Computes reward vectors and modulates logits
│   └── comparison.py        # A/B harness: runs identical benchmarks across engines
├── checkpointing/
│   └── manager.py           # Saves/loads trainable weights, TF-IDF stats, alpha, metadata
├── cli/
│   └── main.py              # Interactive menu (Questionary + Rich) with all handlers
└── config.py                # YAML config loader with defaults merge

License

MIT
