A pluggable & extensible trainer for pytorch
Torchero - A training framework for pytorch
Torchero is a library built on top of the PyTorch framework to facilitate the training of neural networks.
Features
It provides tools and utilities to:
- Set up a training process in a few lines of code.
- Monitor the training performance by checking several prebuilt metrics on a handy progress bar.
- Integrate a TensorBoard dashboard to visualize those metrics live, with minimal setup.
- Add custom functionality via Callbacks.
- NLP: Datasets for text classification tasks. Vectors, etc.
Installation
From PyPI
pip install torchero
From Source Code
git clone https://github.com/juancruzsosa/torchero
cd torchero
python setup.py install
Quickstart - MNIST
In this example we use torchero to train a convolutional neural network on the MNIST dataset.
Loading the Data
First, we load the dataset using torchvision. Then, if we want, we can display sample images using show_imagegrid_dataset.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST # The MNIST dataset
from torchvision import transforms # To convert images to tensors
from torchero import SupervisedTrainer
from torchero.callbacks import ProgbarLogger, ModelCheckpoint, CSVLogger
from torchero.utils import show_imagegrid_dataset
from matplotlib import pyplot as plt
train_ds = MNIST(root='/tmp/data/mnist', download=True, train=True, transform=transforms.ToTensor())
test_ds = MNIST(root='/tmp/data/mnist', download=False, train=False, transform=transforms.ToTensor())
train_dl = DataLoader(train_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)
show_imagegrid_dataset(train_ds)
plt.show()
Creating the Model
Next we define the model. For this case a Sequential model is enough.
model = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
                      nn.ReLU(inplace=True),
                      nn.MaxPool2d(2),
                      nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
                      nn.ReLU(inplace=True),
                      nn.MaxPool2d(2),
                      nn.Flatten(),
                      nn.Linear(5*5*64, 500),
                      nn.ReLU(inplace=True),
                      nn.Linear(500, 10))
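The 5*5*64 input size of the first Linear layer follows from the layer shapes. As a rough check, the sketch below walks a 28x28 MNIST image through the two convolution and pooling stages using the standard output-size arithmetic. The helper names conv_out and pool_out are made up for this illustration and are not part of torchero or PyTorch.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output side length of a square convolution (no dilation)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def pool_out(size, kernel_size):
    """Output side length of max pooling whose stride equals its kernel size."""
    return (size - kernel_size) // kernel_size + 1


size = 28                 # MNIST images are 28x28 pixels
size = conv_out(size, 5)  # nn.Conv2d(1, 32, kernel_size=5)  -> 24
size = pool_out(size, 2)  # nn.MaxPool2d(2)                  -> 12
size = conv_out(size, 3)  # nn.Conv2d(32, 64, kernel_size=3) -> 10
size = pool_out(size, 2)  # nn.MaxPool2d(2)                  -> 5

# 64 channels of 5x5 feature maps are flattened before the first Linear layer
flat_features = size * size * 64
print(flat_features)  # 1600
```

If the kernel sizes or the input resolution change, the same arithmetic gives the new value to pass to nn.Linear.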
Training the Model
Once the network and both the train and test DataLoaders are defined, we can run the training using the SGD optimizer, cross-entropy loss, categorical accuracy, a progress bar, a ModelCheckpoint to save the model whenever its accuracy improves, and a CSVLogger to dump the metrics to a CSV file. We can call trainer.cuda() if we want to train on GPU instead of CPU.
trainer = SupervisedTrainer(model=model,
                            optimizer='sgd',
                            criterion='cross_entropy',
                            acc_meters=['categorical_accuracy_percentage'],
                            callbacks=[ProgbarLogger(notebook=True),
                                       ModelCheckpoint('saved_model', mode='max', monitor='val_acc'),
                                       CSVLogger('training_results.csv')])

# Move the training to the GPU when one is available
if torch.cuda.is_available():
    trainer.cuda()

trainer.train(dataloader=train_dl, valid_dataloader=test_dl, epochs=5)
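To make the 'cross_entropy' criterion and the 'categorical_accuracy_percentage' meter more concrete, here is a minimal pure-Python sketch of what those two quantities usually mean for a batch of class scores (logits). It is not torchero's implementation; the function names and the example numbers are made up for illustration.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def categorical_accuracy_percentage(batch_logits, targets):
    """Percentage of samples whose highest-scoring class equals the target."""
    hits = 0
    for logits, target in zip(batch_logits, targets):
        predicted = max(range(len(logits)), key=logits.__getitem__)
        hits += predicted == target
    return 100.0 * hits / len(targets)


# Illustrative scores for a batch of 4 samples and 3 classes
batch_logits = [[2.0, 0.5, 0.1],
                [0.2, 0.1, 3.0],
                [1.0, 1.5, 0.3],
                [0.0, 0.0, 0.0]]
targets = [0, 2, 0, 1]

mean_loss = sum(cross_entropy(l, t) for l, t in zip(batch_logits, targets)) / len(targets)
accuracy = categorical_accuracy_percentage(batch_logits, targets)
print(round(mean_loss, 4), accuracy)
```

The first two samples are classified correctly and the last two are not, so the accuracy for this batch is 50%.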
Showing the training results
To see the training metrics
fig, axs = plt.subplots(figsize=(14,3), ncols=2, nrows=1)
trainer.history.plot()
plt.show()
And if we want to see, for example, the confusion matrix on the test set:
results = trainer.evaluate(test_dl, ['categorical_accuracy_percentage', 'confusion_matrix'])
plt.figure(figsize=(10, 10))
results['confusion_matrix'].plot(classes=train_ds.classes)
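For readers unfamiliar with confusion matrices, the sketch below builds one by hand from integer class labels. It assumes rows index the true class and columns the predicted class; torchero's own meter may use a different orientation or representation. The function name and the labels are made up for illustration.

```python
def confusion_matrix(targets, predictions, num_classes):
    """Count how often each true class (row) was predicted as each class (column)."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for target, predicted in zip(targets, predictions):
        matrix[target][predicted] += 1
    return matrix


# Illustrative labels for 8 samples and 3 classes
targets     = [0, 0, 1, 1, 2, 2, 2, 1]
predictions = [0, 1, 1, 1, 2, 0, 2, 1]

cm = confusion_matrix(targets, predictions, num_classes=3)
for row in cm:
    print(row)
# [1, 1, 0]
# [0, 3, 0]
# [1, 0, 2]
```

Entries on the diagonal are correct predictions; everything off the diagonal shows which classes get mistaken for which.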
Documentation
Additional documentation can be found here
Extensive List of Classes
Trainers
- BatchTrainer: Abstract class for all trainers that work with batched inputs
- SupervisedTrainer: Trainer for supervised tasks
- AutoencoderTrainer: Trainer for autoencoder tasks
Callbacks
- callbacks.Callback: Base callback class for all epoch/training events
- callbacks.History: Callback that records the history of all training/validation metrics
- callbacks.Logger: Callback that displays metrics at each logging step
- callbacks.ProgbarLogger: Callback that displays progress bars to monitor training/validation metrics
- callbacks.CallbackContainer: Callback that groups multiple hooks
- callbacks.ModelCheckpoint: Callback that saves the best model after every epoch
- callbacks.EarlyStopping: Callback that stops training when the monitored quantity stops improving
- callbacks.CSVLogger: Callback that exports training/validation statistics to a CSV file
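As an illustration of the idea behind early stopping, the following pure-Python sketch stops once a monitored value has failed to improve for a given number of consecutive epochs. It is not torchero's EarlyStopping code: the function name, the patience parameter, and the loss values are assumptions made for this example.

```python
def epoch_to_stop(values, patience=2, mode="min"):
    """Return the 1-based epoch at which training would stop, or None if it never stops."""
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(values, start=1):
        if best is None:
            improved = True
        elif mode == "min":
            improved = value < best
        else:
            improved = value > best
        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Illustrative validation losses, one per epoch
val_losses = [0.90, 0.60, 0.55, 0.57, 0.56, 0.58]
stop_epoch = epoch_to_stop(val_losses, patience=2, mode="min")
print(stop_epoch)  # 5
```

The best loss (0.55) is reached at epoch 3; epochs 4 and 5 do not beat it, so with a patience of 2 the run ends after epoch 5.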
Meters
- meters.BaseMeter: Interface for all meters
- meters.BatchMeters: Superclass of meters that work with batches
- meters.CategoricalAccuracy: Meter for accuracy on categorical targets
- meters.BinaryAccuracy: Meter for accuracy on binary targets (assuming normalized inputs)
- meters.BinaryAccuracyWithLogits: Binary accuracy meter with an integrated activation function (by default the logistic function)
- meters.ConfusionMatrix: Meter for the confusion matrix
- meters.MSE: Mean Squared Error meter
- meters.MSLE: Mean Squared Log Error meter
- meters.RMSE: Root Mean Squared Error meter
- meters.RMSLE: Root Mean Squared Log Error meter
- meters.Precision: Precision meter
- meters.Recall: Recall meter
- meters.Specificity: Specificity meter
- meters.NPV: Negative predictive value meter
- meters.F1Score: F1 Score meter
- meters.F2Score: F2 Score meter
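The binary classification meters above correspond to standard formulas over the four confusion counts (true/false positives and negatives). The sketch below computes them directly; it is only a reference for the definitions, not torchero's code, and the function name and counts are made up. It assumes every denominator is non-zero.

```python
def binary_metrics(tp, fp, tn, fn, beta=1.0):
    """Standard binary metrics from confusion counts (non-zero denominators assumed)."""
    precision = tp / (tp + fp)    # how many predicted positives are right
    recall = tp / (tp + fn)       # how many actual positives are found
    specificity = tn / (tn + fp)  # how many actual negatives are found
    npv = tn / (tn + fn)          # how many predicted negatives are right
    f_beta = (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)
    return {"precision": precision, "recall": recall,
            "specificity": specificity, "npv": npv, "f_beta": f_beta}


# Illustrative counts
f1 = binary_metrics(tp=40, fp=10, tn=45, fn=5, beta=1.0)
f2 = binary_metrics(tp=40, fp=10, tn=45, fn=5, beta=2.0)
print(round(f1["precision"], 3), round(f1["recall"], 3), round(f1["f_beta"], 3))
print(round(f2["f_beta"], 3))
```

With beta=1 the F-score weighs precision and recall equally; with beta=2 recall counts more, which is the F2 score.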
Cross validation
- utils.data.CrossFoldValidation: Iterator over cross-validation folds
- utils.data.train_test_split: Splits a dataset into train and test datasets
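To show what iterating over cross-validation folds involves, here is a small pure-Python sketch that splits sample indices into k folds and yields one train/validation pair per fold. It does not use or reproduce torchero's CrossFoldValidation; the function name and parameters are hypothetical, and no shuffling is done.

```python
def k_fold_indices(n_samples, n_folds):
    """Yield (train_indices, valid_indices) pairs, one per fold, without shuffling."""
    indices = list(range(n_samples))
    # Spread the remainder over the first folds so sizes differ by at most one
    fold_sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
                  for i in range(n_folds)]
    start = 0
    for size in fold_sizes:
        valid = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, valid
        start += size


folds = list(k_fold_indices(n_samples=10, n_folds=3))
for train, valid in folds:
    print(valid)
# [0, 1, 2, 3]
# [4, 5, 6]
# [7, 8, 9]
```

Each sample lands in exactly one validation fold, and the matching training set is everything else.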
Datasets
- utils.data.datasets.SubsetDataset: Dataset that is a subset of the original dataset
- utils.data.datasets.ShrinkDatset: Shrinks a dataset
- utils.data.datasets.UnsuperviseDataset: Makes a dataset unsupervised
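Dataset wrappers like these generally only need __len__ and __getitem__. The sketch below shows the general pattern with two hypothetical classes, IndexSubset and InputsOnly, written for this example; it assumes each item of the wrapped dataset is an (input, target) pair and does not reflect how torchero's own classes are implemented.

```python
class IndexSubset:
    """Hypothetical wrapper exposing only the items at the given indices."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]


class InputsOnly:
    """Hypothetical wrapper that drops the target from (input, target) items."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        x, _target = self.dataset[i]
        return x


# A plain list of (input, target) pairs stands in for a real dataset
labelled = [("img0", 3), ("img1", 7), ("img2", 1), ("img3", 7)]
subset = IndexSubset(labelled, [0, 2])
unlabelled = InputsOnly(subset)
print(len(subset), subset[1], [unlabelled[i] for i in range(len(unlabelled))])
```

Because the wrappers only rely on indexing and length, they can be stacked and then passed to a DataLoader like any other map-style dataset.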