PyTorch Training Manager v1.2
torchmanager
A highly wrapped PyTorch training and testing manager
Prerequisites
- Python 3.8+
- PyTorch
- tqdm
- scipy (Optional)
- tensorboard (Optional)
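Because scipy and tensorboard are optional, it can help to check which of them are importable before enabling features that depend on them, such as the TensorBoard callback. This is a minimal sketch using only the standard library. The helper name available_optional_deps is hypothetical and is not part of torchmanager.

```python
import importlib.util
from typing import Dict, Iterable


def available_optional_deps(names: Iterable[str] = ("scipy", "tensorboard")) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be imported (hypothetical helper)."""
    # find_spec returns None when a top-level module is not installed
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in available_optional_deps().items():
        print(f"{name}: {'installed' if ok else 'missing'}")
```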
Installation
- PyPI:
pip install --pre torchmanager
- Conda:
conda install -c kisonho torchmanager-nightly
The Manager
- Start with configurations:
import argparse
from typing import Union
from torchmanager.configs import Configs as _Configs
# define necessary configurations
class Configs(_Configs):
    epochs: int
    lr: float
    ...
    @staticmethod
    def get_arguments(parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup] = argparse.ArgumentParser()) -> Union[argparse.ArgumentParser, argparse._ArgumentGroup]:
        '''Add arguments to argument parser'''
        ...
    def show_settings(self) -> None:
        ...
# get configs from terminal arguments
configs = Configs.from_arguments()
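The configuration pattern above is a typed class that registers its own command-line arguments and is then built from the parsed terminal input. It can be sketched with the standard library alone. SimpleConfigs, its default values, and the optional argv parameter of from_arguments are hypothetical. They stand in for torchmanager's Configs and do not reproduce its implementation. The argv parameter only makes the sketch easy to exercise without a real terminal.

```python
import argparse
from typing import List, Optional, Union


class SimpleConfigs:
    """Hypothetical, stdlib-only stand-in for a torchmanager-style Configs class."""

    epochs: int
    lr: float

    def __init__(self, epochs: int, lr: float) -> None:
        self.epochs = epochs
        self.lr = lr

    @staticmethod
    def get_arguments(
        parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup]
    ) -> Union[argparse.ArgumentParser, argparse._ArgumentGroup]:
        """Add arguments to argument parser (defaults are illustrative)."""
        parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
        parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
        return parser

    @classmethod
    def from_arguments(cls, argv: Optional[List[str]] = None) -> "SimpleConfigs":
        # A fresh parser per call; argv=None makes argparse read sys.argv[1:]
        parser = argparse.ArgumentParser()
        cls.get_arguments(parser)
        args = parser.parse_args(argv)
        return cls(**vars(args))

    def show_settings(self) -> None:
        print(f"epochs={self.epochs}, lr={self.lr}")
```

Because get_arguments is a static method that accepts either a parser or an argument group, the same arguments can also be attached to one group of a larger parser.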
- Initialize the manager with target model, optimizer, loss function, and metrics:
import torch, torchmanager
# define model
class PytorchModel(torch.nn.Module):
    ...
# initialize model, optimizer, loss function, and metrics
model = PytorchModel(...)
optimizer = torch.optim.SGD(model.parameters(), lr=configs.lr)
loss_fn = torchmanager.losses.CrossEntropy()
metrics = {'accuracy': torchmanager.metrics.SparseCategoricalAccuracy()}
# initialize manager
manager = torchmanager.Manager(model, optimizer, loss_fn=loss_fn, metrics=metrics)
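The metric name follows the Keras convention, in which "sparse" means the targets are integer class indices rather than one-hot vectors. Assuming torchmanager's SparseCategoricalAccuracy works the same way (check its documentation), the computation compares the arg-max of each prediction with its integer label. Here is a plain-Python sketch. The function name is hypothetical.

```python
from typing import Sequence


def sparse_categorical_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the integer label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for scores, label in zip(logits, labels):
        # arg-max over class scores; max() keeps the first index on ties
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(labels)
```

For example, with predictions [[2.0, 0.5, 0.1], [0.2, 0.3, 3.0], [1.0, 2.0, 0.0], [0.1, 0.9, 0.4]] and labels [0, 2, 0, 1], the predicted classes are 0, 2, 1, 1. Three of the four match, so the accuracy is 0.75.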
- Train the model with the fit method:
from torchmanager.data import Dataset
# get datasets
training_dataset: Dataset = ...
val_dataset: Dataset = ...
# train with fit method
manager.fit(training_dataset, epochs=configs.epochs, val_dataset=val_dataset)
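Conceptually, a fit call runs an epoch loop. Each epoch makes one pass of training steps over the training batches, followed by an evaluation pass over the validation data. The sketch below shows that control flow on a toy one-parameter linear model trained with plain mini-batch SGD. The names toy_fit, the (x, y) batch format, and the learning rate are hypothetical. A real manager loop would also handle devices, metrics, progress bars, and callbacks.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # a batch is a list of (x, y) pairs


def toy_fit(training_batches: Sequence[Batch], epochs: int, lr: float = 0.05,
            val_batches: Optional[Sequence[Batch]] = None) -> Tuple[float, List[Dict[str, float]]]:
    """Fit y ~ w * x with mini-batch SGD on mean squared error; return w and per-epoch summaries."""
    w = 0.0
    history: List[Dict[str, float]] = []
    for _ in range(epochs):
        train_losses = []
        for batch in training_batches:
            # training step: forward pass, loss, gradient, parameter update
            loss = sum((w * x - y) ** 2 for x, y in batch) / len(batch)
            grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
            w -= lr * grad
            train_losses.append(loss)
        summary = {"loss": sum(train_losses) / len(train_losses)}
        if val_batches:
            # validation: evaluate only, no parameter updates
            val_losses = [sum((w * x - y) ** 2 for x, y in b) / len(b) for b in val_batches]
            summary["val_loss"] = sum(val_losses) / len(val_losses)
        history.append(summary)
    return w, history
```

Trained on pairs sampled from y = 3x, w converges toward 3, and each epoch's summary holds both the training and the validation loss.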
- Other callbacks are also available:
...
tensorboard_callback = torchmanager.callbacks.TensorBoard('logs') # tensorboard dependency required
last_ckpt_callback = torchmanager.callbacks.LastCheckpoint(manager, 'last.model')
model = manager.fit(..., callbacks_list=[tensorboard_callback, last_ckpt_callback])
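Callbacks are hooks that the training loop calls at fixed points. The illustration below shows a "last checkpoint" callback, which overwrites a single saved state at the end of every epoch. The hook name on_epoch_end, the classes, and run_epochs are hypothetical and are not torchmanager's Callback API. The state is kept in memory here, whereas a real checkpoint would be written to a file.

```python
import copy
from typing import Any, Dict, List


class Callback:
    """Hypothetical base class: the loop calls on_epoch_end after each epoch."""

    def on_epoch_end(self, epoch: int, summary: Dict[str, float], state: Dict[str, Any]) -> None:
        pass


class LastStateCallback(Callback):
    """Keeps only the most recent state, replacing the previous snapshot each epoch."""

    def __init__(self) -> None:
        self.saved: Dict[str, Any] = {}
        self.saved_epoch = -1

    def on_epoch_end(self, epoch: int, summary: Dict[str, float], state: Dict[str, Any]) -> None:
        self.saved = copy.deepcopy(state)  # snapshot, so later training does not mutate it
        self.saved_epoch = epoch


def run_epochs(epochs: int, callbacks_list: List[Callback]) -> Dict[str, Any]:
    """Hypothetical loop that fires every callback at the end of each epoch."""
    state: Dict[str, Any] = {"w": 0.0}
    for epoch in range(epochs):
        state["w"] += 1.0  # stand-in for a real training epoch
        summary = {"loss": 1.0 / (epoch + 1)}
        for cb in callbacks_list:
            cb.on_epoch_end(epoch, summary, state)
    return state
```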
- Or use callbacks.Experiment to handle both callbacks.TensorBoard and callbacks.LastCheckpoint:
...
exp_callback = torchmanager.callbacks.Experiment('test.exp', manager) # tensorboard dependency required
model = manager.fit(..., callbacks_list=[exp_callback])
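Here a single Experiment object replaces two separate callbacks, and its first argument, 'test.exp', appears to name the experiment. The general idea can be expressed as a composite callback that forwards every hook call to its children. This sketch illustrates that pattern with hypothetical class and hook names. It is not the implementation of callbacks.Experiment.

```python
from typing import Dict, List


class Callback:
    """Hypothetical hook interface."""

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        pass


class RecordingCallback(Callback):
    """Stand-in for a logger such as a TensorBoard writer: records each epoch's summary."""

    def __init__(self) -> None:
        self.records: List[Dict[str, float]] = []

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        self.records.append(dict(summary, epoch=epoch))


class LastSummaryCallback(Callback):
    """Stand-in for a last-checkpoint callback: keeps only the latest summary."""

    def __init__(self) -> None:
        self.last: Dict[str, float] = {}

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        self.last = dict(summary)


class CompositeCallback(Callback):
    """Forwards each hook call to every child callback, in order."""

    def __init__(self, name: str, children: List[Callback]) -> None:
        self.name = name
        self.children = children

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        for child in self.children:
            child.on_epoch_end(epoch, summary)
```

The training loop then only needs to know about one object, which is why a single entry in callbacks_list is enough.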
- Test the model with the test method:
# get dataset
testing_dataset: Dataset = ...
# test with test method
manager.test(testing_dataset)
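A test pass evaluates the model without updating its parameters and summarizes the loss and metrics over the whole dataset. The framework-free sketch below shows that reduction. It averages any number of metric functions over batches, weighting each batch by its size so that a smaller final batch does not skew the result. The function names, the metric signature, and the size weighting are assumptions, and torchmanager may aggregate differently.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Metric = Callable[[List[float], List[float]], float]  # (predictions, targets) -> batch score


def evaluate(predict: Callable[[float], float],
             batches: Sequence[Sequence[Tuple[float, float]]],
             metrics: Dict[str, Metric]) -> Dict[str, float]:
    """Return size-weighted averages of each metric over all batches (no training)."""
    totals = {name: 0.0 for name in metrics}
    count = 0
    for batch in batches:
        preds = [predict(x) for x, _ in batch]
        targets = [y for _, y in batch]
        for name, fn in metrics.items():
            totals[name] += fn(preds, targets) * len(batch)
        count += len(batch)
    if count == 0:
        raise ValueError("no samples to evaluate")
    return {name: total / count for name, total in totals.items()}


def mse(preds: List[float], targets: List[float]) -> float:
    """Mean squared error of one batch."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
```

With batches [[(1, 2), (2, 4)], [(3, 7)]] and the predictor x -> 2x, the batch MSEs are 0 and 1. The size-weighted result is 1/3, whereas a plain mean of the batch scores would give 0.5.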
- Save the final model in PyTorch format:
torch.save(model, "model.pth")
Customize your training loop
- Create your own manager class by extending the Manager class:
from torchmanager import Manager
class CustomManager(Manager):
    ...
- Override the train_step method:
from typing import Dict
import torch
class CustomManager(Manager):
    ...
    def train_step(self, x_train: torch.Tensor, y_train: torch.Tensor) -> Dict[str, float]:
        ...
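Overriding train_step works because the base class's loop calls it for each input and collects the Dict[str, float] it returns. The sketch below shows that template-method pattern in plain Python. BaseManager, fit_one_epoch, and the "loss" key are hypothetical names and are not torchmanager's internals. Like any method, train_step takes self as its first parameter.

```python
from typing import Dict, List, Sequence, Tuple


class BaseManager:
    """Hypothetical base class: the loop is fixed, the per-sample step is overridable."""

    def train_step(self, x_train: float, y_train: float) -> Dict[str, float]:
        raise NotImplementedError

    def fit_one_epoch(self, dataset: Sequence[Tuple[float, float]]) -> Dict[str, float]:
        if not dataset:
            raise ValueError("dataset must not be empty")
        # Call the (possibly overridden) train_step on every sample, then average each key
        summaries: List[Dict[str, float]] = [self.train_step(x, y) for x, y in dataset]
        return {k: sum(s[k] for s in summaries) / len(summaries) for k in summaries[0]}


class CustomManager(BaseManager):
    """Overrides only train_step: one SGD update of y ~ weight * x on squared error."""

    def __init__(self, weight: float = 0.0, lr: float = 0.1) -> None:
        self.weight = weight
        self.lr = lr

    def train_step(self, x_train: float, y_train: float) -> Dict[str, float]:
        error = self.weight * x_train - y_train
        self.weight -= self.lr * 2 * error * x_train  # gradient of error**2 w.r.t. weight
        return {"loss": error ** 2}
```

With weight 0.0, lr 0.1, and two samples (1.0, 1.0), the two steps report losses of 1.0 and 0.64, so the epoch summary is a loss of 0.82 and the weight ends at 0.36.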