TorchZQ: A PyTorch experiment runner.
Installation
Install from PyPI (latest):
pip install torchzq --pre --upgrade
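If the install succeeded, the package should import cleanly (a quick sanity check, nothing more):

python -c "import torchzq"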
A customized runner for MNIST classification
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import torchzq


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


class Runner(torchzq.Runner):
    class HParams(torchzq.Runner.HParams):
        lr: float = 1e-3

    hp: HParams

    def create_model(self):
        return Net()

    def create_dataloader(self, mode):
        hp = self.hp
        dataset = datasets.MNIST(
            "../data",
            train=mode == "training",
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )
        return DataLoader(
            dataset,
            batch_size=hp.batch_size,
            num_workers=hp.nj,
            shuffle=mode == mode.TRAIN,
            drop_last=mode == mode.TRAIN,
        )

    def create_metrics(self):
        metrics = super().create_metrics()

        def early_stop(count):
            if count >= 2:
                # the metric has not gone down for the last two validations
                self.hp.max_epochs = -1  # this terminates the training

        metrics.add_metric("val/nll_loss", [early_stop])
        return metrics

    def prepare_batch(self, batch, _):
        x, y = batch
        x = x.to(self.hp.device)
        y = y.to(self.hp.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        x, y = batch
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        x, y = batch
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}


if __name__ == "__main__":
    Runner().start()
Execute the runner
Training
tzq example/config/mnist.yml train
Testing
tzq example/config/mnist.yml test
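Both commands read the Runner's hyperparameters from the YAML file given as the first argument. The actual example/config/mnist.yml is not reproduced here; the snippet below is only a sketch of what such a config might look like, assuming the keys map one-to-one to HParams fields (lr is declared in the Runner above; batch_size and nj are assumed to be inherited base HParams fields used by create_dataloader):

# illustrative sketch only; see example/config/mnist.yml in the repository
lr: 1.0e-3        # overrides Runner.HParams.lr
batch_size: 32    # assumed base HParams field
nj: 4             # assumed number of DataLoader workers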
Weights & Biases
Before you run, log in to Weights & Biases first.
pip install wandb  # install the Weights & Biases client
wandb login        # log in
Supported features
- Model checkpoints
- Logging (Weights & Biases)
- Gradient accumulation (see the sketch after this list)
- Configuration file
- FP16
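Gradient accumulation is handled by the runner itself; the sketch below only illustrates the general technique in plain PyTorch and is not torchzq's internal implementation. The idea: sum gradients over several small batches, then take a single optimizer step, emulating a larger batch. The model and data here are toy stand-ins so the snippet runs on its own.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins; in the runner above these would be Net() and the MNIST dataloader.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
dataloader = DataLoader(
    TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,))),
    batch_size=8,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

accumulation_steps = 4  # effective batch size = 8 * 4 = 32

optimizer.zero_grad()
for step, (x, y) in enumerate(dataloader):
    loss = F.cross_entropy(model(x), y)
    # Scale the loss so the summed gradients average over the micro-batches.
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()        # one update per 4 micro-batches
        optimizer.zero_grad()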
File details

Details for the file torchzq-1.1.0.dev20211222222933.tar.gz.

File metadata
- Download URL: torchzq-1.1.0.dev20211222222933.tar.gz
- Upload date:
- Size: 13.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1

File hashes
Algorithm | Hash digest
---|---
SHA256 | dd9741d2140dafccbddb10ae00bb8b724a67f2bf4b6bb34737ae2206fe87dade
MD5 | 0241615c1b832ca80386a2296831f46c
BLAKE2b-256 | f1cf6d3f6ec923e37de7e6c142a1eb50b62b6141c01c85478a65d14dcca1e82e
File details

Details for the file torchzq-1.1.0.dev20211222222933-py3-none-any.whl.

File metadata
- Download URL: torchzq-1.1.0.dev20211222222933-py3-none-any.whl
- Upload date:
- Size: 14.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1

File hashes
Algorithm | Hash digest
---|---
SHA256 | 804acdc6a653b6c3581798de665e9192c516ce7f51dc56f0ffd6916deecba9fb
MD5 | 58c40e59bb46727bfb04f149c4600752
BLAKE2b-256 | fceed46bfdd3dfa82666a10f256058ec0700f888b64ecd7e41408ec2a61faab4