TorchZQ: A PyTorch experiment runner.

Installation

Install from PyPI (latest):

pip install torchzq --pre --upgrade
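
Because `--pre` allows pip to pick a pre-release build, it can be useful to confirm which version was actually installed. The following is a minimal sketch using only the standard library; the helper name `installed_version` is made up for this illustration and is not part of TorchZQ.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the installed TorchZQ version, if any.
version = installed_version("torchzq")
print(version if version is not None else "torchzq is not installed")
```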

A customized runner for MNIST classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import torchzq


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
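        # dropout2 is applied to the flattened (N, 128) output of fc1,
        # so element-wise Dropout is used instead of channel-wise Dropout2d.
        self.dropout2 = nn.Dropout(0.5)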
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class Runner(torchzq.Runner):
    class HParams(torchzq.Runner.HParams):
        lr: float = 1e-3

    hp: HParams

    def create_model(self):
        return Net()

    def create_dataloader(self, mode):
        hp = self.hp
        dataset = datasets.MNIST(
            "../data",
            train=mode == mode.TRAIN,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )
        return DataLoader(
            dataset,
            batch_size=hp.batch_size,
            num_workers=hp.nj,
            shuffle=mode == mode.TRAIN,
            drop_last=mode == mode.TRAIN,
        )

    def create_metrics(self):
        metrics = super().create_metrics()

        def early_stop(count):
            if count >= 2:
                # the metric has not decreased over the last two validations
                self.hp.max_epochs = -1  # this terminates the training

        metrics.add_metric("val/nll_loss", [early_stop])
        return metrics

    def prepare_batch(self, batch, _):
        x, y = batch
        x = x.to(self.hp.device)
        y = y.to(self.hp.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = batch
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = batch
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    Runner().start()
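
The input size of `fc1` (9216) follows from the layer shapes in `Net`. The sketch below reproduces that arithmetic in plain Python, assuming the standard 28x28 MNIST input; `conv_out` is a helper written for this illustration, not a PyTorch or TorchZQ function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # MNIST images are 28x28
size = conv_out(size, kernel=3)             # conv1: 28 -> 26
size = conv_out(size, kernel=3)             # conv2: 26 -> 24
size = conv_out(size, kernel=2, stride=2)   # max_pool2d(x, 2): 24 -> 12

channels = 64                               # output channels of conv2
flattened = channels * size * size
print(flattened)  # 9216
```

If the kernel sizes, strides, or input resolution change, the first argument of `nn.Linear` in `fc1` has to be recomputed the same way.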

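The `early_stop` callback in `create_metrics` receives a `count` and stops training once it reaches 2. How TorchZQ computes that count is not shown here; the sketch below is one plausible reading, in which the count is the number of consecutive validations without a new lowest value. `StaleCounter` and the loss values are hypothetical and exist only for this illustration.

```python
class StaleCounter:
    """Counts consecutive validations without a new best (lowest) value.

    Hypothetical helper for illustration; not part of TorchZQ.
    """

    def __init__(self):
        self.best = float("inf")
        self.count = 0

    def update(self, value):
        if value < self.best:
            # New best value: remember it and reset the counter.
            self.best = value
            self.count = 0
        else:
            self.count += 1
        return self.count


counter = StaleCounter()
patience = 2
stopped_at = None
for epoch, val_loss in enumerate([0.30, 0.25, 0.26, 0.27, 0.20]):
    if counter.update(val_loss) >= patience:
        stopped_at = epoch  # the runner above sets max_epochs = -1 here
        break
print(stopped_at)  # 3
```
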
Execute the runner

Training

tzq example/config/mnist.yml train

Testing

tzq example/config/mnist.yml test

Weights & Biases

Before running, log in to Weights & Biases.

pip install wandb  # install the Weights & Biases client
wandb login        # log in

Supported features

  • Model checkpoints
  • Logging (Weights & Biases)
  • Gradient accumulation
  • Configuration file
  • FP16
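
To illustrate the idea behind gradient accumulation from the list above, here is a small plain-Python sketch. It is not TorchZQ's implementation; it only shows that summing micro-batch gradients, each scaled by the number of accumulation steps, reproduces the full-batch gradient for a mean-squared-error loss when the micro-batches have equal size. The model `y = w * x` and all names are assumptions made for the example.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Gradient computed on the whole batch at once.
full_grad = grad_mse(w, xs, ys)

# Same gradient, accumulated over equally sized micro-batches.
accum_steps = 2
micro = len(xs) // accum_steps
accum_grad = 0.0
for i in range(accum_steps):
    lo, hi = i * micro, (i + 1) * micro
    # Scale each micro-batch gradient so the sum matches the full-batch mean.
    accum_grad += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accum_steps

# An optimizer step would be taken here, once per accum_steps micro-batches.
print(abs(full_grad - accum_grad) < 1e-9)
```

The practical benefit is memory: only one micro-batch needs to be held at a time, while the update matches that of the larger effective batch.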
