
A pluggable and extensible trainer for PyTorch


Torchero - A training framework for PyTorch

Torchero is a library built on top of PyTorch to make training neural networks easier.


Features

It provides tools and utilities to:

  • Set up a training process in a few lines of code.
  • Monitor the training performance by checking several prebuilt metrics on a handy progress bar.
  • Integrate a TensorBoard dashboard to visualize those metrics in real time with minimal setup.
  • Add custom functionality via Callbacks.
  • Work on NLP and computer vision tasks with datasets for text and image classification, pretrained embedding vectors, models, etc.

Installation

From PyPI

pip install torchero

From Source Code

git clone https://github.com/juancruzsosa/torchero
cd torchero
pip install .

Quickstart - MNIST

Loading the Data

import torch
from torch import nn

import torchero
from torchero.models.vision import ImageClassificationModel
from torchero.callbacks import ProgbarLogger, ModelCheckpoint, CSVLogger
from torchero.utils.data import train_test_split
from torchero.utils.vision import show_imagegrid_dataset, transforms, datasets, download_image
from torchero.meters import ConfusionMatrix

from matplotlib import pyplot as plt

First we load the MNIST train and test datasets and visualize them with show_imagegrid_dataset. The data augmentation used here is RandomInvert, which randomly inverts the grayscale levels of the training images.

train_ds = datasets.MNIST(root='/tmp/data/mnist', download=True, train=True,
                          transform=transforms.Compose([transforms.RandomInvert(),
                                                        transforms.ToTensor()]))
test_ds = datasets.MNIST(root='/tmp/data/mnist', download=False, train=False, transform=transforms.ToTensor())

show_imagegrid_dataset(train_ds)
plt.show()

Figure: MNIST images by class
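For 8-bit grayscale images, inverting maps each pixel value v to 255 - v, so dark digits on a light background become light digits on a dark background and vice versa. The sketch below illustrates that idea in plain Python on a hypothetical row of pixel values. The helper name `random_invert` is hypothetical and this is not Torchero's implementation; the probability parameter `p=0.5` mirrors the default documented for torchvision's `RandomInvert`.

```python
import random


def random_invert(pixels, p=0.5, rng=random):
    """Hypothetical helper: invert 8-bit grayscale pixels with probability p."""
    # One draw per image: either every pixel is inverted or none is.
    if rng.random() < p:
        return [255 - v for v in pixels]
    return list(pixels)


row = [0, 64, 128, 255]           # a hypothetical row of pixel values
print(random_invert(row, p=1.0))  # always inverts: [255, 191, 127, 0]
print(random_invert(row, p=0.0))  # never inverts:  [0, 64, 128, 255]
```

Drawing once per image (rather than once per pixel) keeps the digit intact, which is the point of using inversion as an augmentation.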

Defining the Network

Let's define a network with two convolutional layers followed by two fully connected (Linear) layers, the last of which acts as the classification layer.

network = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),   # 1x28x28 -> 32x24x24
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),                                            # -> 32x12x12
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),  # -> 64x10x10
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),                                            # -> 64x5x5
    nn.Flatten(),                                               # -> 1600 features (5*5*64)
    nn.Linear(5*5*64, 500),
    nn.ReLU(inplace=True),
    nn.Linear(500, 10))                                         # one logit per digit class
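The input size of the first Linear layer, 5*5*64, follows from tracking the spatial size of a 28x28 MNIST image through the convolution and pooling layers. With dilation 1, both Conv2d and MaxPool2d map a side of length n to floor((n + 2*padding - kernel) / stride) + 1, and MaxPool2d(2) uses a stride equal to its kernel size. A minimal sketch of that arithmetic (the helper `out_size` is hypothetical):

```python
def out_size(n, kernel, stride=1, padding=0):
    # Output length along one spatial dimension for Conv2d / MaxPool2d
    # with dilation 1 and floor rounding.
    return (n + 2 * padding - kernel) // stride + 1


side = 28                                  # MNIST images are 28x28
side = out_size(side, kernel=5)            # Conv2d(1, 32, 5)   -> 24
side = out_size(side, kernel=2, stride=2)  # MaxPool2d(2)       -> 12
side = out_size(side, kernel=3)            # Conv2d(32, 64, 3)  -> 10
side = out_size(side, kernel=2, stride=2)  # MaxPool2d(2)       -> 5
flat_features = side * side * 64           # 64 channels after the second conv
print(side, flat_features)                 # 5 1600
```

If the input resolution or kernel sizes change, rerunning this arithmetic gives the new in_features for the first Linear layer.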

Training the Model

The ImageClassificationModel class is responsible for training the network, evaluating metrics on a dataset, and making predictions from an input in multi-class classification tasks.

Before training, the model must be compiled with:

  • an optimizer: 'adam'
  • a loss, which defaults to cross_entropy
  • a list of metrics, which defaults to categorical_accuracy and balanced_accuracy
  • a list of callbacks:
    • ProgbarLogger to show a training progress bar
    • ModelCheckpoint to save a checkpoint whenever the monitored metric (here val_acc) improves
    • CSVLogger to write the metrics to a CSV file after each epoch
model = ImageClassificationModel(model=network,
                                 transform=transforms.Compose([transforms.Grayscale(),
                                                               transforms.Resize((28, 28)),
                                                               transforms.ToTensor()]),
                                 classes=[str(i) for i in range(10)])
model.compile(optimizer='adam',
              callbacks=[ProgbarLogger(notebook=True),
                         ModelCheckpoint('saved_model', mode='max', monitor='val_acc'),
                         CSVLogger('training_results.csv')])

if torch.cuda.is_available():
    model.cuda()

history = model.fit(train_ds,
                    test_ds,
                    batch_size=1024,
                    epochs=5)

Figure: training progress bar
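Conceptually, ModelCheckpoint(mode='max', monitor='val_acc') keeps a checkpoint only when the monitored value exceeds the best value seen so far, and CSVLogger appends one row of metrics per epoch. The sketch below mimics that behaviour in plain Python with made-up metric values. The class `BestTracker` and the in-memory CSV buffer are hypothetical stand-ins, not Torchero's callback API; a real checkpoint would save the model weights (for example with torch.save) instead of recording an epoch number.

```python
import csv
import io


class BestTracker:
    """Hypothetical stand-in for a 'save if improved' checkpoint callback."""

    def __init__(self, monitor, mode="max"):
        self.monitor = monitor
        self.mode = mode
        self.best = None
        self.saved_epochs = []

    def on_epoch_end(self, epoch, metrics):
        value = metrics[self.monitor]
        improved = (self.best is None
                    or (self.mode == "max" and value > self.best)
                    or (self.mode == "min" and value < self.best))
        if improved:
            self.best = value
            self.saved_epochs.append(epoch)  # a real callback would save weights here


# Made-up per-epoch metrics, used only to exercise the logic.
epoch_metrics = [
    {"epoch": 1, "loss": 0.50, "val_acc": 0.90},
    {"epoch": 2, "loss": 0.30, "val_acc": 0.95},
    {"epoch": 3, "loss": 0.25, "val_acc": 0.94},
]

checkpoint = BestTracker(monitor="val_acc", mode="max")
buffer = io.StringIO()  # stands in for a file such as training_results.csv
writer = csv.DictWriter(buffer, fieldnames=["epoch", "loss", "val_acc"])
writer.writeheader()
for metrics in epoch_metrics:
    writer.writerow(metrics)                         # one CSV row per epoch
    checkpoint.on_epoch_end(metrics["epoch"], metrics)

print(checkpoint.saved_epochs)  # [1, 2]
```

Epoch 3 is not saved because its val_acc (0.94) does not beat the best value so far (0.95); with mode='min' the comparison flips, which suits metrics such as a validation loss.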

Displaying the training results

To plot the loss and accuracy for each epoch, run:

history.plot(figsize=(20, 20), smooth=0.2)
plt.show()
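The smooth=0.2 argument smooths the plotted curves. The exact rule Torchero applies is not described here; one common choice is an exponential moving average, sketched below under that assumption. In this hypothetical `ema_smooth` helper, `smooth` is the weight given to the previous smoothed value, so 0 disables smoothing and values closer to 1 smooth more strongly.

```python
def ema_smooth(values, smooth=0.2):
    """Exponential moving average; `smooth` weights the previous smoothed value."""
    if not 0 <= smooth < 1:
        raise ValueError("smooth must be in [0, 1)")
    smoothed = []
    last = None
    for v in values:
        # The first point is kept as-is; later points blend with the running value.
        last = v if last is None else smooth * last + (1 - smooth) * v
        smoothed.append(last)
    return smoothed


losses = [1.0, 0.5, 0.5]  # made-up per-epoch losses
print(ema_smooth(losses, smooth=0.2))
```

Smoothing only changes how the curve is drawn; the recorded metric values themselves are unaffected.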

The .evaluate method computes the given metrics on a dataset and returns them.

results = model.evaluate(test_ds, metrics=['categorical_accuracy', 'balanced_accuracy', ConfusionMatrix()])
for metric in ['acc', 'balanced_acc']:
    print("{}: {:.3f}%".format(metric, results[metric] * 100))
fig, ax = plt.subplots(figsize=(12, 12))
results['confusion_matrix'].plot(fig=fig, ax=ax, classes=model.classes)
plt.show()

Figure: confusion matrix
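Both reported metrics can be read off a confusion matrix: accuracy is the sum of the diagonal divided by the total number of samples, and balanced accuracy is the mean per-class recall (each diagonal entry divided by its row total). The sketch below computes them in plain Python with made-up labels, independently of Torchero's ConfusionMatrix meter. It assumes rows are true labels and columns are predictions; Torchero's own orientation is not stated here, and the helper names are hypothetical.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    # cm[t][p] counts samples whose true class is t and predicted class is p.
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def accuracy(cm):
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total


def balanced_accuracy(cm):
    # Mean recall over the classes that actually appear in y_true.
    recalls = [cm[i][i] / sum(row) for i, row in enumerate(cm) if sum(row) > 0]
    return sum(recalls) / len(recalls)


y_true = [0, 0, 0, 0, 1, 1]   # made-up labels: class 0 is over-represented
y_pred = [0, 0, 0, 0, 0, 1]
cm = confusion_matrix(y_true, y_pred, num_classes=2)
print(cm)                     # [[4, 0], [1, 1]]
print(accuracy(cm))           # 5/6, about 0.833
print(balanced_accuracy(cm))  # 0.75
```

Because balanced accuracy averages recall per class, it penalizes a classifier that neglects a minority class more than plain accuracy does.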

Documentation

Additional documentation can be found here.



Download files


Source Distribution

torchero-0.1.0.tar.gz (79.2 kB)

Uploaded Source

Built Distribution

torchero-0.1.0-py3-none-any.whl (103.1 kB)

Uploaded Python 3

File details

Details for the file torchero-0.1.0.tar.gz.

File metadata

  • Download URL: torchero-0.1.0.tar.gz
  • Upload date:
  • Size: 79.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.0 CPython/3.9.6

File hashes

Hashes for torchero-0.1.0.tar.gz
Algorithm     Hash digest
SHA256        e12a492eaf241fd3e0d85968f956c3924cf5efb8bcc3fa77ccd6b6a5f331a7c9
MD5           b6aea254c17e4036dce4c6dfc34e9e40
BLAKE2b-256   03c361d6073702e2f69a95dada3039f0e6df68c9546cc2ac40a77704cc57f250


File details

Details for the file torchero-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torchero-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 103.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.6.3 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.0 CPython/3.9.6

File hashes

Hashes for torchero-0.1.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        aab43894fdd0edb5498ab92efe2d748a92bf04d6f907955f81565efbdc58352b
MD5           4ce429bf2e1ea1e7cde6050d57666100
BLAKE2b-256   149d61c2c8a2c47e53cbb38df2a3bcb9d29a20ee40f2b30bf720364e9bca0ef8
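To check a downloaded file against the SHA256 digests listed above, you can hash it locally with Python's standard hashlib module. A minimal sketch; the helper name `sha256_of` is hypothetical, and the usage lines assume the file was downloaded into the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Read fixed-size chunks until f.read returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Usage (assumes the wheel sits in the current directory):
# expected = "aab43894fdd0edb5498ab92efe2d748a92bf04d6f907955f81565efbdc58352b"
# assert sha256_of("torchero-0.1.0-py3-none-any.whl") == expected
```

Reading in chunks keeps memory use constant regardless of file size.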

