
Trainer for PyTorch

Project description

Torch Runner

A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.

Note: If you are looking for something more extensive, check out PyTorch Lightning. This package is mostly designed for my personal use.

Requirements

  • torch
  • tqdm

Features

  • seeding of all random number generators
  • text logging
  • early stopping
  • saving of hyperparameters
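The features above are built into the package. As a rough illustration of what each one involves, here is a minimal, framework-agnostic sketch of seeding, early stopping and saving hyperparameters. The helper names `seed_everything`, `EarlyStopping` and `save_hparams` are hypothetical and are not torch_runner's API; numpy and torch are treated as optional so the sketch runs without them.

```python
import json
import random

try:
    import numpy as np
except ImportError:  # numpy is optional in this sketch
    np = None
try:
    import torch
except ImportError:  # torch is optional so the sketch runs anywhere
    torch = None


def seed_everything(seed):
    """Hypothetical helper: seed every RNG the process is likely to use."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)  # seeds the RNG on all devices


class EarlyStopping:
    """Hypothetical helper: stop once the loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def save_hparams(path, **hparams):
    """Hypothetical helper: write hyperparameters to a JSON file."""
    with open(path, "w") as f:
        json.dump(hparams, f, indent=2, sort_keys=True)
```

In a training loop, `es.step(val_loss)` would be called once per epoch, breaking out of the loop when it returns `True`; `min_delta` sets how large a drop in validation loss must be to count as an improvement.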

Example

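```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import torch_runner as T


class MyTrainer(T.TrainerModule):
    def __init__(self, model, optimizer):
        super().__init__(model, optimizer)
        # Keep our own handles; how TrainerModule stores them is not documented here
        self.net = model
        self.opt = optimizer

    def calc_metric(self, preds, target):
        # Calc metrics such as accuracy (here: classification accuracy)
        return (preds.argmax(dim=1) == target).float().mean().item()

    def loss_fct(self, preds, target):
        # Calc loss (here: cross-entropy for classification)
        return F.cross_entropy(preds, target)

    def train_one_step(self, batch, batch_id):
        # Get batch data from the dataloader and perform one update
        inputs, target = batch
        self.opt.zero_grad()
        loss = self.loss_fct(self.net(inputs), target)
        loss.backward()
        self.opt.step()
        return loss.item()  # assumption: fit() consumes the returned loss

    def valid_one_step(self, batch, batch_id):
        # Perform a validation step without tracking gradients
        inputs, target = batch
        with torch.no_grad():
            preds = self.net(inputs)
        return self.loss_fct(preds, target).item()  # assumption, as above


model = torch.nn.Linear(20, 3)  # stand-in for your own nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Random stand-in data; any PyTorch DataLoader works here
data = TensorDataset(torch.randn(256, 20), torch.randint(0, 3, (256,)))
train_dataloader = DataLoader(data, batch_size=32, shuffle=True)
val_dataloader = DataLoader(data, batch_size=32)

trainer = MyTrainer(model, optimizer)
trainer.fit(train_dataloader, val_dataloader, epochs=10, batch_size=32)
```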
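To show the kind of code `fit` saves you from writing, here is a rough, hypothetical sketch of the per-epoch control flow a wrapper like this typically drives through the `train_one_step` / `valid_one_step` hooks. It is not torch_runner's actual implementation: `MiniTrainer`, `MeanTrainer`, the history dictionary and the assumption that hooks return a per-batch loss are all illustrative, and the "dataloaders" are plain lists so the sketch runs without PyTorch.

```python
class MiniTrainer:
    """Hypothetical base class: subclasses supply the per-batch hooks."""

    def train_one_step(self, batch, batch_id):
        raise NotImplementedError

    def valid_one_step(self, batch, batch_id):
        raise NotImplementedError

    def fit(self, train_loader, val_loader, epochs=1):
        history = {"train_loss": [], "val_loss": []}
        for _ in range(epochs):
            # One pass over the training data, one hook call per batch
            train = [self.train_one_step(b, i) for i, b in enumerate(train_loader)]
            # One pass over the validation data
            val = [self.valid_one_step(b, i) for i, b in enumerate(val_loader)]
            # Average the per-batch losses returned by the hooks
            history["train_loss"].append(sum(train) / len(train))
            history["val_loss"].append(sum(val) / len(val))
        return history


class MeanTrainer(MiniTrainer):
    """Toy subclass: fits a single number to the mean of each batch."""

    def __init__(self, lr=0.5):
        self.estimate = 0.0
        self.lr = lr

    def train_one_step(self, batch, batch_id):
        error = self.estimate - sum(batch) / len(batch)
        self.estimate -= self.lr * error  # gradient step on 0.5 * error**2
        return 0.5 * error ** 2

    def valid_one_step(self, batch, batch_id):
        error = self.estimate - sum(batch) / len(batch)
        return 0.5 * error ** 2
```

For example, `MeanTrainer().fit([[1.0, 3.0], [2.0, 2.0]], [[2.0]], epochs=2)` returns a history whose validation loss shrinks from one epoch to the next. The split keeps the model-specific logic in the hooks while the loop, averaging and bookkeeping live once in the base class, which is the same division of labour the example above relies on.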



Download files


Source Distribution

torch_runner-0.0.1.tar.gz (4.1 kB)

Uploaded Source

Built Distribution

torch_runner-0.0.1-py3-none-any.whl (5.6 kB)

Uploaded Python 3

File details

Details for the file torch_runner-0.0.1.tar.gz.

File metadata

  • Download URL: torch_runner-0.0.1.tar.gz
  • Upload date:
  • Size: 4.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12

File hashes

Hashes for torch_runner-0.0.1.tar.gz
Algorithm Hash digest
SHA256 2f58ce676b7983d72373565df3f2613b693c4c226a4d270844467b50af6ca653
MD5 b65a7de6c08d17c22a820662763e57fe
BLAKE2b-256 a6bd479d8080ce7e61f7c300f20e7fd0ee70897b8c600e97ec86365211cb0e24


File details

Details for the file torch_runner-0.0.1-py3-none-any.whl.

File metadata

  • Download URL: torch_runner-0.0.1-py3-none-any.whl
  • Upload date:
  • Size: 5.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12

File hashes

Hashes for torch_runner-0.0.1-py3-none-any.whl
Algorithm Hash digest
SHA256 a2c4bc14b504d4681bf8b44cca6c875e4cce7034fe63ecdca3107790bb32cb63
MD5 01d3fed8b05c43aca21a30778c00e68c
BLAKE2b-256 07d21120d6a8a400599aa06a68fbbeb73e3851f1b779e7af56e6bb65e154e25c

