A pluggable & extensible trainer for pytorch
Torchero - A training framework for pytorch
Torchero is a library built on top of the PyTorch framework to facilitate the training of neural networks.
Features
It provides tools and utilities to:
- Set up a training process in a few lines of code.
- Monitor the training performance by checking several prebuilt metrics on a handy progress bar.
- Visualize those metrics live in a TensorBoard dashboard with minimal setup.
- Add custom functionality via Callbacks.
- NLP: Datasets for text classification tasks. Vectors, etc.
Installation
From PyPI
pip install torchero
From Source Code
git clone https://github.com/juancruzsosa/torchero
cd torchero
python setup.py install
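After either installation route, it can help to confirm that the package is visible to the interpreter. The following is a minimal sketch that uses only the standard library; the helper name installed_version is hypothetical and is not part of torchero.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


found = installed_version("torchero")
print(f"torchero version: {found}" if found else "torchero is not installed")
```

The check reads the installed package metadata, so it works without importing torchero or PyTorch.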
Quickstart - MNIST
In this example we are going to use torchero to train a Convolutional Neural Network on the MNIST dataset.
Loading the Data
First, we load the dataset using torchvision. Then, if we want, we can show the image samples using show_imagegrid_dataset.
import torch # Used later to check for GPU availability
from torch import nn # Layers used to build the model
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST # The MNIST dataset
from torchvision import transforms # To convert images to tensors
from torchero import SupervisedTrainer
from torchero.callbacks import ProgbarLogger, ModelCheckpoint, CSVLogger
from torchero.utils import show_imagegrid_dataset
from matplotlib import pyplot as plt
# Download the training split; the test split reuses the same download
train_ds = MNIST(root='/tmp/data/mnist', download=True, train=True, transform=transforms.ToTensor())
test_ds = MNIST(root='/tmp/data/mnist', download=False, train=False, transform=transforms.ToTensor())
train_dl = DataLoader(train_ds, batch_size=32)
test_dl = DataLoader(test_ds, batch_size=32)
# Show a grid of image samples
show_imagegrid_dataset(train_ds)
plt.show()
Creating the Model
Then we have to define the model. For this case we can use a Sequential one.
model = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
                      nn.ReLU(inplace=True),
                      nn.MaxPool2d(2),
                      nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
                      nn.ReLU(inplace=True),
                      nn.MaxPool2d(2),
                      nn.Flatten(),
                      nn.Linear(5*5*64, 500),
                      nn.ReLU(inplace=True),
                      nn.Linear(500, 10))
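The 5*5*64 input size of the first Linear layer follows from the spatial size of the feature maps after the two convolution and pooling stages. As a rough check, the sketch below applies the standard output-size formula for unpadded convolutions and pooling layers to a 28x28 MNIST image; the helper conv_out is a hypothetical name introduced only for this illustration.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28                                        # MNIST images are 28x28
size = conv_out(size, kernel_size=5)             # Conv2d(1, 32, 5)  -> 24
size = conv_out(size, kernel_size=2, stride=2)   # MaxPool2d(2)      -> 12
size = conv_out(size, kernel_size=3)             # Conv2d(32, 64, 3) -> 10
size = conv_out(size, kernel_size=2, stride=2)   # MaxPool2d(2)      -> 5

flat_features = size * size * 64                 # 64 channels after the last conv
print(flat_features)  # 1600
```

If the kernel sizes or the input resolution change, the same arithmetic gives the new value to pass to the first Linear layer.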
Training the Model
Once the network and both the train and test DataLoaders are defined, we can run the training using the SGD optimizer, cross-entropy loss, categorical accuracy, a progress bar, a ModelCheckpoint to save the model when its accuracy improves, and a CSVLogger to dump the metrics to a CSV file. We can call trainer.cuda() if we want to train on the GPU instead of the CPU.
trainer = SupervisedTrainer(model=model,
                            optimizer='sgd',
                            criterion='cross_entropy',
                            acc_meters=['categorical_accuracy_percentage'],
                            callbacks=[ProgbarLogger(notebook=True),
                                       ModelCheckpoint('saved_model', mode='max', monitor='val_acc'),
                                       CSVLogger('training_results.csv')])
# Move the training to the GPU when one is available
if torch.cuda.is_available():
    trainer.cuda()
trainer.train(dataloader=train_dl, valid_dataloader=test_dl, epochs=5)
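To make the 'categorical_accuracy_percentage' metric more concrete, here is a plain-Python sketch of what such a metric usually computes: the percentage of samples whose highest-scoring class matches the target. This is an illustration under that assumption, not torchero's implementation, and the scores below are made up.

```python
def categorical_accuracy_percentage(scores, targets):
    """Percentage of rows whose highest-scoring class equals the target."""
    correct = 0
    for row, target in zip(scores, targets):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax
        correct += int(predicted == target)
    return 100.0 * correct / len(targets)


# Made-up scores for 4 samples and 3 classes (e.g. raw model outputs).
scores = [[2.0, 0.1, -1.0],
          [0.3, 0.2, 1.5],
          [0.0, 3.0, 0.5],
          [1.0, 0.9, 0.8]]
targets = [0, 2, 1, 2]
print(categorical_accuracy_percentage(scores, targets))  # 75.0
```

Because only the argmax matters, the result is the same whether the scores are raw outputs or softmax probabilities.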
Showing the training results
To see the training metrics:
fig, axs = plt.subplots(figsize=(14,3), ncols=2, nrows=1)
trainer.history.plot()
plt.show()
And if we want to see, for example, the confusion matrix on the test set:
results = trainer.evaluate(test_dl, ['categorical_accuracy_percentage', 'confusion_matrix'])
plt.figure(figsize=(10, 10))
results['confusion_matrix'].plot(classes=train_ds.classes)
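For readers unfamiliar with confusion matrices, the sketch below builds one by hand from a few made-up labels. It assumes the common convention that rows are true classes and columns are predicted classes; torchero's own layout may differ, so check its documentation before comparing plots.

```python
def confusion_matrix(targets, predictions, num_classes):
    """Build a num_classes x num_classes count matrix.

    Row index = true class, column index = predicted class.
    """
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, predicted_label in zip(targets, predictions):
        matrix[true_label][predicted_label] += 1
    return matrix


targets     = [0, 0, 1, 1, 2, 2, 2]
predictions = [0, 1, 1, 1, 2, 0, 2]
cm = confusion_matrix(targets, predictions, num_classes=3)
for row in cm:
    print(row)
# [1, 1, 0]
# [0, 2, 0]
# [1, 0, 2]
```

The diagonal holds the correct predictions, and each off-diagonal cell counts one specific kind of mistake.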
Documentation
Additional documentation can be found here
Extensive List of Classes
Trainers
- BatchTrainer: Abstract class for all trainers that work with batched inputs
- SupervisedTrainer: Trainer for supervised tasks
- AutoencoderTrainer: Trainer for autoencoder tasks
Callbacks
- callbacks.Callback: Base callback class for all epoch/training events
- callbacks.History: Callback that records the history of all training/validation metrics
- callbacks.Logger: Callback that displays metrics at each logging step
- callbacks.ProgbarLogger: Callback that displays progress bars to monitor training/validation metrics
- callbacks.CallbackContainer: Callback that groups multiple hooks
- callbacks.ModelCheckpoint: Callback that saves the best model after every epoch
- callbacks.EarlyStopping: Callback that stops training when the monitored quantity stops improving
- callbacks.CSVLogger: Callback that exports training/validation statistics to a CSV file
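The idea behind monitoring a quantity, as ModelCheckpoint and EarlyStopping do, can be sketched without any framework. The function below is a hypothetical illustration of patience-based early stopping, not torchero's code; the parameter names patience and mode are assumptions, and the loss values are made up.

```python
def stopping_epoch(values, patience, mode="min"):
    """Return the 1-based epoch at which training would stop, or None.

    Training stops once the monitored value has failed to improve for
    `patience` consecutive epochs.
    """
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(values, start=1):
        improved = (best is None
                    or (mode == "min" and value < best)
                    or (mode == "max" and value > best))
        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


val_loss = [0.90, 0.70, 0.65, 0.66, 0.68, 0.64]   # made-up values for illustration
print(stopping_epoch(val_loss, patience=2))        # 5
print(stopping_epoch(val_loss, patience=3))        # None
```

With mode="max" the same logic applies to metrics such as accuracy, where larger values are better.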
Meters
- meters.BaseMeter: Interface for all meters
- meters.BatchMeters: Superclass of meters that work with batches
- meters.CategoricalAccuracy: Meter for accuracy on categorical targets
- meters.BinaryAccuracy: Meter for accuracy on binary targets (assuming normalized inputs)
- meters.BinaryAccuracyWithLogits: Binary accuracy meter with an integrated activation function (by default the logistic function)
- meters.ConfusionMatrix: Meter for the confusion matrix
- meters.MSE: Mean Squared Error meter
- meters.MSLE: Mean Squared Log Error meter
- meters.RMSE: Root Mean Squared Error meter
- meters.RMSLE: Root Mean Squared Log Error meter
- meters.Precision: Precision meter
- meters.Recall: Recall meter
- meters.Specificity: Specificity meter
- meters.NPV: Negative predictive value meter
- meters.F1Score: F1 Score meter
- meters.F2Score: F2 Score meter
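Several of the meters above (Precision, Recall, Specificity, NPV, F1Score, F2Score) are defined in terms of the same four counts: true positives, false positives, true negatives, and false negatives. The sketch below computes them from made-up binary labels using the textbook definitions; it is not torchero's implementation, and the helper names are hypothetical.

```python
def binary_counts(targets, predictions):
    """Return (tp, fp, tn, fn) for sequences of 0/1 labels."""
    tp = fp = tn = fn = 0
    for target, predicted in zip(targets, predictions):
        if predicted == 1 and target == 1:
            tp += 1
        elif predicted == 1 and target == 0:
            fp += 1
        elif predicted == 0 and target == 0:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn


def f_beta(precision, recall, beta):
    """Weighted harmonic mean; beta > 1 weights recall more than precision."""
    beta_sq = beta ** 2
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


targets     = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
predictions = [1, 1, 1, 0, 1, 1, 0, 0, 0, 0]
tp, fp, tn, fn = binary_counts(targets, predictions)

precision   = tp / (tp + fp)   # of the predicted positives, how many are right
recall      = tp / (tp + fn)   # of the actual positives, how many were found
specificity = tn / (tn + fp)   # of the actual negatives, how many were found
npv         = tn / (tn + fn)   # of the predicted negatives, how many are right
f1 = f_beta(precision, recall, beta=1)
f2 = f_beta(precision, recall, beta=2)

print(tp, fp, tn, fn)          # 3 2 4 1
print(precision, recall)       # 0.6 0.75
```

F1 weights precision and recall equally, while F2 favors recall, which is why F2 is higher than F1 in this example where recall exceeds precision.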
Cross validation
- utils.data.CrossFoldValidation: Iterator through cross-fold-validation folds
- utils.data.train_test_split: Split a dataset into train and test datasets
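Both utilities boil down to partitioning sample indices. The following sketch shows one common way to do this in plain Python; the function names and parameters (test_fraction, seed, n_folds) are hypothetical and do not mirror torchero's signatures.

```python
import random


def train_test_split_indices(n_samples, test_fraction=0.2, seed=0):
    """Shuffle indices and split them into (train, test) lists."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_test = int(round(n_samples * test_fraction))
    return indices[n_test:], indices[:n_test]


def k_fold_indices(n_samples, n_folds):
    """Yield (train_indices, valid_indices) once per fold."""
    indices = list(range(n_samples))
    # Spread the remainder over the first folds so sizes differ by at most 1.
    fold_sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
                  for i in range(n_folds)]
    start = 0
    for size in fold_sizes:
        valid = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, valid
        start += size


for train_idx, valid_idx in k_fold_indices(n_samples=10, n_folds=3):
    print(valid_idx)
# [0, 1, 2, 3]
# [4, 5, 6]
# [7, 8, 9]
```

Every sample appears in exactly one validation fold, so averaging the per-fold metrics uses each sample once for validation.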
Datasets
- utils.data.datasets.SubsetDataset: Dataset that is a subset of the original dataset
- utils.data.datasets.ShrinkDatset: Shrinks a dataset
- utils.data.datasets.UnsuperviseDataset: Makes a dataset unsupervised
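Dataset wrappers like these typically rely on the map-style dataset protocol (`__len__` and `__getitem__`). The sketch below shows the idea with two hypothetical classes, SubsetView and InputsOnlyView; they are not torchero's classes, and the assumption that an "unsupervised" dataset simply drops the target from each (input, target) pair is ours.

```python
class SubsetView:
    """Expose only the selected indices of a map-style dataset."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.dataset[self.indices[i]]


class InputsOnlyView:
    """Drop the target from each (input, target) pair."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        x, _target = self.dataset[i]
        return x


labeled = [("a", 0), ("b", 1), ("c", 0), ("d", 1)]   # a list works as a tiny dataset
subset = SubsetView(labeled, indices=[0, 2])
unlabeled = InputsOnlyView(subset)
print(len(subset), subset[1])   # 2 ('c', 0)
print(unlabeled[0])             # a
```

Because each wrapper only forwards index lookups, the wrappers can be stacked without copying the underlying data.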