
Trainer for PyTorch

Project description

Torch Runner

A minimal wrapper that removes some of the boilerplate code involved in training PyTorch models.

Note: If you are looking for something more extensive, check out PyTorch Lightning. This package is mostly designed for my personal use.

Requirements

  • torch
  • tqdm

Installation

pip install torch-runner

Features

  • seed all random number generators
  • text logger
  • early stopping
  • save hyperparameters
  • Weights & Biases support
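Two of the features above are easy to picture in plain Python. The sketch below is not torch_runner's implementation: it assumes that seeding means seeding Python's `random`, NumPy and PyTorch (whichever are installed), and that early stopping watches a validation loss with a `patience` counter. The names `seed_everything` and `EarlyStopping` are hypothetical.

```python
import os
import random


def seed_everything(seed: int) -> None:
    """Hypothetical helper (not torch_runner's API): seed every RNG that is available."""
    random.seed(seed)
    # Only affects subprocesses started later; the current interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)           # PyTorch's default generator
        torch.cuda.manual_seed_all(seed)  # all CUDA devices (safe to call without a GPU)
    except ImportError:
        pass


class EarlyStopping:
    """Hypothetical helper: stop once the monitored loss has not improved by more
    than `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


seed_everything(42)
stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch}")  # prints "stopping after epoch 3"
        break
```

Resetting the counter on every improvement means only consecutive non-improving epochs count toward `patience`.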

Example

Check out the examples folder, which contains a Jupyter notebook that trains a ResNet-50 using torch_runner.

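import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_runner as T


class MyTrainer(T.TrainerModule):
    def calc_metric(self, preds, target):
        # Calculate metrics such as accuracy
        ...

    def loss_fct(self, preds, target):
        # Calculate the loss
        ...

    def train_one_step(self, batch, batch_id):
        # Unpack the batch from the dataloader and perform one update
        ...

    def valid_one_step(self, batch, batch_id):
        # Perform one validation step
        ...


config = T.TrainerConfig()
model = torch.nn.Linear(10, 2)  # placeholder: use your own torch.nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Placeholder data: replace with your own datasets
train_dataloader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=16)
val_dataloader = DataLoader(TensorDataset(torch.randn(16, 10), torch.randint(0, 2, (16,))), batch_size=16)

trainer = MyTrainer(model, optimizer, config)
trainer.fit(train_dataloader, val_dataloader, epochs=10)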
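The example relies on a base class that owns the epoch and batch loops and calls the methods you override. The following framework-free sketch illustrates that general pattern only; it is not torch_runner's source. `MiniTrainer` and `MeanTrainer` are hypothetical, and the assumption that each step hook returns a float loss may not match what torch_runner expects.

```python
from typing import Iterable, List, Tuple


class MiniTrainer:
    """Hypothetical base class: owns the loops and delegates per-batch work to hooks."""

    def train_one_step(self, batch, batch_id: int) -> float:
        raise NotImplementedError

    def valid_one_step(self, batch, batch_id: int) -> float:
        raise NotImplementedError

    def fit(self, train_data: Iterable, val_data: Iterable, epochs: int) -> List[Tuple[float, float]]:
        """Run `epochs` passes and return (mean train loss, mean val loss) per epoch."""
        history = []
        for _ in range(epochs):
            # Training pass: call the user hook on every batch
            train_losses = [self.train_one_step(b, i) for i, b in enumerate(train_data)]
            # Validation pass over the whole validation set
            val_losses = [self.valid_one_step(b, i) for i, b in enumerate(val_data)]
            history.append((sum(train_losses) / len(train_losses),
                            sum(val_losses) / len(val_losses)))
        return history


class MeanTrainer(MiniTrainer):
    # Toy subclass: the "loss" of a batch is just its mean
    def train_one_step(self, batch, batch_id):
        return sum(batch) / len(batch)

    def valid_one_step(self, batch, batch_id):
        return sum(batch) / len(batch)


history = MeanTrainer().fit([[1.0, 3.0], [2.0, 2.0]], [[4.0]], epochs=2)
print(history)  # prints [(2.0, 4.0), (2.0, 4.0)]
```

Keeping the loop in the base class is what lets a subclass contain only the model-specific logic.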


Download files


Source Distribution

torch_runner-0.1.0.tar.gz (16.1 kB)

Uploaded Source

Built Distribution

torch_runner-0.1.0-py3-none-any.whl (6.6 kB)

Uploaded Python 3

File details

Details for the file torch_runner-0.1.0.tar.gz.

File metadata

  • Download URL: torch_runner-0.1.0.tar.gz
  • Size: 16.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12

File hashes

Hashes for torch_runner-0.1.0.tar.gz
Algorithm Hash digest
SHA256 eda717970d740391978820327b0515f0e06a733dee3088e230a34d92ae6a7758
MD5 17be81c5c8b8ac2b9aac3a435271d140
BLAKE2b-256 36f16bb57b7682d2d6599bec872e432cbd9daff38547db68cba67bee4e4c4864
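To check a downloaded file against the SHA256 digest listed above, you can hash it in chunks with Python's standard `hashlib` module. The helper name `sha256_of` and the local file path are placeholders.

```python
import hashlib
from pathlib import Path

# SHA256 digest of torch_runner-0.1.0.tar.gz, copied from the table above
EXPECTED_SHA256 = "eda717970d740391978820327b0515f0e06a733dee3088e230a34d92ae6a7758"


def sha256_of(path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    archive = Path("torch_runner-0.1.0.tar.gz")  # placeholder: wherever you saved the download
    if archive.exists():
        print("OK" if sha256_of(archive) == EXPECTED_SHA256 else "hash mismatch")
```

Reading in chunks keeps memory use constant regardless of file size.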


File details

Details for the file torch_runner-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: torch_runner-0.1.0-py3-none-any.whl
  • Size: 6.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.12.5 setuptools/49.6.0.post20200925 requests-toolbelt/0.9.1 tqdm/4.50.0 CPython/3.6.12

File hashes

Hashes for torch_runner-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 ba98c78078f0d87e74bba3e718a01b794559c5fff1e1231b4579b64be3ed59b7
MD5 97200a11791eac71a2b305d8ab63be37
BLAKE2b-256 18521fcb9bc28b5445edf27da69d76e7b9356c9ca2f5e4b0b4aec5a3ec363eda

