Trainer to optimize PyTorch models

Project description

torchfitter is a simple library to ease the training of PyTorch models. It features a class called Trainer that includes the basic functionality to fit models in a Keras-like style.

Internally, torchfitter leverages the power of accelerate to handle the device management.

The library also provides a callbacks API that can be used to interact with the model during the training process, as well as a set of basic regularization procedures.

Additionally, you will find the Manager class which allows you to run multiple experiments for different random seeds.
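
Conceptually, Manager automates a loop like the one below, where `run_experiment` is a hypothetical stand-in for a full training run (it is not part of the torchfitter API):

```python
import random

def run_experiment(seed: int) -> float:
    """Hypothetical stand-in for one full training run returning a metric."""
    random.seed(seed)   # fix the random state for this run
    return random.random()

# one independent, reproducible run per seed -- the loop Manager wraps
results = {seed: run_experiment(seed) for seed in (0, 1, 2)}
```

Fixing the seed at the start of each run makes every experiment reproducible in isolation, while different seeds give independent runs to aggregate over.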

Installation

Normal user

pip install torchfitter

This library ships with neither CUDA nor XLA. Follow the official PyTorch documentation for more information on how to install the CUDA binaries.

Developer

git clone https://github.com/Xylambda/torchfitter.git
pip install -e torchfitter/. -r torchfitter/requirements-dev.txt

Tests

To run the tests, install the library following the Developer instructions above.

cd torchfitter/
pytest -v tests/

Features

| Feature                  | Supported | Not supported | Planned |
|--------------------------|-----------|---------------|---------|
| Basic training loop      | x         |               |         |
| Gradient Clipping        | x         |               |         |
| Gradient Accumulation    | x         |               |         |
| Multi-device support     | x         |               |         |
| Regularization           | x         |               |         |
| In-loop metrics support  | x         |               |         |
| Mixed precision training | x         |               |         |
| Callbacks System         | x         |               |         |
| Hyperparameter search    |           | x             |         |
| Warm Training            |           | x             | x       |

Usage

Assume we already have DataLoaders for the train and validation sets.

from torch.utils.data import DataLoader


train_loader = DataLoader(...)
val_loader = DataLoader(...)

Then, create the optimizer and the loss criterion as usual and pass them to the trainer along with the PyTorch model. You can also add a regularization procedure if you want one. The same goes for callbacks: create the desired callbacks and pass them to the trainer as a list.

import torch.nn as nn
import torch.optim as optim
from torchfitter.trainer import Trainer
from torchfitter.regularization import L1Regularization
from torchfitter.callbacks import (
    LoggerCallback,
    EarlyStopping,
    LearningRateScheduler
)

model = nn.Linear(in_features=1, out_features=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())
regularizer = L1Regularization(regularization_rate=0.01, biases=False)

# callbacks
logger = LoggerCallback(update_step=50)
early_stopping = EarlyStopping(patience=50, load_best=True, path='checkpoint.pt')
scheduler = LearningRateScheduler(
    scheduler=optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.9)
)

trainer = Trainer(
    model=model, 
    criterion=criterion,
    optimizer=optimizer, 
    regularizer=regularizer,
    mixed_precision=True,
    accumulate_iter=4,  # accumulate gradients every 4 iterations
    gradient_clipping='norm',
    gradient_clipping_kwargs={'max_norm': 1.0, 'norm_type': 2.0},
    callbacks=[logger, early_stopping, scheduler]
)

history = trainer.fit(train_loader, val_loader, epochs=1000)
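
With `accumulate_iter=4`, gradients are summed over four mini-batches before each optimizer step, so the effective batch size is four times the loader's. A minimal sketch of the pattern in plain PyTorch (just the idea, not the Trainer internals):

```python
import torch
import torch.nn as nn
import torch.optim as optim

# toy setup: a tiny model and a list of (input, target) mini-batches
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 1), torch.randn(8, 1)) for _ in range(8)]

accumulate_iter = 4
optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = criterion(model(x), y) / accumulate_iter  # scale so summed grads match the mean
    loss.backward()                                  # grads accumulate in .grad
    if (i + 1) % accumulate_iter == 0:
        optimizer.step()                             # one update per 4 mini-batches
        optimizer.zero_grad()
```

Dividing the loss by `accumulate_iter` keeps the accumulated gradient equal to the gradient of the mean loss over the larger effective batch.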

Since torchfitter leverages accelerate, device management is delegated to it. You can pass your own accelerate.Accelerator object to fine-tune its parameters:

from accelerate import Accelerator
from torchfitter.trainer import Trainer


accelerator = Accelerator(...)
trainer = Trainer(
    accelerator=accelerator,
    **kwargs
)

Regularization

TorchFitter includes regularization algorithms, but you can also create your own procedures. To create your own algorithm:

  1. Inherit from RegularizerBase and call the super constructor appropriately.
  2. Implement the procedure in the compute_penalty method.

Here's an example implementing L1 from scratch:

import torch
from torchfitter.regularization.base import RegularizerBase


class L1Regularization(RegularizerBase):
    def __init__(self, regularization_rate, biases=False):
        super(L1Regularization, self).__init__(regularization_rate, biases)

    def compute_penalty(self, named_parameters, device):
        # Initialize with tensor, cannot be scalar
        penalty_term = torch.zeros(1, 1, requires_grad=True).to(device)

        for name, param in named_parameters:
            if not self.biases and name.endswith("bias"):
                continue  # skip bias parameters unless requested
            penalty_term = penalty_term + param.norm(p=1)

        return self.rate * penalty_term

Notice how penalty_term is moved to the given device. This is necessary to avoid operations between tensors stored on different devices.
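
For reference, the penalty this class computes is just rate * Σ|θ| over the selected parameters. A dependency-free sketch of the arithmetic, using a hypothetical `l1_penalty` helper and plain Python lists in place of tensors:

```python
def l1_penalty(named_parameters, rate=0.01, biases=False):
    """rate * sum(|theta|) over weight tensors, skipping biases by default."""
    total = 0.0
    for name, values in named_parameters:
        if not biases and name.endswith("bias"):
            continue  # biases are excluded unless requested
        total += sum(abs(v) for v in values)
    return rate * total

params = [("layer.weight", [0.5, -1.5]), ("layer.bias", [2.0])]
penalty = l1_penalty(params)  # 0.01 * (|0.5| + |-1.5|) = 0.02
```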

Callbacks

Callbacks allow you to interact with the model during the fitting process. They expose methods that are called at different stages of the training loop. To create a callback, simply extend the base class and implement the desired methods.

import torch
from torchfitter.conventions import ParamsDict
from torchfitter.callbacks.base import Callback


class ModelCheckpoint(Callback):
    def __init__(self, path):
        super(ModelCheckpoint, self).__init__()

        self.path = path

    def __repr__(self) -> str:
        return "ModelCheckpoint()"

    def on_epoch_end(self, params_dict):
        # get params
        accelerator = params_dict[ParamsDict.ACCELERATOR]
        epoch = params_dict[ParamsDict.EPOCH_NUMBER]

        # ensure model is safe to save
        _model = params_dict[ParamsDict.MODEL]
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(_model)

        # actual saving
        fname = self.path / f'model_epoch_{epoch}.pt'
        accelerator.save(unwrapped_model.state_dict(), fname)

Each method receives params_dict, a dictionary containing the internal training parameters. You can list the key-value pairs of each parameter from the conventions module:

>>> from torchfitter.conventions import ParamsDict
>>> [(x, getattr(ParamsDict, x)) for x in ParamsDict.__dict__ if not x.startswith('__')]

You can also check the docstring to understand the meaning of each parameter:

>>> from torchfitter.conventions import ParamsDict
>>> print(ParamsDict.__doc__)

NOTE: the callbacks design is essentially a port of the Keras callbacks design. I am NOT the author of this callback system design, although I made some minor changes. Find more in the Credits section.

FAQ

  • Do you know Pytorch-Lightning/FastAI?

I know them and I think they are awesome. This is a personal project, though I must say the trainer is reasonably well-equipped.

  • Why is the validation loader not optional?

Because I think it enforces good ML practices that way.

  • Why didn't you implement the optimization steps in the model object?

It is certainly another approach you may take when building an optimization loop (PyTorch-Lightning works this way), but I prefer my abstract data types not to track too many things on top of being torch.nn.Module types. Functionality should be clear and atomic: the model tracks gradients and the trainer handles the optimization process.

  • I have a suggestion/question

Thank you! Do not hesitate to open an issue and I'll do my best to answer you.

Credits

Cite

If you've used this library for your projects, please cite it:

@misc{alejandro2019torchfitter,
  title={torchfitter - Simple Trainer to Optimize PyTorch Models},
  author={Alejandro Pérez-Sanjuán},
  year={2020},
  howpublished={\url{https://github.com/Xylambda/torchfitter}},
}

