Trainer for PyTorch
Torch Runner
A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.
Note: If you are looking for something more extensive, check out PyTorch Lightning. This is mostly designed for my personal use.
Requirements
- torch
- tqdm
Installation
pip install torch-runner
Features
- seed all variables
- text logger
- early stopping
- save hyperparameters
- Weights & Biases support
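To make two of the features above concrete, here is a minimal sketch of what seeding and early stopping usually look like in a training loop. It is not taken from torch_runner's source: `seed_everything`, `EarlyStopping`, `patience`, and `min_delta` are hypothetical names chosen for illustration, and the library's own interface may differ.

```python
import random


def seed_everything(seed: int) -> None:
    """Seed the random number generators a training run commonly touches."""
    random.seed(seed)
    try:
        import numpy as np  # optional dependency, seeded only if installed
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch  # optional here so the sketch also runs without PyTorch
        torch.manual_seed(seed)
    except ImportError:
        pass


class EarlyStopping:
    """Signal a stop once the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            # Improvement: remember the new best loss and reset the counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

In a loop, one would call `seed_everything(...)` once before building the model and `stopper.step(val_loss)` after each validation pass, breaking out of the epoch loop when it returns `True`.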
Example
Check out the examples folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.
import torch
import torch_runner as T


class myTrainer(T.TrainerModule):
    def calc_metric(self, preds, target):
        # Compute metrics such as accuracy.
        ...

    def loss_fct(self, preds, target):
        # Compute the loss.
        ...

    def train_one_step(self, batch, batch_id):
        # Unpack the batch from the dataloader and perform one update.
        ...

    def valid_one_step(self, batch, batch_id):
        # Perform one validation step.
        ...


config = T.TrainerConfig()
model = myModel()  # your own torch.nn.Module subclass
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_dataloader = ...  # a PyTorch DataLoader
val_dataloader = ...  # a PyTorch DataLoader

Trainer = myTrainer(model, optimizer, config)
Trainer.fit(train_dataloader, val_dataloader, epochs=10)
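The `train_one_step` and `valid_one_step` hooks above are left empty, so here is a rough sketch in plain PyTorch of the work such steps typically do. It is written as standalone functions and does not use torch_runner. What the hooks should return is an assumption here, as are the `loss_fn` argument, the toy `nn.Linear` model, and the random batch, all of which are illustrative only.

```python
import torch
from torch import nn


def train_one_step(model, optimizer, loss_fn, batch):
    """Run one optimisation step on a single (inputs, targets) batch."""
    inputs, targets = batch
    model.train()
    optimizer.zero_grad()           # clear gradients from the previous step
    preds = model(inputs)           # forward pass
    loss = loss_fn(preds, targets)
    loss.backward()                 # backpropagate
    optimizer.step()                # update the parameters
    return loss.item()


@torch.no_grad()
def valid_one_step(model, loss_fn, batch):
    """Evaluate the loss on one batch without tracking gradients."""
    inputs, targets = batch
    model.eval()
    preds = model(inputs)
    return loss_fn(preds, targets).item()


# Toy setup: a linear classifier on random data, for illustration only.
torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
batch = (torch.randn(8, 4), torch.randint(0, 2, (8,)))

before = valid_one_step(model, loss_fn, batch)
train_loss = train_one_step(model, optimizer, loss_fn, batch)
after = valid_one_step(model, loss_fn, batch)
```

Inside a `T.TrainerModule` subclass, the same logic would presumably live in the methods and use `self.loss_fct` for the loss, but how the base class exposes the model and optimizer is not shown here.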