
Self-Healing Meta-Trainer

A meta-learning framework that supervises model training and automatically intervenes when it detects problems like overfitting, gradient explosion, or catastrophic forgetting.

Installation

# Basic installation
pip install self-healing-trainer

# With PyTorch support
pip install self-healing-trainer[torch]

# With HuggingFace Trainer support
pip install self-healing-trainer[transformers]

# With live dashboard
pip install self-healing-trainer[dashboard]

# Everything
pip install self-healing-trainer[all]

Or install from source:

git clone https://github.com/self-healing-trainer/self-healing-trainer.git
cd self-healing-trainer
pip install -e .

What It Does

The meta-trainer learns HOW to train models by observing training trajectories. It then supervises any training session and takes corrective action:

Problem detected            Action taken
----------------            ------------
Overfitting                 TRUE rollback - restores model weights
Underfitting                Increase learning rate
Gradient explosion          Clip gradients, reduce LR
Catastrophic forgetting     Rollback + reduce LR
Training plateau            Adjust learning rate
NaN/Inf loss                Stop training
Oscillating loss            Reduce learning rate
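
As an illustration of the kind of signal the meta-trainer reacts to, here is a minimal heuristic overfitting check. This is a sketch, not the library's actual decision logic: it flags overfitting when validation loss has risen over a recent window while training loss kept falling.

```python
from typing import List

def looks_like_overfitting(train_hist: List[float],
                           val_hist: List[float],
                           window: int = 5) -> bool:
    """Heuristic: train loss still falling while val loss rises over `window` steps."""
    if len(train_hist) < window or len(val_hist) < window:
        return False  # not enough history to judge
    train_falling = train_hist[-1] < train_hist[-window]
    val_rising = val_hist[-1] > val_hist[-window]
    return train_falling and val_rising

# Train loss keeps dropping while val loss creeps up -> looks like overfitting
print(looks_like_overfitting([1.0, 0.8, 0.6, 0.5, 0.4],
                             [1.0, 0.9, 0.95, 1.0, 1.1]))  # True
```

The real meta-trainer learns such thresholds from trajectories rather than hard-coding them, but the inputs it inspects (loss histories, gradient norms, steps since improvement) are the same.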

Key Features

1. TRUE Rollback

Actually restores model weights from in-memory checkpoints - not just a signal.

2. Direct Optimizer Control

Directly modifies the optimizer's learning rate rather than merely issuing recommendations.

3. Online Learning

The meta-trainer improves with each training run, learning from real trajectories rather than only synthetic ones.

4. Live Dashboard

Rich terminal dashboard showing losses, actions, and decisions in real-time.

5. Pip Installable

Install with pip install self-healing-trainer - no sys.path hacks needed.


Quick Start

1. Train the Meta-Trainer (one-time setup)

# CLI
meta-trainer train --output meta_trainer_model.json

# Or from Python
from meta_trainer import MetaTrainer, TrajectoryGenerator

generator = TrajectoryGenerator(seed=42)
trajectories = generator.generate_all_scenarios(variations_per_scenario=5)
meta = MetaTrainer()
meta.learn_from_trajectories(trajectories)
meta.save("meta_trainer_model.json")

2. Use in Your Training

Option A: HuggingFace Trainer (Recommended)

from callbacks import MetaTrainerCallback
from transformers import Trainer

callback = MetaTrainerCallback(
    meta_trainer_path="meta_trainer_model.json",
    enable_rollback=True,         # TRUE rollback with weight restoration
    enable_lr_adjust=True,        # Direct optimizer LR control
    enable_online_learning=True,  # Learn from this run
    enable_dashboard=True         # Live terminal dashboard
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[callback]
)
trainer.train()

Option B: PyTorch Training Loop

import torch
from meta_trainer import MetaTrainer, TrainingState, ActionType

meta = MetaTrainer.load("meta_trainer_model.json")

for step in range(total_steps):
    loss = train_step(...)
    val_loss = evaluate(...)

    state = TrainingState(
        step=step,
        train_loss=loss,
        val_loss=val_loss,
        train_loss_history=train_losses[-20:],
        val_loss_history=val_losses[-20:],
        learning_rate=lr,
        gradient_norm=grad_norm,
        best_val_loss=best_val_loss,
        steps_since_improvement=steps_no_improve
    )

    action = meta.decide(state)

    if action.action_type == ActionType.STOP:
        break
    elif action.action_type == ActionType.ROLLBACK:
        model.load_state_dict(checkpoints[action.rollback_to_step])
        lr *= 0.5  # Reduce LR after rollback
    elif action.action_type == ActionType.REDUCE_LR:
        lr *= 0.5
    elif action.action_type == ActionType.INCREASE_LR:
        lr *= 2.0
    elif action.action_type == ActionType.CLIP_GRADIENTS:
        torch.nn.utils.clip_grad_norm_(model.parameters(), action.clip_value)
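
The `checkpoints` mapping in the loop above is assumed, not provided by the library. One way to maintain it, mirroring the callback's `max_in_memory_checkpoints` behavior, is a small store that keeps the most recent N deep-copied state dicts (a sketch under those assumptions):

```python
import copy
from collections import OrderedDict

class InMemoryCheckpoints:
    """Keep the most recent `max_keep` deep-copied state dicts, keyed by step."""
    def __init__(self, max_keep: int = 3):
        self.max_keep = max_keep
        self._store: "OrderedDict[int, dict]" = OrderedDict()

    def save(self, step: int, state_dict: dict) -> None:
        # Deep-copy so later in-place weight updates don't mutate the snapshot.
        self._store[step] = copy.deepcopy(state_dict)
        while len(self._store) > self.max_keep:
            self._store.popitem(last=False)  # evict the oldest checkpoint

    def __getitem__(self, step: int) -> dict:
        return self._store[step]

    def steps(self) -> list:
        return list(self._store)

# Usage with a PyTorch model (sketch):
#   checkpoints = InMemoryCheckpoints(max_keep=3)
#   checkpoints.save(step, model.state_dict())
#   model.load_state_dict(checkpoints[action.rollback_to_step])
```

Deep-copying matters here: `model.state_dict()` returns references to live tensors, so a shallow copy would be silently overwritten by subsequent optimizer steps.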

CLI Commands

# Train a new meta-trainer
meta-trainer train --output meta_trainer_model.json --variations 5

# Test on scenarios
meta-trainer test --model meta_trainer_model.json

# Simulate a training scenario
meta-trainer simulate --scenario overfitting --model meta_trainer_model.json

# Launch dashboard demo
meta-trainer dashboard --model meta_trainer_model.json

Dashboard

The live terminal dashboard shows:

  • Loss curves (train + val) with sparklines
  • Learning rate history
  • Gradient norms
  • Meta-trainer actions taken
  • Real-time statistics

To drive the dashboard from your own code:

from dashboard import TerminalDashboard

dashboard = TerminalDashboard(title="My Training")
dashboard.start()

# In training loop:
dashboard.update(step=100, train_loss=0.5, val_loss=0.6, ...)
dashboard.log_action("rollback", "Overfitting detected")

dashboard.stop()

Or run the demo:

meta-trainer dashboard

API Reference

MetaTrainerCallback

MetaTrainerCallback(
    meta_trainer_path: str = None,      # Path to trained model
    check_every_n_steps: int = 10,      # How often to check
    verbose: bool = True,               # Print actions
    enable_rollback: bool = True,       # TRUE rollback (restores weights)
    enable_lr_adjust: bool = True,      # Direct optimizer control
    enable_early_stop: bool = True,     # Allow early stopping
    enable_online_learning: bool = True, # Learn from this run
    enable_dashboard: bool = False,     # Show live dashboard
    min_lr: float = 1e-7,              # Minimum learning rate
    max_lr: float = 1e-3,              # Maximum learning rate
    max_in_memory_checkpoints: int = 3, # Checkpoints to keep
    checkpoint_on_improvement: bool = True  # Auto-save on improvement
)
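
The `min_lr` / `max_lr` parameters bound every learning-rate adjustment before it reaches the optimizer. The effect is roughly the following clamp (a sketch, not the callback's exact code):

```python
def clamp_lr(proposed_lr: float,
             min_lr: float = 1e-7,
             max_lr: float = 1e-3) -> float:
    """Clamp a proposed learning rate into the configured [min_lr, max_lr] range."""
    return max(min_lr, min(proposed_lr, max_lr))

lr = clamp_lr(2e-3)  # proposed doubling overshoots -> capped at max_lr (1e-3)
lr = clamp_lr(1e-9)  # repeated halving undershoots -> floored at min_lr (1e-7)
```

This prevents a run of INCREASE_LR or REDUCE_LR actions from pushing the learning rate to a value that would destabilize or stall training.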

TrainingState

@dataclass
class TrainingState:
    step: int
    epoch: int
    total_steps: int
    train_loss: float
    val_loss: float
    train_loss_history: List[float]
    val_loss_history: List[float]
    learning_rate: float
    gradient_norm: float
    gradient_norm_history: List[float]
    best_val_loss: float
    best_checkpoint_step: int
    steps_since_improvement: int

TrainingAction

@dataclass
class TrainingAction:
    action_type: ActionType
    reasoning: str
    confidence: float
    new_lr: float
    rollback_to_step: int
    clip_value: float

ActionType

class ActionType(Enum):
    CONTINUE = "continue"
    STOP = "stop"
    ROLLBACK = "rollback"
    REDUCE_LR = "reduce_lr"
    INCREASE_LR = "increase_lr"
    CHECKPOINT = "checkpoint"
    CLIP_GRADIENTS = "clip_gradients"
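
The if/elif chain in the PyTorch example above can also be written as a dispatch table keyed by ActionType, which keeps per-action logic in small handlers. A sketch (the enum is restated here so the snippet is self-contained):

```python
from enum import Enum
from typing import Callable, Dict

class ActionType(Enum):
    CONTINUE = "continue"
    STOP = "stop"
    ROLLBACK = "rollback"
    REDUCE_LR = "reduce_lr"
    INCREASE_LR = "increase_lr"
    CHECKPOINT = "checkpoint"
    CLIP_GRADIENTS = "clip_gradients"

def make_dispatcher(handlers: Dict[ActionType, Callable[[], None]]):
    """Return a function that runs the handler for an action, defaulting to a no-op."""
    def dispatch(action_type: ActionType) -> None:
        handlers.get(action_type, lambda: None)()
    return dispatch

log = []
dispatch = make_dispatcher({
    ActionType.REDUCE_LR: lambda: log.append("lr *= 0.5"),
    ActionType.STOP: lambda: log.append("stop"),
})
dispatch(ActionType.REDUCE_LR)   # runs the registered handler
dispatch(ActionType.CONTINUE)    # no handler registered -> no-op
```

In a real loop the handlers would close over the model, optimizer, and checkpoint store instead of appending to a log.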

File Structure

self-healing-trainer/
├── README.md
├── pyproject.toml              # Pip package config
├── meta_trainer_model.json     # Trained model (after setup)
│
├── meta_trainer/               # Core module
│   ├── __init__.py
│   ├── schema.py              # Data structures
│   ├── generator.py           # Trajectory generator
│   └── meta_trainer.py        # Main class
│
├── callbacks/                  # Framework integrations
│   ├── __init__.py
│   └── huggingface_callback.py  # HuggingFace Trainer callback
│
└── dashboard/                  # Live visualization
    ├── __init__.py
    └── terminal_dashboard.py  # Rich-based dashboard

Tested Scenarios

All stress tests pass:

  • Catastrophic forgetting
  • Severe overfitting
  • Gradient explosion
  • NaN/Inf loss
  • Loss oscillation
  • Underfitting
  • Perfect training
  • Mixed scenarios
  • Edge cases
  • Long sequences

License

MIT
