PyTorch training manager (v1.0.7)
Project description
torchmanager
A highly wrapped PyTorch training and testing manager
The main branch is used for unstable beta releases. Please check the stable branch for the latest stable release.
Prerequisites
- Python 3.8+
- PyTorch 1.8.2+
- tqdm
- tensorboard (Optional)
Installation
pip install torchmanager
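After installing, it can help to confirm that the prerequisites listed above are met. The standard-library sketch below is only an illustration and is not part of torchmanager: check_prerequisites and parse_version are hypothetical helpers, and it assumes the packages are installed under their PyPI distribution names torch, tqdm and tensorboard. A package that is not installed is reported as None, which is acceptable for the optional tensorboard.

```python
import sys
from importlib import metadata
from itertools import takewhile
from typing import Dict, List, Optional, Tuple

# Minimum versions from the prerequisites list; tqdm and tensorboard
# have no stated minimum, so only their presence is checked (empty tuple).
MIN_VERSIONS: Dict[str, Tuple[int, ...]] = {"torch": (1, 8, 2), "tqdm": (), "tensorboard": ()}


def parse_version(text: str) -> Tuple[int, ...]:
    """Keep the leading numeric release parts, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    parts: List[int] = []
    for piece in text.split("+")[0].split("."):
        digits = "".join(takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # stop at pre-release tags such as '0rc1'
            break
    return tuple(parts)


def check_prerequisites() -> Dict[str, Optional[bool]]:
    """True/False: installed and new enough or not; None: package not installed."""
    report: Dict[str, Optional[bool]] = {"python": sys.version_info >= (3, 8)}
    for name, minimum in MIN_VERSIONS.items():
        try:
            installed = parse_version(metadata.version(name))
        except metadata.PackageNotFoundError:
            report[name] = None
            continue
        report[name] = installed >= minimum
    return report


if __name__ == "__main__":
    print(check_prerequisites())
```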
The Manager
- Initialize the manager with the target model, optimizer, loss function, and metrics:
import torch, torchmanager
# define model
class PytorchModel(torch.nn.Module):
    ...
# initialize model, optimizer, loss function, and metrics
model = PytorchModel(...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # SGD needs a learning rate; 0.01 is only an example value
loss_fn = torchmanager.losses.CrossEntropy()
metrics = {'accuracy': torchmanager.metrics.SparseCategoricalAccuracy()}
# initialize manager
manager = torchmanager.Manager(model, optimizer, loss_fn=loss_fn, metrics=metrics)
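The metric name SparseCategoricalAccuracy suggests the Keras-style definition: the model outputs one score per class, the target is an integer class index, and a sample counts as correct when its highest-scoring class matches the target. Assuming that definition, the plain-Python sketch below computes the same quantity for a list of score rows; sparse_categorical_accuracy is a hypothetical helper, not part of torchmanager.

```python
from typing import Sequence


def sparse_categorical_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose arg-max class index equals the integer label (assumed definition)."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        return 0.0
    correct = 0
    for row, label in zip(scores, labels):
        # index of the largest score; max() keeps the first index on ties
        predicted = max(range(len(row)), key=lambda i: row[i])
        correct += int(predicted == label)
    return correct / len(labels)
```

For scores [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]] and labels [1, 1, 1], the predicted classes are 1, 0, 1, so the accuracy is 2/3.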
- Train the model with the fit method:
from torch.utils.data import DataLoader
# get datasets
training_dataset: DataLoader = ...
val_dataset: DataLoader = ...
# train with fit method
manager.fit(training_dataset, epochs=10, val_dataset=val_dataset)
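Conceptually, fit(..., epochs=10, val_dataset=...) repeats a training pass over the data for the given number of epochs and, when a validation set is supplied, evaluates after each epoch. The plain-Python sketch below shows only that control flow; it is not torchmanager's implementation, and fit_loop, train_step, test_step and the val_ prefix on history keys are hypothetical names chosen for illustration.

```python
from typing import Any, Callable, Dict, Iterable, List, Optional

Batch = Any  # any batch object; the step functions decide how to use it
Summary = Dict[str, float]


def mean_summary(summaries: List[Summary]) -> Summary:
    """Average per-batch results key by key (unweighted mean over batches)."""
    if not summaries:
        return {}
    return {key: sum(s[key] for s in summaries) / len(summaries) for key in summaries[0]}


def fit_loop(
    train_step: Callable[[Batch], Summary],
    test_step: Callable[[Batch], Summary],
    training_dataset: Iterable[Batch],
    epochs: int,
    val_dataset: Optional[Iterable[Batch]] = None,
) -> List[Summary]:
    """Hypothetical epoch loop; datasets must be re-iterable (a list or a DataLoader)."""
    history: List[Summary] = []
    for _ in range(epochs):
        # one training pass over every batch
        record = mean_summary([train_step(batch) for batch in training_dataset])
        if val_dataset is not None:
            # validation pass after each epoch; keys are prefixed to keep them apart
            val = mean_summary([test_step(batch) for batch in val_dataset])
            record.update({f"val_{key}": value for key, value in val.items()})
        history.append(record)
    return history
```

The datasets are iterated once per epoch, which is why they must be re-iterable: a one-shot generator would be empty from the second epoch on.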
- Test the model with the test method:
# get dataset
testing_dataset: DataLoader = ...
# test with test method
manager.test(testing_dataset)
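Testing runs the model over the test set without updating it and reports the metrics. When a metric such as accuracy is summarized across batches, pooling counts over all samples and averaging per-batch scores disagree whenever batches differ in size, for instance a smaller final batch. This page does not say which one torchmanager uses; the plain-Python sketch below, built around a hypothetical evaluate helper, shows the pooled version.

```python
from typing import Callable, Dict, Iterable, Sequence, Tuple


def evaluate(
    predict: Callable[[float], int],
    batches: Iterable[Tuple[Sequence[float], Sequence[int]]],
) -> Dict[str, float]:
    """Accuracy pooled over all samples, so every sample has equal weight."""
    correct = 0
    total = 0
    for inputs, labels in batches:
        # no parameter updates here: evaluation only runs the prediction
        correct += sum(int(predict(x) == y) for x, y in zip(inputs, labels))
        total += len(labels)
    return {"accuracy": correct / total if total else 0.0}
```

With batches ([1.0, -1.0, 2.0], [1, 0, 0]) and ([3.0], [1]) and the rule predict(x) = int(x > 0), 3 of 4 samples are correct, so the pooled accuracy is 0.75, while the mean of the per-batch accuracies (2/3 and 1) would be about 0.83.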
- Callbacks can also be passed to fit, for example to log to TensorBoard and to save the last checkpoint:
...
tensorboard_callback = torchmanager.callbacks.TensorBoard('logs')
last_ckpt_callback = torchmanager.callbacks.LastCheckpoint(model, 'last.model')
manager.fit(..., callbacks_list=[tensorboard_callback, last_ckpt_callback])
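Callbacks such as TensorBoard and LastCheckpoint are objects the training loop notifies at fixed points, for example at the end of every epoch. The sketch below illustrates that observer pattern in plain Python under stated assumptions: Callback, on_epoch_end, MemoryLastCheckpoint, HistoryLogger and run_epochs are hypothetical names rather than torchmanager's callback API, and the "checkpoint" is an in-memory copy instead of a file such as 'last.model'.

```python
import copy
from typing import Any, Dict, List


class Callback:
    """Hypothetical base class: the loop calls these hooks, subclasses override them."""

    def on_epoch_end(self, epoch: int, summary: Dict[str, float], state: Dict[str, Any]) -> None:
        pass


class MemoryLastCheckpoint(Callback):
    """Keeps a copy of the most recent model state, overwriting it every epoch."""

    def __init__(self) -> None:
        self.epoch = -1
        self.state: Dict[str, Any] = {}

    def on_epoch_end(self, epoch: int, summary: Dict[str, float], state: Dict[str, Any]) -> None:
        self.epoch = epoch
        self.state = copy.deepcopy(state)  # snapshot, so later updates do not leak in


class HistoryLogger(Callback):
    """Records the per-epoch summaries, loosely like a metrics logger would."""

    def __init__(self) -> None:
        self.history: List[Dict[str, float]] = []

    def on_epoch_end(self, epoch: int, summary: Dict[str, float], state: Dict[str, Any]) -> None:
        self.history.append(dict(summary))


def run_epochs(epochs: int, callbacks_list: List[Callback]) -> Dict[str, Any]:
    """Toy loop: updates a fake state each epoch, then notifies every callback."""
    state: Dict[str, Any] = {"weight": 0.0}
    for epoch in range(epochs):
        state["weight"] += 1.0  # stand-in for one epoch of training
        summary = {"loss": 1.0 / (epoch + 1)}
        for callback in callbacks_list:
            callback.on_epoch_end(epoch, summary, state)
    return state
```

Because each callback only sees the hooks it overrides, logging and checkpointing stay independent of each other and of the loop itself.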
Customize your training loop
- Create your own manager class by extending the Manager class:
from torchmanager import Manager

class CustomManager(Manager):
    ...
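A subclass usually keeps the base constructor working unchanged and adds its own options. One common way, sketched below in plain Python, is to forward *args and **kwargs to super().__init__ and accept new settings as keyword-only arguments. ManagerStub is a hypothetical stand-in that only mirrors the Manager(model, optimizer, loss_fn=..., metrics=...) call shown earlier, and grad_clip is an invented example option.

```python
from typing import Any, Dict, Optional


class ManagerStub:
    """Hypothetical stand-in for torchmanager.Manager, only to make this sketch runnable."""

    def __init__(self, model: Any, optimizer: Any = None, loss_fn: Any = None,
                 metrics: Optional[Dict[str, Any]] = None) -> None:
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.metrics = metrics or {}


class CustomManager(ManagerStub):
    """Forwards the base-class arguments unchanged and adds one keyword-only option."""

    def __init__(self, *args: Any, grad_clip: Optional[float] = None, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.grad_clip = grad_clip  # hypothetical extra setting a custom train_step could read
```

Forwarding the arguments means the subclass does not need to repeat, or keep in sync with, the base class's parameter list.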
- Override the train_step method:
from typing import Dict
import torch

class CustomManager(Manager):
    ...
    def train_step(self, x_train: torch.Tensor, y_train: torch.Tensor) -> Dict[str, float]:
        ...
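Above, train_step receives one batch (x_train, y_train) and returns a Dict[str, float] of results. The plain-Python sketch below mimics that contract without PyTorch: a hypothetical ManagerSketch base class calls train_step for each batch and averages the returned dictionaries, and the hypothetical LinearManager subclass fits a single weight with hand-written gradient descent on squared error. It illustrates the override pattern only; torchmanager's actual loop, logging and device handling are not reproduced.

```python
from typing import Dict, Iterable, Tuple


class ManagerSketch:
    """Hypothetical base class: calls train_step once per batch and averages the results."""

    def train_step(self, x_train: float, y_train: float) -> Dict[str, float]:
        raise NotImplementedError("subclasses provide the per-batch update")

    def train(self, dataset: Iterable[Tuple[float, float]]) -> Dict[str, float]:
        totals: Dict[str, float] = {}
        count = 0
        for x_train, y_train in dataset:
            for key, value in self.train_step(x_train, y_train).items():
                totals[key] = totals.get(key, 0.0) + value
            count += 1
        return {key: value / count for key, value in totals.items()} if count else {}


class LinearManager(ManagerSketch):
    """Fits y = weight * x with one plain gradient-descent step per sample."""

    def __init__(self, weight: float = 0.0, lr: float = 0.1) -> None:
        self.weight = weight
        self.lr = lr

    def train_step(self, x_train: float, y_train: float) -> Dict[str, float]:
        error = self.weight * x_train - y_train
        loss = error ** 2
        # d(loss)/d(weight) = 2 * error * x_train
        self.weight -= self.lr * 2 * error * x_train
        return {"loss": loss}
```

Calling LinearManager().train([(1.0, 2.0)]) returns {'loss': 4.0} and moves weight from 0.0 to 0.4; the base class never needs to know how the update is computed.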
Download files
Source Distribution
torchmanager-1.0.7.tar.gz (23.5 kB)
Built Distribution
Hashes for torchmanager-1.0.7-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 62a7db25e448d031585c27a449e5eb4d44d23650a7cec6bd36f1aa5a44ac3a89
MD5 | 9f2d890a463ea3642c1ebe0ebf9f1092
BLAKE2b-256 | fb90b57302a52eb4045c01f0f105f45ab47870e0c83d988178b998e533eefea2