A pluggable & extensible trainer for PyTorch
Torchero - A training framework for PyTorch
Features
- Train/validate models for a given number of epochs
- Hooks/callbacks to add custom behavior
- Metrics of model accuracy/error
- Training/validation statistics monitors
- Cross-fold validation iterators for splitting validation data from training data
Installation
From PyPI
pip install torchero
From Source Code
git clone https://github.com/juancruzsosa/torchero
cd torchero
python setup.py install
Example
Training with MNIST
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch import optim
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
import torchero
from torchero import SupervisedTrainer
from torchero.meters import CategoricalAccuracy
from torchero.callbacks import ProgbarLogger as Logger, CSVLogger
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # Feature extractor: 1x28x28 -> 32x24x24 -> 32x12x12 -> 64x10x10 -> 64x5x5
        self.filter = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
                                    nn.ReLU(inplace=True),
                                    nn.BatchNorm2d(32),
                                    nn.MaxPool2d(2),
                                    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
                                    nn.ReLU(inplace=True),
                                    nn.BatchNorm2d(64),
                                    nn.MaxPool2d(2))
        # Classifier head over the flattened 5*5*64 = 1600 features, one output per digit
        self.linear = nn.Sequential(nn.Linear(5*5*64, 500),
                                    nn.BatchNorm1d(500),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(500, 10))

    def forward(self, x):
        bs = x.shape[0]
        # Flatten each sample's feature maps into a vector before the linear layers
        return self.linear(self.filter(x).view(bs, -1))
train_ds = MNIST(root='data/',
                 download=True,
                 train=True,
                 transform=transforms.Compose([transforms.ToTensor()]))
test_ds = MNIST(root='data/',
                download=False,
                train=False,
                transform=transforms.Compose([transforms.ToTensor()]))
train_dl = DataLoader(train_ds, batch_size=50)
test_dl = DataLoader(test_ds, batch_size=50)
model = Network()
trainer = SupervisedTrainer(model=model,
                            optimizer='sgd',
                            criterion='cross_entropy',
                            acc_meters={'acc': 'categorical_accuracy_percentage'},
                            callbacks=[Logger(),
                                       CSVLogger(output='training_stats.csv')])
# If you want to use CUDA, uncomment the next line
# trainer.cuda()
trainer.train(dataloader=train_dl,
              valid_dataloader=test_dl,
              epochs=2)
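The `5*5*64` input size of the first `nn.Linear` layer comes from how the convolution and pooling layers shrink each 28x28 MNIST image (ReLU and batch normalization leave the shape unchanged). A small pure-Python sketch of that arithmetic, using the usual output-size formula for convolutions and for max pooling whose stride equals its kernel size; the helper names `conv_out`, `pool_out` and `flattened_size` are illustrative only:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Spatial output size of max pooling whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1


def flattened_size(image_size=28):
    """Trace the Network.filter stack and return the number of features fed to nn.Linear."""
    size = image_size
    size = conv_out(size, kernel=5)  # 28 -> 24 (32 channels)
    size = pool_out(size)            # 24 -> 12
    size = conv_out(size, kernel=3)  # 12 -> 10 (64 channels)
    size = pool_out(size)            # 10 -> 5
    channels = 64
    return size * size * channels


print(flattened_size())  # 1600, i.e. 5*5*64
```

If the input resolution or any kernel size changes, rerunning this kind of calculation gives the new value to pass to the first `nn.Linear`.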
Trainers
BatchTrainer
: Abstract class for all trainers that work with batched inputs
SupervisedTrainer
: Trainer for supervised tasks
AutoencoderTrainer
: Trainer for autoencoder tasks
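The trainers above differ mainly in how each batch is split into model inputs and targets. The pure-Python sketch below illustrates that idea only: `BatchTrainerSketch`, `split_batch` and `update_batch` are hypothetical names, not torchero's API, and the update step merely records what a real trainer would train on. In this sketch a supervised trainer consumes `(input, target)` pairs, while an autoencoder trainer reuses each input as its own reconstruction target.

```python
from abc import ABC, abstractmethod


class BatchTrainerSketch(ABC):
    """Hypothetical skeleton: loops over epochs and batches, delegating each batch."""

    def __init__(self):
        self.updates = []  # records (inputs, targets) instead of taking an optimizer step

    def train(self, batches, epochs=1):
        for _ in range(epochs):
            for batch in batches:
                inputs, targets = self.split_batch(batch)
                self.update_batch(inputs, targets)

    @abstractmethod
    def split_batch(self, batch):
        """Turn one raw batch into (inputs, targets)."""

    def update_batch(self, inputs, targets):
        # A real trainer would run the forward pass, compute the loss and step the optimizer here.
        self.updates.append((inputs, targets))


class SupervisedSketch(BatchTrainerSketch):
    def split_batch(self, batch):
        inputs, targets = batch  # supervised batches carry explicit labels
        return inputs, targets


class AutoencoderSketch(BatchTrainerSketch):
    def split_batch(self, batch):
        return batch, batch  # the input doubles as the reconstruction target


supervised = SupervisedSketch()
supervised.train([([0.1, 0.2], [1]), ([0.3, 0.4], [0])])
autoencoder = AutoencoderSketch()
autoencoder.train([[0.5, 0.6]], epochs=2)
```

Keeping the epoch/batch loop in the abstract base and pushing only the batch interpretation into subclasses is one way to share hook and logging logic across task types.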
Callbacks
callbacks.Callback
: Base callback class for all epoch/training events
callbacks.History
: Callback that records the history of all training/validation metrics
callbacks.Logger
: Callback that displays metrics per logging step
callbacks.ProgbarLogger
: Callback that displays progress bars to monitor training/validation metrics
callbacks.CallbackContainer
: Callback that groups multiple hooks
callbacks.ModelCheckpoint
: Callback that saves the best model after every epoch
callbacks.EarlyStopping
: Callback that stops training when a monitored quantity stops improving
callbacks.CSVLogger
: Callback that exports training/validation statistics to a CSV file
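A minimal sketch of how epoch-level hooks, a grouping container, a history recorder and patience-based early stopping can fit together. Everything here is an assumption for illustration: the class names, the single `on_epoch_end(epoch, logs)` hook and the `stop_training` flag are hypothetical and do not reproduce torchero's callback signatures. The toy `replay` loop feeds precomputed validation losses instead of training a model.

```python
import math


class CallbackSketch:
    """Hypothetical base class: subclasses override only the hooks they need."""

    def on_epoch_end(self, epoch, logs):
        pass


class ContainerSketch(CallbackSketch):
    """Groups several callbacks and forwards each hook call to all of them, in order."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def on_epoch_end(self, epoch, logs):
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)


class HistorySketch(CallbackSketch):
    """Records a copy of the metrics reported at the end of every epoch."""

    def __init__(self):
        self.records = []

    def on_epoch_end(self, epoch, logs):
        self.records.append({"epoch": epoch, **logs})


class EarlyStoppingSketch(CallbackSketch):
    """Flags a stop once `monitor` has not dropped by more than `min_delta` for `patience` epochs."""

    def __init__(self, monitor="val_loss", patience=2, min_delta=0.0):
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.wait = 0
        self.stop_training = False

    def on_epoch_end(self, epoch, logs):
        value = logs[self.monitor]
        if value < self.best - self.min_delta:
            self.best, self.wait = value, 0  # improvement: remember it and reset the counter
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stop_training = True


def replay(val_losses, callbacks):
    """Toy loop that feeds precomputed validation losses to the callbacks.

    Returns the number of epochs actually run.
    """
    container = ContainerSketch(callbacks)
    for epoch, loss in enumerate(val_losses):
        container.on_epoch_end(epoch, {"val_loss": loss})
        if any(getattr(cb, "stop_training", False) for cb in callbacks):
            return epoch + 1
    return len(val_losses)


history = HistorySketch()
stopper = EarlyStoppingSketch(patience=2)
epochs_run = replay([1.0, 0.8, 0.85, 0.9, 0.95], [history, stopper])
print(epochs_run, stopper.best)  # 4 0.8
```

With `patience=2`, the two epochs without improvement after the best loss of 0.8 trigger the stop, so the fifth loss is never consumed.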
Meters
meters.BaseMeter
: Interface for all meters
meters.BatchMeters
: Superclass of meters that work with batches
meters.CategoricalAccuracy
: Meter for accuracy on categorical targets
meters.BinaryAccuracy
: Meter for accuracy on binary targets (assuming normalized inputs)
meters.BinaryAccuracyWithLogits
: Binary accuracy meter with an integrated activation function (the logistic function by default)
meters.ConfusionMatrix
: Meter for the confusion matrix
meters.MSE
: Mean Squared Error meter
meters.MSLE
: Mean Squared Log Error meter
meters.RMSE
: Root Mean Squared Error meter
meters.RMSLE
: Root Mean Squared Log Error meter
meters.Precision
: Precision meter
meters.Recall
: Recall meter
meters.Specificity
: Specificity meter
meters.NPV
: Negative predictive value meter
meters.F1Score
: F1 score meter
meters.F2Score
: F2 score meter
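The classification meters listed above can all be derived from the four binary confusion counts: precision = TP/(TP+FP), recall = TP/(TP+FN), specificity = TN/(TN+FP), NPV = TN/(TN+FN), and F-beta = (1 + β²)·P·R/(β²·P + R), with β = 1 for F1 and β = 2 for F2. The pure-Python sketch below accumulates those counts across batches, optionally applying a logistic activation first, and also shows a categorical accuracy expressed as a percentage and the regression errors, using the common log(1 + x) definition for MSLE/RMSLE. The class and function names, the 0.5 threshold and the `measure` method are hypothetical, not torchero's API.

```python
import math


class BinaryConfusionSketch:
    """Hypothetical batch meter: accumulates binary confusion counts across batches."""

    def __init__(self, threshold=0.5, with_logits=False):
        self.threshold = threshold
        self.with_logits = with_logits
        self.tp = self.fp = self.tn = self.fn = 0

    def measure(self, outputs, targets):
        for output, target in zip(outputs, targets):
            if self.with_logits:
                output = 1.0 / (1.0 + math.exp(-output))  # logistic activation
            predicted = 1 if output >= self.threshold else 0
            if predicted and target:
                self.tp += 1
            elif predicted:
                self.fp += 1
            elif target:
                self.fn += 1
            else:
                self.tn += 1

    def accuracy(self):
        return (self.tp + self.tn) / (self.tp + self.fp + self.tn + self.fn)

    def precision(self):
        return self.tp / (self.tp + self.fp)

    def recall(self):
        return self.tp / (self.tp + self.fn)

    def specificity(self):
        return self.tn / (self.tn + self.fp)

    def npv(self):
        return self.tn / (self.tn + self.fn)

    def f_beta(self, beta):
        p, r = self.precision(), self.recall()
        return (1 + beta ** 2) * p * r / (beta ** 2 * p + r)


def categorical_accuracy_percentage(score_rows, targets):
    """Percentage of rows whose highest-scoring class index equals the target index."""
    hits = sum(max(range(len(row)), key=row.__getitem__) == t
               for row, t in zip(score_rows, targets))
    return 100.0 * hits / len(targets)


def regression_errors(preds, targets):
    """MSE, RMSE, MSLE and RMSLE for non-negative predictions and targets."""
    n = len(preds)
    mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / n
    msle = sum((math.log1p(p) - math.log1p(t)) ** 2 for p, t in zip(preds, targets)) / n
    return {"mse": mse, "rmse": math.sqrt(mse), "msle": msle, "rmsle": math.sqrt(msle)}


meter = BinaryConfusionSketch()
meter.measure([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 1])  # first batch
meter.measure([0.8, 0.6, 0.3], [1, 1, 1])          # second batch
print(meter.precision(), meter.recall())  # 0.75 0.6
```

Accumulating raw counts and computing ratios only at the end keeps the result independent of how the data was split into batches.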
Cross validation
utils.data.CrossFoldValidation
: Iterator through cross-fold-validation folds
utils.data.train_test_split
: Splits a dataset into train and test datasets
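A minimal index-based sketch of k-fold splitting and a seeded train/test split. It assumes contiguous, unshuffled folds and that the returned index lists are later wrapped into datasets (for instance with `torch.utils.data.Subset`); the function names are hypothetical, and torchero's own utilities may shuffle or split differently.

```python
import random


def kfold_indices(n, k):
    """Yield (train_indices, valid_indices) for each of k contiguous folds.

    The first n % k folds get one extra sample, so every index is used
    for validation exactly once.
    """
    if not 1 < k <= n:
        raise ValueError("k must satisfy 1 < k <= n")
    fold_sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        valid = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        yield train, valid
        start += size


def train_test_split_indices(n, test_fraction=0.2, seed=0):
    """Shuffle range(n) reproducibly and cut it into (train, test) index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_test = int(round(n * test_fraction))
    return indices[n_test:], indices[:n_test]


folds = list(kfold_indices(10, 3))
train_idx, test_idx = train_test_split_indices(10, test_fraction=0.2)
print([valid for _, valid in folds])  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```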
Datasets
utils.data.datasets.SubsetDataset
: Dataset that is a subset of the original dataset
utils.data.datasets.ShrinkDatset
: Shrinks a dataset
utils.data.datasets.UnsuperviseDataset
: Makes a dataset unsupervised
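These wrappers can follow the map-style dataset protocol (`__len__` and `__getitem__`) that `torch.utils.data.DataLoader` accepts. A pure-Python sketch of the three behaviors with hypothetical class names; the shrinking rule (keep the leading fraction of samples) and the unsupervised rule (drop the target from `(input, target)` pairs) are assumptions about what the torchero classes do.

```python
class SubsetSketch:
    """Exposes only the samples of `base` at the given indices."""

    def __init__(self, base, indices):
        self.base = base
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.base[self.indices[i]]


class ShrinkSketch(SubsetSketch):
    """Keeps only the leading `fraction` of `base` (an assumed shrinking rule)."""

    def __init__(self, base, fraction):
        if not 0 < fraction <= 1:
            raise ValueError("fraction must be in (0, 1]")
        super().__init__(base, range(int(len(base) * fraction)))


class UnsupervisedSketch:
    """Drops the target from (input, target) samples and returns the input only."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        inputs, _target = self.base[i]
        return inputs


base = [(i * 10, i % 2) for i in range(10)]  # toy (input, label) pairs
subset = SubsetSketch(base, [1, 3, 5])
shrunk = ShrinkSketch(base, 0.5)
unsupervised = UnsupervisedSketch(base)
print(subset[1], len(shrunk), unsupervised[7])  # (30, 1) 5 70
```

Because each wrapper only stores a reference to `base` plus indices, no samples are copied.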
Download files
Download the file for your platform.
Source Distribution
torchero-0.0.6.tar.gz (40.9 kB)
Built Distribution
torchero-0.0.6-py3-none-any.whl (55.9 kB)
File details
Details for the file torchero-0.0.6.tar.gz.
File metadata
- Download URL: torchero-0.0.6.tar.gz
- Upload date:
- Size: 40.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | f5482f6d2c8c72b302f918c7d38d6df32e30137caff725a6146532cdc487167e
MD5 | def34224a874fa67cc09d66acffcf52b
BLAKE2b-256 | 24d0c06be30d6d4d5a676ff1444764f71d0d84a784808d487cc5f71aad716a84
File details
Details for the file torchero-0.0.6-py3-none-any.whl.
File metadata
- Download URL: torchero-0.0.6-py3-none-any.whl
- Upload date:
- Size: 55.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0f4d47fe591b3f89b118cdc1c49857ed623e835721c4951b569ce1ed61f5c9ea
MD5 | 9469a430d4ec156a4550587dfa066ac4
BLAKE2b-256 | 98bdc38ea58d5f77d5d79dc1f634ebf623216f1f4e587fe0cde62c84c84cc83c
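To check a downloaded file against the SHA256 digests listed above, a minimal sketch using only Python's standard `hashlib` module; the helper names `sha256_of` and `matches_digest` are illustrative, and the usage line assumes the wheel was saved to the current directory.

```python
import hashlib


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_digest(path, expected_hex):
    """True if the file's SHA256 equals the published hex digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


# Usage, assuming the wheel was downloaded to the current directory:
# matches_digest("torchero-0.0.6-py3-none-any.whl",
#                "0f4d47fe591b3f89b118cdc1c49857ed623e835721c4951b569ce1ed61f5c9ea")
```

Reading in chunks keeps memory use constant regardless of file size.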