TorchZQ: A PyTorch experiment runner.

Installation

Install from PyPI:

pip install torchzq

Install the latest development version from GitHub:

pip install git+https://github.com/enhuiz/torchzq@main

A customized runner for MNIST classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchzq

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
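        # Plain Dropout here: it is applied to the 2-D output of fc1, where channel-wise Dropout2d does not apply.
        self.dropout2 = nn.Dropout(0.5)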
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class Runner(torchzq.Runner):
    def create_model(self):
        return Net()

    def create_dataset(self):
        return datasets.MNIST(
            "../data",
            train=self.training,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )

    def prepare_batch(self, batch):
        x, y = batch
        x = x.to(self.args.device)
        y = y.to(self.args.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = self.prepare_batch(batch)
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = self.prepare_batch(batch)
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    torchzq.start(Runner)
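
The `9216` passed to `nn.Linear` in `Net` is the flattened size of the feature map after the two convolutions and the pooling layer. A minimal sketch of that arithmetic in plain Python, assuming 28x28 MNIST inputs and no padding (the helper name `conv_out` is hypothetical and used only for illustration):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                  # MNIST images are 28x28
side = conv_out(side, kernel=3, stride=1)  # after conv1: 26
side = conv_out(side, kernel=3, stride=1)  # after conv2: 24
side = conv_out(side, kernel=2, stride=2)  # after max_pool2d(x, 2): 12
channels = 64                              # output channels of conv2
flat_features = channels * side * side     # input size expected by fc1
print(flat_features)  # 9216
```

If the kernel sizes, strides, or input resolution change, the first argument of `fc1` has to change accordingly.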

Execute the runner

Training

tzq example/config/mnist.yml train

Testing

tzq example/config/mnist.yml test

Weights & Biases

Before running, log in to Weights & Biases.

pip install wandb # install the Weights & Biases client
wandb login       # log in

Supported features

  • Model checkpoints
  • Logging (Weights & Biases)
  • Gradient accumulation
  • Configuration file
  • FP16
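
For readers unfamiliar with gradient accumulation, the idea can be sketched in plain Python. This is an illustration of the general technique, not torchzq's implementation; the function `grad_mse`, the toy data, and the variable names are all assumptions made for the example. It checks that summing the gradients of equally sized micro-batches, each scaled by `1 / accum_steps`, reproduces the gradient of the full batch for a mean-reduced loss:

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


# Toy data for a one-parameter linear model y = w * x (hypothetical values).
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

# Gradient computed on the full batch in one pass.
full_grad = grad_mse(w, xs, ys)

# Same gradient accumulated over two equally sized micro-batches.
accum_steps = 2
micro = len(xs) // accum_steps
accumulated = 0.0
for i in range(accum_steps):
    bx = xs[i * micro:(i + 1) * micro]
    by = ys[i * micro:(i + 1) * micro]
    # Scale each micro-batch gradient so that the sum matches the full-batch mean.
    accumulated += grad_mse(w, bx, by) / accum_steps

assert abs(accumulated - full_grad) < 1e-12
```

In a real training loop the optimizer step would be taken only after the last micro-batch, which allows a larger effective batch size than fits in memory at once.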
