TorchZQ: A PyTorch experiment runner.

Installation

Install from PyPI (latest):

pip install torchzq --pre --upgrade

A customized runner for MNIST classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import torchzq


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 runs on the flattened (N, 128) features, so use element-wise Dropout
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class Runner(torchzq.Runner):
    class HParams(torchzq.Runner.HParams):
        lr: float = 1e-3

    hp: HParams

    def create_model(self):
        return Net()

    def create_dataloader(self, mode):
        hp = self.hp
        dataset = datasets.MNIST(
            "../data",
            train=mode == mode.TRAIN,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )
        return DataLoader(
            dataset,
            batch_size=hp.batch_size,
            num_workers=hp.nj,
            shuffle=mode == mode.TRAIN,
            drop_last=mode == mode.TRAIN,
        )

    def create_metrics(self):
        metrics = super().create_metrics()

        def early_stop(count):
            if count >= 2:
                # the metric has not decreased over the last two validations
                self.hp.max_epochs = -1  # setting this terminates training

        metrics.add_metric("val/nll_loss", [early_stop])
        return metrics

    def prepare_batch(self, batch, _):
        x, y = batch
        x = x.to(self.hp.device)
        y = y.to(self.hp.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = batch
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = batch
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    Runner().start()
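
The early_stop callback above receives a count of validations in which the metric did not go down. How TorchZQ computes that count is not shown here; the following is only a minimal plain-Python sketch of one plausible way such a counter could work, assuming "did not go down" means "did not beat the best value seen so far". PatienceCounter, state, and the sample losses are hypothetical names and values used for illustration, not part of TorchZQ.

```python
class PatienceCounter:
    """Hypothetical helper: counts consecutive validations without improvement."""

    def __init__(self, callbacks):
        self.best = float("inf")  # best (lowest) metric value seen so far
        self.count = 0            # validations since the last improvement
        self.callbacks = list(callbacks)

    def update(self, value):
        if value < self.best:
            self.best = value
            self.count = 0
        else:
            self.count += 1
        # Each callback gets the current count, like early_stop(count) above.
        for callback in self.callbacks:
            callback(self.count)


state = {"stop": False}  # stands in for setting hp.max_epochs = -1


def early_stop(count):
    if count >= 2:
        state["stop"] = True


counter = PatienceCounter([early_stop])
history = []
for val_loss in [0.30, 0.25, 0.26, 0.27, 0.20]:  # made-up validation losses
    counter.update(val_loss)
    history.append(counter.count)
    if state["stop"]:
        break  # the last loss (0.20) is never reached
```

In this sketch the loop stops after the fourth validation, because the loss failed to improve on 0.25 twice in a row.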

Execute the runner

Training

tzq example/config/mnist.yml train

Testing

tzq example/config/mnist.yml test

Weights & Biases

Before running, log in to Weights & Biases.

pip install wandb # install the Weights & Biases client
wandb login       # log in

Supported features

  • Model checkpoints
  • Logging (Weights & Biases)
  • Gradient accumulation
  • Configuration file
  • FP16
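
Gradient accumulation, listed above, can be illustrated independently of TorchZQ. The sketch below uses plain Python (no PyTorch) to show the arithmetic: averaging the gradients of equally sized micro-batches gives the same gradient as one large batch. The model, data, and names (mse_grad, accum_steps, micro_batch_size) are illustrative assumptions, and this is not how TorchZQ implements or configures the feature.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


# Toy data for a one-parameter linear model y_hat = w * x (made-up values).
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]
w = 0.5

micro_batch_size = 2
accum_steps = len(xs) // micro_batch_size  # 4 micro-batches per update

# Reference: gradient computed on the full batch at once.
full_grad = mse_grad(w, xs, ys)

# Accumulation: add up micro-batch gradients, each scaled by 1 / accum_steps.
accumulated = 0.0
for i in range(accum_steps):
    lo = i * micro_batch_size
    hi = lo + micro_batch_size
    accumulated += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps

# A single optimizer step then uses the accumulated gradient.
lr = 0.01
w_new = w - lr * accumulated
```

In PyTorch the same pattern is usually written by calling backward() on loss / accum_steps for each micro-batch and calling optimizer.step() followed by optimizer.zero_grad() once every accum_steps batches.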
