TorchZQ: A PyTorch experiment runner.
Project description
TorchZQ: a PyTorch experiment runner
Installation
Install from PyPI (latest):
pip install torchzq --pre --upgrade
A customized runner for MNIST classification
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchzq
class Net(nn.Module):
    """Small CNN classifier for single-channel 28x28 images (MNIST).

    ``forward`` returns per-class log-probabilities of shape ``(N, 10)``
    (the output of ``log_softmax`` over the 10 logits of ``fc2``).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout: applied to the 4-D feature map after pooling.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: applied to the 2-D (N, 128) output of fc1,
        # where channel-wise Dropout2d is the wrong layer (and deprecated for
        # 2-D inputs in recent PyTorch).
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: a 28x28 input becomes 26x26 and then
        # 24x24 through the two unpadded 3x3 convs, and 12x12 after 2x2 pooling.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a ``(N, 1, 28, 28)`` batch to ``(N, 10)`` log-probabilities."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
class Runner(torchzq.Runner):
    """Example torchzq runner that trains and tests ``Net`` on MNIST."""

    class HParams(torchzq.Runner.HParams):
        # Learning rate. The other hyperparameters read below (batch_size,
        # nj, device, max_epochs) are inherited from torchzq.Runner.HParams.
        lr: float = 1e-3

    hp: HParams

    def create_model(self):
        """Build the network to be trained."""
        return Net()

    def create_dataloader(self, mode):
        """Return an MNIST DataLoader for ``mode``.

        In training mode the MNIST training split is used, shuffled, with the
        last incomplete batch dropped; in any other mode the test split is
        used in a fixed order with every sample kept.
        """
        hp = self.hp
        # Decide "is this training?" exactly once so the dataset split and the
        # shuffle/drop_last flags can never disagree. Accepts either an
        # enum-style mode exposing TRAIN or the plain string "training".
        # NOTE(review): confirm which representation torchzq actually passes.
        train_mode = getattr(type(mode), "TRAIN", "training")
        is_training = mode == train_mode
        dataset = datasets.MNIST(
            "../data",
            train=is_training,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            ),
        )
        return DataLoader(
            dataset,
            batch_size=hp.batch_size,
            num_workers=hp.nj,
            shuffle=is_training,
            drop_last=is_training,
        )

    def create_metrics(self):
        """Extend the base metrics with an early-stopping hook on val/nll_loss."""
        metrics = super().create_metrics()

        def early_stop(count):
            # NOTE(review): `count` presumably counts consecutive validations
            # in which the metric did not improve — confirm against torchzq.
            if count >= 2:
                # This example relies on max_epochs = -1 to end training.
                self.hp.max_epochs = -1

        metrics.add_metric("val/nll_loss", [early_stop])
        return metrics

    def prepare_batch(self, batch, _):
        """Move the (inputs, labels) pair onto ``hp.device``."""
        x, y = batch
        x = x.to(self.hp.device)
        y = y.to(self.hp.device)
        return x, y

    def training_step(self, batch, optimizer_index):
        """Return the NLL loss tensor and a dict holding its scalar value.

        ``optimizer_index`` is unused because this example has a single
        optimizer path. The dict is presumably what torchzq logs — confirm.
        """
        x, y = batch
        loss = F.nll_loss(self.model(x), y)
        return loss, {"nll_loss": loss.item()}

    @torch.no_grad()
    def testing_step(self, batch, batch_index):
        """Return the batch accuracy: fraction of argmax predictions equal to y."""
        x, y = batch
        y_ = self.model(x).argmax(dim=-1)
        return {"accuracy": (y_ == y).float().mean().item()}
if __name__ == "__main__":
    # Run as a script: build the runner and call start(), which is inherited
    # from torchzq.Runner.
    Runner().start()
Execute the runner
Training
tzq example/config/mnist.yml train
Testing
tzq example/config/mnist.yml test
Weights & Biases
Before running, log in to Weights & Biases:
pip install wandb  # install the Weights & Biases client
wandb login        # log in
Supported features
- Model checkpoints
- Logging (Weights & Biases)
- Gradient accumulation
- Configuration file
- FP16
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file torchzq-1.1.0.dev20211222222933.tar.gz.
File metadata
- Download URL: torchzq-1.1.0.dev20211222222933.tar.gz
- Upload date:
- Size: 13.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | dd9741d2140dafccbddb10ae00bb8b724a67f2bf4b6bb34737ae2206fe87dade |
| MD5 | 0241615c1b832ca80386a2296831f46c |
| BLAKE2b-256 | f1cf6d3f6ec923e37de7e6c142a1eb50b62b6141c01c85478a65d14dcca1e82e |
File details
Details for the file torchzq-1.1.0.dev20211222222933-py3-none-any.whl.
File metadata
- Download URL: torchzq-1.1.0.dev20211222222933-py3-none-any.whl
- Upload date:
- Size: 14.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.7.1 importlib_metadata/4.10.0 pkginfo/1.8.2 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.10.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 804acdc6a653b6c3581798de665e9192c516ce7f51dc56f0ffd6916deecba9fb |
| MD5 | 58c40e59bb46727bfb04f149c4600752 |
| BLAKE2b-256 | fceed46bfdd3dfa82666a10f256058ec0700f888b64ecd7e41408ec2a61faab4 |