
A pluggable & extensible trainer for PyTorch

Project description

Torchero - A training framework for PyTorch

Features

  • Train and validate models for a given number of epochs
  • Hooks/callbacks to add custom behavior
  • Metrics of model accuracy and error
  • Monitors for training/validation statistics
  • Cross-fold validation iterators for splitting validation data from training data

Example

Training with MNIST

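import argparse

from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torchero import SupervisedTrainer
from torchero.callbacks import ProgbarLogger as Logger, CSVLogger

# Command-line options referenced below as args.* (the default values are illustrative)
parser = argparse.ArgumentParser(description='Train a small CNN on MNIST with torchero')
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--val-batch-size', type=int, default=1000)
parser.add_argument('--epochs', type=int, default=10)
# Spelled to match the trainer's logging_frecuency keyword argument
parser.add_argument('--logging-frecuency', type=int, default=1)
parser.add_argument('--use-cuda', action='store_true')
args = parser.parse_args()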

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        # 28x28 input -> conv(5) 24x24 -> pool 12x12 -> conv(3) 10x10 -> pool 5x5, 64 channels
        self.filter = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
                                    nn.ReLU(inplace=True),
                                    nn.BatchNorm2d(32),
                                    nn.MaxPool2d(2),
                                    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
                                    nn.ReLU(inplace=True),
                                    nn.BatchNorm2d(64),
                                    nn.MaxPool2d(2))
        self.linear = nn.Sequential(nn.Linear(5*5*64, 500),
                                    nn.BatchNorm1d(500),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(500, 10))

    def forward(self, x):
        bs = x.shape[0]
        return self.linear(self.filter(x).view(bs, -1))

train_ds = MNIST(root='data/',
                 download=True,
                 train=True,
                 transform=transforms.Compose([transforms.ToTensor()]))
test_ds = MNIST(root='data/',
                download=False,
                train=False,
                transform=transforms.Compose([transforms.ToTensor()]))
train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)  # reshuffle training data every epoch
test_dl = DataLoader(test_ds, batch_size=args.val_batch_size)

model = Network()

trainer = SupervisedTrainer(model=model,
                            optimizer='sgd',
                            criterion='cross_entropy',
                            logging_frecuency=args.logging_frecuency,
                            acc_meters={'acc': 'categorical_accuracy_percentage'},
                            callbacks=[Logger(),
                                       CSVLogger(output='training_stats.csv')
                                      ])
if args.use_cuda:
    trainer.cuda()

trainer.train(dataloader=train_dl,
              valid_dataloader=test_dl,
              epochs=args.epochs)

Trainers

  • BatchTrainer: Abstract base class for all trainers that work with batched inputs
  • SupervisedTrainer: Trainer for supervised tasks
  • AutoencoderTrainer: Trainer for autoencoder tasks
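A batch trainer of this kind typically runs the same outer loop: for each epoch, perform one update per training batch, then optionally evaluate on validation batches without updating. The plain-Python sketch below illustrates that loop so it runs without PyTorch; `train_batched`, `update_fn`, `eval_fn` and the toy `sgd_step` are hypothetical names, not torchero's API or implementation.

```python
from typing import Callable, Dict, List, Optional, Sequence


def train_batched(
    train_batches: Sequence,
    update_fn: Callable[[object], float],
    epochs: int,
    valid_batches: Sequence = (),
    eval_fn: Optional[Callable[[object], float]] = None,
) -> List[Dict[str, float]]:
    """Hypothetical batched loop: one update per batch, one stats dict per epoch."""
    if not train_batches:
        raise ValueError("train_batches must not be empty")
    history = []
    for epoch in range(epochs):
        # Training phase: update_fn performs one optimization step and returns the loss.
        train_losses = [update_fn(batch) for batch in train_batches]
        stats = {"epoch": epoch, "train_loss": sum(train_losses) / len(train_losses)}
        # Validation phase: evaluation only, no parameter updates.
        if eval_fn is not None and valid_batches:
            val_losses = [eval_fn(batch) for batch in valid_batches]
            stats["val_loss"] = sum(val_losses) / len(val_losses)
        history.append(stats)
    return history


# Toy usage: fit a single weight w so that y = w * x matches y = 2x.
params = {"w": 0.0}
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]  # two batches of (x, y) pairs


def sgd_step(batch, lr=0.05):
    """One SGD step on the batch mean squared error; returns the loss after the update."""
    n = len(batch)
    grad = sum(2 * (params["w"] * x - y) * x for x, y in batch) / n
    params["w"] -= lr * grad
    return sum((params["w"] * x - y) ** 2 for x, y in batch) / n


history = train_batched(batches, sgd_step, epochs=20)
```

Keeping the update and evaluation steps as callables is what lets one loop serve different tasks: a supervised step compares predictions with targets, while an autoencoder step compares reconstructions with the inputs themselves.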

Callbacks

  • callbacks.Callback: Base class for all callbacks triggered on training/epoch events
  • callbacks.History: Callback that records the history of all training/validation metrics
  • callbacks.Logger: Callback that displays metrics at each logging step
  • callbacks.ProgbarLogger: Callback that displays progress bars to monitor training/validation metrics
  • callbacks.CallbackContainer: Callback that groups multiple callbacks
  • callbacks.ModelCheckpoint: Callback that saves the best model after every epoch
  • callbacks.EarlyStopping: Callback that stops training when a monitored quantity stops improving
  • callbacks.CSVLogger: Callback that exports training/validation statistics to a CSV file
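One common way to implement "stop when a monitored quantity stops improving" is a patience counter: remember the best value seen so far and stop after a fixed number of epochs without a better one. The sketch below assumes a hypothetical `on_epoch_end(value)` hook that returns True when training should stop; the class name and its parameters (`patience`, `min_delta`, `mode`) are illustrative and may not match torchero's `EarlyStopping`.

```python
from typing import Optional


class EarlyStoppingSketch:
    """Hypothetical early stopping: stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best: Optional[float] = None
        self.bad_epochs = 0

    def _improved(self, value: float) -> bool:
        # The first observed value always counts as an improvement.
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def on_epoch_end(self, value: float) -> bool:
        """Record the monitored value; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage: monitor a validation loss; stops at epoch 5 (three epochs without beating 0.7).
stopper = EarlyStoppingSketch(patience=3)
stopped_at = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.6]):
    if stopper.on_epoch_end(val_loss):
        stopped_at = epoch
        break
```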

Meters

  • meters.BaseMeter: Interface for all meters
  • meters.BatchMeters: Superclass of meters that work with batches
  • meters.CategoricalAccuracy: Meter for accuracy on categorical targets
  • meters.BinaryAccuracy: Meter for accuracy on binary targets (assuming normalized inputs)
  • meters.BinaryAccuracyWithLogits: Binary accuracy meter with an integrated activation function (the logistic function by default)
  • meters.ConfusionMatrix: Meter for the confusion matrix
  • meters.MSE: Mean Squared Error meter
  • meters.MSLE: Mean Squared Log Error meter
  • meters.RMSE: Root Mean Squared Error meter
  • meters.RMSLE: Root Mean Squared Log Error meter
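The error meters reduce to simple formulas: MSE is the mean of (prediction - target)^2, RMSE is its square root, and MSLE/RMSLE apply the same computation to log(1 + value). Because a meter sees the data batch by batch, it should accumulate running sums rather than average per-batch results: for the two batches below, averaging the per-batch RMSEs gives about 2.21, while the RMSE over all three samples is about 2.08. The class and its methods (`reset`, `measure`, `mse`, `rmse`) are a hypothetical sketch, not torchero's `BaseMeter` interface.

```python
import math
from typing import Iterable


class StreamingSquaredErrorSketch:
    """Hypothetical meter that accumulates squared errors across batches.

    With log=True it measures (log1p(pred) - log1p(target))^2, the quantity
    behind MSLE/RMSLE; values must then be greater than -1.
    """

    def __init__(self, log: bool = False):
        self.log = log
        self.reset()

    def reset(self) -> None:
        self.sum_sq = 0.0
        self.count = 0

    def measure(self, preds: Iterable[float], targets: Iterable[float]) -> None:
        preds, targets = list(preds), list(targets)
        if len(preds) != len(targets):
            raise ValueError("preds and targets must have the same length")
        for p, t in zip(preds, targets):
            if self.log:
                p, t = math.log1p(p), math.log1p(t)
            self.sum_sq += (p - t) ** 2
            self.count += 1

    def mse(self) -> float:
        if self.count == 0:
            raise ValueError("no samples measured")
        return self.sum_sq / self.count

    def rmse(self) -> float:
        return math.sqrt(self.mse())


meter = StreamingSquaredErrorSketch()
meter.measure([1.0, 2.0], [1.0, 4.0])  # batch 1: squared errors 0 and 4
meter.measure([3.0], [6.0])            # batch 2: squared error 9
overall_rmse = meter.rmse()            # sqrt(13 / 3), about 2.08
```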

Cross validation

  • utils.data.CrossFoldValidation: Iterator over cross-validation folds
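The index arithmetic behind k-fold cross validation can be sketched in plain Python: split n sample indices into k contiguous folds and let each fold serve once as the validation set while the rest form the training set. `k_fold_indices` is a hypothetical helper, not the interface of torchero's `CrossFoldValidation`; it does not shuffle, so in practice the dataset order should be randomized first.

```python
from typing import Iterator, List, Tuple


def k_fold_indices(n_samples: int, k: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, valid_indices) for each of k contiguous folds."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must satisfy 2 <= k <= n_samples")
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        # The first `extra` folds receive one additional sample each.
        size = base + (1 if fold < extra else 0)
        valid = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, valid
        start += size


folds = list(k_fold_indices(10, 3))  # validation folds of sizes 4, 3 and 3
```

Each index list can then be wrapped with `torch.utils.data.Subset(dataset, indices)` to build the training and validation datasets for that fold.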

Datasets

  • utils.data.datasets.SubsetDataset: Dataset that is a subset of the original dataset
  • utils.data.datasets.ShrinkDatset: Shrinks a dataset
  • utils.data.datasets.UnsuperviseDataset: Makes a dataset unsupervised
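These wrappers can follow the `__len__`/`__getitem__` protocol used by PyTorch map-style datasets. The plain-Python sketch below illustrates the three ideas with hypothetical class names; the exact behavior is an assumption and may differ from torchero's classes: here "shrinking" keeps the leading fraction `p` of the samples, and "unsupervised" drops the target from each (input, target) pair.

```python
from typing import Any, Sequence


class SubsetSketch:
    """Hypothetical view exposing only the given indices of a base dataset."""

    def __init__(self, dataset: Sequence, indices: Sequence[int]):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> Any:
        return self.dataset[self.indices[i]]


class ShrinkSketch(SubsetSketch):
    """Hypothetical: keep only the first fraction `p` of the samples."""

    def __init__(self, dataset: Sequence, p: float):
        if not 0.0 < p <= 1.0:
            raise ValueError("p must be in (0, 1]")
        super().__init__(dataset, range(int(len(dataset) * p)))


class UnsupervisedSketch:
    """Hypothetical: return only the input of each (input, target) pair."""

    def __init__(self, dataset: Sequence):
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, i: int) -> Any:
        x, _ = self.dataset[i]
        return x


base = [(i, i * 10) for i in range(10)]  # toy (input, target) pairs
```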



Download files

Download the file for your platform.

Source Distribution

torchero-0.0.1.tar.gz (34.0 kB)

Uploaded Source

Built Distributions

torchero-0.0.1-py3.8.egg (126.7 kB)

Uploaded Source

torchero-0.0.1-py3-none-any.whl (47.6 kB)

Uploaded Python 3

File details

Details for the file torchero-0.0.1.tar.gz.

File metadata

  • Download URL: torchero-0.0.1.tar.gz
  • Upload date:
  • Size: 34.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.3 requests-toolbelt/0.9.1 tqdm/4.45.0 CPython/3.8.2

File hashes

Hashes for torchero-0.0.1.tar.gz
Algorithm Hash digest
SHA256 3a7f5185e75d152d340f433cc21b3415addf22a712577bfc062e11cba6575fb5
MD5 87345f5e1ec4decabf4b01e5268e5960
BLAKE2b-256 96404a1c7ed968298bff7f1c90458ba9175c29d481164bc9daa0faef699e79a3
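A downloaded file can be checked against the SHA256 digests listed here with Python's standard `hashlib` module. A minimal sketch; the helper names are illustrative, and the path in the example comment assumes the source archive was saved under its original name in the current directory.

```python
import hashlib
import hmac


def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_published(path: str, expected_hex: str) -> bool:
    """Compare a file's SHA256 with a published hex digest (case-insensitive)."""
    return hmac.compare_digest(sha256_of_file(path), expected_hex.lower())


# Example, assuming the sdist was downloaded to the current directory:
# matches_published("torchero-0.0.1.tar.gz",
#                   "3a7f5185e75d152d340f433cc21b3415addf22a712577bfc062e11cba6575fb5")
```

pip can enforce the same check at install time in hash-checking mode, for example with a requirements line such as `torchero==0.0.1 --hash=sha256:<digest>` and the `--require-hashes` option.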


File details

Details for the file torchero-0.0.1-py3.8.egg.

File metadata

  • Download URL: torchero-0.0.1-py3.8.egg
  • Upload date:
  • Size: 126.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.3 requests-toolbelt/0.9.1 tqdm/4.45.0 CPython/3.8.2

File hashes

Hashes for torchero-0.0.1-py3.8.egg
Algorithm Hash digest
SHA256 5c00d0f22abd725497676c75c8f43c8a061d55193912851eceb6de272408d57f
MD5 6ea1c3a562bd0173384a1bec69405993
BLAKE2b-256 16c412d7a22befe57c37489e3807a674a69e402baad2cc28c693194e4965274d


File details

Details for the file torchero-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: torchero-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 47.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.3 requests-toolbelt/0.9.1 tqdm/4.45.0 CPython/3.8.2

File hashes

Hashes for torchero-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 39ac1edaab7ca120cc31f69e9c4381e189fc487b876795bcafc032a17cdf31cc
MD5 08c54e09001a5a5c5f17793697a74cf0
BLAKE2b-256 8a0ac430fd89e5b1c6f09bbc3f84f4c7e00c640d2de47f1d4b93a41c865ed510

