A lightweight trainer and evaluator for PyTorch models with a focus on simplicity and flexibility.

🏋️‍♂️ EpochEngine: A Lightweight Training Framework for PyTorch

EpochEngine is a minimal yet flexible library for training and validating PyTorch models, currently focused on computer vision classification tasks. It provides a simple API to configure models, optimizers, schedulers, and metrics while handling logging, checkpointing, and plotting automatically.

The project is under active development, and further changes are expected.

✨ Key Features

  • 🔧 Config-driven setup — pass model, loss, optimizer, scheduler, and metrics in a single config.
  • 💾 Automatic checkpointing — saves model weights, optimizer state and other important data after each epoch, with easy resuming via run ID.
  • 📊 Metric tracking & plotting — logs training/validation metrics to JSON and generates plots at the end of training.
  • 🚀 Resuming training — continue any previous run by providing its run ID.
  • 🧩 Extensible — easily add custom models, losses, optimizers, schedulers, and metrics.

Designed to be lightweight, transparent, and beginner-friendly, without the overhead of larger frameworks like PyTorch Lightning.

Installation

The package can be installed as follows:

# Installing the main package
pip install epoch-engine

# Installing additional optional dependencies
pip install "epoch-engine[build,linters]"

Development mode

# Cloning the repo and moving to the repo dir
git clone https://github.com/spolivin/epoch-engine.git
cd epoch-engine

# Installing the package and optional deps (dev mode)
pip install -e .
pip install -e ".[build,linters]"

Pre-commit support

The repository provides support for running pre-commit checks via hooks defined in .pre-commit-config.yaml. These can be loaded in the current git repository by running:

pre-commit install

pre-commit is already installed in the virtual environment after running pip install epoch-engine[linters] or pip install -e .[linters].

Python API

Suppose we have a PyTorch model called net (the example below builds a ResNet from the blocks shipped with the package) and a loss function to optimize:

import torch  # needed for torch.optim in the later snippets
import torch.nn as nn

from epoch_engine.models import BasicBlock, ResNet

# Instantiating a ResNet model for gray-scale images
net = ResNet(
    in_channels=1,
    block=BasicBlock,
    num_blocks=[3, 3, 3],
    num_classes=10,
)
loss_func = nn.CrossEntropyLoss()

The dataloaders train_loader and valid_loader for the training and validation sets are assumed to be prepared already. The Trainer can then be set up in two ways:

  • Multi-step setup
  • Config-based setup

Multi-step setup

Instantiating Trainer

from epoch_engine.core import Trainer

# Instantiating a Trainer (with auto device detection)
trainer = Trainer(
    model=net,
    criterion=loss_func,
    train_loader=train_loader,
    valid_loader=valid_loader,
    train_on="auto",
)

When train_on is set to "auto", the Trainer detects whether to use CUDA, MPS, or CPU.
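The order in which "auto" probes the devices is not stated here. A common convention is to prefer CUDA, then MPS, then CPU. The sketch below illustrates that convention only; pick_device is a hypothetical helper, not part of EpochEngine, and the priority order is an assumption.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return a device string, preferring CUDA, then MPS, then CPU (assumed order)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# With PyTorch installed, the two flags would typically come from
# torch.cuda.is_available() and torch.backends.mps.is_available().
print(pick_device(cuda_available=False, mps_available=True))  # mps
```

Keeping the decision in a pure function makes the fallback order easy to test without any accelerator present.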

Optimizer and scheduler

The next step is to configure the optimizer and, optionally, the scheduler. This can be done in one of two ways:

Configs

from epoch_engine.core import OptimizerConfig, SchedulerConfig

# Setting up configs for optimizer and scheduler
optimizer_config = OptimizerConfig(
    optimizer_class=torch.optim.SGD,
    optimizer_params={"lr": 0.25, "momentum": 0.75},
)
scheduler_config = SchedulerConfig(
    scheduler_class=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"gamma": 0.1, "step_size": 2},
    scheduler_level="epoch",
)

trainer.configure_trainer(
    optimizer_config=optimizer_config,
    scheduler_config=scheduler_config,
)

Direct

trainer.configure_trainer(
    optimizer_class=torch.optim.SGD,
    optimizer_params={"lr": 0.25, "momentum": 0.75},
    scheduler_class=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"gamma": 0.1, "step_size": 2},
    scheduler_level="epoch",
)
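To see what the StepLR settings above imply, the learning rate can be computed by hand from StepLR's closed form, lr = base_lr * gamma ** (steps // step_size). The sketch assumes that scheduler_level="epoch" means the scheduler is stepped once per epoch; step_lr is a hypothetical helper written for this illustration.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, steps: int) -> float:
    """Closed-form StepLR value after `steps` scheduler steps."""
    return base_lr * gamma ** (steps // step_size)


# Values from the example: lr=0.25, gamma=0.1, step_size=2, five epochs.
schedule = [step_lr(0.25, 0.1, 2, steps) for steps in range(5)]
for steps, lr in enumerate(schedule):
    print(f"after {steps} scheduler steps: lr = {lr:.4f}")
```

With these settings the rate stays at 0.25 for the first two epochs, drops to 0.025 for the next two, and reaches 0.0025 in the fifth.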

Metrics registration

By default, only the loss is computed, but extra metrics can be registered for tracking during a Trainer run:

from sklearn.metrics import accuracy_score, f1_score, precision_score

trainer.register_metrics({
    "accuracy": accuracy_score,
    "precision": lambda y_true, y_pred: precision_score(y_true, y_pred, average="macro"),
    "f1": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro"),
})

The passed dict must map each metric name, as it should appear in the logs, to a callable that takes targets and predictions and returns a float.
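Any function with that shape should fit, not only scikit-learn metrics. Below is a minimal sketch of a hand-written metric, assuming targets and predictions arrive as sequences of class labels (as the scikit-learn functions above expect). The name macro_recall is illustrative only.

```python
from collections import defaultdict


def macro_recall(y_true, y_pred) -> float:
    """Average per-class recall over the classes that occur in y_true."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for target, pred in zip(y_true, y_pred):
        totals[target] += 1
        if target == pred:
            hits[target] += 1
    # Assumes y_true is non-empty; an empty input would divide by zero.
    return sum(hits[c] / totals[c] for c in totals) / len(totals)


# Same dict shape as the one passed to register_metrics above.
custom_metrics = {"recall": macro_recall}
print(custom_metrics["recall"]([0, 0, 1, 1, 2], [0, 1, 1, 1, 2]))
```

In the sample call, class 0 has recall 1/2 and classes 1 and 2 have recall 1, so the result is their mean, 5/6.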

Advanced

Metrics can also be registered via MetricConfig, which works like the dictionary form but adds a plot parameter that controls whether a plot is generated for the metric (with a dictionary, plot=True is the default):

from epoch_engine.core.configs import MetricConfig

trainer.register_metrics([
    MetricConfig(name="accuracy", fn=accuracy_score, plot=False),
])

The example above registers an accuracy metric and states explicitly that no plot needs to be generated for it at the end of the training run.
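The relation between the two forms can be pictured as a normalization step: a dictionary entry becomes a config with plot=True, while an explicit config keeps its own setting. The sketch below uses a hypothetical stand-in dataclass, MetricSpec, and a hypothetical normalize_metrics function. It is not EpochEngine's implementation and only mirrors the behavior described above.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Union


@dataclass
class MetricSpec:  # hypothetical stand-in, not epoch_engine's MetricConfig
    name: str
    fn: Callable[[Sequence, Sequence], float]
    plot: bool = True  # dictionary entries fall back to this default


def normalize_metrics(
    metrics: Union[Dict[str, Callable], List[MetricSpec]]
) -> List[MetricSpec]:
    """Turn either accepted form into a uniform list of specs."""
    if isinstance(metrics, dict):
        return [MetricSpec(name=name, fn=fn) for name, fn in metrics.items()]
    return list(metrics)


def exact_match(y_true, y_pred) -> float:
    """Fraction of positions where prediction equals target."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


from_dict = normalize_metrics({"accuracy": exact_match})
from_list = normalize_metrics([MetricSpec("accuracy", exact_match, plot=False)])
print(from_dict[0].plot, from_list[0].plot)  # True False
```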

Config-based setup (recommended)

Version 0.1.3 added another way of creating a Trainer: specify all required parameters in a TrainerConfig and pass it to Trainer.from_config. Training can then be launched without any further setup:

from epoch_engine.core.configs import TrainerConfig

trainer_config = TrainerConfig(
    model=net,
    criterion=nn.CrossEntropyLoss(),
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer_config=OptimizerConfig(
        optimizer_class=torch.optim.SGD,
        optimizer_params={"lr": 0.25, "momentum": 0.75},
    ),
    scheduler_config=SchedulerConfig(
        scheduler_class=torch.optim.lr_scheduler.StepLR,
        scheduler_params={"gamma": 0.1, "step_size": 2},
    ),
    metrics={
        "accuracy": accuracy_score,
        "precision": lambda y_true, y_pred: precision_score(
            y_true, y_pred, average="macro"
        ),
        "f1": lambda y_true, y_pred: f1_score(
            y_true, y_pred, average="macro"
        ),
    },
)

trainer = Trainer.from_config(config=trainer_config)

Running the Trainer

New run

Now we launch training for the first time, which shows a progress bar (if enabled):

# Launching training
trainer.run(
    epochs=5,
    run_id=None,
    seed=42,
    enable_tqdm=True,
)

Running the Trainer this way creates the following directory structure in the current directory:

current-dir/
└── runs/
    ├── run_id=82cc72/
    │   ├── checkpoints/
    │   │   ├── ckpt_epoch_1.pt
    │   │   ├── ckpt_epoch_2.pt
    │   │   ├── ckpt_epoch_3.pt
    │   │   ├── ckpt_epoch_4.pt
    │   │   └── ckpt_epoch_5.pt
    │   ├── results_accuracy.png
    │   ├── results_f1.png
    │   ├── results_loss.png
    │   └── results_precision.png
    │
    └── training_history.json

At the beginning of the run, a new run ID is generated (if run_id=None), and the method creates a runs folder in the current directory with a separate subfolder for the files of the current run (run_id=82cc72 in the example above). After each epoch, a checkpoint containing the last trained epoch, the run ID, the model parameters, and the optimizer state is saved to the checkpoints subfolder.

At the end of the training, the plots for the registered metrics are saved as well (new in 0.1.3).

Additionally, the losses for each dataset and the registered metrics are saved per epoch to runs/training_history.json, grouped by run. The results can look like this:

{
     "runs": [
          {
               "run_id=82cc72": [
                    {
                         "loss/train": 0.7220337147848059,
                         "loss/valid": 0.056390782489593255,
                         "accuracy/train": 0.7354,
                         "accuracy/valid": 0.9828888888888889,
                         "precision/train": 0.745122229756801,
                         "precision/valid": 0.9828108100143986,
                         "f1/train": 0.7367284983928013,
                         "f1/valid": 0.982719354776437,
                         "epoch": 1
                    }
               ]
          }
     ]
}

For readability, the output above shows the training and validation results for a single epoch. Training for more epochs produces a longer list of such dictionaries, distinguished by the epoch key.
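Because the history is plain JSON, it can be inspected with the standard library alone. The sketch below assumes the layout shown above (a top-level "runs" list whose items map "run_id=<id>" to a list of per-epoch dictionaries). metric_by_epoch is a hypothetical helper, not part of EpochEngine, and the inline sample reuses two values from the example output.

```python
import json


def metric_by_epoch(history: dict, run_id: str, metric: str) -> dict:
    """Map epoch number -> metric value for one run."""
    key = f"run_id={run_id}"
    for run in history["runs"]:
        if key in run:
            return {entry["epoch"]: entry[metric] for entry in run[key]}
    raise KeyError(f"run {run_id!r} not found")


# Inline sample mirroring the structure above; in practice the dict would
# come from json.load() on runs/training_history.json.
sample = json.loads("""
{"runs": [{"run_id=82cc72": [
    {"accuracy/train": 0.7354, "accuracy/valid": 0.9828888888888889, "epoch": 1}
]}]}
""")
print(metric_by_epoch(sample, "82cc72", "accuracy/valid"))
```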

If the plot for a metric is missing, it can be generated from the logged history with a dedicated function:

from epoch_engine.core.plotting import generate_plot_from_json

generate_plot_from_json(metric_name="accuracy", run_id="82cc72")

Resuming training

Training can be resumed by passing the run_id of the run to continue (resuming is supported only from the last logged epoch):

trainer.run(
    epochs=2,
    run_id="82cc72",
    seed=42,
    enable_tqdm=True,
)

The new checkpoints are saved to the same folder for this run, and the new metrics are appended to the same run ID's data in training_history.json. The respective plots are updated as well.
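Since resuming starts from the last logged epoch, it can be useful to check which checkpoint that is. The sketch below assumes the ckpt_epoch_N.pt naming from the directory tree above. latest_checkpoint is a hypothetical helper for inspecting a run folder and does not describe how EpochEngine locates checkpoints internally.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

CKPT_PATTERN = re.compile(r"ckpt_epoch_(\d+)\.pt")


def latest_checkpoint(ckpt_dir: Path) -> Optional[Path]:
    """Return the checkpoint with the highest epoch number, or None."""
    best_epoch, best_path = -1, None
    for path in ckpt_dir.glob("ckpt_epoch_*.pt"):
        match = CKPT_PATTERN.fullmatch(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path


# Demo on empty placeholder files; a real call would point at
# runs/run_id=<id>/checkpoints instead.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (1, 2, 10):
        (Path(tmp) / f"ckpt_epoch_{epoch}.pt").touch()
    print(latest_checkpoint(Path(tmp)).name)  # ckpt_epoch_10.pt
```

Comparing parsed integers rather than file names matters once a run passes nine epochs, because a lexicographic sort would rank ckpt_epoch_2.pt above ckpt_epoch_10.pt.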

Test script

The basics of the API are demonstrated in run_trainer.py in the root of the repository. It can be run as follows:

# Installing scikit-learn for using metrics
pip install -r requirements.txt

# Running the Trainer
python run_trainer.py --model=resnet --epochs=3

Training is launched on a device chosen automatically based on CUDA availability, and the final training checkpoint is saved in the checkpoints directory.

TODOs

  • Add an option to store the Trainer's runs and metrics in a local SQLite database, with an API for convenient access
  • Add logging support via TensorBoard and/or Weights & Biases
  • Re-format the output of training_history.json to allow for more detail and an easier-to-read representation of metrics
  • Add lightweight tests
  • Introduce an option to train using Automatic Mixed Precision (AMP)
  • Add plot generation for registered metrics within the tracking/logging system
