
A minimal and simple machine learning experiment module for PyTorch.

Project description

torchplate: Minimal Experiment Workflows in PyTorch



An extremely minimal and simple experiment module for machine learning in PyTorch (PyTorch + boilerplate = torchplate).

In addition to abstracting away the training loop, we provide several abstractions to improve the efficiency of machine learning workflows with PyTorch.
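
The training loop being abstracted is the familiar zero_grad / forward / backward / step cycle. The sketch below shows one way a base class can own that loop and delegate only the loss computation to an `evaluate` method that subclasses implement (the template-method pattern). `MiniExperiment`, `RegressionExp`, and the `train(num_epochs)` signature here are hypothetical illustrations, not torchplate's actual source.

```python
import abc

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class MiniExperiment(abc.ABC):
    """Hypothetical base class: owns the training loop, subclasses supply evaluate()."""

    def __init__(self, model, optimizer, trainloader, verbose=True):
        self.model = model
        self.optimizer = optimizer
        self.trainloader = trainloader
        self.verbose = verbose

    @abc.abstractmethod
    def evaluate(self, batch):
        """Return a scalar loss tensor for one batch (forward pass + loss)."""

    def on_epoch_end(self):
        """Optional callback; does nothing unless a subclass overrides it."""

    def train(self, num_epochs=1):
        history = []
        self.model.train()
        for epoch in range(num_epochs):
            running = 0.0
            for batch in self.trainloader:
                self.optimizer.zero_grad()
                loss = self.evaluate(batch)  # the only experiment-specific step
                loss.backward()
                self.optimizer.step()
                running += loss.item()
            mean_loss = running / len(self.trainloader)
            history.append(mean_loss)
            if self.verbose:
                print(f"Epoch {epoch + 1}: mean loss {mean_loss:.4f}")
            self.on_epoch_end()
        return history


class RegressionExp(MiniExperiment):
    """Toy subclass: fit y = 3x + 1 with a single linear layer."""

    def __init__(self):
        torch.manual_seed(0)
        x = torch.randn(64, 1)
        y = 3 * x + 1
        model = nn.Linear(1, 1)
        super().__init__(
            model=model,
            optimizer=optim.SGD(model.parameters(), lr=0.1),
            trainloader=DataLoader(TensorDataset(x, y), batch_size=16),
            verbose=False,
        )
        self.criterion = nn.MSELoss()

    def evaluate(self, batch):
        x, y = batch
        return self.criterion(self.model(x), y)


history = RegressionExp().train(num_epochs=20)
```

Keeping the loop in the base class means a new experiment only has to describe how one batch becomes a loss, while hooks such as `on_epoch_end` give subclasses a place to add behavior without copying the loop.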

Installation

$ pip install torchplate

Example

To get started, create a subclass of torchplate.experiment.Experiment and provide the key experiment-specific items: a model, an optimizer, and a training-set dataloader. Then implement the abstract method evaluate, which takes in a batch from the trainloader and returns the loss (i.e., it implements the forward pass and the loss calculation). Add any custom methods you need to the class, then start training. That's it!

import torchplate
from torchplate import experiment
from torchplate import utils
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import cloudpickle as cp
from urllib.request import urlopen


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3*32*32, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class CifarExp(torchplate.experiment.Experiment):
    def __init__(self): 
        self.model = Net()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        # download a small pickled CIFAR subset (only unpickle data from sources you trust)
        dataset = cp.load(urlopen("https://stanford.edu/~rsikand/assets/datasets/mini_cifar.pkl"))
        # torchplate.utils helpers cover common steps, such as splitting a dataset into train/test loaders
        self.trainloader, self.testloader = torchplate.utils.get_xy_loaders(dataset)

        # initialize the torchplate.experiment.Experiment base class with the
        # model, optimizer, and training dataloader
        super().__init__(
            model = self.model,
            optimizer = self.optimizer,
            trainloader = self.trainloader,
            verbose = True
        )
    
    # implement the abstract evaluate method: forward pass + loss for one batch
    def evaluate(self, batch):
        x, y = batch
        logits = self.model(x)
        loss_val = self.criterion(logits, y)
        return loss_val

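    def test(self):
        # evaluate on the held-out set; works for any test batch size
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in self.testloader:
                logits = self.model(x)
                # argmax over the class dimension (softmax would not change the argmax)
                preds = torch.argmax(logits, dim=1)
                for p, t in zip(preds.tolist(), y.tolist()):
                    print(f"Prediction: {p}, True: {t}")
                correct += (preds == y).sum().item()
                total += y.size(0)
        self.model.train()
        print("Accuracy: ", correct / total)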

    def on_epoch_end(self):
        # to illustrate the concept of callbacks 
        print("------------------ (Epoch end) --------------------")



exp = CifarExp()
exp.train(num_epochs=100)
exp.test()
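
The example relies on torchplate.utils.get_xy_loaders to turn the downloaded dataset into train and test loaders. Its exact signature and defaults are not documented here, so the following is only a sketch of a helper in that spirit, built on torch.utils.data.random_split and DataLoader. The name `make_xy_loaders`, its parameters, and the assumption that the dataset is a list of (x, y) pairs are all hypothetical.

```python
import torch
from torch.utils.data import DataLoader, random_split


def make_xy_loaders(pairs, train_frac=0.8, train_batch_size=1, test_batch_size=1, seed=0):
    """Hypothetical helper: split a list of (x, y) pairs into train/test DataLoaders."""
    n_train = int(len(pairs) * train_frac)
    generator = torch.Generator().manual_seed(seed)  # makes the split reproducible
    train_set, test_set = random_split(
        pairs, [n_train, len(pairs) - n_train], generator=generator
    )
    trainloader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    testloader = DataLoader(test_set, batch_size=test_batch_size, shuffle=False)
    return trainloader, testloader


# toy data: ten CIFAR-shaped images with labels in {0, 1, 2}
pairs = [(torch.randn(3, 32, 32), torch.tensor(i % 3)) for i in range(10)]
trainloader, testloader = make_xy_loaders(pairs, test_batch_size=4)
x, y = next(iter(testloader))  # x: (2, 3, 32, 32), y: (2,)
```

Because a loader like this can return several samples per batch, evaluation code should compare whole prediction tensors against the labels rather than calling `.item()` on a single label.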

More examples

See examples/cifar for another minimal example and examples/starter for a complete program. To start running your own experiments, you can use examples/starter as a base (or use cookiecutter as shown below).

Starter project

The starter branch holds the source for a cookiecutter project, so you can generate a new project from the starter example with a single command. To get started, install cookiecutter and then run

$ cookiecutter https://github.com/rosikand/torchplate.git --checkout starter

which will generate the following structure for you to use as a base for your projects:

torchplate_starter
├── datasets.py
├── experiments.py
├── models.py
└── runner.py


Download files


Source Distribution

torchplate-0.0.10.tar.gz (519.4 kB)

Built Distribution

torchplate-0.0.10-py3-none-any.whl (9.3 kB)

File details

Details for the file torchplate-0.0.10.tar.gz.

File metadata

  • Download URL: torchplate-0.0.10.tar.gz
  • Size: 519.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.5

File hashes

Hashes for torchplate-0.0.10.tar.gz
Algorithm    Hash digest
SHA256       f6afa15b65a464a378906d64cb4ca94526bb2c5d7c522968dd1258bb7711917a
MD5          d9d2089e4ede97c0803471dae5246c90
BLAKE2b-256  701dd1250cd23f10bd42ad41f00035df721e6ff9c67d322c5aee717a50bab633


File details

Details for the file torchplate-0.0.10-py3-none-any.whl.

File metadata

  • Download URL: torchplate-0.0.10-py3-none-any.whl
  • Size: 9.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.5

File hashes

Hashes for torchplate-0.0.10-py3-none-any.whl
Algorithm    Hash digest
SHA256       4f81911ec506fd9b52e3a7cafca559ee864f2eb7ed672e065c65a30ea821a3ce
MD5          9fb5702c8a4eec96204d756c07e79b98
BLAKE2b-256  449d53bb62cb9f53de610ef8f10a0d77a2af378b03e1f4235943ac208e6d16d1
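
The published digests let you check that a downloaded file is intact. A minimal sketch using only the standard library's hashlib, assuming the wheel has been downloaded into the current directory (the helper names `sha256_of` and `verify` are illustrative). pip can also enforce this automatically in hash-checking mode, e.g. a requirements line `torchplate==0.0.10 --hash=sha256:<digest>` installed with `pip install --require-hashes -r requirements.txt`.

```python
import hashlib
from pathlib import Path


def sha256_of(path, chunk_size=1 << 16):
    """Return the hex SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_hex):
    """Compare a file's SHA-256 digest with a published one (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()


# assumes the wheel was downloaded into the current directory
wheel = Path("torchplate-0.0.10-py3-none-any.whl")
expected = "4f81911ec506fd9b52e3a7cafca559ee864f2eb7ed672e065c65a30ea821a3ce"
if wheel.exists():
    print("OK" if verify(wheel, expected) else "MISMATCH")
```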

