Trainer for PyTorch
Torch Runner
A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.
Note: If you are looking for something more extensive, check out PyTorch Lightning. This library is mostly designed for my personal use.
Requirements
- torch
- tqdm
Installation
pip install torch-runner
Features
- seeding of all random number generators
- text logger
- early stopping
- saving of hyperparameters
- Weights & Biases support
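The README lists early stopping but does not show how torch_runner configures it. As a generic illustration only (this is not torch_runner's API; the `EarlyStopping` class and its `patience` and `min_delta` parameters are hypothetical), patience-based early stopping on the validation loss can be sketched as:

```python
class EarlyStopping:
    """Hypothetical helper: signal a stop when the validation loss has not
    improved by more than `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            # Improvement: remember the new best and reset the counter
            self.best = val_loss
            self.bad_epochs = 0
        else:
            # No improvement this epoch
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")
        break
```

With `patience=2`, the two non-improving epochs after the best loss of 0.8 trigger the stop, so the final 0.7 is never seen.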
Example
Check out the examples folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.
import torch
import torch_runner as T


class MyTrainer(T.TrainerModule):
    def calc_metric(self, preds, target):
        # Compute metrics such as accuracy
        ...

    def loss_fct(self, preds, target):
        # Compute the loss
        ...

    def train_one_step(self, batch, batch_id):
        # Unpack the batch from the dataloader and perform one update
        ...

    def valid_one_step(self, batch, batch_id):
        # Perform one validation step
        ...


config = T.TrainerConfig()
model = MyModel()  # placeholder: your own torch.nn.Module subclass
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_dataloader = ...  # placeholder: a torch.utils.data.DataLoader for training
val_dataloader = ...  # placeholder: a torch.utils.data.DataLoader for validation

trainer = MyTrainer(model, optimizer, config)
trainer.fit(train_dataloader, val_dataloader, epochs=10)
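For comparison, here is a rough sketch of the kind of hand-written loop a wrapper like this is meant to replace. It uses plain PyTorch only and makes no torch_runner calls; the toy dataset, the linear model, and the `loss_fct` and `calc_metric` names (chosen to mirror the hooks in the example above) are illustrative assumptions, not the library's internals.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy problem: classify 2-D points by the sign of x0 + x1
X = torch.randn(512, 2)
y = (X.sum(dim=1) > 0).long()
train_dataloader = DataLoader(TensorDataset(X[:400], y[:400]), batch_size=32, shuffle=True)
val_dataloader = DataLoader(TensorDataset(X[400:], y[400:]), batch_size=64)

model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fct = nn.CrossEntropyLoss()


def calc_metric(preds, target):
    # Accuracy: fraction of argmax predictions that match the target labels
    return (preds.argmax(dim=1) == target).float().mean().item()


history = []  # (val_loss, val_acc) per epoch
for epoch in range(10):
    # Training: one optimizer update per batch
    model.train()
    for batch_id, (inputs, target) in enumerate(train_dataloader):
        optimizer.zero_grad()
        loss = loss_fct(model(inputs), target)
        loss.backward()
        optimizer.step()

    # Validation: no gradients, sample-weighted averages over all batches
    model.eval()
    val_loss, val_acc, n = 0.0, 0.0, 0
    with torch.no_grad():
        for batch_id, (inputs, target) in enumerate(val_dataloader):
            preds = model(inputs)
            bs = target.size(0)
            val_loss += loss_fct(preds, target).item() * bs
            val_acc += calc_metric(preds, target) * bs
            n += bs
    history.append((val_loss / n, val_acc / n))
    print(f"epoch {epoch}: val_loss={val_loss / n:.4f} val_acc={val_acc / n:.3f}")
```

In the wrapper, the per-batch bodies of these two loops are what `train_one_step` and `valid_one_step` would hold, while the epoch iteration, device handling, and logging are left to `fit`.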
Download files
Source Distribution
- torch_runner-0.1.0.tar.gz (16.1 kB)
Built Distribution
- torch_runner-0.1.0-py3-none-any.whl (6.6 kB)
File details
Details for the file torch_runner-0.1.0.tar.gz.
File metadata
- Download URL: torch_runner-0.1.0.tar.gz
- Upload date:
- Size: 16.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | eda717970d740391978820327b0515f0e06a733dee3088e230a34d92ae6a7758 |
| MD5 | 17be81c5c8b8ac2b9aac3a435271d140 |
| BLAKE2b-256 | 36f16bb57b7682d2d6599bec872e432cbd9daff38547db68cba67bee4e4c4864 |
File details
Details for the file torch_runner-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_runner-0.1.0-py3-none-any.whl
- Upload date:
- Size: 6.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12
File hashes

| Algorithm | Hash digest |
|---|---|
| SHA256 | ba98c78078f0d87e74bba3e718a01b794559c5fff1e1231b4579b64be3ed59b7 |
| MD5 | 97200a11791eac71a2b305d8ab63be37 |
| BLAKE2b-256 | 18521fcb9bc28b5445edf27da69d76e7b9356c9ca2f5e4b0b4aec5a3ec363eda |
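The published digests let you check that a downloaded file is intact. A minimal sketch using only the standard library's `hashlib` follows; the helper names `sha256_of` and `verify` are hypothetical, and the local file name assumes the source distribution was saved under its published name in the current directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_sha256):
    """True if the file's SHA256 digest matches the published one (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# SHA256 digest published for the source distribution
expected = "eda717970d740391978820327b0515f0e06a733dee3088e230a34d92ae6a7758"
path = Path("torch_runner-0.1.0.tar.gz")  # assumed local download location
if path.exists():
    print("OK" if verify(path, expected) else "MISMATCH")
```

pip can also enforce digests at install time through its hash-checking mode, by listing `--hash=sha256:<digest>` next to the requirement in a requirements file.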