
A pluggable & extensible trainer for PyTorch

Project description

Torchero - A training framework for PyTorch

Torchero is a library that works on top of the PyTorch framework to facilitate the training of neural networks.


Features

It provides tools and utilities to:

  • Set up a training process in a few lines of code.
  • Monitor the training performance by checking several prebuilt metrics on a handy progress bar.
  • Integrate with TensorBoard to visualize those metrics live, with minimal setup.
  • Add custom functionality via Callbacks (see the sketch after this list).
  • NLP: datasets for text classification tasks, vectors, etc.
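
For example, a custom callback can hook into the training loop. The following is a minimal sketch, assuming the base class exposes an on_epoch_end hook; check the torchero documentation for the exact Callback interface.

from torchero.callbacks import Callback

class EpochPrinter(Callback):
    """Toy callback that prints a message after every epoch.

    NOTE: the hook name `on_epoch_end` is an assumption made for
    illustration; consult the torchero docs for the real hook names.
    """
    def on_epoch_end(self):
        print("finished an epoch")

# It would then be passed like any built-in callback:
# trainer = SupervisedTrainer(..., callbacks=[EpochPrinter()])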

Installation

From PyPI

pip install torchero

From Source Code

git clone https://github.com/juancruzsosa/torchero
cd torchero
python setup.py install

Quickstart - MNIST

In this example we use torchero to train a convolutional neural network on the MNIST dataset.

Loading the Data

First, we load the dataset using torchvision. Optionally, we can display sample images with show_imagegrid_dataset.

import torch                                # Used below to check for CUDA
from torch import nn                        # Used below to define the model
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST      # The MNIST dataset
from torchvision import transforms          # To convert the images to tensors

from torchero import SupervisedTrainer
from torchero.callbacks import ProgbarLogger, ModelCheckpoint, CSVLogger
from torchero.utils import show_imagegrid_dataset

from matplotlib import pyplot as plt

train_ds = MNIST(root='/tmp/data/mnist', download=True, train=True, transform=transforms.ToTensor())
test_ds = MNIST(root='/tmp/data/mnist', download=False, train=False, transform=transforms.ToTensor())

train_dl = DataLoader(train_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)

show_imagegrid_dataset(train_ds)
plt.show()

mnist images by class

Creating the Model

Next, we define the model. For this task a simple nn.Sequential network is enough.

model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(5*5*64, 500),   # 28x28 -> conv5 -> 24 -> pool -> 12 -> conv3 -> 10 -> pool -> 5
    nn.ReLU(inplace=True),
    nn.Linear(500, 10))

Training the Model

With the network and both the train and test DataLoaders defined, we can run the training with an SGD optimizer, cross-entropy loss, categorical accuracy as the metric, a progress bar, a ModelCheckpoint to save the model whenever the validation accuracy improves, and a CSVLogger to dump the metrics to a CSV file. If we want to train on the GPU instead of the CPU, we call trainer.cuda().

trainer = SupervisedTrainer(model=model,
                            optimizer='sgd',
                            criterion='cross_entropy',
                            acc_meters=['categorical_accuracy_percentage'],
                            callbacks=[ProgbarLogger(notebook=True),
                                       ModelCheckpoint('saved_model', mode='max', monitor='val_acc'),
                                       CSVLogger('training_results.csv')])

if torch.cuda.is_available():
    trainer.cuda()

trainer.train(dataloader=train_dl, valid_dataloader=test_dl, epochs=5)

progress bar training

Showing the training results

To plot the training and validation metrics recorded during training:

fig, axs = plt.subplots(figsize=(14,3), ncols=2, nrows=1)
trainer.history.plot()
plt.show()

We can also evaluate on the test set and display, for example, the confusion matrix.

results = trainer.evaluate(test_dl, ['categorical_accuracy_percentage', 'confusion_matrix'])
plt.figure(figsize=(10, 10))
results['confusion_matrix'].plot(classes=train_ds.classes)

confusion matrix
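
After training, the fitted network can be used for inference with plain PyTorch, independently of torchero. A minimal sketch using a batch from the test loader defined above:

# Plain PyTorch inference with the trained model
model.eval()
device = next(model.parameters()).device      # wherever training left the model (CPU or GPU)
images, labels = next(iter(test_dl))
with torch.no_grad():
    logits = model(images.to(device))
    predictions = logits.argmax(dim=1)
print(predictions[:10].tolist(), labels[:10].tolist())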

Documentation

Additional documentation can be found here

Extensive List of Classes

Trainers

  • BatchTrainer: Abstract class for all trainers that work with batched inputs
  • SupervisedTrainer: Trainer for supervised tasks
  • AutoencoderTrainer: Trainer for autoencoder tasks

Callbacks

  • callbacks.Callback: Base callback class for all epoch/training events
  • callbacks.History: Callback that records the history of all training/validation metrics
  • callbacks.Logger: Callback that displays metrics per logging step
  • callbacks.ProgbarLogger: Callback that displays progress bars to monitor training/validation metrics
  • callbacks.CallbackContainer: Callback to group multiple hooks
  • callbacks.ModelCheckpoint: Callback to save the best model after every epoch
  • callbacks.EarlyStopping: Callback to stop training when the monitored quantity stops improving (see the sketch after this list)
  • callbacks.CSVLogger: Callback that exports training/validation statistics to a CSV file
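
A sketch of how EarlyStopping could be combined with the ModelCheckpoint from the quickstart above; the monitor, mode and patience argument names for EarlyStopping are assumptions, so check the documentation for the exact signature.

from torchero.callbacks import EarlyStopping, ModelCheckpoint

# Hypothetical sketch: stop once validation accuracy has not improved
# for 3 consecutive epochs, keeping the best checkpoint seen so far.
# The EarlyStopping argument names (`monitor`, `mode`, `patience`) are assumptions.
callbacks = [ModelCheckpoint('saved_model', mode='max', monitor='val_acc'),
             EarlyStopping(monitor='val_acc', mode='max', patience=3)]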

Meters

  • meters.BaseMeter: Interface for all meters
  • meters.BatchMeters: Superclass of meters that work with batches
  • meters.CategoricalAccuracy: Meter for accuracy on categorical targets
  • meters.BinaryAccuracy: Meter for accuracy on binary targets (assuming normalized inputs)
  • meters.BinaryAccuracyWithLogits: Binary accuracy meter with an integrated activation function (by default logistic function)
  • meters.ConfusionMatrix: Meter for confusion matrix.
  • meters.MSE: Mean Squared Error meter
  • meters.MSLE: Mean Squared Log Error meter
  • meters.RMSE: Rooted Mean Squared Error meter
  • meters.RMSLE: Rooted Mean Squared Log Error meter
  • meters.Precision: Precision meter
  • meters.Recall: Recall meter
  • meters.Specificity: Specificity meter
  • meters.NPV: Negative predictive value meter
  • meters.F1Score: F1 Score meter
  • meters.F2Score: F2 Score meter

Cross validation

  • utils.data.CrossFoldValidation: Iterator over cross-validation folds (see the sketch after this list)
  • utils.data.train_test_split: Split a dataset into train and test datasets
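
A minimal sketch of how these utilities might be used on the MNIST training set from the quickstart; the valid_size keyword and the fold iteration protocol are assumptions, so check the documentation for the exact signatures.

from torchero.utils.data import train_test_split, CrossFoldValidation

# Hypothetical sketch: `valid_size` is an assumed keyword argument.
train_subset, valid_subset = train_test_split(train_ds, valid_size=0.2)

# Iterating over cross-validation folds; the constructor arguments and
# the (train, valid) pairs yielded here are assumptions for illustration.
for fold_train, fold_valid in CrossFoldValidation(train_ds, valid_size=0.2):
    print(len(fold_train), len(fold_valid))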

Datasets

  • utils.data.datasets.SubsetDataset: Dataset that is a subset of the original dataset
  • utils.data.datasets.ShrinkDatset: Shrinks a dataset
  • utils.data.datasets.UnsuperviseDataset: Makes a dataset unsupervised



Download files


Source Distribution

torchero-0.0.8.tar.gz (64.4 kB)

Uploaded Source

Built Distribution

torchero-0.0.8-py3-none-any.whl (83.0 kB)

Uploaded Python 3

File details

Details for the file torchero-0.0.8.tar.gz.

File metadata

  • Download URL: torchero-0.0.8.tar.gz
  • Upload date:
  • Size: 64.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5

File hashes

Hashes for torchero-0.0.8.tar.gz
Algorithm Hash digest
SHA256 54c8d3281182bbd1906e3e1e2e5715d1ba9426457f77aae86eccf022d94f66d6
MD5 fbee5c9fea71738716563692f0129194
BLAKE2b-256 879df90b697b2b9616b6f46de5491941ef0e9cd341ec7bf825c540940bc3963c


File details

Details for the file torchero-0.0.8-py3-none-any.whl.

File metadata

  • Download URL: torchero-0.0.8-py3-none-any.whl
  • Upload date:
  • Size: 83.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5

File hashes

Hashes for torchero-0.0.8-py3-none-any.whl
Algorithm Hash digest
SHA256 f56eb42c00119779dae48a0eb6ab2b067e87179b5228efb7f4dcd4076cbd38f9
MD5 529ccaf8b172b0848c97c5bdc2926d0f
BLAKE2b-256 d621b8cae35b4294052ecf4903b26bc4e5e77431788403d9b0f522f521241b00

