fitloop trains PyTorch models

➰ fitloop


What's fitloop

fitloop replaces the boilerplate code of a PyTorch training loop and adds a few niceties.

Installation

$ pip install fitloop

Sample

Training a model using fitloop for 10 epochs.

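from torch import nn
from torch.optim import Adam

from fitloop import FitLoop

# model (an nn.Module) and train_dl (a DataLoader) must already be defined.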

fdict = {
    "model": model,
    "loss_function": nn.CrossEntropyLoss(),
    "optimizer": Adam(model.parameters()),
    "train_dl": train_dl,
}
trainer = FitLoop(**fdict)
trainer.fit(epochs=10)

Demo

COLAB Links for training a simple CNN on FashionMNIST.


Intro

Stages

A training loop can be divided into stages. Here it is divided into three.

  1. Batch Step
    • The stage where the forward pass, model(X), occurs.
    • In the train phase, loss.backward and optimizer.step also have to be called here.
  2. Epoch End
    • The stage at the end of an epoch, where metrics for model evaluation are calculated.
  3. Epoch Start
    • The stage at the start of an epoch, where pretraining initializations can be executed if required.

Phases

Model training and evaluation consists of three phases.

  1. Training
  2. Validation
  3. Testing

Each of the three phases can have its own stages, in which the same or different computations are performed.

This is what fitloop lets you do: define the stage functions for the three phases. fitloop handles most of the rest. For simple use cases even this is not required, because the predefined stage functions will work.
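To make the split into stages concrete, the sketch below runs one epoch of a single phase with the three stages as plain callables. It only illustrates the idea and is not fitloop's implementation; `run_epoch`, the `stages` dict and the toy stage functions are hypothetical names.

```python
from typing import Callable, Dict, Iterable, List


def run_epoch(phase: str, batches: Iterable, stages: Dict[str, Callable]) -> Dict[str, float]:
    """Run one epoch of one phase, calling the three stages in order."""
    # Epoch Start: optional initialization before any batch is seen.
    stages[f"{phase}_epoch_start"]()

    # Batch Step: called once per batch; each call returns a dict of values.
    batch_outputs: List[Dict[str, float]] = []
    for batch in batches:
        batch_outputs.append(stages[f"{phase}_step"](batch))

    # Epoch End: reduce the per-batch values to epoch-level metrics.
    return stages[f"{phase}_epoch_end"](batch_outputs)


# Toy stage functions: each "batch" is a list of numbers, the metric is their mean.
stages = {
    "valid_epoch_start": lambda: None,
    "valid_step": lambda batch: {"running_sum": sum(batch), "count": len(batch)},
    "valid_epoch_end": lambda outs: {
        "mean": sum(o["running_sum"] for o in outs) / sum(o["count"] for o in outs)
    },
}

metrics = run_epoch("valid", [[1.0, 2.0], [3.0]], stages)
print(metrics)
```

The same `run_epoch` could be called with "train" or "test" and a different set of stage functions, which is the per-phase flexibility described above.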

Examples

  • For a simple example check out Basic Usage.ipynb
  • For a more comprehensive example showcasing more of fitloop's features, check out Usage.ipynb

Setup

For the most basic usage of fitloop only a few things need to be defined:

  • model - the PyTorch model that needs to be trained.
  • optimizer - the PyTorch optimizer that is used to optimize the model.
  • loss_function - a loss function for computing, you guessed it, the loss.
  • train_dl - DataLoader for training the model.
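As a rough sketch, the four basic components could be built as follows. The toy data, layer sizes and batch size are made up for illustration; only standard PyTorch calls are used, and nothing here is specific to fitloop.

```python
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 64 samples, 8 features, 3 classes.
X = torch.randn(64, 8)
y = torch.randint(0, 3, (64,))

# model: a small classifier that outputs one logit per class.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))

# optimizer: updates the parameters of the model.
optimizer = Adam(model.parameters())

# loss_function: cross entropy over the class logits.
loss_function = nn.CrossEntropyLoss()

# train_dl: yields (inputs, targets) batches of 16 samples.
train_dl = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
```

These are the objects that the Sample section passes to FitLoop under the keys model, optimizer, loss_function and train_dl.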

To go beyond basic usage, a few more components are required.

  • valid_dl - DataLoader for validation of the model.
  • test_dl - DataLoader for testing the model.
  • stage functions - a set of functions that are called at different stages throughout the loop.
  • criteria - name of the validation criterion that is used to evaluate the model; this is the name of a value returned in the valid_epoch_end stage function dict.
  • configure_optimizer - a function that configures the optimizer parameters so that the model is ready to train after it is loaded or restored. The default one only works with a single optimizer and a single lr_scheduler.

Configure Optimizer

configure_optimizer(floop:FitLoop) -> None
  • This function is required so that the optimizer parameter groups can be reset when the model is reset to its last best state during training.
  • If this function isn't set, the FitLoop.optimizer parameter groups have to be set manually; otherwise the weights won't update on further training.
  • Model weights are reset:
    • After training if load_best arg is True in the FitLoop.fit function.
    • When FitLoop.reset is called.
    • When FitLoop.run_sanity_check or FitLoop.run_profiler are called.
  • configure_optimizer can be called using FitLoop.configure_optimizer.
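A minimal sketch of such a function for the single-optimizer case is shown below. It assumes that the FitLoop instance exposes the model as `floop.model` (only `FitLoop.optimizer` is named above), and it uses a `SimpleNamespace` stand-in instead of a real FitLoop so that it runs on its own; it is not necessarily what fitloop's default does.

```python
import copy
from types import SimpleNamespace

import torch
from torch import nn
from torch.optim import SGD


def configure_optimizer(floop) -> None:
    """Point the single optimizer at the current model parameters."""
    # Drop the parameter groups that still refer to the old weights ...
    floop.optimizer.param_groups.clear()
    # ... and register the parameters of the restored model instead.
    floop.optimizer.add_param_group({"params": floop.model.parameters()})


# Stand-in for a FitLoop instance (assumption: it has .model and .optimizer).
model = nn.Linear(4, 2)
floop = SimpleNamespace(model=model, optimizer=SGD(model.parameters(), lr=0.1))

# Simulate a restore that swaps in a new model object.
floop.model = copy.deepcopy(model)
configure_optimizer(floop)

# The optimizer now updates the restored model rather than the old one.
before = floop.model.weight.detach().clone()
floop.model(torch.ones(1, 4)).sum().backward()
floop.optimizer.step()
```

Per-parameter optimizer state such as momentum buffers is not touched by this sketch, and setups with several optimizers or schedulers would need their own function.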

Stage Functions

There are nine stage functions (three for each of the three phases) that are called in the loop.

phase_stagefunction(state:LoopState) -> Dict[str,float]

Batch Step

  • train_step - should call loss.backward and optimizer.step and return metrics to be tracked.
  • valid_step, test_step - should return metrics to be tracked.

Epoch End Step

  • train_epoch_end, valid_epoch_end, test_epoch_end - calculate required metrics such as loss from values returned in the batch step.

Epoch Start Step

  • train_epoch_start, valid_epoch_start, test_epoch_start - can be used for initializations.

Example stage functions

Batch Step stage function for train phase

def train_step(state):
    X, y = state.batch
    y_ = state.model(X)

    loss = state.loss_function(y_, y)

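    # NOTE: call state.optimizer.zero_grad() here if gradients are not
    # already cleared before each batch step (not stated in this README).
    # Compute gradients (backprop) and update the weights
    loss.backward()
    state.optimizer.step()

    # Scale the batch loss by the batch size (assumes a mean-reduced loss)
    running_loss = loss.item() * state.batch_size
    # Number of correct predictions in the batch
    running_corr = (y_.argmax(dim=1)==y).sum().float().item()

    return {
        "running_loss": running_loss,
        "running_corr": running_corr
    }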
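A validation counterpart of the train_step above could look like the following sketch. It reuses the same state attributes (batch, model, loss_function, batch_size) and simply leaves out the backward pass and the optimizer step. Whether fitloop already disables gradients in the validation phase is not stated here, so the `torch.no_grad()` block is a precaution, and the `SimpleNamespace` object is only a stand-in for the LoopState that fitloop would pass in.

```python
from types import SimpleNamespace

import torch
from torch import nn


def valid_step(state):
    X, y = state.batch

    # No gradients are needed for evaluation.
    with torch.no_grad():
        y_ = state.model(X)
        loss = state.loss_function(y_, y)

    return {
        "running_loss": loss.item() * state.batch_size,
        "running_corr": (y_.argmax(dim=1) == y).sum().float().item(),
    }


# Stand-in for a LoopState with hypothetical values: one batch of 5 samples.
state = SimpleNamespace(
    batch=(torch.randn(5, 8), torch.randint(0, 3, (5,))),
    batch_size=5,
    model=nn.Linear(8, 3),
    loss_function=nn.CrossEntropyLoss(),
)
out = valid_step(state)
```

A test_step would typically have the same body.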

Epoch End Step stage function for train phase

def train_epoch_end(state):
    # The values returned from the batch step are available as PyTorch tensors
    running_losses = state['running_loss']
    running_corrs = state['running_corr']

    # state.size : number of samples in the dataset
    loss = running_losses.sum().item() / state.size
    accuracy = running_corrs.sum().item() / state.size

    return {
        "loss": loss,
        "accuracy": accuracy
    }
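The reason train_step multiplies the loss by the batch size becomes clear with a small worked example. The numbers below are made up, and the sketch assumes that the loss function returns the mean over the batch and that state.batch_size is the size of the current batch.

```python
# Hypothetical per-batch values: (mean loss of the batch, batch size).
# The last batch is smaller because 10 samples do not divide evenly by 4.
batches = [(0.5, 4), (1.0, 4), (2.0, 2)]
size = sum(n for _, n in batches)  # 10 samples in the dataset

# What train_step accumulates: the mean batch loss scaled by the batch size.
running_losses = [mean_loss * n for mean_loss, n in batches]

# What train_epoch_end computes: every sample carries the same weight.
epoch_loss = sum(running_losses) / size

# Averaging the batch means instead would overweight the small last batch.
naive_loss = sum(m for m, _ in batches) / len(batches)

print(epoch_loss, round(naive_loss, 4))
```

Here the per-sample average is 1.0, while the plain average of the three batch means is about 1.1667.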

Creating a fitloop trainer

Once all the components have been defined, a fitloop trainer can be set up like so:

fdict = {
    "model": model,
    "loss_function": loss_function,
    "optimizer": optimizer,

    # DataLoaders
    "train_dl": train_dl,
    "valid_dl": valid_dl,
    "test_dl": test_dl,

    # Batch Step Stage Functions
    "train_step": train_step,
    "valid_step": valid_step,
    "test_step": test_step,

    # Epoch End Stage Functions
    "train_epoch_end": train_epoch_end,
    "valid_epoch_end": valid_epoch_end,
    "test_epoch_end": test_epoch_end,

    # Model Evaluation
    "criteria": "accuracy",  # Returned in valid_epoch_end stage function dict

    # Param Restoration/Update
    "configure_optimizer": configure_optimizer,
}

trainer = FitLoop(**fdict)

Usage

Training

The model can be trained using FitLoop.fit. In the example below the model is trained for 4 epochs; after every 2 epochs it asks whether to continue training.

[Image: training]

Metrics

All metrics returned from the stage functions can be accessed using FitLoop.M.

FitLoop.plot(metric_name) can be used to plot the metrics that have been returned from the train_epoch_end and valid_epoch_end stage functions.

[Image: plotting]

Testing

Calling FitLoop.test() runs the loop in test mode, i.e. for one epoch with the stage functions defined for testing.


For additional functionality such as loading and restoring models, resetting, sanity checks and the profiler, check out Usage.ipynb


TODO

  • Create better documentation.
  • Update examples.
