
TorchZQ: A PyTorch experiment runner.


TorchZQ: A PyTorch experiment runner built with zouqi

Installation

Install from PyPI:

pip install torchzq

Install the latest development version from GitHub:

pip install git+https://github.com/enhuiz/torchzq@main

An Example: MNIST Classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import torchzq

class Net(nn.Module):
    def __init__(self):
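        super().__init__()
        # 1x28x28 input -> 32x26x26 -> 64x24x24 (3x3 kernels, stride 1, no padding)
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Dropout2d zeroes whole feature maps of the 4-D conv output
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 sees a flat (N, 128) tensor, so use element-wise nn.Dropout;
        # recent PyTorch warns that Dropout2d on 2-D input is deprecated
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 after 2x2 max pooling
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)  # 10 digit classes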

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class Runner(torchzq.Runner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def create_model(self):
        return Net()

    def create_dataset(self):
        return datasets.MNIST(
            "../data",
            train=self.training,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )

    def prepare_batch(self, batch):
        x, y = batch
        x = x.to(self.args.device)
        y = y.to(self.args.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = self.prepare_batch(batch)
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = self.prepare_batch(batch)
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    torchzq.start(Runner)
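The `9216` input size of `fc1` follows from the layer shapes in `Net` applied to 28x28 MNIST images. The helper functions below are hypothetical (not part of torchzq or PyTorch); they apply the standard output-size formula for unpadded convolutions and for 2x2 max pooling to reproduce that number.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size: int, kernel: int) -> int:
    """Spatial output size of max pooling whose stride defaults to the kernel size."""
    return (size - kernel) // kernel + 1


def fc1_in_features(image_size: int = 28) -> int:
    h = conv2d_out(image_size, kernel=3)  # conv1: 28 -> 26
    h = conv2d_out(h, kernel=3)           # conv2: 26 -> 24
    h = max_pool2d_out(h, kernel=2)       # max_pool2d: 24 -> 12
    channels = 64                         # conv2 has 64 output channels
    return channels * h * h               # torch.flatten(x, 1): 64 * 12 * 12


print(fc1_in_features())  # 9216
```

Changing the input resolution or the conv/pool settings in `Net` changes this value, so `fc1` must be resized to match.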

Run an Example

Training

tzq example/config/mnist.yml train
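Training minimizes `F.nll_loss` on the `F.log_softmax` output of `Net.forward`, a pairing that is equivalent to cross-entropy on the raw logits. The plain-Python sketch below (hypothetical helper names, no PyTorch required) checks that equivalence on a toy batch, assuming the default mean reduction of `F.nll_loss`.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs: List[List[float]], targets: List[int]) -> float:
    # Mean negative log-probability of each sample's target class.
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    # Cross-entropy computed directly: -log(exp(z_t) / sum_j exp(z_j)).
    total = 0.0
    for row, t in zip(logits, targets):
        total -= math.log(math.exp(row[t]) / sum(math.exp(z) for z in row))
    return total / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
print(nll_loss([log_softmax(row) for row in logits], targets))
print(cross_entropy(logits, targets))  # same value as the line above
```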

Testing

tzq example/config/mnist.yml test
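`testing_step` returns the accuracy of a single batch. How torchzq combines these per-batch values is not described here, so the sketch below is only an illustration: if batch means were averaged without weighting, a short final batch would count as much as a full one. For instance, the 10,000-image MNIST test split with a hypothetical batch size of 128 leaves 16 images in the last batch. The function names and counts are made up for illustration.

```python
from typing import List


def mean_of_batch_means(correct: List[int], sizes: List[int]) -> float:
    # Unweighted average of per-batch accuracies.
    return sum(c / n for c, n in zip(correct, sizes)) / len(sizes)


def dataset_accuracy(correct: List[int], sizes: List[int]) -> float:
    # Accuracy over all samples: total correct / total seen.
    return sum(correct) / sum(sizes)


# Hypothetical counts: two full batches of 128 and a final partial batch of 16.
correct = [120, 124, 8]
sizes = [128, 128, 16]
print(mean_of_batch_means(correct, sizes))
print(dataset_accuracy(correct, sizes))
```

The two numbers coincide only when every batch has the same size.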

Weights & Biases

Before running, log in to Weights & Biases.

pip install wandb # install the Weights & Biases client
wandb login       # log in

Supported Features

  • Model checkpoints
  • Logging (Weights & Biases)
  • Gradient accumulation
  • Configuration file
  • FP16
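The feature list does not describe how torchzq implements gradient accumulation. As a general illustration only, the sketch below uses a hypothetical one-parameter linear model with a mean-squared-error loss (plain Python, no PyTorch) to show why summing gradients over equal-sized micro-batches, each scaled by `1 / accumulation_steps`, reproduces the gradient of one full-batch step.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def grad_mse(w: float, batch: Batch) -> float:
    # Gradient of mean((w * x - y) ** 2) with respect to w.
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: Batch, accumulation_steps: int) -> float:
    # Split the batch into equal micro-batches, scale each micro-batch loss by
    # 1 / accumulation_steps, and sum the resulting gradients, as repeated
    # backward() calls without zeroing gradients in between would.
    if len(batch) % accumulation_steps != 0:
        raise ValueError("batch size must be divisible by accumulation_steps")
    size = len(batch) // accumulation_steps
    grad = 0.0
    for i in range(accumulation_steps):
        micro_batch = batch[i * size:(i + 1) * size]
        grad += grad_mse(w, micro_batch) / accumulation_steps
    return grad


data = [(1.0, 2.0), (2.0, 3.9), (3.0, 6.2), (4.0, 7.8)]
print(grad_mse(0.5, data), accumulated_grad(0.5, data, accumulation_steps=2))
```

The equality relies on equal micro-batch sizes; with a ragged final micro-batch, the scaling would have to use sample counts instead of a fixed `1 / accumulation_steps`.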


Download files

Download the file for your platform.

Source Distribution

torchzq-1.0.10.dev20211121215433.tar.gz (13.4 kB)

Built Distribution

torchzq-1.0.10.dev20211121215433-py3-none-any.whl (14.8 kB)

File details

Details for the file torchzq-1.0.10.dev20211121215433.tar.gz.

File metadata

  • Download URL: torchzq-1.0.10.dev20211121215433.tar.gz
  • Upload date:
  • Size: 13.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0

File hashes

Hashes for torchzq-1.0.10.dev20211121215433.tar.gz

  • SHA256: 58095f39fe16dc5c07e5cf2b797c1df173fa46b5ee6ad9c203185fa32692a040
  • MD5: 3318200fac70884b5f98be781e6d50f3
  • BLAKE2b-256: e583310b99de3e05f0bfa1084ed9b0ee1b8687a90f9f8060339e1c4938abcc66

File details

Details for the file torchzq-1.0.10.dev20211121215433-py3-none-any.whl.

File metadata

  • Download URL: torchzq-1.0.10.dev20211121215433-py3-none-any.whl
  • Upload date:
  • Size: 14.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.6.0 importlib_metadata/4.8.2 pkginfo/1.8.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.0

File hashes

Hashes for torchzq-1.0.10.dev20211121215433-py3-none-any.whl

  • SHA256: 1c2987545cbc73e878553877827f33b505dd59d6922923b07c1415c4abdc33da
  • MD5: d0cfbda6e770ec03df8eb086ca0e9a3c
  • BLAKE2b-256: 79426e37b638bb04c73c9cf44f927835d2b84ca9bc80cc5d7d680e12ad93aa8c
