A minimal and simple machine learning experiment module for PyTorch.
torchplate: Minimal Experiment Workflows in PyTorch
An extremely minimal and simple experiment module for machine learning in PyTorch.
In addition to abstracting away the training loop, we provide several abstractions to improve the efficiency of machine learning workflows with PyTorch.
Example
To get started, create an experiment subclass of torchplate.experiment.Experiment and provide the key, experiment-specific items: a model, an optimizer, and a training-set dataloader. Then implement the abstract evaluate(batch) method so that it returns the loss for a batch, add any custom methods you want, and start training. That's it!
import torchplate
# data and models are assumed to be your own project modules (not part of torchplate)
import data
import models
import torch
import torch.optim as optim
import torch.nn as nn


class SampleExp(torchplate.experiment.Experiment):
    def __init__(self):
        self.model = models.Net()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        dataset = data.load_set('cifar')
        # torchplate.utils provides helpers for common workflows,
        # e.g. building train/test dataloaders from a dataset
        self.trainloader, self.testloader = torchplate.utils.get_xy_loaders(dataset)
        # initialize the torchplate.experiment.Experiment parent with the
        # model, optimizer, and training dataloader
        super().__init__(
            model=self.model,
            optimizer=self.optimizer,
            trainloader=self.trainloader
        )

    # required: implement the abstract evaluate() method,
    # which returns the loss for one batch
    def evaluate(self, batch):
        x, y = batch
        logits = self.model(x)
        loss_val = self.criterion(logits, y)
        return loss_val


exp = SampleExp()
exp.train(num_epochs=5)
Output:
Epoch 1: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 293.98it/s]
Training Loss (epoch 1): 1.3564644632516083
Epoch 2: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 598.46it/s]
Training Loss (epoch 2): 1.2066593832439847
Epoch 3: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 579.40it/s]
Training Loss (epoch 3): 1.1030386642173484
Epoch 4: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 563.90it/s]
Training Loss (epoch 4): 1.0885229706764221
Epoch 5: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 577.54it/s]
Training Loss (epoch 5): 1.0520343957123932
Finished Training!
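The log above prints one loss value per epoch. As a rough mental model of what exp.train(num_epochs=5) does, here is a minimal, hypothetical base class showing the pattern: the base class owns the training loop and the subclass only supplies evaluate. MinimalExperiment is an illustrative name, not torchplate's API or source; the sketch assumes the logged value is the mean batch loss and omits the progress bars.

```python
from abc import ABC, abstractmethod


class MinimalExperiment(ABC):
    """Hypothetical base class illustrating the pattern; NOT torchplate's source.

    The base class owns the training loop and subclasses supply `evaluate`.
    It relies only on duck typing (optimizer.zero_grad/step, loss.backward/item),
    so it accepts torch.optim optimizers and scalar loss tensors without
    importing torch itself.
    """

    def __init__(self, model, optimizer, trainloader):
        self.model = model
        self.optimizer = optimizer
        self.trainloader = trainloader

    @abstractmethod
    def evaluate(self, batch):
        """Return the loss for one batch (a scalar tensor in PyTorch)."""

    def train(self, num_epochs):
        # put an nn.Module into training mode (affects dropout, batch norm)
        if hasattr(self.model, "train"):
            self.model.train()
        history = []
        for epoch in range(1, num_epochs + 1):
            total, n_batches = 0.0, 0
            for batch in self.trainloader:
                self.optimizer.zero_grad()   # clear gradients from the previous step
                loss = self.evaluate(batch)  # subclass-defined forward pass + loss
                loss.backward()              # backpropagate
                self.optimizer.step()        # update parameters
                total += loss.item()
                n_batches += 1
            # assumption: the logged epoch loss is the mean of the batch losses
            epoch_loss = total / max(n_batches, 1)
            history.append(epoch_loss)
            print(f"Training Loss (epoch {epoch}): {epoch_loss}")
        print("Finished Training!")
        return history
```

A subclass shaped like SampleExp above would fit this sketch unchanged, since it passes model, optimizer, and trainloader to the parent constructor and implements evaluate.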
Installation
$ pip install torchplate
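To confirm which release pip installed (the changelog below lists 0.0.2 as the most recent), a small standard-library check is enough. installed_version is an illustrative helper name, not part of torchplate; importlib.metadata requires Python 3.8 or later.

```python
from importlib import metadata


def installed_version(dist_name="torchplate"):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # prints the installed torchplate version string, or None if not installed
    print(installed_version())
```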
Changelog
0.0.2
- Fixed a polymorphic bug
0.0.1
- First version published. Provides basic data-loading utilities and the base experiment module.
Download files
Source distribution: torchplate-0.0.2.tar.gz (2.8 MB)
Built distribution: torchplate-0.0.2-py3-none-any.whl (5.9 kB)
File details
Details for the file torchplate-0.0.2.tar.gz.
File metadata
- Download URL: torchplate-0.0.2.tar.gz
- Upload date:
- Size: 2.8 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 3862175f6ff0a8fa304f18aaa8ad98544c28cd0d973c9097c62e7ca40c3f2876
MD5 | 3d4e95826034b3b1daf69cbee9253c8c
BLAKE2b-256 | 7c419406a06c608773b674302ffc9039d8748cd0922ff9cf5fa907b4359058cf
File details
Details for the file torchplate-0.0.2-py3-none-any.whl.
File metadata
- Download URL: torchplate-0.0.2-py3-none-any.whl
- Upload date:
- Size: 5.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | 9f3a40eb0caed42b9a46f5c3db4781c02688539b512f48682b8a4a11ab50e575
MD5 | 774482baf6257aa664f89fe1cf44f493
BLAKE2b-256 | 7922e3045d93e2e04c7c452af845e18126abd561eff2082876e639e734277945
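To check a downloaded file against the digests listed for torchplate 0.0.2, Python's hashlib is sufficient. This is a minimal sketch that assumes the file keeps its original name; EXPECTED, file_digests, and verify are illustrative names. PyPI's BLAKE2b-256 digest corresponds to BLAKE2b with a 32-byte digest size.

```python
import hashlib
from pathlib import Path

# Digests copied from the "File hashes" tables for torchplate 0.0.2.
EXPECTED = {
    "torchplate-0.0.2.tar.gz": {
        "sha256": "3862175f6ff0a8fa304f18aaa8ad98544c28cd0d973c9097c62e7ca40c3f2876",
        "md5": "3d4e95826034b3b1daf69cbee9253c8c",
        "blake2b_256": "7c419406a06c608773b674302ffc9039d8748cd0922ff9cf5fa907b4359058cf",
    },
    "torchplate-0.0.2-py3-none-any.whl": {
        "sha256": "9f3a40eb0caed42b9a46f5c3db4781c02688539b512f48682b8a4a11ab50e575",
        "md5": "774482baf6257aa664f89fe1cf44f493",
        "blake2b_256": "7922e3045d93e2e04c7c452af845e18126abd561eff2082876e639e734277945",
    },
}


def file_digests(path, chunk_size=1 << 16):
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file in one pass."""
    hashers = {
        "sha256": hashlib.sha256(),
        "md5": hashlib.md5(),
        "blake2b_256": hashlib.blake2b(digest_size=32),  # 256-bit BLAKE2b
    }
    with open(path, "rb") as f:
        # read in chunks so large archives need not fit in memory
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify(path):
    """True only if every listed digest matches the file at `path`."""
    expected = EXPECTED[Path(path).name]
    actual = file_digests(path)
    return all(actual[name] == digest for name, digest in expected.items())
```

pip can also enforce these digests itself through its hash-checking mode (a requirements file with --hash=sha256:... entries, installed with --require-hashes).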