A minimal and simple machine learning experiment module for PyTorch.

torchplate: Minimal Experiment Workflows in PyTorch

(Github | PyPI | Documentation)

An extremely minimal and simple experiment module for machine learning in PyTorch.

In addition to abstracting away the training loop, torchplate provides utilities (in torchplate.utils) that streamline common steps of PyTorch machine learning workflows, such as building data loaders.
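As an illustration of the kind of boilerplate such utilities remove, here is a minimal sketch in plain PyTorch that splits a dataset into train and test DataLoaders. The helper name `make_train_test_loaders` and its defaults are hypothetical; this is not the implementation or signature of `torchplate.utils.get_xy_loaders`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_train_test_loaders(dataset, train_frac=0.8, batch_size=32, seed=0):
    """Hypothetical helper: split a dataset and wrap both parts in DataLoaders."""
    n_train = int(len(dataset) * train_frac)
    n_test = len(dataset) - n_train
    # A seeded generator makes the random split reproducible across runs
    generator = torch.Generator().manual_seed(seed)
    train_set, test_set = random_split(dataset, [n_train, n_test], generator=generator)
    # Shuffle only the training data; keep evaluation order fixed
    trainloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return trainloader, testloader


# Usage with a small synthetic dataset of 100 (x, y) pairs
x = torch.randn(100, 3)
y = torch.randint(0, 2, (100,))
trainloader, testloader = make_train_test_loaders(TensorDataset(x, y), batch_size=16)
```

With an 80/20 split and a batch size of 16, the training loader yields 5 batches and the test loader yields 2.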

Example

To get started, create a subclass of torchplate.experiment.Experiment and supply the experiment-specific components: a model, an optimizer, and a training-set dataloader. Implement the abstract evaluate method, which computes the loss for one batch, and add whatever custom methods you want to the class. Then start training. That's it!

import torchplate
import torch.nn as nn
import torch.optim as optim

import data    # local module in the example project (not part of torchplate)
import models  # local module defining the network (not part of torchplate)


class SampleExp(torchplate.experiment.Experiment):
    def __init__(self):
        self.model = models.Net()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()
        dataset = data.load_set('cifar')
        # torchplate.utils provides helpers for common steps, such as building loaders
        self.trainloader, self.testloader = torchplate.utils.get_xy_loaders(dataset)

        # initialize the parent Experiment with the model, optimizer,
        # and training dataloader
        super().__init__(
            model=self.model,
            optimizer=self.optimizer,
            trainloader=self.trainloader
        )

    # implement this abstract method: compute and return the loss for one batch
    def evaluate(self, batch):
        x, y = batch
        logits = self.model(x)
        loss_val = self.criterion(logits, y)
        return loss_val


exp = SampleExp()
exp.train(num_epochs=5)

Output:

Epoch 1: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 293.98it/s]
Training Loss (epoch 1): 1.3564644632516083
Epoch 2: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 598.46it/s]
Training Loss (epoch 2): 1.2066593832439847
Epoch 3: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 579.40it/s]
Training Loss (epoch 3): 1.1030386642173484
Epoch 4: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 563.90it/s]
Training Loss (epoch 4): 1.0885229706764221
Epoch 5: 100%|███████████████████████████████████████████████████████████| 27/27 [00:00<00:00, 577.54it/s]
Training Loss (epoch 5): 1.0520343957123932
Finished Training!
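To make the division of labor concrete, here is a minimal sketch, in plain PyTorch, of the kind of loop that Experiment.train abstracts away. Based on the example and its output, it assumes training iterates over the trainloader, calls the subclass's evaluate for each batch, backpropagates, steps the optimizer, and reports the mean batch loss per epoch. The names `MiniExperiment` and `ToyExp` are hypothetical, and this is not torchplate's actual implementation.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class MiniExperiment(ABC):
    """Hypothetical sketch of the pattern; not torchplate's code."""

    def __init__(self, model, optimizer, trainloader):
        self.model = model
        self.optimizer = optimizer
        self.trainloader = trainloader

    @abstractmethod
    def evaluate(self, batch):
        """Return a scalar loss tensor for one batch."""

    def train(self, num_epochs):
        epoch_losses = []
        self.model.train()  # put the nn.Module in training mode
        for epoch in range(1, num_epochs + 1):
            running_loss = 0.0
            for batch in self.trainloader:
                self.optimizer.zero_grad()
                loss = self.evaluate(batch)  # subclass-defined loss
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
            # mean of the per-batch losses seen during this epoch
            mean_loss = running_loss / len(self.trainloader)
            epoch_losses.append(mean_loss)
            print(f"Training Loss (epoch {epoch}): {mean_loss}")
        print("Finished Training!")
        return epoch_losses


class ToyExp(MiniExperiment):
    def __init__(self):
        torch.manual_seed(0)
        # synthetic, linearly separable 2-class data
        x = torch.randn(64, 4)
        y = (x.sum(dim=1) > 0).long()
        model = nn.Linear(4, 2)
        super().__init__(
            model=model,
            optimizer=torch.optim.SGD(model.parameters(), lr=0.5),
            trainloader=DataLoader(TensorDataset(x, y), batch_size=16),
        )
        self.criterion = nn.CrossEntropyLoss()

    def evaluate(self, batch):
        x, y = batch
        return self.criterion(self.model(x), y)


losses = ToyExp().train(num_epochs=3)
```

Declaring evaluate with `abc.abstractmethod` means a subclass that forgets to implement it fails at instantiation with a `TypeError`, rather than partway through training.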

Installation

$ pip install torchplate

Changelog

0.0.3

  • Added a verbose option and Weights & Biases (wandb) logging

0.0.2

  • Fixed a polymorphic bug

0.0.1

  • First version published. Provides basic data-loading utilities and the base experiment module.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchplate-0.0.3.tar.gz (2.8 MB)

Uploaded Source

Built Distribution

torchplate-0.0.3-py3-none-any.whl (6.1 kB)

Uploaded Python 3

File details

Details for the file torchplate-0.0.3.tar.gz.

File metadata

  • Download URL: torchplate-0.0.3.tar.gz
  • Upload date:
  • Size: 2.8 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.5

File hashes

Hashes for torchplate-0.0.3.tar.gz
Algorithm Hash digest
SHA256 a8e39bd859ed5e3979844a53d2d594521ddedcc3ae7c42ffd7323cdee6f3baee
MD5 104d80510a0c2cc7e414613ef217f6a8
BLAKE2b-256 41eb5236c386b6f28dbd73a38cba48e897c486a9c0915eec587cc5cf841e97af

File details

Details for the file torchplate-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: torchplate-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 6.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.8.5

File hashes

Hashes for torchplate-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 76706074ec351ef63dec9c15da5a3c1bc40bca5ff7605d8b57180f7953af37a4
MD5 e2b1699a8cdc82cbd39c14db17c2eac4
BLAKE2b-256 9a912a5c2efa03ec0c41c964a2000c102fc168fbc6f26295b4494044caf0c198
