mentor: a PyTorch training framework for lazy and impatient people

A lightweight PyTorch training framework built around a single idea: a model should carry its own training history.

Mentee is a torch.nn.Module subclass that transparently records every epoch of training, validation metrics, software environment, and command-line invocation — all saved into a single .pt checkpoint. Resuming on a different machine, reporting on a run, or rolling back to the best epoch requires no extra bookkeeping.


Installation

pip install torch-mentor

Or from source:

git clone https://github.com/anguelos/mentor
pip install -e mentor/

Quick start

Option A — Built-in trainer (least code)

Assign a Classifier or Regressor trainer to self.trainer and implement only forward:

import torch.nn as nn
from mentor import Mentee, Classifier

class MyNet(Mentee):
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__(num_classes=num_classes)
        self.fc = nn.Linear(784, num_classes)
        self.trainer = Classifier()   # supplies training_step + validation_step

    def forward(self, x):
        return self.fc(x.flatten(1))

Option B — Custom training step

Override training_step and optionally validation_step for full control:

import torch.nn as nn
import torch.nn.functional as F
from mentor import Mentee

class MyNet(Mentee):
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__(num_classes=num_classes)
        self.fc = nn.Linear(784, num_classes)

    def forward(self, x):
        return self.fc(x.flatten(1))

    def training_step(self, sample):
        x, y = sample
        x, y = x.to(self.device), y.to(self.device)   # move the batch once
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean().item()
        return loss, {"accuracy": acc, "loss": loss.item()}

    def validation_step(self, sample):
        x, y = sample
        x, y = x.to(self.device), y.to(self.device)   # move the batch once
        logits = self(x)
        acc = (logits.argmax(1) == y).float().mean().item()
        return {"accuracy": acc}

Train, validate, save

from torch.utils.data import DataLoader

model = MyNet(num_classes=10)
model.to("cuda")
train_objs = model.create_train_objects(lr=1e-3)
optimizer, scheduler = train_objs["optimizer"], train_objs["lr_scheduler"]

# train_set and val_set are placeholders for any torch Dataset yielding (input, label) pairs
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=32)

for epoch in range(20):
    train_metrics = model.train_epoch(train_loader, optimizer, lr_scheduler=scheduler,
                                      pseudo_batch_size=4, verbose=True)
    val_metrics   = model.validate_epoch(val_loader, verbose=True)
    print(model)                              # shows live training summary
    model.save("checkpoint.pt", optimizer=optimizer, lr_scheduler=scheduler)

Resume anywhere

model, optimizer, scheduler = MyNet.resume_training(
    "checkpoint.pt", model_class=MyNet, device="cuda", lr=1e-3
)
print(f"Resuming from epoch {model.current_epoch}")
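Because current_epoch is described in the API section as len(train_history), a resumed script can train only the epochs that are still missing instead of restarting the count. The sketch below illustrates that loop with a hypothetical stand-in class (ToyRun is not part of mentor, and train_until is an illustrative helper); with a real Mentee the body of the loop would call train_epoch, validate_epoch, and save as in the quick start.

```python
class ToyRun:
    """Hypothetical stand-in: records one history entry per trained epoch."""

    def __init__(self, train_history=None):
        self.train_history = list(train_history or [])

    @property
    def current_epoch(self):
        # Mirrors the documented definition: number of recorded epochs.
        return len(self.train_history)

    def train_epoch(self):
        # Placeholder for real training; appends fake metrics for this epoch.
        self.train_history.append({"loss": 1.0 / (self.current_epoch + 1)})


def train_until(run, total_epochs):
    """Train only the epochs between run.current_epoch and total_epochs."""
    for _ in range(run.current_epoch, total_epochs):
        run.train_epoch()
    return run.current_epoch


# A run that already completed 3 epochs before being "resumed".
resumed = ToyRun(train_history=[{"loss": 1.0}, {"loss": 0.5}, {"loss": 0.33}])
final_epoch = train_until(resumed, total_epochs=5)
print(final_epoch)
```

Calling train_until again with the same target is a no-op, which is what makes a training script safe to relaunch after an interruption.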

Key features

  • Automatically resumable: save() + resume_training() restore weights, optimizer state, and full history; one line picks up where you left off, on any machine.
  • Self-contained checkpoints: everything in one .pt file: weights, optimizer, LR scheduler, training and validation history, best weights, and inference state.
  • Automatic TensorBoard: pass tensorboard_writer to train_epoch / validate_epoch and all metrics are logged with no extra code.
  • Automatic provenance: Git hash, Python and PyTorch versions, hostname, user, and sys.argv are recorded automatically every epoch.
  • Best-weights tracking: the best checkpoint is updated whenever the principal validation metric improves; roll back with one call.
  • Built-in trainers: Classifier and Regressor supply loss, training_step, and validation_step via composition; only forward is required.
  • Gradient accumulation: pseudo_batch_size accumulates gradients over N batches before each optimizer step.
  • OOM tolerance: memfail="ignore" skips samples that raise MemoryError and counts them in the epoch metrics.
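To make the gradient accumulation feature concrete, here is a framework-free sketch of the arithmetic for a one-parameter model. It is not mentor's implementation: the helper names (mse_grad, sgd_accumulated) are hypothetical, and averaging the accumulated gradients before the step is an assumption, since the text above does not say whether mentor averages or sums them.

```python
# Model: y_hat = w * x, loss: mean squared error over a batch.

def mse_grad(w, batch):
    """Gradient of mean((w*x - y)^2) with respect to w for one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def sgd_accumulated(w, batches, lr, pseudo_batch_size):
    """Take one SGD step per `pseudo_batch_size` batches, using their mean gradient."""
    grad_sum, pending = 0.0, 0
    for batch in batches:
        grad_sum += mse_grad(w, batch)   # accumulate, do not step yet
        pending += 1
        if pending == pseudo_batch_size:
            w -= lr * grad_sum / pending  # single optimizer step
            grad_sum, pending = 0.0, 0
    if pending:                           # flush a trailing partial group
        w -= lr * grad_sum / pending
    return w


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
small_batches = [data[0:2], data[2:4]]    # two batches of 2 samples each

# Accumulating over 2 small batches versus one step on the full batch of 4.
w_accum = sgd_accumulated(0.0, small_batches, lr=0.01, pseudo_batch_size=2)
w_large = sgd_accumulated(0.0, [data], lr=0.01, pseudo_batch_size=1)
print(w_accum, w_large)
```

With equal-sized batches the two updates coincide, which is why accumulation is commonly used to emulate a larger batch when memory is limited.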

How mentor compares

Feature mentor Lightning HF Trainer fastai Ignite
Model is plain nn.Module ⚠️
You own the training loop 👍
Full metric history in checkpoint
One-call resume (weights + optimizer + history) ⚠️ ⚠️
Model carries its own history
Automatic provenance (git, env, argv)
Best-epoch weights auto-saved ⚠️
Inference state bundled in checkpoint
TensorBoard logging
Gradient accumulation ⚠️
OOM-tolerant training
High-level fit() 👍
Early stopping 👍
Multi-GPU / distributed training ⚠️
Mixed precision (AMP) 👍 ⚠️
Callback / hook system
LR finder 👍

✅ built-in · 👍 optional · ⚠️ partial or via plugin · ❌ not supported


The Mentee API

Properties

model.current_epoch   # int — len(train_history)
model.device          # torch.device — inferred from parameters
model.optimizer       # cached optimizer (from trainer or create_train_objects)
model.lr_scheduler    # cached LR scheduler
model.loss_fn         # cached default loss function

Core methods to implement in your subclass

def training_step(self, sample) -> tuple[Tensor, dict[str, float]]: ...
def validation_step(self, sample)    -> dict[str, float]: ...

Both are optional when self.trainer is set — the trainer's classmethods are used instead.

The first key of the returned dict is the principal metric used for best-checkpoint selection.
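The "first key" rule relies on Python dicts preserving insertion order (guaranteed since Python 3.7). The sketch below shows how a principal metric could be read from such a dict and used to pick the best epoch. The helper names are hypothetical, and treating a higher value as better is an assumption; the text above does not state how mentor compares values.

```python
def principal_metric(metrics):
    """Return (name, value) for the first key of a metrics dict."""
    name = next(iter(metrics))   # dicts iterate in insertion order
    return name, metrics[name]


def best_epoch(validation_history, higher_is_better=True):
    """Index of the epoch whose principal metric is best."""
    scores = [principal_metric(m)[1] for m in validation_history]
    pick = max if higher_is_better else min
    return pick(range(len(scores)), key=scores.__getitem__)


history = [
    {"accuracy": 0.18, "loss": 2.30},
    {"accuracy": 0.63, "loss": 1.02},
    {"accuracy": 0.59, "loss": 1.10},
]
name, value = principal_metric(history[-1])
best = best_epoch(history)
print(name, value, best)
```

In practice this means the order in which validation_step builds its dict matters: whichever metric is inserted first drives best-checkpoint selection.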

Training

model.create_train_objects(lr=1e-3, step_size=10, gamma=0.1)
# -> {"optimizer": Adam, "lr_scheduler": StepLR, "loss_fn": <fn or None>}

model.train_epoch(dataset, optimizer,
                  lr_scheduler=None, pseudo_batch_size=1,
                  memfail="raise", tensorboard_writer=None,
                  verbose=False, refresh_freq=20)

model.validate_epoch(dataset,
                     recalculate=False, memfail="raise",
                     tensorboard_writer=None, verbose=False, refresh_freq=20)
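The memfail argument is described under Key features as skipping samples that raise MemoryError and counting them in the epoch metrics. A minimal sketch of that policy, independent of mentor, might look as follows. The function run_epoch and the step callback are hypothetical, and reporting memfails as a fraction of samples is an assumption.

```python
def run_epoch(samples, step, memfail="raise"):
    """Apply `step` to each sample; optionally skip samples that run out of memory."""
    if memfail not in ("raise", "ignore"):
        raise ValueError(f"unknown memfail policy: {memfail!r}")
    losses, memfails = [], 0
    for sample in samples:
        try:
            losses.append(step(sample))
        except MemoryError:
            if memfail == "raise":
                raise                 # default: propagate the failure
            memfails += 1             # "ignore": skip the sample and count it
    seen = len(losses) + memfails
    return {
        "loss": sum(losses) / len(losses) if losses else float("nan"),
        "memfails": memfails / seen if seen else 0.0,
    }


def flaky_step(sample):
    """Toy step that fails on one oversized sample."""
    if sample == 3:
        raise MemoryError("simulated out-of-memory")
    return float(sample)


metrics = run_epoch([1, 2, 3, 4], flaky_step, memfail="ignore")
print(metrics)
```

Counting the skipped samples keeps the failure visible in the history, so a run that silently dropped a large share of its data can be spotted afterwards.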

Checkpointing

model.save("run.pt", optimizer=optimizer, lr_scheduler=scheduler)

# load weights only (no optimizer)
model = MyNet.resume("run.pt", model_class=MyNet)

# full resume (weights + optimizer + scheduler, moved to device)
model, optimizer, scheduler = MyNet.resume_training(
    "run.pt", model_class=MyNet, device="cuda", lr=1e-3
)

All tensors are saved on CPU regardless of the training device.


Built-in trainers

MentorTrainer is a pure-Python strategy class (not an nn.Module) that is composed into a Mentee via self.trainer. It separates stateless logic (classmethods) from stateful training objects (cached per-instance):

Trainer      Default loss          Metrics
Classifier   nn.CrossEntropyLoss   loss, acc
Regressor    nn.MSELoss            loss, rmse

import torch.nn as nn
from mentor import Mentee, Classifier, Regressor

class MyClassifier(Mentee):
    def __init__(self, num_classes=10):
        super().__init__(num_classes=num_classes)
        self.fc = nn.Linear(128, num_classes)
        self.trainer = Classifier()
    def forward(self, x): return self.fc(x.flatten(1))

class MyRegressor(Mentee):
    def __init__(self, in_features=8):
        super().__init__(in_features=in_features)
        self.fc = nn.Linear(in_features, 1)
        self.trainer = Regressor()
    def forward(self, x): return self.fc(x).squeeze(-1)

Custom trainers can be added by subclassing MentorTrainer and implementing default_training_step (classmethod) and create_train_objects.


Included examples — CIFAR-10 with ResNet

cd examples/cifar

# full control — custom training_step
python train_cifar.py -resume_path ./runs/cifar.pt -epochs 20 -verbose

# minimal — uses Classifier trainer
python train_cifar_classifier.py -resume_path ./runs/cifar2.pt -epochs 20 -verbose

Reporting CLI

After installation a command-line tool is registered:

mtr_report_file -path ./runs/cifar.pt

Example output:

Checkpoint: /runs/cifar.pt
File size:  89.3 KB

Model class:   examples.cifar.train_cifar.CifarResNet
Importable:    OK (found in 'examples.cifar.train_cifar')

Architecture (inferred from state_dict):
  Parameters:   11,181,642 in 122 tensors
  Modules:      61 parameter-bearing
  Input:        3 channels  (inferred from first conv)
  Output:       10 features  (inferred from last linear)

Epochs trained: 5
  First epoch:  accuracy=0.1823  loss=2.2987  memfails=0.0000
  Last epoch:   accuracy=0.6341  loss=1.0201  memfails=0.0000
...

Design philosophy

  • No magic: Mentee is an ordinary nn.Module subclass. Models work identically whether used through the framework or as bare PyTorch modules.
  • Single file: one .pt file holds everything. No sidecar JSON, no separate history database.
  • You own the loop: train_epoch and validate_epoch are helpers, not a Trainer class. Call them however you like.
  • Composition over inheritance: trainers are strategy objects assigned to self.trainer, not base classes. A model is always a Mentee; a trainer is always a MentorTrainer.
  • Reproducibility first: every change in git hash, environment, or invocation is recorded automatically.
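As an illustration of the kind of provenance record the last point refers to, the sketch below gathers a Git hash, Python version, hostname, user, and sys.argv using only the standard library. It is not mentor's code: the function names and dictionary keys are hypothetical, and the PyTorch version is left out so the example runs without torch installed.

```python
import getpass
import platform
import socket
import subprocess
import sys


def git_hash():
    """Current commit hash, or None outside a Git repository or without git."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True, text=True, check=True,
        )
        return result.stdout.strip()
    except (OSError, subprocess.CalledProcessError):
        return None


def current_user():
    """Login name, or None when it cannot be determined."""
    try:
        return getpass.getuser()
    except Exception:
        return None


def collect_provenance():
    """Snapshot of the environment and invocation for one training epoch."""
    return {
        "git_hash": git_hash(),
        "python": platform.python_version(),
        "hostname": socket.gethostname(),
        "user": current_user(),
        "argv": list(sys.argv),
    }


record = collect_provenance()
print(record)
```

Storing one such record per epoch makes it possible to see later whether the code, the machine, or the command line changed partway through a run.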

License

MIT
