
Trainer to optimize PyTorch models

Project description


torchfitter is a simple library to ease the training of PyTorch models. It features a class called Trainer that includes the basic functionality to fit models in a Keras-like style.

Internally, torchfitter relies on Hugging Face's accelerate library to handle device management.

The library also provides a callbacks API that can be used to interact with the model during the training process, as well as a set of basic regularization procedures.
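As an illustration of what an L1 regularization procedure computes, here is a minimal pure-Python sketch. It is not torchfitter's implementation: it assumes parameters are given as a hypothetical mapping from names to flat lists of weights, and that bias parameters can be recognised by names ending in `bias` (mirroring the `biases=False` option used in the usage example). The penalty is `regularization_rate * sum(|w|)`, added to the data loss.

```python
def l1_penalty(params, regularization_rate, biases=False):
    """Return regularization_rate * sum of |w| over the selected parameters.

    `params` is a hypothetical mapping from parameter name to a flat list of
    floats; names ending in "bias" are treated as bias terms (assumption).
    """
    total = 0.0
    for name, values in params.items():
        if not biases and name.endswith("bias"):
            continue  # skip bias terms when biases=False
        total += sum(abs(v) for v in values)
    return regularization_rate * total


# Toy parameters for a 1-in/1-out linear layer (values are illustrative)
params = {"weight": [-2.0], "bias": [0.5]}
data_loss = 0.25
loss = data_loss + l1_penalty(params, regularization_rate=0.01)
print(loss)
```

Excluding biases is a common choice because penalizing them rarely helps with overfitting; the flag simply switches that filter off.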

Installation

Normal user

pip install torchfitter

This library ships neither CUDA nor XLA. Follow the official PyTorch documentation to learn how to install the CUDA binaries.

Developer

git clone https://github.com/Xylambda/torchfitter.git
pip install -e torchfitter/. -r torchfitter/requirements-dev.txt

Tests

To run the tests, you must install the library in developer mode (see above).

cd torchfitter/
pytest -v tests/

Features

Supported Not supported Planned
Basic training loop x
Gradient Clipping x
Gradient Accumulation x
Multi-device support x
Regularization x
In-loop metrics support x
Mixed precision training x
Callbacks System x
Hyperparameter search x
Warm Training x x

Usage

Assume we already have DataLoaders for the train and validation sets.

from torch.utils.data import DataLoader


train_loader = DataLoader(...)
val_loader = DataLoader(...)

Then, create the optimizer and the loss criterion as usual and pass them to the trainer along with the PyTorch model. You can also add a regularization procedure if needed. The same goes for callbacks: create the desired callbacks and pass them to the trainer as a list.

import torch.nn as nn
import torch.optim as optim
from torchfitter.trainer import Trainer
from torchfitter.callbacks import (
    LoggerCallback,
    EarlyStopping,
    LearningRateScheduler,
    L1Regularization,
)

model = nn.Linear(in_features=1, out_features=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())
l1_reg = L1Regularization(regularization_rate=0.01, biases=False)

# callbacks
logger = LoggerCallback(update_step=50)
early_stopping = EarlyStopping(patience=50, load_best=True, path='checkpoint.pt')
scheduler = LearningRateScheduler(
    scheduler=optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.9)
)

trainer = Trainer(
    model=model, 
    criterion=criterion,
    optimizer=optimizer,
    mixed_precision="fp16",
    accumulate_iter=4,  # accumulate gradients over 4 iterations before each optimizer step
    gradient_clipping='norm',
    gradient_clipping_kwrgs={'max_norm': 1.0, 'norm_type': 2.0},
    callbacks=[l1_reg, scheduler, early_stopping, logger]
)

history = trainer.fit(train_loader, val_loader, epochs=1000)
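The `accumulate_iter` and `gradient_clipping` options above correspond to two standard techniques. The sketch below is a pure-Python illustration, not torchfitter's code: gradients are plain lists of floats, `clip_grad_norm` mimics what PyTorch's `torch.nn.utils.clip_grad_norm_` does with `norm_type=2.0` (rescale all gradients so their combined L2 norm is at most `max_norm`), and the loop takes an optimizer step only every `accumulate_iter` batches, dividing each batch's gradient by `accumulate_iter` (a common convention, assumed here) so the update matches the average over the window. The plain SGD update and all function names are hypothetical.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale gradients in place so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:  # only shrink, never enlarge
        for i in range(len(grads)):
            grads[i] *= clip_coef
    return total_norm


def train_with_accumulation(weights, batch_grads, lr, accumulate_iter, max_norm):
    """Plain SGD with gradient accumulation and norm clipping (illustrative).

    `batch_grads` yields one gradient list per batch (a hypothetical stand-in
    for loss.backward()); an update happens every `accumulate_iter` batches.
    """
    accumulated = [0.0] * len(weights)
    steps = 0
    for i, grads in enumerate(batch_grads, start=1):
        # average over the accumulation window by scaling each contribution
        for j, g in enumerate(grads):
            accumulated[j] += g / accumulate_iter
        if i % accumulate_iter == 0:
            clip_grad_norm(accumulated, max_norm)  # clip the accumulated gradient
            for j in range(len(weights)):
                weights[j] -= lr * accumulated[j]
            accumulated = [0.0] * len(weights)  # equivalent to zero_grad()
            steps += 1
    return weights, steps


# 8 batches with accumulate_iter=4 -> 2 optimizer steps
weights, steps = train_with_accumulation(
    weights=[0.0, 0.0],
    batch_grads=[[3.0, 4.0]] * 8,
    lr=0.1,
    accumulate_iter=4,
    max_norm=1.0,
)
print(steps, weights)
```

Clipping is applied to the accumulated gradient right before the update, so the norm limit constrains each actual optimizer step rather than each individual batch.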

Since torchfitter relies on accelerate, device management is delegated to it. You can pass your own accelerate.Accelerator object to fine-tune its parameters:

from accelerate import Accelerator
from torchfitter.trainer import Trainer


accelerator = Accelerator(...)
trainer = Trainer(
    accelerator=accelerator,
    **kwargs
)

Callbacks

Callbacks allow you to interact with the model during the fitting process. They provide different methods that are called at different stages of training. To create a callback, simply extend the base class and implement the desired methods.

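from pathlib import Path

from torchfitter.conventions import ParamsDict
from torchfitter.callbacks.base import Callback


class ModelCheckpoint(Callback):
    """Save the model weights at the end of every epoch."""

    def __init__(self, path):
        super().__init__()
        # accept str or Path; the "/" operator below requires a Path
        self.path = Path(path)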

    def __repr__(self) -> str:
        return "ModelCheckpoint()"

    def on_epoch_end(self, params_dict):
        # get params
        accelerator = params_dict[ParamsDict.ACCELERATOR]
        epoch = params_dict[ParamsDict.EPOCH_NUMBER]

        # ensure model is safe to save
        _model = params_dict[ParamsDict.MODEL]
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(_model)

        # actual saving
        fname = self.path / f'model_epoch_{epoch}.pt'
        accelerator.save(unwrapped_model.state_dict(), fname)
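To show how hook methods such as `on_epoch_end` get called, here is a self-contained sketch of the general Keras-style pattern. Everything in it (`BaseCallback`, `PatienceStopping`, `fit`, and the plain string keys of the params dictionary) is hypothetical and does not reproduce torchfitter's internals; the real trainer uses `ParamsDict` keys and offers more hooks. The stopping callback follows the usual early-stopping rule: stop once the validation loss has not improved for `patience` consecutive epochs.

```python
class BaseCallback:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def on_fit_start(self, params):
        pass

    def on_epoch_end(self, params):
        pass


class PatienceStopping(BaseCallback):
    """Request a stop when validation loss stalls for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.wait = 0

    def on_epoch_end(self, params):
        val_loss = params["validation_loss"]
        if val_loss < self.best:
            self.best, self.wait = val_loss, 0  # improvement: reset counter
        else:
            self.wait += 1
            if self.wait >= self.patience:
                params["stop_training"] = True  # the loop checks this flag


def fit(val_losses, callbacks):
    """Toy loop: replays precomputed validation losses, one per epoch."""
    params = {"epoch_number": 0, "stop_training": False}
    for cb in callbacks:
        cb.on_fit_start(params)
    for epoch, val_loss in enumerate(val_losses, start=1):
        params["epoch_number"] = epoch
        params["validation_loss"] = val_loss
        for cb in callbacks:  # hooks run in list order
            cb.on_epoch_end(params)
        if params["stop_training"]:
            break
    return params["epoch_number"]


last_epoch = fit([1.0, 0.8, 0.9, 0.85, 0.95, 0.7], [PatienceStopping(patience=3)])
print(last_epoch)
```

Because the callbacks share one mutable dictionary with the loop, a callback can both read training state and signal back to the trainer without the trainer knowing anything about that callback's logic.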

Each method receives params_dict, a dictionary containing the internal training parameters. You can list the key-value pair of each parameter in the conventions:

>>> from torchfitter.conventions import ParamsDict
>>> [(x, getattr(ParamsDict, x)) for x in ParamsDict.__dict__ if not x.startswith('__')]

You can also read the docstring to understand the meaning of each parameter:

>>> from torchfitter.conventions import ParamsDict
>>> print(ParamsDict.__doc__)
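The list comprehension above works for any class that stores its constants as class attributes. The self-contained sketch below applies the same idiom to a hypothetical stand-in class (`ExampleParams`, whose names and values are invented and are not the real contents of `ParamsDict`), using `vars()` and a slightly stricter filter that skips every underscore-prefixed name.

```python
class ExampleParams:
    """Hypothetical stand-in for a constants class such as ParamsDict.

    EPOCH_NUMBER: index of the current epoch.
    MODEL: the model being trained.
    """

    EPOCH_NUMBER = "epoch_number"
    MODEL = "model"


# Same idiom as above: keep class attributes whose names are not private/dunder
pairs = [
    (name, getattr(ExampleParams, name))
    for name in vars(ExampleParams)
    if not name.startswith("_")
]
print(pairs)
print(ExampleParams.__doc__)
```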

NOTE: the callbacks design can be considered a port of the Keras design. I AM NOT the author of this callback system design, although I made some minor design changes. Find more in the Credits section.

FAQ

  • Do you know PyTorch-Lightning/FastAI?

I know them and I think they are awesome. This is a personal project, though I must say the trainer is reasonably well-equipped.

  • Why is the validation loader not optional?

Because I think making it mandatory enforces good ML practices.

  • Why didn't you implement the optimization steps in the model object?

It is certainly a valid approach when building an optimization loop (PyTorch-Lightning works this way), but I don't like my abstract data types tracking too many things in addition to being torch.nn.Module instances. Functionality should be clear and atomic: the model tracks gradients and the trainer takes care of the optimization process.

  • I have a suggestion/question

Thank you! Do not hesitate to open an issue and I'll do my best to answer you.

CREDITS

Cite

If you've used this library in your projects, please cite it:

@misc{alejandro2019torchfitter,
  title={torchfitter - Simple Trainer to Optimize PyTorch Models},
  author={Alejandro Pérez-Sanjuán},
  year={2020},
  howpublished={\url{https://github.com/Xylambda/torchfitter}},
}



Download files

Source distribution: torchfitter-4.3.0.tar.gz (44.7 kB)

Built distribution: torchfitter-4.3.0-py3-none-any.whl (31.6 kB, Python 3)
