
Continual learning for LLM fine-tuning via ARIA mechanisms (PlasticityGatedMLP, SPC, Task Adapters)


aria-trl: Continual Learning for LLM Fine-tuning

Python 3.10+ | PyTorch 2.0+ | License: MIT

aria-trl brings ARIA's continual learning mechanisms to Hugging Face's TRL library, enabling fine-tuning of large language models on sequential tasks without catastrophic forgetting.

Overview

Fine-tuning LLMs sequentially on multiple tasks typically leads to catastrophic forgetting — the model forgets earlier tasks while learning new ones. aria-trl prevents this through three ARIA mechanisms:

  1. PlasticityGatedMLP: Dual fast/slow pathways in FFN layers

    • Fast pathway: volatile, learns task-specific patterns
    • Slow pathway: stable, retains task-generic knowledge
    • Learned gate routes computation per token
  2. Slow-Pathway Consolidation (SPC): Fisher-weighted regularization

    • Estimates diagonal Fisher Information after each task
    • Protects consolidated knowledge from being overwritten
    • Regularizes about 50% fewer parameters than standard EWC
  3. Task-Specific Adapters: Lightweight LoRA-like modules

    • One residual adapter per task
    • Frozen after training to prevent overwriting
    • Minimal parameter overhead

Features

  • Drop-in SFTTrainer subclass — minimal code changes
  • Asymmetric learning rates — slow pathway learns slower
  • Gradient dampening — slow grads scaled by (1−π̄)
  • Works with any HF model — LLaMA, Mistral, DistilGPT2, etc.
  • Full checkpoint support — save/load with consolidation state
  • Production-ready — comprehensive error handling
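
The asymmetric learning rates and gradient dampening listed above can be pictured with plain PyTorch optimizer parameter groups. This is a minimal sketch under stated assumptions, not aria-trl's internals: the `TinyDual` module, the `dampen_slow_grads` helper, and the reading of π̄ as the batch-mean gate value are all hypothetical names and interpretations introduced for illustration.

```python
import torch
import torch.nn as nn


class TinyDual(nn.Module):
    """Hypothetical stand-in for a layer with separate fast and slow weights."""

    def __init__(self):
        super().__init__()
        self.fast = nn.Linear(8, 8)
        self.slow = nn.Linear(8, 8)


torch.manual_seed(0)
model = TinyDual()
base_lr, slow_lr_ratio = 2e-4, 0.5

# Asymmetric learning rates: one parameter group per pathway.
optimizer = torch.optim.AdamW(
    [
        {"params": model.fast.parameters(), "lr": base_lr},
        {"params": model.slow.parameters(), "lr": base_lr * slow_lr_ratio},
    ]
)


def dampen_slow_grads(module: TinyDual, mean_gate: float) -> None:
    """Scale slow-pathway gradients by (1 - mean_gate), in place."""
    for p in module.slow.parameters():
        if p.grad is not None:
            p.grad.mul_(1.0 - mean_gate)


# One backward pass, then dampen before the optimizer step.
x = torch.randn(4, 8)
loss = (model.fast(x) + model.slow(x)).pow(2).mean()
loss.backward()

slow_grad_before = model.slow.weight.grad.clone()
fast_grad_before = model.fast.weight.grad.clone()
dampen_slow_grads(model, mean_gate=0.75)  # assumed batch-mean gate value
optimizer.step()
```

With a mean gate of 0.75, the slow pathway receives a quarter of its raw gradient on top of the halved learning rate, while the fast pathway is untouched.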

Installation

pip install aria-trl

Or install from source:

git clone https://github.com/rsd-darshan/aria-trl.git
cd aria-trl
pip install -e .

Quick Start

from aria_trl import ContinualSFTTrainer, ARIAConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments

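# Load model
model = AutoModelForSequenceClassification.from_pretrained("distilgpt2", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

# GPT-2 tokenizers ship without a pad token; reuse EOS so batches can be padded
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id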

# Configure ARIA
aria_config = ARIAConfig(
    plasticity_lambda=0.01,      # bimodal gate specialization
    spc_lambda=100.0,            # Fisher consolidation strength
    adapter_dim=64,              # task adapter bottleneck
    slow_lr_ratio=0.5,           # asymmetric LR: slow=0.5x fast
)

# Training arguments
args = TrainingArguments(
    output_dir="./checkpoints",
    learning_rate=2e-4,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

# Create trainer (task1_train / task1_eval are placeholders for your own datasets)
trainer = ContinualSFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=task1_train,
    eval_dataset=task1_eval,
    aria_config=aria_config,
)

# Train on task 1
trainer.train()

# Consolidate before next task (Fisher estimation)
trainer.consolidate_task(task_id=0)

# Train on task 2
trainer.add_task(task_id=1)
trainer.train_dataset = task2_train
trainer.eval_dataset = task2_eval
trainer.train()

# Consolidate again
trainer.consolidate_task(task_id=1)

# Evaluate on all tasks to measure forgetting
# (all_tasks is a placeholder for your per-task evaluation data)
metrics = trainer.evaluate_all_tasks(all_tasks)

Core Components

ARIAConfig

Configuration for ARIA continual learning:

from aria_trl import ARIAConfig

config = ARIAConfig(
    plasticity_lambda=0.01,          # Bimodal specialization loss weight
    spc_lambda=100.0,                # Fisher regularization strength
    adapter_dim=64,                  # Task adapter bottleneck dimension
    slow_lr_ratio=0.5,               # Slow pathway LR multiplier
    warmup_steps=500,                # Before plasticity loss activates
    consolidation_steps_per_task=None # Fisher estimation steps (None=all)
)

PlasticityGatedMLP

Replaces transformer FFN layers automatically. Routes per-token computation:

output = π * fast_pathway(x) + (1−π) * slow_pathway(x)

where π ∈ (0,1) is a learned per-token gate, regularized toward the extremes (near 0 or 1) so that each token is routed mostly through one pathway.
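
To make the gating formula concrete, here is a minimal PyTorch sketch of a dual-pathway FFN with a sigmoid gate. It is an illustration under assumptions, not the PlasticityGatedMLP source: the class name `GatedDualMLP`, the layer sizes, the GELU activation, and the `bimodal_penalty` regularizer (π(1−π), which vanishes at 0 and 1) are hypothetical choices.

```python
import torch
import torch.nn as nn


class GatedDualMLP(nn.Module):
    """Illustrative dual-pathway FFN; not the aria-trl implementation."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.fast = nn.Sequential(
            nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model)
        )
        self.slow = nn.Sequential(
            nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model)
        )
        self.gate = nn.Linear(d_model, 1)  # one gate value per token

    def forward(self, x: torch.Tensor):
        pi = torch.sigmoid(self.gate(x))  # shape (batch, seq, 1), values in (0, 1)
        out = pi * self.fast(x) + (1 - pi) * self.slow(x)
        return out, pi


def bimodal_penalty(pi: torch.Tensor) -> torch.Tensor:
    """Assumed specialization loss: largest at pi = 0.5, zero at pi = 0 or 1."""
    return (pi * (1 - pi)).mean()


torch.manual_seed(0)
layer = GatedDualMLP(d_model=16, d_hidden=32)
x = torch.randn(2, 5, 16)  # (batch, seq, d_model)
out, pi = layer(x)
penalty = bimodal_penalty(pi)
```

Returning `pi` alongside the output makes it easy to log per-token gate values, which is the kind of inspection the gate is meant to support.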

TaskFastAdapter

Per-task residual adapter (LoRA-like):

h_new = h + adapter(h)

Bottleneck design: h → compress → ReLU → expand, zero-initialized expansion.
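
The bottleneck design can be sketched in a few lines of PyTorch. This is a hedged illustration rather than the TaskFastAdapter source: the class name `ResidualAdapter` and the `freeze` helper are hypothetical. The point it shows is that zero-initializing the expansion layer makes a new adapter start as an exact identity map, so adding a task does not disturb the model's outputs before training begins.

```python
import torch
import torch.nn as nn


class ResidualAdapter(nn.Module):
    """Illustrative bottleneck adapter: h + expand(relu(compress(h)))."""

    def __init__(self, d_model: int, bottleneck: int = 64):
        super().__init__()
        self.down = nn.Linear(d_model, bottleneck)  # compress
        self.up = nn.Linear(bottleneck, d_model)    # expand
        # Zero-initialized expansion: the adapter contributes nothing at first.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return h + self.up(torch.relu(self.down(h)))

    def freeze(self) -> None:
        """Stop gradient updates once the adapter's task has been trained."""
        for p in self.parameters():
            p.requires_grad_(False)


torch.manual_seed(0)
adapter = ResidualAdapter(d_model=32, bottleneck=8)
h = torch.randn(2, 5, 32)
h_new = adapter(h)  # identical to h before any training
adapter.freeze()
```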

FisherConsolidator

Estimates diagonal Fisher Information on validation data after each task:

consolidator = FisherConsolidator(model, device)
consolidator.consolidate(task_id, eval_loader)
loss = consolidator.compute_spc_loss(global_step)
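
For readers unfamiliar with Fisher-weighted regularization, the sketch below shows the general EWC-style technique on a toy classifier. It does not reproduce FisherConsolidator: the function names `estimate_diag_fisher` and `ewc_penalty`, the use of cross-entropy, and the per-batch (rather than per-sample) gradient squaring are assumptions made to keep the example short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def estimate_diag_fisher(model: nn.Module, batches) -> dict:
    """Empirical diagonal Fisher: average squared gradient of the loss.

    Squares the gradient of each batch-mean loss, a cheap approximation
    to averaging squared per-sample gradients.
    """
    fisher = {
        n: torch.zeros_like(p)
        for n, p in model.named_parameters()
        if p.requires_grad
    }
    n_batches = 0
    model.eval()
    for inputs, labels in batches:
        model.zero_grad()
        F.cross_entropy(model(inputs), labels).backward()
        for n, p in model.named_parameters():
            if n in fisher and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        n_batches += 1
    return {n: f / max(n_batches, 1) for n, f in fisher.items()}


def ewc_penalty(model: nn.Module, fisher: dict, anchor: dict, strength: float):
    """0.5 * strength * sum_i F_i * (theta_i - theta_anchor_i)^2."""
    total = torch.zeros(())
    for n, p in model.named_parameters():
        if n in fisher:
            total = total + (fisher[n] * (p - anchor[n]) ** 2).sum()
    return 0.5 * strength * total


torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy stand-in for the protected parameters
batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(3)]

fisher = estimate_diag_fisher(model, batches)
anchor = {n: p.detach().clone() for n, p in model.named_parameters()}
penalty_at_anchor = ewc_penalty(model, fisher, anchor, strength=100.0)
```

The penalty is zero at the anchor and grows as parameters with high Fisher values drift, which is how consolidated knowledge is protected while low-importance weights stay free to change.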

Metrics

Compute continual learning metrics with:

from aria_trl.utils import compute_continual_metrics

metrics = compute_continual_metrics(task_accuracies)
# Returns: avg_accuracy, forgetting, forward_transfer

  • Average Accuracy: Mean accuracy across all tasks after training on the final task
  • Forgetting: How much accuracy on old tasks degrades (lower is better)
  • Forward Transfer: How much new tasks benefit from training on earlier tasks
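
As a worked example, the three metrics can be computed by hand from an accuracy matrix using the definitions common in the continual learning literature. This sketch is not `compute_continual_metrics` itself: the function name `continual_metrics`, the matrix layout (`acc[i][j]` is accuracy on task j after training through task i), and the `baseline` argument (accuracy of an untrained model on each task) are assumptions, and the package's input format may differ.

```python
def continual_metrics(acc, baseline):
    """Average accuracy, forgetting, and forward transfer from an accuracy matrix.

    acc[i][j]: accuracy on task j after training through task i (assumed layout).
    baseline[j]: accuracy on task j before any training (assumed input).
    """
    num_tasks = len(acc)
    if num_tasks < 2:
        raise ValueError("need at least two tasks")

    final = acc[-1]
    avg_accuracy = sum(final) / num_tasks

    # Forgetting: best accuracy ever reached on an old task minus its final accuracy.
    forgetting = sum(
        max(acc[i][j] for i in range(j, num_tasks - 1)) - final[j]
        for j in range(num_tasks - 1)
    ) / (num_tasks - 1)

    # Forward transfer: accuracy on task j just before training on it, vs. baseline.
    forward_transfer = sum(
        acc[j - 1][j] - baseline[j] for j in range(1, num_tasks)
    ) / (num_tasks - 1)

    return {
        "avg_accuracy": avg_accuracy,
        "forgetting": forgetting,
        "forward_transfer": forward_transfer,
    }


# Made-up numbers for three sequential tasks, purely for illustration.
acc = [
    [0.90, 0.50, 0.50],
    [0.85, 0.88, 0.55],
    [0.80, 0.84, 0.92],
]
baseline = [0.50, 0.50, 0.50]
metrics = continual_metrics(acc, baseline)
```

In this toy matrix, task 0 drops from 0.90 to 0.80 and task 1 from 0.88 to 0.84, so forgetting averages 0.07.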

Example: DistilGPT2 on Sequential Tasks

See examples/distilgpt2_example.py for a complete working example:

python examples/distilgpt2_example.py

Trains on 3 sequential tasks (sentiment, toxicity, spam) and prints metrics.

How ARIA Prevents Forgetting

The Problem

Standard fine-tuning on sequential tasks overwrites weights learned on earlier tasks:

Task 1:  [update all weights]
Task 2:  [update all weights again] ← Task 1 weights overwritten
Task 3:  [update all weights again] ← Task 1 & 2 weights overwritten
Eval:    [accuracy on Task 1 ↓↓↓]   ← Catastrophic forgetting!

The Solution

ARIA separates fast (volatile) and slow (stable) pathways:

Fast pathway:  [updates freely, learns task-specific patterns]
Slow pathway:  [updates slowly, consolidated by Fisher regularization]
Gate π:        [routes each token: fast for new patterns, slow for stability]

Result: Old tasks remain in slow pathway, new tasks learn in fast pathway

Design Philosophy

aria-trl's design prioritizes:

  1. Compatibility: Subclasses SFTTrainer, works with any HF model
  2. Simplicity: Three mechanisms, no complex multi-model ensembles
  3. Efficiency: Task adapters add <1% parameters, Fisher diagonal reduces memory
  4. Interpretability: Gate π shows per-token plasticity, easy to debug

Papers & Motivation

Limitations & Future Work

  • Slow pathway at 0.5× LR: A conservative default; tuning the ratio improves CIFAR-10 accuracy
  • Fisher diagonal only: Full Fisher or Kronecker-factored approximations are future work
  • DistilGPT2 tested: Larger models (7B+) need memory optimization
  • Binary classification: Multi-class and language modeling tasks need validation

Contributing

Contributions welcome! Please:

  1. Fork the repo
  2. Create a feature branch (git checkout -b feature/my-feature)
  3. Commit changes (no AI attribution required)
  4. Push to remote and open a PR

License

MIT License — see LICENSE for details.

Citation

If you use aria-trl in your research, please cite:

@software{poudel2026aria_trl,
  title   = {aria-trl: Continual Learning for LLM Fine-tuning},
  author  = {Poudel, Darshan},
  year    = {2026},
  url     = {https://github.com/rsd-darshan/aria-trl}
}

Contact


aria-trl — continual learning, without the catastrophe.
