A minimal and simple machine learning experiment module for PyTorch.

torchplate: Minimal Experiment Workflows in PyTorch

(GitHub | PyPI | Documentation)

Installation | Example | More examples | Starter project | Changelog

An extremely minimal and simple experiment module for machine learning in PyTorch (PyTorch + boilerplate = torchplate).

In addition to abstracting away the training loop, we provide several abstractions to improve the efficiency of machine learning workflows with PyTorch.

Installation

$ pip install torchplate

Example

To get started, create a child class of torchplate.experiment.Experiment and provide several key, experiment-unique items: a model, an optimizer, and a training dataloader. Then provide an implementation of the abstract method evaluate. This function takes in a batch from the trainloader and should return the loss (i.e., it implements the forward pass plus loss calculation). Add whatever custom methods you want to this class. Then start training! That's it!

import torchplate.experiment  # import submodules explicitly so the
import torchplate.utils       # torchplate.* references below resolve
import data    # local module with dataset-loading helpers
import models  # local module defining the network
import torch
import torch.optim as optim
import torch.nn as nn


class SampleExp(torchplate.experiment.Experiment):
    def __init__(self): 
        self.model = models.Net()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        dataset = data.load_set('cifar')
        # use various torchplate.utils to improve efficiency of common workflows 
        self.trainloader, self.testloader = torchplate.utils.get_xy_loaders(dataset)

        # inherit from torchplate.experiment.Experiment and pass in
        # model, optimizer, and dataloader 
        super().__init__(
            model = self.model,
            optimizer = self.optimizer,
            trainloader = self.trainloader 
        )
    
    # provide this abstract method to calculate loss 
    def evaluate(self, batch):
        x, y = batch
        logits = self.model(x)
        loss_val = self.criterion(logits, y)
        return loss_val


exp = SampleExp()
exp.train(num_epochs=5, gradient_accumulate_every_n_batches=4, display_batch_loss=False)

Output:

Epoch 1: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 293.98it/s]
Training Loss (epoch 1): 1.3564644632516083
Epoch 2: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 598.46it/s]
Training Loss (epoch 2): 1.2066593832439847
Epoch 3: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 579.40it/s]
Training Loss (epoch 3): 1.1030386642173484
Epoch 4: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 563.90it/s]
Training Loss (epoch 4): 1.0885229706764221
Epoch 5: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 577.54it/s]
Training Loss (epoch 5): 1.0520343957123932
Finished Training!
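
Since evaluate just takes a batch and returns a loss tensor, it can be reused after training to estimate test loss. The snippet below is a hedged sketch (not part of the original example); it relies only on the exp instance and the self.testloader attribute defined in SampleExp above:

exp.model.eval()  # switch to eval mode (matters if the model uses dropout/batch norm)
with torch.no_grad():
    # reuse the user-defined evaluate() to compute per-batch test losses
    test_losses = [exp.evaluate(batch).item() for batch in exp.testloader]
print(f"Mean test loss: {sum(test_losses) / len(test_losses):.4f}")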

More examples

See examples/cifar for another minimal example. See examples/starter for a full program example. To get started running your own experiments, you can use examples/starter as a base (or use cookiecutter as shown below).

Starter project

The starter branch holds the source for a cookiecutter project, which lets you scaffold a new project from the starter example with a single command. To get started, install cookiecutter (pip install cookiecutter) and then run

$ cookiecutter https://github.com/rosikand/torchplate.git --checkout starter

which will generate the following structure for you to use as a base for your projects:

torchplate_starter
├── datasets.py
├── experiments.py
├── models.py
└── runner.py
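
To give a sense of how these files fit together, here is a rough, hypothetical sketch of what runner.py could contain, following the Experiment pattern from the example above (the actual generated starter code may differ):

# runner.py: hypothetical sketch, not the actual starter source.
# experiments.py is assumed to define an Experiment subclass like SampleExp above.
import experiments

exp = experiments.SampleExp()
exp.train(num_epochs=5)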

Changelog

0.0.7

  • Largest change to date. New features: gradient accumulation, saving weights every n epochs, batch-loss display, and metrics interfaced with train.

0.0.6

  • Fixed bug in model weight saving.

0.0.5

  • Added model weights loading and saving.

0.0.4

  • Several changes: added callbacks, changed verbose default to true, added ModelInterface pipeline to utils.

0.0.3

  • Added a verbose option as well as wandb logging.

0.0.2

  • Fixed a polymorphism-related bug.

0.0.1

  • First version published. Provides basic data-loading utilities and the base experiment module.

