TorchZQ: A PyTorch experiment runner.
TorchZQ: A PyTorch experiment runner built with zouqi
Installation
Install from PyPI:
pip install torchzq
Install the latest version from GitHub:
pip install git+https://github.com/enhuiz/torchzq@main
Example: MNIST Classification
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import torchzq


class Net(nn.Module):
    """Small CNN for 1x28x28 MNIST digits; outputs log-probabilities over 10 classes."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)  # channel-wise dropout on 4-D feature maps
        # dropout2 is applied to flattened (N, 128) features, so element-wise
        # nn.Dropout is used; Dropout2d is meant for (N, C, H, W) feature maps.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after conv1, conv2, max-pool
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)  # log-probabilities, as F.nll_loss expects
        return output


class Runner(torchzq.Runner):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def create_model(self):
        return Net()

    def create_dataset(self):
        # Use the MNIST training split when self.training is True, else the test split.
        return datasets.MNIST(
            "../data",
            train=self.training,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std
                ]
            ),
        )

    def prepare_batch(self, batch):
        # Move images and labels to the configured device.
        x, y = batch
        x = x.to(self.args.device)
        y = y.to(self.args.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = self.prepare_batch(batch)
        loss = F.nll_loss(self.model(x), y)
        # Return the loss together with a dict of scalar statistics.
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = self.prepare_batch(batch)
        y_ = self.model(x).argmax(dim=-1)  # predicted class per sample
        # Fraction of correctly classified samples in this batch.
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    torchzq.start(Runner)
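The `9216` passed to `fc1` in the `Net` above follows from the layer shapes. Each unpadded layer maps a spatial size `n` to `floor((n + 2*padding - kernel) / stride) + 1`. The plain-Python check below assumes 28x28 single-channel inputs, the default `padding=0` and `dilation=1`, and the fact that `F.max_pool2d(x, 2)` uses a stride equal to its kernel size by default; `conv2d_out` and `net_flat_features` are illustrative helpers, not part of torchzq or PyTorch.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size along one spatial dimension for a conv or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def net_flat_features(height: int = 28, width: int = 28) -> int:
    """Number of features that reach fc1 in the Net above for a 1-channel input."""
    h, w = height, width
    h, w = conv2d_out(h, 3), conv2d_out(w, 3)  # conv1: 3x3 kernel, stride 1
    h, w = conv2d_out(h, 3), conv2d_out(w, 3)  # conv2: 3x3 kernel, stride 1
    h, w = conv2d_out(h, 2, stride=2), conv2d_out(w, 2, stride=2)  # max_pool2d(x, 2)
    channels = 64  # output channels of conv2
    return channels * h * w


print(net_flat_features())  # 9216 (28 -> 26 -> 24 -> 12, then 64 * 12 * 12)
```

If the input resolution or the convolutional layers change, the first argument of `nn.Linear` in `fc1` has to be recomputed the same way.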
Run an Example
Training
tzq example/config/mnist.yml train
Testing
tzq example/config/mnist.yml test
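The `train` and `test` commands run the hooks defined in the `Runner` subclass above. This README does not show how torchzq invokes them, so the following is only a hypothetical, dependency-free toy that illustrates the hook pattern: a base class owns the loop and calls `create_model`, `create_dataset`, `training_step`, and `testing_step`, which a subclass overrides. `ToyRunner` and `ThresholdRunner` are invented names, and averaging per-batch metrics in `test` is an assumption, not documented torchzq behavior.

```python
from typing import Any, Dict, Iterable, List, Tuple


class ToyRunner:
    """Hypothetical hook-based runner; NOT torchzq's implementation."""

    def __init__(self, training: bool) -> None:
        self.training = training
        self.model = self.create_model()

    # Hooks a subclass overrides (names mirror the Runner example above).
    def create_model(self) -> Any:
        raise NotImplementedError

    def create_dataset(self) -> Iterable[Any]:
        raise NotImplementedError

    def training_step(self, batch: Any, optimizer_index: int) -> Tuple[float, Dict[str, float]]:
        raise NotImplementedError

    def testing_step(self, batch: Any, batch_index: int) -> Dict[str, float]:
        raise NotImplementedError

    def train(self) -> List[Dict[str, float]]:
        # Call training_step on every batch and collect its statistics.
        # A real runner would also backpropagate the loss and step an optimizer here.
        logs = []
        for batch in self.create_dataset():
            loss, stats = self.training_step(batch, 0)
            logs.append({"loss": loss, **stats})
        return logs

    def test(self) -> Dict[str, float]:
        # Average each metric returned by testing_step over all batches.
        totals: Dict[str, float] = {}
        num_batches = 0
        for i, batch in enumerate(self.create_dataset()):
            for key, value in self.testing_step(batch, i).items():
                totals[key] = totals.get(key, 0.0) + value
            num_batches += 1
        return {key: value / num_batches for key, value in totals.items()}


class ThresholdRunner(ToyRunner):
    """Toy subclass: the 'model' predicts 1 when the input exceeds 0.5."""

    def create_model(self):
        return lambda x: 1 if x > 0.5 else 0

    def create_dataset(self):
        # Two batches of (inputs, labels).
        return [([0.2, 0.9], [0, 1]), ([0.7, 0.4], [0, 0])]

    def training_step(self, batch, optimizer_index):
        xs, ys = batch
        errors = sum(self.model(x) != y for x, y in zip(xs, ys))
        loss = errors / len(xs)
        return loss, {"error_rate": loss}

    def testing_step(self, batch, batch_index):
        xs, ys = batch
        correct = sum(self.model(x) == y for x, y in zip(xs, ys))
        return {"accuracy": correct / len(xs)}


print(ThresholdRunner(training=False).test())  # {'accuracy': 0.75}
```

Keeping the loop in the base class means the subclass only describes the model, data, and per-batch computation, which is the division of labor the MNIST example relies on.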
Weights & Biases
Before running, log in to Weights & Biases.
pip install wandb  # install the Weights & Biases client
wandb login  # log in
Supported Features
- Model checkpoints
- Logging (Weights & Biases)
- Gradient accumulation
- Configuration file
- FP16
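Gradient accumulation sums gradients over several smaller micro-batches before taking a single optimizer step, emulating a larger batch. In PyTorch this is commonly done by calling `backward()` on each micro-batch loss divided by the number of accumulation steps and calling `optimizer.step()` once afterwards. The framework-free sketch below illustrates the idea with a one-parameter linear model and mean-squared-error loss; it is not torchzq's implementation, and `mse_grad` and `accumulated_step` are hypothetical names. With equal-sized micro-batches, the averaged micro-batch gradient equals the full-batch gradient.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def mse_grad(w: float, batch: Batch) -> float:
    """Gradient with respect to w of mean((w * x - y) ** 2) over the batch."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)


def accumulated_step(w: float, micro_batches: List[Batch], lr: float) -> float:
    """One gradient-descent step using gradients accumulated over micro-batches."""
    grad = 0.0
    for mb in micro_batches:
        # Scale each micro-batch gradient so the sum is an average, like dividing
        # the loss by the number of accumulation steps before backward().
        grad += mse_grad(w, mb) / len(micro_batches)
    return w - lr * grad  # a single update after all micro-batches


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
print(mse_grad(0.0, data))  # -30.0 (full-batch gradient)
w = accumulated_step(0.0, [data[:2], data[2:]], lr=0.01)  # same update as one full-batch step
```

If the micro-batches have different sizes, each gradient should instead be weighted by its share of samples for the result to match the full-batch gradient.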
File details
Details for the file torchzq-1.0.10.dev20211007230726.tar.gz.
File metadata
- Download URL: torchzq-1.0.10.dev20211007230726.tar.gz
- Upload date:
- Size: 13.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 28659cf34c8be69e3e2c5aef73babdcc8c64ef83d3903aa666996bb3679cf09f
MD5 | 24eb300ca346e93f5ea04ed30ae75655
BLAKE2b-256 | 14e0f0518a2f48229c23049d551b795b4ef37f6091d664b20f70eb91ac2eb603
File details
Details for the file torchzq-1.0.10.dev20211007230726-py3-none-any.whl.
File metadata
- Download URL: torchzq-1.0.10.dev20211007230726-py3-none-any.whl
- Upload date:
- Size: 14.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | 7c0bf6b3a9378ebf5e05d1ab8cba4cd019a549d7c112d61160dd1af1e4b5d71e
MD5 | 1faca00968e9ef55fe58e49e9ad29a3c
BLAKE2b-256 | 77642f6d0c7dc20ab850064fcd4167d750e2a7e12633a44a4c3fafec95423584
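The published digests can be used to check a downloaded file before installing it. A minimal sketch using Python's standard `hashlib`; the helper names `sha256_of` and `verify` are illustrative, and the example call assumes the source distribution was saved in the current directory.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path, expected_sha256: str) -> bool:
    """True if the file's SHA256 matches the published digest (case-insensitive)."""
    return sha256_of(path) == expected_sha256.strip().lower()


# Example (assumes the sdist was downloaded to the current directory):
# verify(Path("torchzq-1.0.10.dev20211007230726.tar.gz"),
#        "28659cf34c8be69e3e2c5aef73babdcc8c64ef83d3903aa666996bb3679cf09f")
```

pip can enforce the same check during installation through its hash-checking mode (`--hash=sha256:...` entries in a requirements file).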