TorchZQ: A simple PyTorch experiment runner.

TorchZQ: A PyTorch experiment runner based on Zouqi

Installation

Install from PyPI:

pip install torchzq

Install the latest development version from GitHub:

pip install git+https://github.com/enhuiz/torchzq@main

An Example for MNIST Classification

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import torchzq

class Net(nn.Module):
    def __init__(self):
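        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 acts on the flat (N, 128) output of fc1, so use element-wise
        # Dropout; Dropout2d on 2-D input is deprecated in recent PyTorch releases.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 after two 3x3 convs and one 2x2 max pool
        self.fc1 = nn.Linear(9216, 128)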
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class Runner(torchzq.Runner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def create_model(self):
        return Net()

    def create_dataset(self):
        return datasets.MNIST(
            "../data",
            train=self.training,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )

    def prepare_batch(self, batch):
        x, y = batch
        x = x.to(self.args.device)
        y = y.to(self.args.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = self.prepare_batch(batch)
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = self.prepare_batch(batch)
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    torchzq.start(Runner)
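
The input size of `fc1` (9216) in the example above follows from the MNIST image size and the layer settings. As a rough check, the sketch below recomputes it in plain Python; the helper names `conv2d_out` and `flattened_features` are hypothetical and are not part of TorchZQ or PyTorch. It assumes 28x28 inputs, no padding, and a dilation of 1.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(side: int = 28) -> int:
    """Number of features entering fc1 for a square input of the given side."""
    side = conv2d_out(side, kernel=3)  # conv1: 28 -> 26
    side = conv2d_out(side, kernel=3)  # conv2: 26 -> 24
    side = side // 2                   # max_pool2d(x, 2): 24 -> 12
    return 64 * side * side            # conv2 produces 64 channels


print(flattened_features())  # 9216
```

If the input resolution or kernel sizes change, the first argument of `nn.Linear` in `fc1` would need to change accordingly.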

Run an Example

Training

$ tzq example/config/mnist.yml train

Testing

$ tzq example/config/mnist.yml test

TensorBoard

$ tensorboard --logdir runs

Supported Features

  • Model checkpoints
  • Logging
  • Gradient accumulation
  • Configuration file
  • TensorBoard
  • FP16
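
To illustrate what gradient accumulation means in general, here is a minimal, framework-free sketch. It is not TorchZQ's implementation, and the function names (`grad_mse`, `accumulated_grad`) are hypothetical. It shows that averaging the gradients of equally sized micro-batches reproduces the gradient of the full batch, which is why accumulation can emulate a larger batch size with less memory.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch):
    """Accumulate micro-batch gradients, each scaled by 1 / number of micro-batches."""
    chunks = [
        (xs[i:i + micro_batch], ys[i:i + micro_batch])
        for i in range(0, len(xs), micro_batch)
    ]
    total = 0.0
    for chunk_xs, chunk_ys in chunks:
        # Scaling keeps the accumulated value equal to the full-batch mean
        # gradient (assumes all micro-batches have the same size).
        total += grad_mse(w, chunk_xs, chunk_ys) / len(chunks)
    return total


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full = grad_mse(0.5, xs, ys)
accumulated = accumulated_grad(0.5, xs, ys, micro_batch=2)
print(full, accumulated)
```

In PyTorch, the same effect is typically obtained by calling `backward()` on a scaled loss for several batches before a single `optimizer.step()`.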
