A toolkit for training PyTorch models

Project description

Introduction

This package provides a Trainer class that streamlines model training and the recording of results. The Trainer class is designed in a modular way using Mixins, which makes it straightforward to extend its capabilities beyond what it currently provides. Additionally, the class uses an eventing pattern that lets users register event handlers to be executed at specified points during training. The Trainer class lives in trainer.py; all of the Mixins are in mixins.py; events.py defines the possible events; and utils.py contains miscellaneous helper code.

Installation

Install from PyPI with pip install pt-trainer

Usage

Initialize a Trainer instance by passing a PyTorch model (a subclass of nn.Module), a PyTorch DataLoader instance, an optimizer (PyTorch or apex), and a loss function that accepts the model prediction and the targets and returns a loss tensor. Alternatively, a Trainer can be created from a config file. The config file should be a Python file that defines the following variables:

  • MODEL: class of the model
  • DATASET: class of the dataset
  • LOSS: class of the loss function
  • OPTIMIZER: class of the optimizer
  • LOGDIR: path to the directory in which files generated by the trainer will be written
  • model: dict with kwargs for MODEL
  • dataset: dict with kwargs for DATASET
  • dataloader: dict with kwargs for the dataloader that will wrap DATASET
  • loss: dict with kwargs for LOSS
  • optimizer: dict with kwargs for OPTIMIZER
  • trainer: dict with kwargs for the Trainer class, such as split_sample

Optionally, the APEX variable and an apex kwargs dict can be specified to wrap the OPTIMIZER.
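A minimal config file following the variable list above might look like this. The concrete model and dataset classes imported here (MLP, MNISTSubset, my_project) are hypothetical placeholders for your own code, not part of pt-trainer; only the uppercase/lowercase variable names are what the package expects.

```python
# config.py -- a sketch of a pt-trainer config file.
import torch.nn as nn
import torch.optim

from my_project.models import MLP          # hypothetical model class (nn.Module subclass)
from my_project.data import MNISTSubset    # hypothetical dataset class

MODEL = MLP                      # class of the model
DATASET = MNISTSubset            # class of the dataset
LOSS = nn.CrossEntropyLoss       # class of the loss function
OPTIMIZER = torch.optim.Adam     # class of the optimizer
LOGDIR = "./runs/mnist"          # where the trainer writes its output files

model = {"hidden_size": 128}                       # kwargs for MODEL
dataset = {"root": "./data"}                       # kwargs for DATASET
dataloader = {"batch_size": 64, "shuffle": True}   # kwargs for the DataLoader wrapping DATASET
loss = {}                                          # kwargs for LOSS
optimizer = {"lr": 1e-3}                           # kwargs for OPTIMIZER
trainer = {}                                       # kwargs for Trainer, e.g. split_sample
```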

Once initialized, you can register event handlers using the register_event_handler method, specifying the handler and the event on which it will be called. There are four possible events: before training, each step, each epoch, and after training.
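The register_event_handler method and the four events come from the package; the class, event names, and dispatch logic below are a hand-rolled, self-contained illustration of the eventing pattern, not the library's actual implementation:

```python
# Minimal sketch of the eventing pattern: handlers are registered per event
# and invoked whenever training reaches the corresponding point.
from collections import defaultdict

class EventedTrainer:
    EVENTS = ("before_training", "each_step", "each_epoch", "after_training")

    def __init__(self):
        self._handlers = defaultdict(list)

    def register_event_handler(self, handler, event):
        # Reject events the trainer will never emit.
        if event not in self.EVENTS:
            raise ValueError(f"unknown event: {event}")
        self._handlers[event].append(handler)

    def _emit(self, event, **kwargs):
        # Call every handler registered for this event.
        for handler in self._handlers[event]:
            handler(**kwargs)

    def train(self, n_epochs, steps_per_epoch=2):
        self._emit("before_training")
        step = 0
        for epoch in range(n_epochs):
            for _ in range(steps_per_epoch):
                step += 1
                self._emit("each_step", step=step)
            self._emit("each_epoch", epoch=epoch)
        self._emit("after_training")

log = []
t = EventedTrainer()
t.register_event_handler(lambda **kw: log.append("start"), "before_training")
t.register_event_handler(lambda epoch, **kw: log.append(f"epoch {epoch}"), "each_epoch")
t.train(n_epochs=2)
# log is now ["start", "epoch 0", "epoch 1"]
```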

Training is then run by calling the train method with either n_epochs or n_steps.
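How train reconciles the two arguments internally is not documented here, but the usual relationship is steps = epochs × batches per epoch, where the batch count depends on the DataLoader's drop_last setting. A hypothetical helper (not part of the package) makes the arithmetic concrete:

```python
import math

def epochs_to_steps(n_epochs, dataset_size, batch_size, drop_last=False):
    """Number of optimizer steps needed to cover n_epochs of a dataset.

    Mirrors DataLoader batching: the last partial batch is dropped when
    drop_last=True (floor division), otherwise it counts (ceiling division).
    """
    if drop_last:
        batches_per_epoch = dataset_size // batch_size
    else:
        batches_per_epoch = math.ceil(dataset_size / batch_size)
    return n_epochs * batches_per_epoch

# e.g. 3 epochs over 1000 samples at batch size 64: 16 batches/epoch -> 48 steps
steps = epochs_to_steps(3, 1000, 64)
```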

Example

In the folder titled 'examples' I have set up a simple case of training a feed-forward neural net on a portion of MNIST. This example illustrates how to set up the config and how to use the trainer. Run dummy_training.py to train the model.

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pt-trainer-0.1.9.tar.gz (12.9 kB)

Uploaded Source

Built Distribution

pt_trainer-0.1.9-py3-none-any.whl (17.7 kB)

Uploaded Python 3

File details

Details for the file pt-trainer-0.1.9.tar.gz.

File metadata

  • Download URL: pt-trainer-0.1.9.tar.gz
  • Upload date:
  • Size: 12.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.1.0.post20200119 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.6

File hashes

Hashes for pt-trainer-0.1.9.tar.gz

  • SHA256: f7c49a1d7ac2c2c328e45b8f94c834f410ff9c5f6a2c1adda2ca8b41b3ce2603
  • MD5: 2e400c62d294c95672059514419cc3fa
  • BLAKE2b-256: d26dbc94a20eadc6c1ff79a800b08232e0ef7d2ea5591c527ae3e760dafd6fc1

File details

Details for the file pt_trainer-0.1.9-py3-none-any.whl.

File metadata

  • Download URL: pt_trainer-0.1.9-py3-none-any.whl
  • Upload date:
  • Size: 17.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/45.1.0.post20200119 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.6

File hashes

Hashes for pt_trainer-0.1.9-py3-none-any.whl

  • SHA256: b55aeb3217c8148647afb13bd4d2502ec0f700aeddf539d8efabeee6af5cca02
  • MD5: b5828cd11560f04a89eb2e00e915d18a
  • BLAKE2b-256: 55d828b2144779ca1800855e1ce8528fb9ca41aadaca38ed42ae29e8bd6496e1
