clearn
Wrap once. Train forever.
Continual learning for PyTorch models.
Prevent catastrophic forgetting with one line of code.
When you fine-tune a neural network on new data, it tends to overwrite what it learned before, a failure known as catastrophic forgetting. clearn mitigates this. Wrap any PyTorch model, train on sequential tasks, and check how much of each earlier task the model retains.
import clearn
model = clearn.wrap(your_model, strategy="ewc")
model.fit(task1_loader, optimizer, task_id="q1_fraud")
model.fit(task2_loader, optimizer, task_id="q2_fraud")
print(model.diff())
RetentionReport
├── q1_fraud: 94.2% retained (-5.8%)
├── q2_fraud: 100.0% (current task)
├── plasticity_score: 0.87
├── stability_score: 0.94
└── recommendation: "stable — no action needed"
Installation
pip install clearn-ai
For HuggingFace integration:
pip install clearn-ai[hf]
Quickstart
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import clearn
# 1. Your PyTorch model
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
# 2. Wrap it — one line
cl_model = clearn.wrap(model, strategy="ewc")
# 3. Train on sequential tasks
# sequential_tasks is your own iterable of Datasets, one per task
for i, task_data in enumerate(sequential_tasks):
    loader = DataLoader(task_data, batch_size=64)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    metrics = cl_model.fit(loader, optimizer, task_id=f"task_{i}")
    print(f"Task {i}: loss={metrics.final_loss:.4f}, acc={metrics.final_accuracy:.2%}")
# 4. See what was retained
print(cl_model.diff())
That's it. Four steps. Your model now remembers.
Why clearn?
| Problem | Without clearn | With clearn |
|---|---|---|
| Train on Task 2 | Task 1 accuracy: 8% | Task 1 accuracy: 94% |
| Train on 20 tasks | First task: destroyed | First task: preserved |
| Debug forgetting | Print loss, guess | model.diff() tells you exactly |
Strategies
clearn ships five strategies:
EWC (Elastic Weight Consolidation)
Regularization-based. Identifies which weights matter most via the Fisher Information Matrix, then protects them during future training. No need to store past data.
model = clearn.wrap(net, strategy="ewc", lambda_=5000)
| Parameter | Default | Description |
|---|---|---|
| lambda_ | 5000 | Regularization strength. Higher = less forgetting, less plasticity |
| n_fisher_samples | 200 | Samples used to estimate weight importance |
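The penalty that EWC adds to the training loss can be sketched in plain Python. This illustrates the standard quadratic EWC formulation, not clearn's internal code: the function name `ewc_penalty`, the flat-list representation of weights, and the 1/2 scaling are all assumptions made for the example.

```python
def ewc_penalty(params, anchor_params, fisher, lambda_=5000.0):
    """Quadratic EWC penalty: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2.

    params        -- current weights (flat list of floats)
    anchor_params -- weights saved after the previous task
    fisher        -- diagonal Fisher estimate, one importance value per weight
    """
    total = 0.0
    for theta, theta_star, f in zip(params, anchor_params, fisher):
        total += f * (theta - theta_star) ** 2
    return 0.5 * lambda_ * total


# Hypothetical two-weight model: the first weight is important, the second is not.
anchor = [1.0, -2.0]
fisher = [0.01, 0.0001]

# Moving the important weight by 0.1 costs far more than moving the other one.
cost_important = ewc_penalty([1.1, -2.0], anchor, fisher)
cost_unimportant = ewc_penalty([1.0, -1.9], anchor, fisher)
```

Raising `lambda_` scales every term, which is why a higher value trades plasticity for retention.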
SI (Synaptic Intelligence)
Online importance estimation. Tracks per-parameter contribution to loss reduction during training, then penalizes changes to important weights. No separate computation pass needed — importance is accumulated during training.
model = clearn.wrap(net, strategy="si", c=1.0)
| Parameter | Default | Description |
|---|---|---|
| c | 1.0 | Regularization strength (analogous to EWC's lambda) |
| epsilon | 1e-3 | Numerical stability constant |
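A toy run shows how SI accumulates importance while training. This is a sketch of the usual SI bookkeeping under simplifying assumptions (plain gradient descent, a two-weight quadratic loss); the names `si_importance` and `path_integral` are illustrative and not taken from clearn.

```python
def si_importance(path_integral, total_change, epsilon=1e-3):
    """Per-weight SI importance: omega_i / (delta_i^2 + epsilon)."""
    return [w / (d ** 2 + epsilon) for w, d in zip(path_integral, total_change)]


# Toy run: plain gradient descent on loss(w) = w0^2 + 0.01 * w1^2.
weights = [1.0, 1.0]
start = list(weights)
path_integral = [0.0, 0.0]   # running sum of -gradient * weight update
lr = 0.1

for _ in range(50):
    grads = [2.0 * weights[0], 0.02 * weights[1]]
    for i, g in enumerate(grads):
        step = -lr * g
        path_integral[i] += -g * step   # this weight's contribution to the loss drop
        weights[i] += step

total_change = [w - s for w, s in zip(weights, start)]
importance = si_importance(path_integral, total_change)
```

The first weight drives most of the loss reduction, so it ends up with the larger importance value. `epsilon` keeps the division well behaved for weights that barely moved.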
DER++ (Dark Experience Replay)
Replay-based. Stores a small buffer of past examples and replays them during training, matching original logits via KL divergence with temperature scaling. Best general-purpose performance.
model = clearn.wrap(net, strategy="der", buffer_size=500)
| Parameter | Default | Description |
|---|---|---|
| buffer_size | 200 | Number of past samples to store |
| alpha | 0.1 | Weight for cross-entropy replay loss |
| beta | 0.5 | Weight for KL divergence logit-matching loss |
| temperature | 2.0 | Temperature for KL divergence softmax |
| buffer_device | "cpu" | Device to store buffer on ("cuda" avoids transfers) |
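The logit-matching term can be illustrated without any framework. The sketch below computes a KL divergence between temperature-softened distributions for one buffered sample; the helper names are hypothetical, and the T² scaling follows the common distillation convention, which clearn may or may not apply.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def logit_matching_loss(stored_logits, current_logits, temperature=2.0):
    """KL(stored || current) on temperature-softened distributions."""
    p = softmax(stored_logits, temperature)
    q = softmax(current_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return temperature ** 2 * kl            # T^2 scaling is an assumption


stored = [2.0, 0.5, -1.0]   # logits saved in the buffer when the sample was first seen
unchanged = logit_matching_loss(stored, stored)
drifted = logit_matching_loss(stored, [0.0, 2.0, -1.0])
```

The loss is zero while the model still produces the stored logits and grows as its predictions drift. Going by the parameter table, this term would be weighted by `beta` and the cross-entropy on replayed labels by `alpha`.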
GEM (Gradient Episodic Memory)
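Constraint-based. Stores episodic memories from past tasks and projects any gradient that would increase the loss on those memories. Uses the efficient A-GEM variant, which enforces the constraint on the average loss over a sampled memory batch rather than on each previous task separately.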
model = clearn.wrap(net, strategy="gem", memory_size=256)
| Parameter | Default | Description |
|---|---|---|
| memory_size | 256 | Samples to store per task |
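The A-GEM projection itself is a single vector operation. The following sketch uses flat Python lists in place of tensors; `agem_project` is an illustrative name, not a clearn function.

```python
def agem_project(grad, ref_grad):
    """Project `grad` so it no longer conflicts with `ref_grad`.

    grad     -- gradient of the current-task loss (flat list)
    ref_grad -- gradient of the loss on a batch drawn from episodic memory
    """
    dot = sum(g * r for g, r in zip(grad, ref_grad))
    if dot >= 0.0:
        return list(grad)            # no conflict: use the gradient unchanged
    ref_sq = sum(r * r for r in ref_grad)
    scale = dot / ref_sq
    return [g - scale * r for g, r in zip(grad, ref_grad)]


conflicting = agem_project([1.0, -1.0], [0.0, 1.0])   # -> [1.0, 0.0]
compatible = agem_project([1.0, 1.0], [0.0, 1.0])     # -> [1.0, 1.0]
```

Only the component that points against the memory gradient is removed, so progress on the new task continues in every direction that does not hurt the stored examples.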
LoRA-EWC (Parameter-Efficient Continual Learning)
Combines LoRA adapters (via peft) with EWC regularization. Only the low-rank adapter weights are trained and protected — the base model stays frozen. Ideal for LLMs.
# Requires: pip install clearn-ai[hf]
model = clearn.from_pretrained("bert-base-uncased", strategy="lora-ewc", lora_r=8)
| Parameter | Default | Description |
|---|---|---|
| lora_r | 8 | LoRA rank (lower = more efficient) |
| lora_alpha | 16 | LoRA alpha scaling |
| lambda_ | 5000 | EWC regularization on LoRA weights |
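A quick calculation shows why the rank matters. The sketch counts the trainable weights LoRA adds to one 768×768 projection (768 is the hidden size of bert-base-uncased); which layers clearn actually adapts is not stated in this README, so the single-layer view is an assumption.

```python
def lora_trainable_params(d_in, d_out, r):
    """Parameters in the two low-rank factors A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


d_in = d_out = 768            # hidden size of bert-base-uncased
full = d_in * d_out           # weights in one dense projection (bias ignored)
lora = lora_trainable_params(d_in, d_out, r=8)
fraction = lora / full

print(f"{lora} trainable vs {full} frozen ({fraction:.2%})")
```

With `lora_r=8` the adapter holds 12,288 weights against 589,824 in the frozen projection, roughly 2%. The EWC penalty and its Fisher estimate then only need to cover that small set.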
Which strategy should I use?
Using a large language model?
├── Yes → LoRA-EWC (parameter-efficient + forgetting protection)
└── No → Can you store past data?
    ├── Yes → DER++ (best retention)
    └── No → Do you need online tracking?
        ├── Yes → SI (no Fisher pass needed)
        └── No → Want hard constraints?
            ├── Yes → GEM (gradient projection)
            └── No → EWC (classic, reliable)
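The same decision tree can be written as a small helper. `choose_strategy` is a hypothetical function, not part of clearn; the strings it returns are the `strategy` values used in the examples in this README.

```python
def choose_strategy(is_llm, can_store_data=False, needs_online_tracking=False,
                    wants_hard_constraints=False):
    """Walk the decision tree and return a value for the `strategy` argument."""
    if is_llm:
        return "lora-ewc"
    if can_store_data:
        return "der"
    if needs_online_tracking:
        return "si"
    if wants_hard_constraints:
        return "gem"
    return "ewc"


# A small vision model where past data cannot be kept and no special needs apply.
strategy = choose_strategy(is_llm=False)
```

The questions are checked in the same order as the tree, so an earlier "yes" always wins over later ones.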
The diff() Report
The key feature. Like git diff, but for model knowledge.
report = model.diff()
print(report)
RetentionReport
├── task_a: 94.2% retained (-5.8%)
├── task_b: 88.1% retained (-11.9%)
├── task_c: 100.0% (current task)
├── plasticity_score: 0.91
├── stability_score: 0.91
└── recommendation: "stable — no action needed"
The report gives you:
- Per-task retention — exactly how much each task was preserved
- Plasticity score — how well the latest task was learned
- Stability score — average retention across all past tasks
- Recommendation — actionable advice ("increase lambda", "try DER++", etc.)
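The stability score in the sample report can be reproduced by hand. The sketch assumes retention is the ratio of a task's current accuracy to its accuracy right after it was learned; the README does not spell out how clearn measures it, so treat `retention` as an illustration.

```python
def retention(accuracy_when_learned, accuracy_now):
    """Fraction of a task's original accuracy that survives later training."""
    return accuracy_now / accuracy_when_learned


def stability_score(retentions):
    """Average retention across all past tasks (the current task is excluded)."""
    return sum(retentions) / len(retentions)


# Retention values taken from the sample report above.
past_tasks = {"task_a": 0.942, "task_b": 0.881}
stability = stability_score(list(past_tasks.values()))
print(f"stability_score: {stability:.2f}")
```

Averaging 94.2% and 88.1% gives 0.9115, which rounds to the 0.91 shown in the report.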
Training Metrics
Every fit() call returns detailed metrics:
metrics = model.fit(loader, optimizer, task_id="q1", epochs=5)
print(metrics)
TrainingMetrics(task='q1')
├── epochs: 5
├── final_loss: 0.3421
├── final_accuracy: 91.20%
└── wall_time: 2.15s
Access per-epoch data: metrics.epoch_losses, metrics.epoch_accuracies.
Strategy Diagnostics
Inspect the internals of your strategy at any time:
diag = model.diagnostics()
# EWC example:
# {'strategy': 'ewc', 'lambda': 5000, 'consolidated': True,
# 'fisher_mean': 0.0023, 'fisher_max': 10000.0, 'current_penalty': 42.5, ...}
# DER++ example:
# {'strategy': 'der++', 'buffer_used': 200, 'buffer_utilization': 1.0,
# 'buffer_class_distribution': {0: 45, 1: 38, ...}, ...}
Callbacks
Hook into training with the callback system:
from clearn import ContinualCallback
class LogCallback(ContinualCallback):
    def on_task_start(self, model, task_id):
        print(f"Starting {task_id}")

    def on_batch_end(self, model, loss):
        pass  # Log to wandb, etc.

    def on_task_end(self, model, task_id, metrics):
        print(f"Finished {task_id}: {metrics.final_accuracy:.2%}")
model.fit(loader, optimizer, callbacks=[LogCallback()])
Built-in: EarlyStoppingCallback(patience=50).
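Patience-based early stopping follows a simple rule, sketched below in plain Python. `PatienceTracker` is a hypothetical stand-in: the README does not say which quantity EarlyStoppingCallback monitors or whether patience counts batches or epochs, so this shows the general technique only.

```python
class PatienceTracker:
    """Signals a stop after `patience` consecutive updates without improvement."""

    def __init__(self, patience=50, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.stale = 0

    def update(self, loss):
        """Record a loss value; return True when training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.stale = 0
        else:
            self.stale += 1
        return self.stale >= self.patience


tracker = PatienceTracker(patience=3)
losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
stopped_at = None
for step, loss in enumerate(losses):
    if tracker.update(loss):
        stopped_at = step
        break
```

Here the loop stops at step 4, after three values in a row fail to beat the best loss of 0.8, and never sees the later 0.7. A larger patience tolerates longer plateaus.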
Gradient Clipping & Mixed Precision
# Gradient clipping
model.fit(loader, optimizer, grad_clip=1.0)
# Mixed precision (AMP) — requires CUDA
model.fit(loader, optimizer, use_amp=True)
# Both
model.fit(loader, optimizer, grad_clip=1.0, use_amp=True)
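For readers unfamiliar with gradient clipping, the sketch below shows clipping by global L2 norm, the approach taken by `torch.nn.utils.clip_grad_norm_`. That `grad_clip` is a maximum norm rather than a per-value limit is an assumption; the README does not say.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together so their L2 norm is at most `max_norm`."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
```

The gradient keeps its direction and only its length shrinks, here from 5.0 to 1.0.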
Save & Load
# Save full state (model + strategy + task history)
model.save("./checkpoints/my_model")
# Load it back — diff() works after load
model = clearn.load("./checkpoints/my_model", model=your_model)
print(model.diff()) # Retention report preserved
HuggingFace Integration
First-class support for HuggingFace Transformers.
# Load any HuggingFace model with continual learning
model = clearn.from_pretrained("bert-base-uncased", strategy="ewc", task="classification")
model = clearn.from_pretrained("gpt2", strategy="lora-ewc", task="causal-lm")
# Get the tokenizer too
model, tokenizer = clearn.from_pretrained(
    "bert-base-uncased", strategy="ewc", return_tokenizer=True
)
# Supported tasks: classification, token-classification, causal-lm, seq2seq-lm
ContinualTrainer — drop-in replacement for HuggingFace Trainer:
from clearn.integrations.huggingface import ContinualTrainer
trainer = ContinualTrainer(
    model=cl_model,
    args=training_args,
    train_dataset=dataset,
    task_id="sentiment_v1",
)
trainer.train() # Automatically applies forgetting protection
Push to HuggingFace Hub:
model.push_to_hub("your-username/my-continual-model")
API Reference
import clearn
# Wrap any PyTorch model
model = clearn.wrap(model, strategy="ewc", **kwargs)
# Train on a task (returns TrainingMetrics)
metrics = model.fit(dataloader, optimizer, epochs=1, task_id=None,
                    loss_fn=None, grad_clip=None, callbacks=None, use_amp=False)
# Get retention report
report = model.diff()
# Get strategy diagnostics
diag = model.diagnostics()
# Save / Load (diff() works after load)
model.save("path/to/checkpoint")
model = clearn.load("path/to/checkpoint", model=your_model)
# HuggingFace (requires clearn-ai[hf])
model = clearn.from_pretrained("bert-base-uncased", strategy="ewc", task="classification")
model, tokenizer = clearn.from_pretrained("gpt2", strategy="lora-ewc",
                                          task="causal-lm", return_tokenizer=True)
model.push_to_hub("user/model-name")
Benchmark: CIFAR-100 Sequential
Split CIFAR-100 into 20 tasks. Train a ResNet-18 on the tasks one after another, then measure Task 1 accuracy after the last task.
| Method | Task 1 Accuracy (after 20 tasks) |
|---|---|
| Baseline (SGD) | ~8% |
| clearn EWC | ~82% |
| clearn DER++ | ~88% |
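Splitting 100 classes into 20 tasks leaves 5 classes per task. A minimal sketch of that split follows, assuming contiguous class indices; the ordering used by the actual benchmark is not given here, and `split_classes` is an illustrative name.

```python
def split_classes(num_classes=100, num_tasks=20):
    """Partition class indices into equally sized, contiguous tasks."""
    if num_classes % num_tasks != 0:
        raise ValueError("num_classes must be divisible by num_tasks")
    per_task = num_classes // num_tasks
    return [list(range(t * per_task, (t + 1) * per_task)) for t in range(num_tasks)]


tasks = split_classes()
first_task_classes = tasks[0]   # the classes whose accuracy the table tracks
```

Each task's dataset is then the subset of CIFAR-100 images whose label falls in that task's class list.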
To run the benchmark yourself, use the CIFAR-100 notebook in the benchmarks/ directory.
Project Structure
clearn/
├── clearn/
│ ├── core.py # ContinualModel — the main wrapper
│ ├── strategies/
│ │ ├── base.py # Abstract strategy interface
│ │ ├── ewc.py # Elastic Weight Consolidation
│ │ ├── si.py # Synaptic Intelligence
│ │ ├── der.py # Dark Experience Replay++
│ │ ├── gem.py # Gradient Episodic Memory (A-GEM)
│ │ └── lora_ewc.py # LoRA + EWC hybrid
│ ├── metrics.py # RetentionReport, TrainingMetrics, diff() logic
│ ├── callbacks.py # ContinualCallback, EarlyStoppingCallback
│ └── integrations/
│ └── huggingface.py # from_pretrained(), ContinualTrainer, push_to_hub
├── tests/ # 114 tests, all passing
├── examples/ # Runnable demo scripts
└── benchmarks/ # CIFAR-100 notebook
Contributing
git clone https://github.com/itisrmk/clearn.git
cd clearn
pip install -e ".[dev]"
pytest tests/ -v
License
MIT
Built by Rahul Kashyap
Continual learning infrastructure for production ML