
TorchZQ: A PyTorch experiment runner.

TorchZQ: A PyTorch experiment runner built with zouqi

Installation

Install from PyPI:

pip install torchzq

Install the latest development version from the main branch on GitHub:

pip install git+https://github.com/enhuiz/torchzq@main

An Example for MNIST Classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import torchzq

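class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 acts on the flattened fc1 output, so it needs element-wise
        # nn.Dropout; nn.Dropout2d expects channel feature maps and warns here.
        self.dropout2 = nn.Dropout(0.5)
        # 64 channels x 12 x 12 after two 3x3 convolutions and a 2x2 max-pool
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)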

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class Runner(torchzq.Runner):
    def create_model(self):
        return Net()

    def create_dataset(self):
        return datasets.MNIST(
            "../data",
            train=self.training,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST pixel mean and std
                ]
            ),
        )

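    def prepare_batch(self, batch):
        # Move images and labels to the device configured for the run.
        x, y = batch
        x = x.to(self.args.device)
        y = y.to(self.args.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = self.prepare_batch(batch)
        # The model outputs log-probabilities, so NLL loss here is cross-entropy.
        loss = F.nll_loss(self.model(x), y)
        # Return the loss to optimize and a dict of scalar values to report.
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = self.prepare_batch(batch)
        # Predicted class = index of the largest log-probability.
        y_ = self.model(x).argmax(dim=-1)
        # Fraction of correctly classified samples in this batch.
        return {"accuracy": (y_ == y).float().mean().item()}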


if __name__ == "__main__":
    torchzq.start(Runner)
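The 9216 passed to fc1 is not arbitrary: it is the number of features left after the convolutional stack processes a 1x28x28 MNIST image and torch.flatten(x, 1) flattens it. The sketch below reproduces that arithmetic with the standard output-size formula for convolution and pooling (no dilation), floor((n + 2*padding - kernel) / stride) + 1. The helper name conv_output_size is hypothetical and is not part of TorchZQ or PyTorch.

```python
def conv_output_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a conv or pooling layer without dilation."""
    return (n + 2 * padding - kernel) // stride + 1


side = 28                                          # MNIST images are 1x28x28
side = conv_output_size(side, kernel=3)            # conv1: 3x3, stride 1 -> 26
side = conv_output_size(side, kernel=3)            # conv2: 3x3, stride 1 -> 24
side = conv_output_size(side, kernel=2, stride=2)  # F.max_pool2d(x, 2) -> 12
channels = 64                                      # conv2 has 64 output channels
flat_features = channels * side * side             # what torch.flatten(x, 1) yields
print(side, flat_features)  # 12 9216
```

If you change the input resolution, kernel sizes, or pooling, recompute this value before editing nn.Linear(9216, 128).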

Run an Example

Training

tzq example/config/mnist.yml train

Testing

tzq example/config/mnist.yml test

Weights & Biases

Before running, log in to Weights & Biases:

pip install wandb # install the Weights & Biases client
wandb login       # log in (prompts for your API key)

Supported Features

  • Model checkpoints
  • Logging (Weights & Biases)
  • Gradient accumulation
  • Configuration file
  • FP16
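Gradient accumulation trades extra steps for lower memory: gradients from several small micro-batches are summed before a single optimizer update, so that update matches one step on the larger combined batch. The plain-Python sketch below illustrates the idea on a one-parameter least-squares model. It is a hypothetical illustration of the general technique, not TorchZQ's implementation, and the names grad and accumulated_grad are made up for this example.

```python
from typing import Sequence


def grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """d/dw of the mean squared error mean((w*x - y)**2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: Sequence[float], ys: Sequence[float],
                     micro_batch_size: int) -> float:
    """Sum micro-batch gradients, each scaled by 1 / number_of_micro_batches.

    This equals the full-batch gradient when all micro-batches have the
    same size (a smaller trailing micro-batch breaks the exact match).
    """
    chunks = [(xs[i:i + micro_batch_size], ys[i:i + micro_batch_size])
              for i in range(0, len(xs), micro_batch_size)]
    total = 0.0
    for cx, cy in chunks:
        # Same effect as dividing each micro-batch loss by the accumulation steps.
        total += grad(w, cx, cy) / len(chunks)
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w, lr = 0.0, 0.01
full = grad(w, xs, ys)
accum = accumulated_grad(w, xs, ys, micro_batch_size=2)
print(abs(full - accum) < 1e-9)  # True
w -= lr * accum  # a single optimizer step after all micro-batches
```

The memory saving comes from only ever holding one micro-batch at a time; the price is more forward/backward passes per update.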

Download files

Source Distribution

torchzq-1.0.10.dev20211007230726.tar.gz (13.1 kB)

Built Distribution

torchzq-1.0.10.dev20211007230726-py3-none-any.whl (14.6 kB)

File details

Details for the file torchzq-1.0.10.dev20211007230726.tar.gz.

File metadata

  • Download URL: torchzq-1.0.10.dev20211007230726.tar.gz
  • Size: 13.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for torchzq-1.0.10.dev20211007230726.tar.gz
Algorithm Hash digest
SHA256 28659cf34c8be69e3e2c5aef73babdcc8c64ef83d3903aa666996bb3679cf09f
MD5 24eb300ca346e93f5ea04ed30ae75655
BLAKE2b-256 14e0f0518a2f48229c23049d551b795b4ef37f6091d664b20f70eb91ac2eb603

File details

Details for the file torchzq-1.0.10.dev20211007230726-py3-none-any.whl.

File metadata

  • Download URL: torchzq-1.0.10.dev20211007230726-py3-none-any.whl
  • Size: 14.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for torchzq-1.0.10.dev20211007230726-py3-none-any.whl
Algorithm Hash digest
SHA256 7c0bf6b3a9378ebf5e05d1ab8cba4cd019a549d7c112d61160dd1af1e4b5d71e
MD5 1faca00968e9ef55fe58e49e9ad29a3c
BLAKE2b-256 77642f6d0c7dc20ab850064fcd4167d750e2a7e12633a44a4c3fafec95423584
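To check that a downloaded file matches the SHA256 digests listed above, you can hash it locally. The sketch below uses only Python's standard hashlib; the function names sha256_of and verify, and the assumption that the archive sits in the current directory, are illustrative choices rather than part of any tool.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: Path, expected: str) -> bool:
    """Compare a file's SHA256 digest with a published one (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()


if __name__ == "__main__":
    # Assumed location of the downloaded source distribution.
    sdist = Path("torchzq-1.0.10.dev20211007230726.tar.gz")
    expected = "28659cf34c8be69e3e2c5aef73babdcc8c64ef83d3903aa666996bb3679cf09f"
    if sdist.exists():
        print("OK" if verify(sdist, expected) else "MISMATCH")
    else:
        print(f"{sdist} not found")
```

pip's hash-checking mode (a --hash=sha256:<digest> option on the requirement line, listing both the sdist and wheel digests) enforces the same check at install time.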