
Trainer for PyTorch

Project description

Torch Runner

A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.

Note: If you are looking for something more extensive, check out PyTorch Lightning. This package is mostly designed for my personal use.

Requirements

  • torch
  • tqdm

Installation

pip install torch-runner

Features

  • seed all random number generators
  • text logger
  • early stopping
  • save hyperparameters
  • Weights & Biases support
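Seeding and early stopping are easy to picture in isolation. Below is a minimal, stdlib-only sketch that assumes nothing about torch_runner's internals: `seed_everything` and `EarlyStopping` are hypothetical names, not part of the library's documented API. The stopping rule here (stop once the validation loss has failed to improve by more than `min_delta` for `patience` consecutive epochs) is one common convention, not necessarily the one torch_runner uses.

```python
import random


# Hypothetical helpers for illustration only; not torch_runner's API.
def seed_everything(seed: int) -> None:
    """Seed Python's RNG and, if they are installed, NumPy's and PyTorch's."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices (CPU and CUDA)
    except ImportError:
        pass


class EarlyStopping:
    """Signal a stop when validation loss stalls for `patience` epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: remember it and reset the counter
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs >= self.patience


# Usage: stop after two epochs without improvement.
seed_everything(42)
stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([0.9, 0.7, 0.71, 0.72, 0.5]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch}")
        break
```

Note that the loop stops before it ever sees the 0.5 loss; a larger `patience` trades extra epochs for tolerance of noisy validation curves.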

Example

Check out the examples folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.

import torch
import torch_runner as T


class MyTrainer(T.TrainerModule):
    def calc_metric(self, preds, target):
        # Compute metrics such as accuracy (replace ... with your code)
        ...

    def loss_fct(self, preds, target):
        # Compute the loss (replace ... with your code)
        ...

    def train_one_step(self, batch, batch_id):
        # Unpack the batch from the dataloader and perform one update
        ...

    def valid_one_step(self, batch, batch_id):
        # Perform one validation step
        ...


config = T.TrainerConfig()
model = MyModel()  # your own torch.nn.Module subclass
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

train_dataloader = ...  # a torch.utils.data.DataLoader over the training set
val_dataloader = ...  # a torch.utils.data.DataLoader over the validation set

trainer = MyTrainer(model, optimizer, config)
trainer.fit(train_dataloader, val_dataloader, epochs=10)
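The example relies on a template-method pattern: the subclass fills in per-batch hooks, and `fit` drives the epoch loop. The plain-Python sketch below illustrates only that pattern. `MiniTrainer` and `LinearTrainer` are hypothetical classes, not torch_runner code, and the real `TrainerModule.fit` very likely also handles the optimizer, logging, early stopping and devices in ways not shown here. The toy subclass fits y = w * x by gradient descent, so it runs without PyTorch.

```python
from typing import Any, Iterable


class MiniTrainer:
    """Hypothetical base class: subclasses implement the per-batch hooks."""

    def train_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError

    def valid_one_step(self, batch: Any, batch_id: int) -> float:
        raise NotImplementedError

    def fit(self, train_loader: Iterable, val_loader: Iterable, epochs: int) -> list:
        history = []
        for epoch in range(epochs):
            # The loop only calls the hooks; all model-specific logic lives in them.
            train_losses = [self.train_one_step(b, i) for i, b in enumerate(train_loader)]
            val_losses = [self.valid_one_step(b, i) for i, b in enumerate(val_loader)]
            history.append({
                "epoch": epoch,
                "train_loss": sum(train_losses) / len(train_losses),
                "val_loss": sum(val_losses) / len(val_losses),
            })
        return history


class LinearTrainer(MiniTrainer):
    """Toy subclass: fit y = w * x by gradient descent on squared error."""

    def __init__(self, lr: float = 0.1):
        self.w = 0.0
        self.lr = lr

    def loss_fct(self, preds: float, target: float) -> float:
        return (preds - target) ** 2

    def train_one_step(self, batch, batch_id):
        x, y = batch
        preds = self.w * x
        grad = 2 * (preds - y) * x  # derivative of (w * x - y) ** 2 w.r.t. w
        self.w -= self.lr * grad
        return self.loss_fct(preds, y)

    def valid_one_step(self, batch, batch_id):
        x, y = batch
        return self.loss_fct(self.w * x, y)


data = [(1.0, 2.0), (2.0, 4.0), (0.5, 1.0)]  # samples of y = 2x
trainer = LinearTrainer(lr=0.1)
history = trainer.fit(data, data, epochs=20)
print(round(trainer.w, 3))
```

Because `fit` never touches the model directly, the same loop serves any subclass whose steps return a scalar loss, which is what keeps a subclass like `MyTrainer` short.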


Download files


Source Distribution

torch_runner-0.1.0.tar.gz (16.1 kB)

Uploaded Source

Built Distribution

torch_runner-0.1.0-py3-none-any.whl (6.6 kB)

Uploaded Python 3
