Trainer for PyTorch
Project description
Torch Runner
A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.
Note: If you are looking for something more extensive, check out PyTorch Lightning. This package is mostly designed for my personal use.
Requirements
- torch
- tqdm
Installation
pip install torch-runner
Features
- seed all random number generators
- text logger
- early stopping
- save hyperparameters
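The package's own implementation of these features is not shown here. As a rough illustration of what two of them usually involve, here is a minimal, self-contained sketch. `seed_everything` and `EarlyStopping` are hypothetical names chosen for this example, not torch_runner's API, and NumPy/PyTorch are only seeded if they happen to be installed.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed Python's RNG, plus NumPy and PyTorch if installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


class EarlyStopping:
    """Hypothetical helper: signal a stop once the validation loss has not
    improved by more than `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # new best: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1  # no improvement this epoch
        return self.bad_epochs >= self.patience


# Usage: the loss improves once, then stalls for three epochs in a row.
seed_everything(42)
stopper = EarlyStopping(patience=3)
val_losses = [1.0, 0.8, 0.85, 0.9, 0.95]
stopped_at = next(i for i, loss in enumerate(val_losses) if stopper.step(loss))
```

With `patience=3`, training stops at the fifth epoch (index 4), the third consecutive epoch without beating the best loss of 0.8.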
Example
Check out the examples folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.
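import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch_runner as T

class myTrainer(T.TrainerModule):
    def __init__(self, model, optimizer):
        super().__init__(model, optimizer)
        self.model, self.optimizer = model, optimizer  # kept explicitly for the steps below

    def calc_metric(self, preds, target):
        return (preds.argmax(dim=1) == target).float().mean().item()  # e.g. top-1 accuracy

    def loss_fct(self, preds, target):
        return F.cross_entropy(preds, target)  # e.g. cross-entropy for classification

    def train_one_step(self, batch, batch_id):
        # Get batch data from dataloader and perform one update
        inputs, target = batch
        self.optimizer.zero_grad()
        loss = self.loss_fct(self.model(inputs), target)
        loss.backward()
        self.optimizer.step()
        return loss.item()  # assumed return value; check TrainerModule for the expected format

    def valid_one_step(self, batch, batch_id):
        # Perform validation step without tracking gradients
        inputs, target = batch
        with torch.no_grad():
            preds = self.model(inputs)
        return self.loss_fct(preds, target).item(), self.calc_metric(preds, target)  # format assumed

model = myModel()  # your torch.nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # train_dataset: your Dataset
val_dataloader = DataLoader(val_dataset, batch_size=32)  # val_dataset: your Dataset
Trainer = myTrainer(model, optimizer)
Trainer.fit(train_dataloader, val_dataloader, epochs=10, batch_size=32)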
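The example relies on a template-method design: the base class owns the epoch and batch loops inside `fit()`, and the subclass only fills in per-batch hooks such as `train_one_step` and `valid_one_step`. The framework-free sketch below shows that pattern in isolation. `MiniTrainer` and `DemoTrainer` are hypothetical names, this is not torch_runner's actual implementation, and plain Python lists stand in for DataLoaders so it runs without PyTorch.

```python
from typing import Any, Iterable, List


class MiniTrainer:
    """Hypothetical base class: fit() drives the loops, subclasses supply the steps."""

    def __init__(self, model: Any, optimizer: Any) -> None:
        self.model = model
        self.optimizer = optimizer
        self.history: List[dict] = []

    def train_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError  # subclasses perform one update and return its loss

    def valid_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError  # subclasses evaluate one batch and return its loss

    def fit(self, train_loader: Iterable, val_loader: Iterable, epochs: int) -> List[dict]:
        for epoch in range(epochs):
            # One pass over the training data, calling the subclass hook per batch
            train_losses = [self.train_one_step(b, i) for i, b in enumerate(train_loader)]
            # One pass over the validation data
            val_losses = [self.valid_one_step(b, i) for i, b in enumerate(val_loader)]
            self.history.append({
                "epoch": epoch,
                "train_loss": sum(train_losses) / max(len(train_losses), 1),
                "val_loss": sum(val_losses) / max(len(val_losses), 1),
            })
        return self.history


class DemoTrainer(MiniTrainer):
    def train_one_step(self, batch, batch_id):
        return float(batch)  # pretend each batch value is its training loss

    def valid_one_step(self, batch, batch_id):
        return float(batch) / 2  # pretend validation loss is half the batch value


history = DemoTrainer(model=None, optimizer=None).fit([1, 2, 3], [4, 6], epochs=2)
```

Keeping the loops in one place is what lets a subclass stay short: only the per-batch logic changes between projects.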
Download files
Source Distribution
- torch_runner-0.0.2.tar.gz (14.0 kB)
Built Distribution
- torch_runner-0.0.2-py3-none-any.whl (6.0 kB)
File details
Details for the file torch_runner-0.0.2.tar.gz.
File metadata
- Download URL: torch_runner-0.0.2.tar.gz
- Upload date:
- Size: 14.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | bd9722ec1c874502713eab15b6f374d03b5b9b7c2f9515c0de50702e66fbda8d
MD5 | 3f09c9e092d63c2113d636c324ca6bdd
BLAKE2b-256 | 5e997c825482e6e1f57bce6852a2cd44576de91ed13a41a00628a07edc91bb6a
File details
Details for the file torch_runner-0.0.2-py3-none-any.whl.
File metadata
- Download URL: torch_runner-0.0.2-py3-none-any.whl
- Upload date:
- Size: 6.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12
File hashes
Algorithm | Hash digest
---|---
SHA256 | fb19e850530e3d6f503f2ee87bd937071edf9f24db9cc68a547e48b3dfae25ca
MD5 | 1f1b35a9ac39bf1008cc808ae8029d16
BLAKE2b-256 | afb91850429cdf7cc0c77feae5aede4553a4979efe8529020b50d79944cd4ea6