
TorchZQ: A PyTorch experiment runner

Project description

TorchZQ is a PyTorch experiment runner built with zouqi.

Installation

Install from PyPI:

pip install torchzq

Or install the latest development version from GitHub:

pip install git+https://github.com/enhuiz/torchzq@main

An Example: MNIST Classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import torchzq

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # nn.Dropout (not Dropout2d) here: this dropout is applied to the
        # flattened (N, 128) features, where channel-wise dropout is undefined.
        self.dropout2 = nn.Dropout(0.5)
        # 28x28 input -> 26x26 after conv1 -> 24x24 after conv2 -> 12x12
        # after the 2x2 max pool; 64 channels * 12 * 12 = 9216 features.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class Runner(torchzq.Runner):
    def create_model(self):
        return Net()

    def create_dataset(self):
        return datasets.MNIST(
            "../data",
            train=self.training,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )

    def prepare_batch(self, batch):
        x, y = batch
        x = x.to(self.args.device)
        y = y.to(self.args.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = self.prepare_batch(batch)
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = self.prepare_batch(batch)
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    torchzq.start(Runner)
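
A quick shape check for the model above (illustrative, not part of the original example): a dummy MNIST-sized batch should come out as one row of 10 log-probabilities per image.

# Assumes Net from the example above is in scope.
net = Net()
x = torch.zeros(2, 1, 28, 28)  # a dummy batch of two 28x28 grayscale images
print(net(x).shape)  # torch.Size([2, 10])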

Run an Example

Training

tzq example/config/mnist.yml train

Testing

tzq example/config/mnist.yml test
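
The first argument to tzq is the YAML configuration file that supplies the Runner's options, and the trailing word (train or test) selects the command to run. The actual keys depend on the Runner's arguments, so the following sketch is purely illustrative; every key name here is an assumption, except that prepare_batch above does read self.args.device.

# example/config/mnist.yml — illustrative sketch only, not the project's actual file
batch_size: 128  # hypothetical key
lr: 1.0e-3       # hypothetical key
device: cuda     # plausible, since the Runner reads self.args.device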

Weights & Biases

Before you run, log in to Weights & Biases first:

pip install wandb # install the Weights & Biases client
wandb login       # log in

Supported Features

  • Model checkpoints
  • Logging (Weights & Biases)
  • Gradient accumulation (see the sketch after this list)
  • Configuration file
  • FP16 (also covered in the sketch after this list)
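
TorchZQ's own implementations are not shown on this page, but gradient accumulation and FP16 follow a standard PyTorch pattern. A minimal generic sketch (this is not TorchZQ code; the toy model, optimizer, and data below are stand-ins):

import torch
import torch.nn.functional as F

# Toy stand-ins, purely illustrative.
model = torch.nn.Linear(784, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loader = [
    (torch.randn(8, 784).cuda(), torch.randint(0, 10, (8,)).cuda())
    for _ in range(8)
]

scaler = torch.cuda.amp.GradScaler()  # FP16: dynamic loss scaling
accumulate = 4                        # micro-batches per optimizer step

optimizer.zero_grad()
for i, (x, y) in enumerate(loader):
    with torch.cuda.amp.autocast():   # run the forward pass in mixed precision
        loss = F.cross_entropy(model(x), y)
    # Divide by `accumulate` so the summed gradients match one large batch.
    scaler.scale(loss / accumulate).backward()
    if (i + 1) % accumulate == 0:
        scaler.step(optimizer)        # unscales grads; skips the step on inf/nan
        scaler.update()
        optimizer.zero_grad()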

Download files

Download the file for your platform.

Source Distribution

torchzq-1.0.10.dev20210906183416.tar.gz (13.0 kB)


Built Distribution

torchzq-1.0.10.dev20210906183416-py3-none-any.whl (14.1 kB)

File details

Details for the file torchzq-1.0.10.dev20210906183416.tar.gz.

File metadata

  • Download URL: torchzq-1.0.10.dev20210906183416.tar.gz
  • Upload date:
  • Size: 13.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.9.6

File hashes

Hashes for torchzq-1.0.10.dev20210906183416.tar.gz
Algorithm Hash digest
SHA256 5fa951d4080dba9ce5c0f7b5c14d1309240d42f6a0570e240e3e632cf0d53f8f
MD5 f84807af26495c7d6db5d33cca517b2f
BLAKE2b-256 60acc6c3b45fce43b877981f1f29eb1906b215c2526e0535600b6b5371c083f0

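To verify a download against the published SHA256 above (a minimal sketch; the same pattern works for the other listed algorithms):

import hashlib

# Expected SHA256 for the sdist, copied from the table above.
EXPECTED = "5fa951d4080dba9ce5c0f7b5c14d1309240d42f6a0570e240e3e632cf0d53f8f"

with open("torchzq-1.0.10.dev20210906183416.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == EXPECTED else "hash mismatch")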

File details

Details for the file torchzq-1.0.10.dev20210906183416-py3-none-any.whl.

File metadata

  • Download URL: torchzq-1.0.10.dev20210906183416-py3-none-any.whl
  • Upload date:
  • Size: 14.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.2 CPython/3.9.6

File hashes

Hashes for torchzq-1.0.10.dev20210906183416-py3-none-any.whl
Algorithm Hash digest
SHA256 2fe4a8195568c6b58c074901e1aea873d5e448fa3034a5c5c2ff61416e96cac6
MD5 75a84afee4138398349b1cc02088acde
BLAKE2b-256 f0d3ed094d2b7347260c48292f59345de90a51054e010a3406bb914367e3f149

