
Trainer for PyTorch

Project description

Torch Runner

A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.

Note: If you are looking for something more extensive, check out PyTorch Lightning. This library is mostly designed for my personal use.

Requirements

  • torch
  • tqdm

Installation

pip install torch-runner

Features

  • seed all random number generators
  • text logger
  • early stopping
  • save hyperparameters
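The page lists these features without showing how torch_runner implements them, so the following is only a hypothetical sketch of what three of them typically involve. `seed_everything`, `EarlyStopping`, and `save_hparams` are illustrative names, not torch_runner's API. Seeding here covers Python's `random` module, plus NumPy and PyTorch when they are installed; early stopping tracks the best validation loss and signals a stop after `patience` epochs without improvement.

```python
import json
import random
import tempfile
from pathlib import Path


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed Python's RNG, plus NumPy and PyTorch if installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # documented to seed the RNGs of all devices
    except ImportError:
        pass


class EarlyStopping:
    """Hypothetical helper: stop once validation loss fails to improve for `patience` epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def save_hparams(hparams: dict, path: Path) -> None:
    """Hypothetical helper: write hyperparameters as JSON so a run can be reproduced."""
    path.write_text(json.dumps(hparams, indent=2, sort_keys=True))


# Usage: seed once, then stop after two epochs without improvement.
seed_everything(42)
stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(val_loss):
        stopped_at = epoch  # 0.85 and 0.9 do not beat 0.8, so this is epoch 3
        break

with tempfile.TemporaryDirectory() as run_dir:
    hp_path = Path(run_dir) / "hparams.json"
    save_hparams({"lr": 0.01, "epochs": 10}, hp_path)
    reloaded = json.loads(hp_path.read_text())
```

Keeping early stopping as a small stateful object means the training loop only has to call `step` once per epoch and check the returned flag.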

Example

Check out the examples folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.

import torch
import torch_runner as T


class MyTrainer(T.TrainerModule):
    def __init__(self, model, optimizer):
        super().__init__(model, optimizer)

    def calc_metric(self, preds, target):
        # Calculate metrics such as accuracy.
        raise NotImplementedError

    def loss_fct(self, preds, target):
        # Calculate and return the loss.
        raise NotImplementedError

    def train_one_step(self, batch, batch_id):
        # Unpack the batch from the dataloader and perform one optimizer update.
        raise NotImplementedError

    def valid_one_step(self, batch, batch_id):
        # Perform one validation step.
        raise NotImplementedError


model = MyModel()  # your own torch.nn.Module subclass
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

train_dataloader = ...  # a torch.utils.data.DataLoader with the training data
val_dataloader = ...  # a torch.utils.data.DataLoader with the validation data

trainer = MyTrainer(model, optimizer)
trainer.fit(train_dataloader, val_dataloader, epochs=10, batch_size=32)
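torch_runner's own `fit` is not shown on this page, so the following is only a hypothetical sketch of the hook pattern the example relies on: a base class owns the epoch and batch loops and calls the subclass's `train_one_step` and `valid_one_step`. `HookTrainer`, `MeanEstimator`, and the returned loss history are assumptions for illustration; the real `TrainerModule` also accepts `batch_size` and likely does more (for example logging and early stopping, per the feature list). Plain Python lists stand in for DataLoaders so the sketch runs without PyTorch.

```python
from typing import Any, Iterable, List


class HookTrainer:
    """Hypothetical hook-based trainer base class (a sketch, not torch_runner's code)."""

    def train_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError

    def valid_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError

    def fit(self, train_loader: Iterable, val_loader: Iterable, epochs: int) -> List[float]:
        """Alternate training and validation passes; return mean validation loss per epoch."""
        history = []
        for _ in range(epochs):
            # Training pass: the subclass performs one update per batch.
            for batch_id, batch in enumerate(train_loader):
                self.train_one_step(batch, batch_id)
            # Validation pass: collect each batch's loss and average it.
            losses = [self.valid_one_step(b, i) for i, b in enumerate(val_loader)]
            history.append(sum(losses) / len(losses))
        return history


class MeanEstimator(HookTrainer):
    """Toy subclass: learns one scalar w by gradient descent on mean((w - x)^2)."""

    def __init__(self, lr: float = 0.1) -> None:
        self.w = 0.0
        self.lr = lr

    def _loss(self, batch: List[float]) -> float:
        return sum((self.w - x) ** 2 for x in batch) / len(batch)

    def train_one_step(self, batch, batch_id):
        # d/dw mean((w - x)^2) = 2 * mean(w - x)
        grad = 2 * sum(self.w - x for x in batch) / len(batch)
        self.w -= self.lr * grad
        return self._loss(batch)

    def valid_one_step(self, batch, batch_id):
        return self._loss(batch)


# Every batch has mean 2.0, so w should converge towards 2.0.
history = MeanEstimator(lr=0.1).fit([[1.0, 3.0], [2.0, 2.0]], [[2.0, 2.0]], epochs=20)
```

The design choice this illustrates is that the subclass only describes a single step; looping, averaging, and bookkeeping stay in the base class, which is what keeps the user-facing code in the example short.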



Download files


Source Distribution

torch_runner-0.0.2.tar.gz (14.0 kB)

Uploaded Source

Built Distribution

torch_runner-0.0.2-py3-none-any.whl (6.0 kB)

Uploaded Python 3

File details

Details for the file torch_runner-0.0.2.tar.gz.

File metadata

  • Download URL: torch_runner-0.0.2.tar.gz
  • Upload date:
  • Size: 14.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12

File hashes

Hashes for torch_runner-0.0.2.tar.gz
Algorithm    Hash digest
SHA256       bd9722ec1c874502713eab15b6f374d03b5b9b7c2f9515c0de50702e66fbda8d
MD5          3f09c9e092d63c2113d636c324ca6bdd
BLAKE2b-256  5e997c825482e6e1f57bce6852a2cd44576de91ed13a41a00628a07edc91bb6a
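To check a downloaded file against a published SHA256 digest, a few lines of standard-library Python are enough. This is a minimal sketch: `sha256_of` and `matches` are illustrative helper names, not part of torch_runner, and the self-check uses the well-known SHA256 test vector for the bytes `abc` rather than the real archive.

```python
import hashlib
import tempfile
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path: Path, expected_hex: str) -> bool:
    """Compare a file's SHA256 against a published digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.strip().lower()


# Published SHA256 of the source distribution listed on this page.
SDIST_SHA256 = "bd9722ec1c874502713eab15b6f374d03b5b9b7c2f9515c0de50702e66fbda8d"
# For an intact download: matches(Path("torch_runner-0.0.2.tar.gz"), SDIST_SHA256) is True.

# Self-check against the standard SHA256 test vector for b"abc".
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "abc.bin"
    sample.write_bytes(b"abc")
    abc_ok = matches(sample, "BA7816BF8F01CFEA414140DE5DAE2223B00361A396177A9CB410FF61F20015AD")
```

pip can enforce the same check itself in hash-checking mode: list the requirement with its digest (`torch-runner==0.0.2 --hash=sha256:...`) in a requirements file and install with `pip install --require-hashes -r requirements.txt`.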


File details

Details for the file torch_runner-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: torch_runner-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 6.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12

File hashes

Hashes for torch_runner-0.0.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       fb19e850530e3d6f503f2ee87bd937071edf9f24db9cc68a547e48b3dfae25ca
MD5          1f1b35a9ac39bf1008cc808ae8029d16
BLAKE2b-256  afb91850429cdf7cc0c77feae5aede4553a4979efe8529020b50d79944cd4ea6

