Trainer for PyTorch
Torch Runner
A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.
Note: If you are looking for something more extensive, check out PyTorch Lightning. This package is mostly designed for my personal use.
Requirements
- torch
- tqdm
Installation
pip install torch-runner
Features
- seeding of all random number generators
- text logger
- early stopping
- saving of hyperparameters
- Weights & Biases (wandb) support
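Early stopping usually means halting training once the validation loss has stopped improving for a set number of epochs. The sketch below shows that general idea in plain Python. The `EarlyStopping` class and its `patience` and `min_delta` parameters are hypothetical names chosen for illustration; this is not torch_runner's API or implementation, and the library's own behaviour may differ.

```python
class EarlyStopping:
    """Stop once validation loss has not improved for `patience` epochs.

    Hypothetical helper for illustration only; not part of torch_runner.
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best = float("inf")      # best validation loss seen so far
        self.bad_epochs = 0           # consecutive epochs without improvement

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            # Improvement: remember the new best and reset the counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with made-up validation losses:
stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)  # 3 0.8
```

Training stops at epoch 3 because epochs 2 and 3 both fail to beat the best loss of 0.8. The later value of 0.7 is never seen, which is the usual trade-off in choosing `patience`.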
Example
Check out the examples folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.
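import torch
import torch.nn.functional as F
import torch_runner as T

class myTrainer(T.TrainerModule):
    def calc_metric(self, preds, target):
        # Calc metrics such as accuracy, e.g. top-1 accuracy (illustrative)
        return (preds.argmax(dim=1) == target).float().mean()

    def loss_fct(self, preds, target):
        # Calc loss, e.g. cross-entropy for classification (illustrative)
        return F.cross_entropy(preds, target)

    def train_one_step(self, batch, batch_id):
        # Get batch data from the dataloader and perform one update
        ...

    def valid_one_step(self, batch, batch_id):
        # Perform one validation step
        ...

config = T.TrainerConfig()
model = myModel()  # your torch.nn.Module subclass, defined elsewhere
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_dataloader = ...  # a torch.utils.data.DataLoader for training data
val_dataloader = ...  # a torch.utils.data.DataLoader for validation data

trainer = myTrainer(model, optimizer, config)
trainer.fit(train_dataloader, val_dataloader, epochs=10)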
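The example relies on a hook pattern: you subclass `T.TrainerModule`, override per-batch methods such as `train_one_step` and `valid_one_step`, and `fit` drives the loop. The framework-free sketch below illustrates that pattern only. `MiniTrainer`, `CountingTrainer`, and the assumption that `valid_one_step` returns a number to be averaged are hypothetical. The real `TrainerModule` also takes a model, optimizer, and config, and its `fit` may work quite differently.

```python
from typing import Any, Iterable, List


class MiniTrainer:
    """Hypothetical, simplified hook-based trainer (NOT torch_runner's implementation)."""

    def train_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError  # subclasses perform one update here

    def valid_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError  # subclasses return a validation value here

    def fit(self, train_loader: Iterable, val_loader: Iterable, epochs: int = 1) -> List[float]:
        val_history = []
        for _ in range(epochs):
            # Training phase: call the user hook once per batch.
            for batch_id, batch in enumerate(train_loader):
                self.train_one_step(batch, batch_id)
            # Validation phase: average the values returned by the hook
            # (assumes val_loader yields at least one batch).
            values = [self.valid_one_step(b, i) for i, b in enumerate(val_loader)]
            val_history.append(sum(values) / len(values))
        return val_history


class CountingTrainer(MiniTrainer):
    """Toy subclass: 'trains' by counting steps, 'validates' by echoing the batch."""

    def __init__(self) -> None:
        self.steps = 0

    def train_one_step(self, batch, batch_id):
        self.steps += 1
        return 0.0

    def valid_one_step(self, batch, batch_id):
        return float(batch)


trainer = CountingTrainer()
history = trainer.fit(train_loader=[1, 2, 3], val_loader=[2, 4], epochs=2)
print(trainer.steps, history)  # 6 [3.0, 3.0]
```

Keeping the loop in the base class and only the per-batch logic in subclasses is what lets a wrapper like this remove repeated training-loop code.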
Download files
Source Distribution: torch_runner-0.1.0.tar.gz (16.1 kB)
Built Distribution: torch_runner-0.1.0-py3-none-any.whl
Hashes for torch_runner-0.1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | ba98c78078f0d87e74bba3e718a01b794559c5fff1e1231b4579b64be3ed59b7
MD5 | 97200a11791eac71a2b305d8ab63be37
BLAKE2b-256 | 18521fcb9bc28b5445edf27da69d76e7b9356c9ca2f5e4b0b4aec5a3ec363eda