A lightweight trainer and evaluator for PyTorch models with a focus on simplicity and flexibility.
Project description
๐๏ธโโ๏ธ EpochEngine: A Lightweight Training Framework for PyTorch
EpochEngine is a minimal yet flexible library for training and validating PyTorch models, currently focused on computer vision classification tasks. It wraps the PyTorch training loop with automatic device detection, checkpointing, metric tracking, AMP support, and a callback system โ without hiding your optimizer or model behind extra config layers.
The project is currently under active development, more changes expected.
๐ Motivation behind the project
Training deep learning models often requires repeating the same boilerplate code: device management, checkpointing, resuming runs, logging results, and tracking metrics. Existing frameworks like PyTorch Lightning or Hugging Face Trainer are powerful, but they can feel heavy for smaller projects or quick experiments.
This library started as a lightweight training framework I could use in my own computer vision experiments.
It provides:
- A single constructor that accepts your model, optimizer, and loaders directly โ no wrapper configs required
- Automatic run management (checkpoints, plots, JSON history)
- Easy resume support for interrupted or already finished runs
- A simple but extendable interface for adding metrics, callbacks, and training options
The goal is not to replace bigger frameworks, but to make my own workflow faster, cleaner, and more reproducible โ and to practice designing a training system from scratch.
โจ Key Features
- ๐ง Simple setup โ pass your model, optimizer, scheduler, and metrics directly to
Trainerโ no wrapper configs needed. - ๐พ Automatic checkpointing โ saves model weights and optimizer state after each epoch, with easy resuming via run ID.
- ๐ Metric tracking & plotting โ logs training/validation metrics to JSON and generates plots at the end of training.
- ๐ Resuming training โ continue any previous run by providing its run ID.
- โก Mixed precision training โ supports AMP with
torch.ampandGradScaler. - ๐ Callbacks โ built-in early stopping, best checkpoint saving, checkpoint pruning, and NaN detection; fully extensible with custom callbacks.
- ๐งฉ Extensible โ easily add custom models, losses, optimizers, schedulers, and metrics.
Designed to be lightweight, transparent, and beginner-friendly, without the overhead of larger frameworks like PyTorch Lightning.
Installation
The package can be installed as follows:
# Installing the main package
pip install epoch-engine
# Installing additional optional dependencies
pip install epoch-engine[build,linters]
Development mode
The project uses uv for local development. After cloning the repo, install all dependencies (including optional extras) with:
# Cloning the repo and moving to the repo dir
git clone https://github.com/spolivin/epoch-engine.git
cd epoch-engine
# Installing the package and all optional deps (dev mode)
uv sync --all-extras
uvresolves PyTorch from the CPU-only index by default, avoiding the large CUDA download during development. If you need CUDA support locally, install torch manually:uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
Alternatively, with plain pip:
pip install -e .[build,linters,metrics,test]
Pre-commit support
The repository provides support for running pre-commit checks via hooks defined in .pre-commit-config.yaml. These can be loaded in the current git repository by running:
pre-commit install
pre-commitwill already be available after runninguv sync --all-extrasorpip install epoch-engine[linters]
A pre-push hook is also configured to run the test suite automatically before each git push, but only when the push includes at least one .py file. To activate it:
pre-commit install --hook-type pre-push
Python API
The examples below assume a model, a loss function, and dataloaders are already prepared. EpochEngine ships two built-in architectures; here we use ResNet as a starting point:
from epoch_engine.models import ResNet
net = ResNet(
in_channels=1,
num_blocks=[3, 3, 3],
num_classes=10,
)
train_loader and valid_loader are standard PyTorch DataLoader instances for the training and validation sets respectively. test_loader is optional and only needed for trainer.evaluate().
Trainer set-up
Instantiate Trainer directly by passing your PyTorch objects โ no wrapper configs needed:
import torch
import torch.nn as nn
from epoch_engine.core import Trainer
optimizer = torch.optim.SGD(net.parameters(), lr=0.25, momentum=0.75)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=2)
trainer = Trainer(
model=net,
criterion=nn.CrossEntropyLoss(),
optimizer=optimizer,
train_loader=train_loader,
valid_loader=valid_loader,
test_loader=test_loader, # optional
scheduler=scheduler, # optional
scheduler_level="epoch", # optional, "epoch" (default) or "batch"
metrics=metrics, # optional, see below
callbacks=callbacks, # optional, see below
enable_amp=True, # optional, AMP only activates on CUDA
)
Trainerwill automatically detect whether CUDA, MPS or CPU is to be used. Theschedulerargument is optional; omitting it disables LR scheduling.
Metrics
By default only loss is tracked. Extra metrics can be added by passing a dict mapping a metric name to a callable with signature (y_true, y_pred) -> float:
from sklearn.metrics import accuracy_score, f1_score, precision_score
metrics = {
"accuracy": accuracy_score,
"precision": lambda y_true, y_pred: precision_score(y_true, y_pred, average="macro"),
"f1": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro"),
}
When using the dict form, plots are generated for all metrics at the end of training.
Fine-grained control with MetricConfig
Use MetricConfig when you want to selectively disable plots for specific metrics:
from epoch_engine.core.configs import MetricConfig
metrics = [
MetricConfig(name="accuracy", fn=accuracy_score, plot=True),
MetricConfig(name="f1", fn=f1_metric, plot=False), # tracked but not plotted
]
Callbacks
Callbacks let you hook into the training lifecycle without modifying the Trainer. Pass a list of callback instances to the callbacks argument of the constructor:
from epoch_engine.core.callbacks import EarlyStopping, BestCheckpoint, CheckpointPruner, NanCallback
trainer = Trainer(
...,
callbacks=[
BestCheckpoint(monitor="loss/valid", mode="min"),
EarlyStopping(monitor="loss/valid", patience=5, mode="min"),
CheckpointPruner(keep_last_n=2),
],
)
Built-in callbacks
EarlyStopping
Stops training when a monitored metric stops improving for patience consecutive epochs. Sets trainer.interrupted = True when triggered.
EarlyStopping(
monitor="loss/valid", # metric key to watch
patience=5, # epochs without improvement before stopping
mode="min", # 'min' or 'max'
min_delta=0.001, # minimum change to count as improvement
verbose=True, # log warnings when patience counter increments
)
BestCheckpoint
Copies the current epoch checkpoint to best.pt whenever the monitored metric reaches a new best. Required before using resume_from_best=True.
BestCheckpoint(
monitor="loss/valid", # metric key to watch
mode="min", # 'min' or 'max'
min_delta=0.0,
verbose=True,
)
CheckpointPruner
Deletes old ckpt_epoch_*.pt files after each save, keeping only the most recent keep_last_n. best.pt is never removed.
CheckpointPruner(keep_last_n=2)
NanCallback
Stops training immediately if any metric value is NaN.
NanCallback(
monitor=["loss/train", "loss/valid"], # specific keys to check; None checks all
verbose=True,
)
Custom callbacks
Subclass Callback and override whichever hooks you need. Only on_epoch_end and on_checkpoint_save are currently dispatched by the trainer; the remaining hooks are available as no-op stubs for future use.
from epoch_engine.core.callbacks import Callback
class PrintLR(Callback):
def on_epoch_end(self, trainer, epoch, metrics):
lr = trainer.optimizer.param_groups[0]["lr"]
print(f"Epoch {epoch} โ lr={lr:.6f}")
To halt training from a custom callback, set self.should_stop = True inside on_epoch_end.
Running the Trainer
New run
Now, we launch the training for the first time which will show the progress bar (if enabled):
# Launching training (with gradient clipping)
trainer.run(
epochs=5,
seed=42,
enable_tqdm=True,
clip_grad_norm=1.0,
)
Running the Trainer in this way will set up in the current directory the following directory structure:
current-dir/
โโโ runs/
โโโ run_id=82cc72/
โ โโโ checkpoints/
โ | โโโ ckpt_epoch_1.pt
โ | โโโ ckpt_epoch_2.pt
โ | โโโ ckpt_epoch_3.pt
โ | โโโ ckpt_epoch_4.pt
โ | โโโ ckpt_epoch_5.pt
| โโโ plots/
| โโโ accuracy.png
| โโโ f1.png
| โโโ loss.png
| โโโ precision.png
|
โโโ metrics_history.json
At the beginning of the run, a new run_id is generated (if run_id=None) and in the current folder the method creates runs folder with a separate folder for the files related to the current run (in the above example folder named run_id=82cc72 with the generated run ID). After each epoch the checkpoint (containing last trained epoch, model parameters and optimizer state) is saved to checkpoints subfolder.
At the end of training, plots for the registered metrics are saved in the run-specific plots directory.
Additionally, the losses for each data set as well as the registered metrics for each epoch and training run are saved to runs/metrics_history.json and are written to each run. Such results can look for instance like this:
{
"runs": [
{
"run_id=82cc72": [
{
"loss/train": 0.7220337147848059,
"loss/valid": 0.056390782489593255,
"accuracy/train": 0.7354,
"accuracy/valid": 0.9828888888888889,
"precision/train": 0.745122229756801,
"precision/valid": 0.9828108100143986,
"f1/train": 0.7367284983928013,
"f1/valid": 0.982719354776437,
"epoch": 1
}
]
}
]
}
To make the training results representation readable, the above output shows the training/validation results for only one epoch but in case of training for more epochs, there would be a longer list of such dictionaries distinguishable by
epoch.
Metrics display
After each epoch, the trainer prints a formatted table with train and valid values for every tracked metric, side by side:
======================================================================
Epoch 3 Results:
----------------------------------------------------------------------
accuracy | Train: 0.7354 | Valid: 0.9829
f1 | Train: 0.7367 | Valid: 0.9827
loss | Train: 0.7220 | Valid: 0.0564
precision | Train: 0.7451 | Valid: 0.9828
======================================================================
The table is printed via tqdm.write when enable_tqdm=True (the default), so it doesn't interfere with the progress bars. Pass enable_tqdm=False to use plain print instead.
Resuming training
The training can be easily resumed by specifying the run_id from which we would like to continue training:
trainer.run(
epochs=2,
run_id="82cc72",
seed=42,
enable_tqdm=True,
)
The new checkpoints will be saved to the same folder for this run and new metrics will be appended to the same run ID's data in
runs/metrics_history.json. Additionally, the respective plots will also be updated.
Alternatively, if resuming training using the same Trainer instance, then we can just omit run_id whatsoever (equivalent to setting run_id=None), in which case Trainer will automatically infer that training for the last assigned run_id should be continued:
trainer.run(epochs=2)
Resuming from the best checkpoint
By default, resuming always loads the last epoch checkpoint. If you used the BestCheckpoint callback in a prior run, you can instead resume from best.pt:
trainer.run(epochs=2, resume_from_best=True)
If
best.ptdoes not exist (e.g.BestCheckpointwas not used), the trainer logs a warning and falls back to the last epoch checkpoint automatically.
Testing the trained model
After the model has been trained and validated using Trainer, we can quickly test it on test set either in this way:
test_metrics = trainer.evaluate()
or by passing a loader directly if test_loader was not set at construction time:
test_metrics = trainer.evaluate(loader=test_loader)
Running tests
The test suite uses pytest. Install the test dependencies first, then run:
# Install with test extras (uv)
uv sync --extra test
# or with pip
pip install -e ".[test]"
# Run all tests
pytest tests/
# Run a specific test file
pytest tests/test_trainer.py -v
# Run a single test
pytest tests/test_trainer.py::TestTrainer::test_run -v
Test script
The basics of the developed API are presented in the run_trainer.py script in the root of the repository. It can be run via the provided Makefile targets:
# Run with EDNet model (default 5 epochs)
make run-trainer-ednet
# Run with ResNet model (default 5 epochs)
make run-trainer-resnet
# Override the number of epochs
make run-trainer-ednet EPOCHS=3
make run-trainer-resnet EPOCHS=10
The training will be launched on the automatically detected device (CUDA โ MPS โ CPU). Metric plots are generated at the end of each run.
If make is not available (e.g. on Windows), run the script directly:
python run_trainer.py --model resnet --epochs 5 --plot-metrics
python run_trainer.py --model ednet --epochs 5 --plot-metrics
TODOs
- Add gradient clipping
- Add lightweight tests
- Simplify trainer setup โ accept optimizer/scheduler directly, removing
OptimizerConfig,SchedulerConfig, andTrainerConfig - Change the structure of the generated
runsdirectory to allow for more convenient structure - Introduce an option to train using Automatic Mixed Precision (AMP)
- Add plots generation for registered metrics within the tracking/logging system
- Introduce callbacks for early stopping training, saving best checkpoint, etc.
- Come up with a way to track metrics live during training/validation
- Introduce multi-GPU training
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file epoch_engine-0.2.0.tar.gz.
File metadata
- Download URL: epoch_engine-0.2.0.tar.gz
- Upload date:
- Size: 48.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
88774fd86473fa9b11c233769e5c43c8574639c0a9571ebb3753106f0d372ed1
|
|
| MD5 |
1ceb56e07d09915ef78a324da456e6a2
|
|
| BLAKE2b-256 |
98e58b45084d68cd1673efed47fb097ae335e8434b703cc1d83fa70a6f9d9fa5
|
Provenance
The following attestation bundles were made for epoch_engine-0.2.0.tar.gz:
Publisher:
publish.yml on spolivin/epoch-engine
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
epoch_engine-0.2.0.tar.gz -
Subject digest:
88774fd86473fa9b11c233769e5c43c8574639c0a9571ebb3753106f0d372ed1 - Sigstore transparency entry: 1004119930
- Sigstore integration time:
-
Permalink:
spolivin/epoch-engine@35576ac326e12a8335e62c3f1f1f469ccb00aa4c -
Branch / Tag:
refs/tags/v0.2.0 - Owner: https://github.com/spolivin
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
publish.yml@35576ac326e12a8335e62c3f1f1f469ccb00aa4c -
Trigger Event:
release
-
Statement type:
File details
Details for the file epoch_engine-0.2.0-py3-none-any.whl.
File metadata
- Download URL: epoch_engine-0.2.0-py3-none-any.whl
- Upload date:
- Size: 28.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
0dcd90ef1522d33a5f489bea22f736262e147260f45b41361b34fc2e57cfeb12
|
|
| MD5 |
ddb8bb8182e63711ba6cfd01e0f38e88
|
|
| BLAKE2b-256 |
adb5f5b4616a8b43619dbb63352a1fc13f39f27fe2d21cc9be41014c2afed69a
|
Provenance
The following attestation bundles were made for epoch_engine-0.2.0-py3-none-any.whl:
Publisher:
publish.yml on spolivin/epoch-engine
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
epoch_engine-0.2.0-py3-none-any.whl -
Subject digest:
0dcd90ef1522d33a5f489bea22f736262e147260f45b41361b34fc2e57cfeb12 - Sigstore transparency entry: 1004119932
- Sigstore integration time:
-
Permalink:
spolivin/epoch-engine@35576ac326e12a8335e62c3f1f1f469ccb00aa4c -
Branch / Tag:
refs/tags/v0.2.0 - Owner: https://github.com/spolivin
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
publish.yml@35576ac326e12a8335e62c3f1f1f469ccb00aa4c -
Trigger Event:
release
-
Statement type: