fitloop trains Pytorch models
➰ fitloop
What's fitloop
fitloop replaces the boilerplate code you would otherwise write for a training loop, and adds some niceties on top.
Sample code using fitloop to train a model for 10 epochs:
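```python
from torch import nn
from torch.optim import Adam
from fitloop import FitLoop

# `model` (an nn.Module) and `train_dl` (a DataLoader) are assumed to be defined already.

def configure_optimizer(floop):
    # Point the optimizer at the model's current parameters
    # (needed whenever FitLoop restores the model's weights).
    floop.optimizer.param_groups.clear()
    floop.optimizer.add_param_group({
        'params': floop.model.parameters()
    })

params = model.parameters()
fdict = {
    "model": model,
    "loss_function": nn.CrossEntropyLoss(),
    "optimizer": Adam(params),
    "train_dl": train_dl,
    "configure_optimizer": configure_optimizer
}
trainer = FitLoop(**fdict)
trainer.fit(epochs=10)
```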
Intro
Stages
A training loop can be divided into stages. Here it has been divided into three.
- Batch Step
  - The stage where the forward pass `model(X)` occurs. If the loop is in the `train` phase, `loss.backward` and `optimizer.step` also have to be called.
- Epoch End
  - Stage at the end of each epoch where metrics for model evaluation are calculated.
- Epoch Start
  - Stage at the start of each epoch where (if required) initializations can be executed before the batches are processed.
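The order in which these stages run within a single epoch can be sketched in plain Python. This is not fitloop's implementation: `run_epoch` and the list-of-lists `batches` are hypothetical stand-ins used only to show that Epoch Start runs once, Batch Step runs once per batch, and Epoch End runs once on the collected batch outputs.

```python
from typing import Any, Callable, Dict, Iterable, List


# Hypothetical stand-in for one epoch of the loop; not fitloop's code.
def run_epoch(
    batches: Iterable[Any],
    epoch_start: Callable[[], None],
    batch_step: Callable[[Any], Dict[str, float]],
    epoch_end: Callable[[List[Dict[str, float]]], Dict[str, float]],
) -> Dict[str, float]:
    epoch_start()                                # Epoch Start: once, before any batch
    outputs = [batch_step(b) for b in batches]   # Batch Step: once per batch
    return epoch_end(outputs)                    # Epoch End: once, on all batch outputs


calls: List[str] = []

def epoch_start() -> None:
    calls.append("start")

def batch_step(batch: Any) -> Dict[str, float]:
    calls.append("step")
    return {"value": float(sum(batch))}

def epoch_end(outputs: List[Dict[str, float]]) -> Dict[str, float]:
    calls.append("end")
    return {"total": sum(o["value"] for o in outputs)}


result = run_epoch([[1, 2], [3, 4]], epoch_start, batch_step, epoch_end)
print(calls)   # ['start', 'step', 'step', 'end']
print(result)  # {'total': 10.0}
```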
Phases
Model training and evaluation consists of three phases.
- Training
- Validation
- Testing
Each of the three phases has its own stages, in which the same or different computations can be performed.
This is what fitloop lets you do, i.e. define the stage functions for the three phases. Most of everything else is handled by fitloop.
If the use case is simple, even this is not required: there are predefined stage functions that will work for simple use cases.
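With three phases of three stages each, there are nine stage functions in total. One way to picture "define only what you need, fall back to predefined functions for the rest" is a lookup keyed by `phase_stage` names. The `resolve_stage_functions` helper and the `default_*` functions below are hypothetical illustrations, not fitloop's API; only the stage function names come from this README.

```python
from typing import Callable, Dict

PHASES = ("train", "valid", "test")
STAGES = ("epoch_start", "step", "epoch_end")

# Hypothetical placeholders for predefined stage functions
def default_epoch_start(state) -> Dict[str, float]:
    return {}

def default_step(state) -> Dict[str, float]:
    return {}

def default_epoch_end(state) -> Dict[str, float]:
    return {}

DEFAULTS: Dict[str, Callable] = {
    "epoch_start": default_epoch_start,
    "step": default_step,
    "epoch_end": default_epoch_end,
}


def resolve_stage_functions(user_fns: Dict[str, Callable]) -> Dict[str, Callable]:
    """Return all nine `phase_stage` functions, preferring user-defined ones."""
    resolved: Dict[str, Callable] = {}
    for phase in PHASES:
        for stage in STAGES:
            name = f"{phase}_{stage}"  # e.g. "train_step", "valid_epoch_end"
            resolved[name] = user_fns.get(name, DEFAULTS[stage])
    return resolved


def my_train_step(state) -> Dict[str, float]:
    return {"running_loss": 0.0}


fns = resolve_stage_functions({"train_step": my_train_step})
print(len(fns))                             # 9
print(fns["train_step"] is my_train_step)   # True
print(fns["valid_step"] is default_step)    # True
```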
Examples
- For a simple example, check out Basic Usage.ipynb
- For a more comprehensive example showcasing more of fitloop's features, check out Usage.ipynb
Setup
For the most basic usage of fitloop, only a few things need to be defined:
- `model` - the PyTorch model to be trained.
- `optimizer` - the PyTorch optimizer used to optimize the model.
- `loss_function` - a loss function for computing... you guessed it, the loss.
- `configure_optimizer` - a function that configures the optimizer's parameter groups so that the model is ready to train when it is loaded or restored.
- `train_dl` - DataLoader for training the model.
Going beyond basic usage requires a few additional components:
- `valid_dl` - DataLoader for validating the model.
- `test_dl` - DataLoader for testing the model.
- stage functions - a set of functions that are called at different stages throughout the loop.
- `criteria` - name of the validation metric used to evaluate the model; this is one of the values returned in the dict of the `valid_epoch_end` stage function.
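A quick check of the config dict before calling `FitLoop(**fdict)` can catch a missing component early. The `missing_components` helper below is a hypothetical sketch that only inspects key names (the names themselves are the ones used in this README); the rule that `criteria` needs `valid_dl` is an assumption based on `criteria` being a validation metric.

```python
from typing import Dict, List

# Components needed for the most basic usage (key names from this README)
REQUIRED = ["model", "optimizer", "loss_function", "configure_optimizer", "train_dl"]


def missing_components(fdict: Dict[str, object]) -> List[str]:
    """Hypothetical pre-flight check: list required keys absent from fdict."""
    missing = [key for key in REQUIRED if key not in fdict]
    # Assumption: `criteria` names a validation metric, so it is only
    # useful when validation data (`valid_dl`) is also provided.
    if "criteria" in fdict and "valid_dl" not in fdict:
        missing.append("valid_dl")
    return missing


fdict = {"model": object(), "optimizer": object(), "train_dl": [], "criteria": "accuracy"}
print(missing_components(fdict))
# ['loss_function', 'configure_optimizer', 'valid_dl']
```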
Configure Optimizer
configure_optimizer(floop:FitLoop) -> None
- This function is required so that when the model is reset to its last best state during training, the optimizer parameter groups can be reset as well.
- If this function isn't set, the `FitLoop.optimizer` parameter groups will have to be set manually, otherwise the weights won't update on further training.
- Model weights are reset:
  - After training, if the `load_best` arg of the `FitLoop.fit` function is `True`.
  - When `FitLoop.reset` is called.
  - When `FitLoop.run_sanity_check` or `FitLoop.run_profiler` is called.
- `configure_optimizer` can be called using `FitLoop.configure_optimizer`.
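Why the weights stop updating without `configure_optimizer` can be shown with small stand-in classes, assuming that restoring the model produces new parameter objects (for example by deep-copying a saved model), so the optimizer still holds references to the old ones. `Param`, `Model`, `Optimizer` and `Loop` below are hypothetical stand-ins that only mimic the shape of PyTorch's `param_groups` / `add_param_group`; this is not how fitloop restores models internally.

```python
import copy
from typing import Any, Dict, List


class Param:
    """Stand-in for a parameter tensor."""
    def __init__(self, value: float):
        self.value = value


class Model:
    """Stand-in for an nn.Module with a single parameter."""
    def __init__(self):
        self.weight = Param(1.0)

    def parameters(self):
        return iter([self.weight])


class Optimizer:
    """Stand-in mimicking the param_groups / add_param_group shape of torch optimizers."""
    def __init__(self, params):
        self.param_groups: List[Dict[str, Any]] = []
        self.add_param_group({"params": params})

    def add_param_group(self, group: Dict[str, Any]) -> None:
        group["params"] = list(group["params"])
        self.param_groups.append(group)

    def step(self) -> None:
        # Pretend update: subtract 0.1 from every parameter the optimizer knows about
        for group in self.param_groups:
            for p in group["params"]:
                p.value -= 0.1


class Loop:
    """Stand-in for a FitLoop instance; only .model and .optimizer are used."""
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer


def configure_optimizer(floop) -> None:
    # Same body as configure_optimizer in the sample at the top of this README
    floop.optimizer.param_groups.clear()
    floop.optimizer.add_param_group({"params": floop.model.parameters()})


model = Model()
loop = Loop(model, Optimizer(model.parameters()))

# Assumed restore: the "best" model is a deep copy, so its parameters are new objects
loop.model = copy.deepcopy(model)
loop.optimizer.step()
stale = loop.model.weight.value   # 1.0: the step only touched the old parameters

configure_optimizer(loop)
loop.optimizer.step()
fresh = loop.model.weight.value   # 0.9: the restored model is updated again
print(stale, fresh)
```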
Stage Functions
There are nine stage functions (three for each of the three phases) that are called in the loop. Each one takes the loop state and returns a dict of metrics:
phase_stagefunction(state:LoopState) -> Dict[str,float]
Batch Step
- `train_step` - should call `loss.backward` and `optimizer.step`, and return the metrics to be tracked.
- `valid_step`, `test_step` - should return the metrics to be tracked.
Epoch End Step
- `train_epoch_end`, `valid_epoch_end`, `test_epoch_end` - calculate the required metrics, such as `loss`, from the values returned in the batch step.
Epoch Start Step
- `train_epoch_start`, `valid_epoch_start`, `test_epoch_start` - can be used for initializations.
Example stage functions
Batch Step stage function for train phase
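```python
def train_step(state):
    X, y = state.batch
    # Clear gradients from the previous batch (harmless if the loop already does this)
    state.optimizer.zero_grad()
    y_ = state.model(X)
    loss = state.loss_function(y_, y)
    # Backpropagate to compute gradients, then update the weights
    loss.backward()
    state.optimizer.step()
    # Batch totals: summed loss and number of correct predictions
    running_loss = loss.item() * state.batch_size
    running_corr = (y_.argmax(dim=1) == y).sum().float().item()
    return {
        "running_loss": running_loss,
        "running_corr": running_corr
    }
```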
Epoch End Step stage function for train phase
```python
def train_epoch_end(state):
    # Values returned are pytorch Tensors
    running_losses = state['running_loss']
    running_corrs = state['running_corr']
    # state.size : number of samples in the dataset
    loss = running_losses.sum().item() / state.size
    accuracy = running_corrs.sum().item() / state.size
    return {
        "loss": loss,
        "accuracy": accuracy
    }
```
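To make the arithmetic behind the two example stage functions above concrete: the batch step returns per-batch totals and the epoch end divides their sum by the dataset size. This sketch uses plain lists instead of tensors; `collect` is a hypothetical stand-in for how the loop groups batch outputs by key before they are read as `state['running_loss']`, and the numbers are made up.

```python
from collections import defaultdict
from typing import Dict, List


def collect(batch_outputs: List[Dict[str, float]]) -> Dict[str, List[float]]:
    """Group per-batch metric dicts by key (stand-in for what state[...] returns)."""
    grouped: Dict[str, List[float]] = defaultdict(list)
    for out in batch_outputs:
        for key, value in out.items():
            grouped[key].append(value)
    return dict(grouped)


# Two batches of 4 and 2 samples; running_loss = mean batch loss * batch size
batch_outputs = [
    {"running_loss": 0.5 * 4, "running_corr": 3.0},
    {"running_loss": 0.25 * 2, "running_corr": 2.0},
]
size = 6  # number of samples in the dataset

collected = collect(batch_outputs)
loss = sum(collected["running_loss"]) / size       # (2.0 + 0.5) / 6
accuracy = sum(collected["running_corr"]) / size   # (3 + 2) / 6
print(round(loss, 4), round(accuracy, 4))          # 0.4167 0.8333
```

Multiplying each batch's mean loss by its batch size before summing weights every batch by its number of samples, so a smaller final batch does not skew the epoch average.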
Creating a fitloop trainer
Once all the components have been set up, a fitloop trainer can be created like so:
```python
from fitloop import FitLoop

fdict = {
    "model": model,
    "loss_function": loss_function,
    "optimizer": optimizer,
    # DataLoaders
    "train_dl": train_dl,
    "valid_dl": valid_dl,
    "test_dl": test_dl,
    # Batch Step Stage Functions
    "train_step": train_step,
    "valid_step": valid_step,
    "test_step": test_step,
    # Epoch End Stage Functions
    "train_epoch_end": train_epoch_end,
    "valid_epoch_end": valid_epoch_end,
    "test_epoch_end": test_epoch_end,
    # Model Evaluation
    "criteria": "accuracy",  # Returned in the valid_epoch_end stage function dict
    # Model Preservation
    "save_to_disk": True,  # Will save pretrained and best model to disk
    # Param Restoration/Update
    "configure_optimizer": configure_optimizer,
}
trainer = FitLoop(**fdict)
```
Usage
Training
The model can be trained using `FitLoop.fit`. For example, the model can be trained for 4 epochs, with the loop asking after every 2 epochs whether to continue training.
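The "ask whether to continue every N epochs" behaviour can be sketched as a plain loop. This is a hypothetical illustration: `fit_with_prompts`, `ask_user` and their parameters are not fitloop's API, and the actual `FitLoop.fit` argument for this feature is not shown here. The prompt function is injectable so the loop can be exercised without stdin.

```python
from typing import Callable, List


def ask_user(epoch: int) -> bool:
    """Prompt on stdin; any answer starting with 'y' continues training."""
    reply = input(f"Finished epoch {epoch}. Continue training? [y/n] ")
    return reply.strip().lower().startswith("y")


def fit_with_prompts(
    epochs: int,
    ask_every: int,
    train_one_epoch: Callable[[int], None],
    ask: Callable[[int], bool] = ask_user,
) -> int:
    """Hypothetical loop: run up to `epochs` epochs, asking every `ask_every`
    epochs whether to continue. Returns the number of epochs actually run."""
    for epoch in range(1, epochs + 1):
        train_one_epoch(epoch)
        # Ask only at the checkpoints, and not after the final epoch
        if epoch % ask_every == 0 and epoch < epochs and not ask(epoch):
            return epoch
    return epochs


seen: List[int] = []
run = fit_with_prompts(4, 2, seen.append, ask=lambda epoch: True)
print(run, seen)   # 4 [1, 2, 3, 4]

stopped = fit_with_prompts(4, 2, lambda epoch: None, ask=lambda epoch: False)
print(stopped)     # 2
```

Skipping the prompt after the last epoch is a design choice in this sketch: there is nothing left to continue at that point.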
Metrics
All metrics returned from the stage functions can be accessed using `FitLoop.M`.
`FitLoop.plot(metric_name)` can be used to plot any of the metrics returned from the `train_epoch_end` and `valid_epoch_end` stage functions.
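The `criteria` metric is what decides which epoch counts as "best", for example when `load_best` is used. A hypothetical sketch of picking the best epoch from a per-epoch metric history follows; the list-of-dicts layout is an assumption and is not necessarily how `FitLoop.M` stores metrics, and since this README does not say whether higher or lower is better, the sketch takes that as a parameter.

```python
from typing import Dict, List, Tuple


def best_epoch(history: List[Dict[str, float]], criteria: str,
               higher_is_better: bool = True) -> Tuple[int, float]:
    """Return (1-based epoch, value) of the best `criteria` value in a metric history."""
    values = [metrics[criteria] for metrics in history]
    pick = max if higher_is_better else min
    best_value = pick(values)
    return values.index(best_value) + 1, best_value


# Made-up per-epoch results of valid_epoch_end
history = [
    {"loss": 0.90, "accuracy": 0.61},
    {"loss": 0.72, "accuracy": 0.70},
    {"loss": 0.75, "accuracy": 0.68},
]
print(best_epoch(history, "accuracy"))                       # (2, 0.7)
print(best_epoch(history, "loss", higher_is_better=False))   # (2, 0.72)
```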
Testing
Calling `FitLoop.test()` will run the loop in test mode, i.e. for one epoch with the stage functions defined for testing.
For additional functionality such as loading and restoring models, resetting, sanity checks and profiling, check out Usage.ipynb.