Continual learning for LLM fine-tuning via ARIA mechanisms (PlasticityGatedMLP, SPC, Task Adapters)
aria-trl: Continual Learning for LLM Fine-tuning
aria-trl brings ARIA's continual learning mechanisms to Hugging Face's TRL library, enabling fine-tuning of large language models on sequential tasks without catastrophic forgetting.
Overview
Fine-tuning LLMs sequentially on multiple tasks typically leads to catastrophic forgetting — the model forgets earlier tasks while learning new ones. aria-trl prevents this through three ARIA mechanisms:
- PlasticityGatedMLP: Dual fast/slow pathways in FFN layers
  - Fast pathway: volatile, learns task-specific patterns
  - Slow pathway: stable, retains task-generic knowledge
  - Learned gate routes computation per token
- Slow-Pathway Consolidation (SPC): Fisher-weighted regularization
  - Estimates diagonal Fisher Information after each task
  - Protects consolidated knowledge from being overwritten
  - 50% fewer parameters than standard EWC
- Task-Specific Adapters: Lightweight LoRA-like modules
  - One residual adapter per task
  - Frozen after training to prevent overwriting
  - Minimal parameter overhead
Features
- ✅ Drop-in SFTTrainer subclass — minimal code changes
- ✅ Asymmetric learning rates — slow pathway learns slower
- ✅ Gradient dampening — slow grads scaled by (1−π̄)
- ✅ Works with any HF model — LLaMA, Mistral, DistilGPT2, etc.
- ✅ Full checkpoint support — save/load with consolidation state
- ✅ Production-ready — comprehensive error handling
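The asymmetric learning rate and gradient dampening features can be illustrated with a single plain SGD step. This is a minimal sketch, not aria-trl's optimizer code: the helpers `slow_pathway_step` and `fast_pathway_step` are hypothetical, and the sketch assumes π̄ is the mean gate value over the batch (π weights the fast pathway) and that both effects simply multiply the slow-pathway update.

```python
def slow_pathway_step(weights, grads, lr, slow_lr_ratio, gate_mean):
    """One plain SGD step for slow-pathway weights (hypothetical sketch).

    Assumptions:
      - the slow pathway trains at lr * slow_lr_ratio (asymmetric learning rate)
      - slow gradients are dampened by (1 - gate_mean), where gate_mean is
        the average gate value pi-bar and pi weights the fast pathway
    """
    if not 0.0 <= gate_mean <= 1.0:
        raise ValueError("gate_mean must lie in [0, 1]")
    slow_lr = lr * slow_lr_ratio
    damp = 1.0 - gate_mean
    return [w - slow_lr * damp * g for w, g in zip(weights, grads)]


def fast_pathway_step(weights, grads, lr):
    """Fast-pathway weights take an ordinary SGD step at the full learning rate."""
    return [w - lr * g for w, g in zip(weights, grads)]


if __name__ == "__main__":
    w = [1.0, -0.5]
    g = [0.2, 0.4]
    # lr=0.1, slow_lr_ratio=0.5, pi-bar=0.75: the slow update scales the
    # gradient by 0.1 * 0.5 * 0.25, versus 0.1 for the fast pathway.
    print(slow_pathway_step(w, g, lr=0.1, slow_lr_ratio=0.5, gate_mean=0.75))
    print(fast_pathway_step(w, g, lr=0.1))
```

When tokens are routed mostly to the fast pathway (π̄ close to 1), the slow weights barely move, which is the intended stabilizing effect.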
Installation
pip install aria-trl
Or install from source:
git clone https://github.com/rsd-darshan/aria-trl.git
cd aria-trl
pip install -e .
Quick Start
from aria_trl import ContinualSFTTrainer, ARIAConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
# task1_train, task1_eval, task2_train, task2_eval and all_tasks are datasets
# you prepare and tokenize yourself; they are not defined in this snippet.
# Load model
model = AutoModelForSequenceClassification.from_pretrained("distilgpt2", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# GPT-2 tokenizers have no pad token; reuse EOS so batches larger than 1 can be padded
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id
# Configure ARIA
aria_config = ARIAConfig(
    plasticity_lambda=0.01,  # bimodal gate specialization
    spc_lambda=100.0,        # Fisher consolidation strength
    adapter_dim=64,          # task adapter bottleneck
    slow_lr_ratio=0.5,       # asymmetric LR: slow = 0.5x fast
)
# Training arguments
args = TrainingArguments(
    output_dir="./checkpoints",
    learning_rate=2e-4,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)
# Create trainer
trainer = ContinualSFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=task1_train,
    eval_dataset=task1_eval,
    aria_config=aria_config,
)
# Train on task 1
trainer.train()
# Consolidate before the next task (Fisher estimation)
trainer.consolidate_task(task_id=0)
# Train on task 2
trainer.add_task(task_id=1)
trainer.train_dataset = task2_train
trainer.eval_dataset = task2_eval
trainer.train()
# Consolidate again
trainer.consolidate_task(task_id=1)
# Evaluate on all tasks to measure forgetting
metrics = trainer.evaluate_all_tasks(all_tasks)
Core Components
ARIAConfig
Configuration for ARIA continual learning:
from aria_trl import ARIAConfig
config = ARIAConfig(
    plasticity_lambda=0.01,             # Bimodal specialization loss weight
    spc_lambda=100.0,                   # Fisher regularization strength
    adapter_dim=64,                     # Task adapter bottleneck dimension
    slow_lr_ratio=0.5,                  # Slow pathway LR multiplier
    warmup_steps=500,                   # Steps before the plasticity loss activates
    consolidation_steps_per_task=None,  # Fisher estimation steps (None = use all)
)
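One plausible way the loss-related fields combine into the training objective is sketched below. This is an assumption, not taken from aria-trl's source: `LossWeights` is a hypothetical stand-in for part of `ARIAConfig`, `total_loss` is a hypothetical helper, and the sketch assumes the plasticity loss is simply switched off before `warmup_steps` and that `spc_loss` arrives unweighted so `spc_lambda` scales it here.

```python
from dataclasses import dataclass


@dataclass
class LossWeights:
    """Hypothetical stand-in for the loss-related fields of ARIAConfig."""
    plasticity_lambda: float = 0.01
    spc_lambda: float = 100.0
    warmup_steps: int = 500


def total_loss(task_loss, plasticity_loss, spc_loss, global_step, cfg):
    """Combine the task loss with the two ARIA regularizers (assumed form).

    - The plasticity (gate specialization) term only counts once
      global_step >= cfg.warmup_steps.
    - spc_loss is assumed to be the raw Fisher-weighted penalty, which is
      zero while no task has been consolidated yet.
    """
    loss = task_loss + cfg.spc_lambda * spc_loss
    if global_step >= cfg.warmup_steps:
        loss += cfg.plasticity_lambda * plasticity_loss
    return loss


if __name__ == "__main__":
    cfg = LossWeights()
    print(total_loss(1.0, 0.25, 0.0, global_step=100, cfg=cfg))  # warmup: plasticity ignored
    print(total_loss(1.0, 0.25, 0.0, global_step=600, cfg=cfg))  # plasticity term active
```

Delaying the plasticity term lets the gate first learn a useful routing before it is pushed toward 0 or 1.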
PlasticityGatedMLP
Automatically replaces the transformer's FFN layers and routes computation per token:
output = π * fast_pathway(x) + (1−π) * slow_pathway(x)
where π ∈ (0,1) is a learned gate, regularized to specialize toward 0 or 1.
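The gated mixture above can be sketched for a single token as follows. The gate parametrization (a sigmoid over a learned logit) and the specialization penalty π(1−π), which is zero at π = 0 or 1 and largest at π = 0.5, are assumptions chosen to illustrate "regularized to specialize"; aria-trl's actual gate network and loss may differ. The lists `fast` and `slow` stand in for the outputs of the two FFN pathways.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_output(fast_out, slow_out, gate_logit):
    """output = pi * fast(x) + (1 - pi) * slow(x), elementwise for one token."""
    pi = sigmoid(gate_logit)  # pi in (0, 1)
    return [pi * f + (1.0 - pi) * s for f, s in zip(fast_out, slow_out)]


def specialization_penalty(gates):
    """Assumed bimodal regularizer: mean of pi * (1 - pi) over tokens.

    Zero when every gate is exactly 0 or 1; maximal (0.25) when gates sit at 0.5.
    """
    return sum(p * (1.0 - p) for p in gates) / len(gates)


if __name__ == "__main__":
    fast = [2.0, 0.0]
    slow = [0.0, 4.0]
    print(gated_output(fast, slow, gate_logit=0.0))  # pi = 0.5, even blend: [1.0, 2.0]
    print(specialization_penalty([0.5, 0.5]))       # undecided gates: 0.25
    print(specialization_penalty([0.01, 0.99]))     # nearly specialized: much smaller
```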
TaskFastAdapter
Per-task residual adapter (LoRA-like):
h_new = h + adapter(h)
Bottleneck design: h → compress → ReLU → expand, with the expansion zero-initialized so that a freshly added adapter initially leaves h unchanged.
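A dependency-free sketch of the bottleneck adapter, with plain lists standing in for tensors. It assumes `adapter_dim` is the bottleneck width, that both layers have biases, and that "zero-initialized expansion" means both the expansion weights and bias start at zero; the class name `BottleneckAdapter` is hypothetical. It demonstrates that a new adapter is an identity mapping until it is trained, so adding a task does not perturb the model.

```python
import random


def matvec(W, x):
    """Multiply a matrix stored as a list of rows by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


class BottleneckAdapter:
    """h_new = h + expand(relu(compress(h))), with a zero-initialized expansion."""

    def __init__(self, hidden_dim, adapter_dim, seed=0):
        rng = random.Random(seed)
        scale = 1.0 / hidden_dim ** 0.5
        # Compress: hidden_dim -> adapter_dim, small random init
        self.W_down = [[rng.gauss(0.0, scale) for _ in range(hidden_dim)]
                       for _ in range(adapter_dim)]
        self.b_down = [0.0] * adapter_dim
        # Expand: adapter_dim -> hidden_dim, zero init so the adapter starts as identity
        self.W_up = [[0.0] * adapter_dim for _ in range(hidden_dim)]
        self.b_up = [0.0] * hidden_dim

    def __call__(self, h):
        z = [max(0.0, v + b) for v, b in zip(matvec(self.W_down, h), self.b_down)]
        delta = [v + b for v, b in zip(matvec(self.W_up, z), self.b_up)]
        return [hi + di for hi, di in zip(h, delta)]  # residual connection

    def num_params(self):
        hidden_dim, adapter_dim = len(self.W_up), len(self.W_down)
        return 2 * hidden_dim * adapter_dim + adapter_dim + hidden_dim


if __name__ == "__main__":
    adapter = BottleneckAdapter(hidden_dim=8, adapter_dim=2)
    h = [0.1 * i for i in range(8)]
    print(adapter(h) == h)  # True: untrained adapter is the identity
```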
FisherConsolidator
Estimates diagonal Fisher Information on validation data after each task:
consolidator = FisherConsolidator(model, device)
consolidator.consolidate(task_id, eval_loader)
loss = consolidator.compute_spc_loss(global_step)
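For reference, the standard EWC-style quantities behind this kind of API are sketched below: the empirical diagonal Fisher is estimated as the mean squared gradient of the loss over consolidation batches, and the penalty is (λ/2) · Σ F_i (θ_i − θ*_i)², where θ* are the parameters saved at consolidation. Applying it only to slow-pathway parameters, and the names `diagonal_fisher` and `spc_penalty`, are assumptions for illustration; the exact formula inside `FisherConsolidator` (for example, how it accumulates Fisher values across tasks) may differ.

```python
def diagonal_fisher(per_batch_grads):
    """Empirical diagonal Fisher: mean of squared gradients over batches.

    per_batch_grads: list of gradient vectors (one list of floats per batch),
    all computed at the parameters being consolidated.
    """
    n = len(per_batch_grads)
    dim = len(per_batch_grads[0])
    return [sum(g[i] ** 2 for g in per_batch_grads) / n for i in range(dim)]


def spc_penalty(params, anchor_params, fisher, spc_lambda):
    """EWC-style penalty: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2."""
    return 0.5 * spc_lambda * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher)
    )


if __name__ == "__main__":
    # Toy gradients of the loss for two consolidation batches
    grads = [[1.0, 0.0], [3.0, 0.0]]
    fisher = diagonal_fisher(grads)  # parameter 0 matters, parameter 1 does not
    anchor = [0.5, 0.5]              # slow-pathway weights saved after the task
    moved = [0.6, 2.0]               # weights after training on the next task
    print(fisher, spc_penalty(moved, anchor, fisher, spc_lambda=100.0))
```

Parameter 1 moved far but has zero Fisher value, so it adds nothing to the penalty; only the small shift in parameter 0 is penalized.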
Metrics
Compute continual learning metrics with:
from aria_trl.utils import compute_continual_metrics
metrics = compute_continual_metrics(task_accuracies)
# Returns: avg_accuracy, forgetting, forward_transfer
- Average Accuracy: Mean accuracy over all tasks after training on the final task
- Forgetting: How much accuracy on earlier tasks degrades by the end (lower is better)
- Forward Transfer: How much new tasks benefit from what was learned on earlier tasks
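The three metrics are usually defined over an accuracy matrix R, where R[i][j] is the accuracy on task j after training on task i. The sketch below uses common definitions from the continual-learning literature; whether `compute_continual_metrics` expects this matrix layout, and whether it measures forward transfer against a per-task `baseline` (for example, the accuracy of the model before any training on that task), are assumptions. `continual_metrics` is a hypothetical name, and the numbers in the usage example are made up for illustration, not results.

```python
def continual_metrics(R, baseline):
    """Continual-learning metrics from an accuracy matrix (hypothetical helper).

    R[i][j]    : accuracy on task j after finishing training on task i (T x T).
    baseline[j]: reference accuracy on task j, used for forward transfer.
    """
    T = len(R)
    if T < 2:
        raise ValueError("need at least two tasks")
    final = R[T - 1]
    avg_accuracy = sum(final) / T
    # Forgetting: best accuracy on task j at any point before the final task,
    # minus its final accuracy; averaged over tasks 0..T-2. Lower is better.
    forgetting = sum(
        max(R[i][j] for i in range(T - 1)) - final[j] for j in range(T - 1)
    ) / (T - 1)
    # Forward transfer: accuracy on task j just before training on it,
    # relative to the baseline; averaged over tasks 1..T-1.
    forward_transfer = sum(R[j - 1][j] - baseline[j] for j in range(1, T)) / (T - 1)
    return {
        "avg_accuracy": avg_accuracy,
        "forgetting": forgetting,
        "forward_transfer": forward_transfer,
    }


if __name__ == "__main__":
    R = [  # toy numbers for three tasks
        [0.90, 0.50, 0.50],
        [0.80, 0.85, 0.55],
        [0.70, 0.80, 0.88],
    ]
    print(continual_metrics(R, baseline=[0.5, 0.5, 0.5]))
```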
Example: DistilGPT2 on Sequential Tasks
See examples/distilgpt2_example.py for a complete working example:
python examples/distilgpt2_example.py
Trains on 3 sequential tasks (sentiment, toxicity, spam) and prints metrics.
How ARIA Prevents Forgetting
The Problem
Standard fine-tuning on sequential tasks overwrites weights learned on earlier tasks:
Task 1: [update all weights]
Task 2: [update all weights again] ← Task 1 weights overwritten
Task 3: [update all weights again] ← Task 1 & 2 weights overwritten
Eval: [accuracy on Task 1 ↓↓↓] ← Catastrophic forgetting!
The Solution
ARIA separates fast (volatile) and slow (stable) pathways:
Fast pathway: [updates freely, learns task-specific patterns]
Slow pathway: [updates slowly, consolidated by Fisher regularization]
Gate π: [routes each token: fast for new patterns, slow for stability]
Result: Knowledge from old tasks is retained in the slow pathway, while new tasks are learned in the fast pathway
Design Philosophy
aria-trl's design prioritizes:
- Compatibility: Subclasses SFTTrainer, works with any HF model
- Simplicity: Three mechanisms, no complex multi-model ensembles
- Efficiency: Task adapters add <1% parameters; the diagonal Fisher approximation keeps consolidation memory linear in model size
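To sanity-check the "<1% parameters" figure for the DistilGPT2 example, the arithmetic below counts one bottleneck adapter (compress and expand layers, both with biases) per transformer block. Placing one adapter per block is an assumption; DistilGPT2's 6 layers, hidden size 768 and roughly 82M parameters come from its Hugging Face model card, and `adapter_dim=64` matches the Quick Start. The helper names are hypothetical.

```python
def adapter_params(hidden_dim, adapter_dim):
    """Parameters in one bottleneck adapter: down (d*r + r) plus up (r*d + d)."""
    return (hidden_dim * adapter_dim + adapter_dim) + (adapter_dim * hidden_dim + hidden_dim)


def adapter_overhead(n_layers, hidden_dim, adapter_dim, base_params, n_tasks=1):
    """Fraction of extra parameters from n_tasks adapters, one per layer per task (assumed)."""
    extra = n_tasks * n_layers * adapter_params(hidden_dim, adapter_dim)
    return extra / base_params


if __name__ == "__main__":
    print(adapter_params(768, 64))  # 99136 per layer
    frac = adapter_overhead(n_layers=6, hidden_dim=768, adapter_dim=64,
                            base_params=82_000_000)
    print(f"{frac:.2%} extra parameters per task")
```

Under these assumptions one task adds about 0.73% parameters. Because there is one adapter per task, the overhead grows linearly with the number of tasks.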
- Interpretability: Gate π shows per-token plasticity, easy to debug
Papers & Motivation
- ARIA (PyTorch): "Adaptive Recurrent Intelligence Architecture" — core continual learning research on CNNs and task-incremental learning
- aria-trl (TRL): This package — brings ARIA mechanisms to LLM fine-tuning via Hugging Face
- NCG: "Novelty-triggered Capacity Growth" — reactive capacity expansion (predecessor)
- EWC: "Elastic Weight Consolidation" — standard baseline for catastrophic forgetting
Limitations & Future Work
- Slow pathway at 0.5× LR: This default is conservative; tuning it improves CIFAR-10 accuracy
- Fisher diagonal only: Full Fisher or Kronecker-factored approximations are future work
- DistilGPT2 only: Only DistilGPT2 has been tested; larger models (7B+) will need memory optimization
- Binary classification only: Multi-class and language modeling tasks still need validation
Contributing
Contributions welcome! Please:
- Fork the repo
- Create a feature branch (git checkout -b feature/my-feature)
- Commit changes (no AI attribution required)
- Push to remote and open a PR
License
MIT License — see LICENSE for details.
Citation
If you use aria-trl in your research, please cite:
@software{poudel2026aria_trl,
title = {aria-trl: Continual Learning for LLM Fine-tuning},
author = {Poudel, Darshan},
year = {2026},
url = {https://github.com/rsd-darshan/aria-trl}
}
Contact
- GitHub: @rsd-darshan
- Email: poudeldarshan44@gmail.com
aria-trl — continual learning, without the catastrophe.
File details
Details for the file aria_trl-1.0.0.tar.gz.
File metadata
- Download URL: aria_trl-1.0.0.tar.gz
- Upload date:
- Size: 19.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 7903af77c76ed5ced2a88bad19388d89a88ffc3f5c194060efa5afd0dd7c8eca |
| MD5 | 53fe5944ac6f6f89fc3f68dea4525fc9 |
| BLAKE2b-256 | 8f65af147a62102128807bdfa2f99408b62e541f0ac97e58a72a6dd8da4afc72 |
File details
Details for the file aria_trl-1.0.0-py3-none-any.whl.
File metadata
- Download URL: aria_trl-1.0.0-py3-none-any.whl
- Upload date:
- Size: 15.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8f02581b89a6c1ac0f8c8b4e12a969dfc083dae0c1dfab0df4a40d9dde854e02 |
| MD5 | b93d3852f5a38f86e23aa5671dfc1891 |
| BLAKE2b-256 | f0ee5f0ae687648e0998e522fe03cdd1b948c951941778deddec1206ec0be37b |