
PyTorch training manager (v1.0.2)


torchmanager

A Keras-like PyTorch training and testing manager

Prerequisites

  • Python 3.8+
  • PyTorch 1.8.2+
  • tqdm

Installation

pip install torchmanager

The Manager

  • Initialize the manager with the target model, optimizer, loss function, and metrics:
import torch, torchmanager

# define model
class PytorchModel(torch.nn.Module):
    ...

# initialize model, optimizer, loss function, and metrics
model = PytorchModel(...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # example learning rate; SGD has no default lr in PyTorch 1.x
loss_fn = torchmanager.losses.CrossEntropy()
metrics = {'accuracy': torchmanager.metrics.SparseCategoricalAccuracy()}

# initialize manager
manager = torchmanager.Manager(model, optimizer, loss_fn=loss_fn, metrics=metrics)
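The metrics argument is a dictionary that maps a display name to a metric object. The stdlib-only sketch below shows roughly what a sparse categorical accuracy measures. The predicted class is the index of the highest score, and it is compared against an integer label. The helpers sparse_categorical_accuracy and evaluate_metrics are hypothetical and are not part of torchmanager's API.

```python
from typing import Callable, Dict, Sequence

# Hypothetical metric signature: (batch of score rows, batch of integer labels) -> value
Metric = Callable[[Sequence[Sequence[float]], Sequence[int]], float]


def sparse_categorical_accuracy(y_pred: Sequence[Sequence[float]], y_true: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the integer label."""
    correct = 0
    for scores, label in zip(y_pred, y_true):
        # argmax over the class scores of one sample
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct / len(y_true)


def evaluate_metrics(metrics: Dict[str, Metric],
                     y_pred: Sequence[Sequence[float]],
                     y_true: Sequence[int]) -> Dict[str, float]:
    """Apply every named metric to one batch and return a name -> value summary."""
    return {name: fn(y_pred, y_true) for name, fn in metrics.items()}


# Usage: a batch of three samples with three classes each
metrics = {'accuracy': sparse_categorical_accuracy}
logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
labels = [1, 0, 1]
summary = evaluate_metrics(metrics, logits, labels)
```

In this batch, two of the three samples have their highest score at the labelled class, so the accuracy is 2/3.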
  • Train the model with the fit method:
from torch.utils.data import DataLoader

# get datasets
training_dataset: DataLoader = ...
val_dataset: DataLoader = ...

# train with fit method
manager.fit(training_dataset, epochs=10, val_dataset=val_dataset)
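Conceptually, a Keras-style fit call runs an epoch loop. For every batch it computes the loss, derives gradients and takes an optimizer step, and after each epoch it evaluates on the validation set. The stdlib-only sketch below shows that control flow with a one-parameter linear model and a hand-written gradient. LinearModel, train_step, toy_fit and the returned history dictionary are hypothetical stand-ins, not torchmanager's implementation.

```python
from typing import Dict, List, Tuple

Batch = Tuple[List[float], List[float]]


class LinearModel:
    """Toy model y = w * x with a single trainable weight."""

    def __init__(self, w: float = 0.0) -> None:
        self.w = w

    def __call__(self, xs: List[float]) -> List[float]:
        return [self.w * x for x in xs]


def mse(y_pred: List[float], y_true: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)


def train_step(model: LinearModel, batch: Batch, lr: float) -> Dict[str, float]:
    xs, ys = batch
    y_pred = model(xs)
    # d(MSE)/dw = mean(2 * (w*x - y) * x), written out by hand instead of autograd
    grad = sum(2 * (p - t) * x for p, t, x in zip(y_pred, ys, xs)) / len(xs)
    model.w -= lr * grad  # plain SGD update
    return {'loss': mse(y_pred, ys)}  # loss measured before the update


def toy_fit(model: LinearModel, training_dataset: List[Batch], epochs: int,
            val_dataset: List[Batch], lr: float = 0.05) -> Dict[str, List[float]]:
    history: Dict[str, List[float]] = {'loss': [], 'val_loss': []}
    for _ in range(epochs):
        # training phase: one optimizer step per batch
        losses = [train_step(model, batch, lr)['loss'] for batch in training_dataset]
        history['loss'].append(sum(losses) / len(losses))
        # validation phase: forward passes only, no parameter updates
        val_losses = [mse(model(xs), ys) for xs, ys in val_dataset]
        history['val_loss'].append(sum(val_losses) / len(val_losses))
    return history


# Usage: data generated from y = 3x, so w should approach 3
train_data = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]
val_data = [([5.0], [15.0])]
model = LinearModel()
history = toy_fit(model, train_data, epochs=10, val_dataset=val_data)
```

In PyTorch, the same loop structure applies. The difference is that loss.backward() and optimizer.step() replace the hand-written gradient and update.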
  • Test the model with the test method:
# get dataset
testing_dataset: DataLoader = ...

# test with test method
manager.test(testing_dataset)
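A test pass runs only forward passes and aggregates metrics over the whole dataset. When aggregating, weight the results by the number of samples rather than averaging per-batch values, because the last batch is often smaller than the others. The sketch below is a stdlib-only illustration of that point. The functions toy_test and predict are hypothetical and do not describe how torchmanager's test method works internally.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Batch = Tuple[List[float], List[int]]


def toy_test(predict: Callable[[List[float]], List[int]],
             dataset: Sequence[Batch]) -> Dict[str, float]:
    """Evaluate accuracy over a whole dataset without touching any parameters."""
    correct = 0
    total = 0
    for xs, labels in dataset:
        preds = predict(xs)
        correct += sum(int(p == t) for p, t in zip(preds, labels))
        total += len(labels)  # count samples, so uneven batch sizes are weighted correctly
    return {'accuracy': correct / total}


# Usage: a fixed threshold classifier that predicts class 1 when x > 0
def predict(xs: List[float]) -> List[int]:
    return [int(x > 0) for x in xs]


testing_data = [([-1.0, 2.0, 3.0], [0, 1, 0]), ([0.5], [1])]
result = toy_test(predict, testing_data)
```

The result here is 3 correct out of 4 samples, or 0.75. Averaging the per-batch accuracies instead would report (2/3 + 1) / 2 ≈ 0.83.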
  • Keras-like callbacks are also available:
...

tensorboard_callback = torchmanager.callbacks.TensorBoard('logs')
last_ckpt_callback = torchmanager.callbacks.LastCheckpoint(model, 'last.model')
manager.fit(..., callbacks_list=[tensorboard_callback, last_ckpt_callback])
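Callbacks follow the Keras pattern: they define hook methods that the training loop calls at fixed points. The stdlib-only sketch below shows this pattern with an on_epoch_end hook. It includes a checkpoint callback that overwrites a single file each epoch, which is the idea behind keeping a "last" checkpoint. Callback, LastCheckpointSketch, HistoryLogger and run_epochs are hypothetical names, and torchmanager's actual callback base class and hook names may differ.

```python
import json
import os
import tempfile
from typing import Dict, List


class Callback:
    """Keras-style hook interface; subclasses override only the hooks they need."""

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        pass


class LastCheckpointSketch(Callback):
    """Hypothetical 'last checkpoint' callback: overwrite one file at the end of every epoch."""

    def __init__(self, state: Dict[str, float], path: str) -> None:
        self.state = state
        self.path = path

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        with open(self.path, 'w') as f:
            json.dump({'epoch': epoch, 'state': self.state, 'summary': summary}, f)


class HistoryLogger(Callback):
    """Collects per-epoch summaries in memory (a logging callback would write them out instead)."""

    def __init__(self) -> None:
        self.records: List[Dict[str, float]] = []

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        self.records.append(dict(summary, epoch=epoch))


def run_epochs(epochs: int, state: Dict[str, float], callbacks_list: List[Callback]) -> None:
    for epoch in range(epochs):
        state['w'] += 1.0                      # placeholder for a real training epoch
        summary = {'loss': 1.0 / (epoch + 1)}  # placeholder metric values
        for callback in callbacks_list:
            callback.on_epoch_end(epoch, summary)


# Usage
state = {'w': 0.0}
ckpt_path = os.path.join(tempfile.mkdtemp(), 'last.json')
logger = HistoryLogger()
run_epochs(3, state, [logger, LastCheckpointSketch(state, ckpt_path)])
with open(ckpt_path) as f:
    saved = json.load(f)
```

Because the file is overwritten every epoch, only the state from the final epoch (epoch 2 here) remains on disk.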

Customize your training loop

  1. Create your own manager class by extending the Manager class:
...

class CustomManager(Manager):
    ...
  2. Override the train_step method:
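from typing import Dict
import torch
from torchmanager import Manager

class CustomManager(Manager):
    ...

    # train_step is an instance method, so it needs self; returns a summary dict for the batch
    def train_step(self, x_train: torch.Tensor, y_train: torch.Tensor) -> Dict[str, float]:
        ...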
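Overriding train_step works because the base class's training loop calls self.train_step for every batch. A subclass can therefore change per-batch behavior without rewriting the loop. The stdlib-only sketch below shows this template-method pattern with two hypothetical classes. ToyManager is the base class, and ClippedManager is a subclass that clips the gradient before the update. None of these names come from torchmanager, and real code would use autograd and a torch optimizer instead of the hand-written gradient.

```python
from typing import Dict, List, Tuple

Batch = Tuple[List[float], List[float]]


class ToyManager:
    """Hypothetical base class: fit() delegates every batch to self.train_step()."""

    def __init__(self, w: float = 0.0, lr: float = 0.05) -> None:
        self.w = w
        self.lr = lr

    def gradient(self, xs: List[float], ys: List[float]) -> float:
        # d(MSE)/dw for the model y = w * x
        return sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

    def train_step(self, x_train: List[float], y_train: List[float]) -> Dict[str, float]:
        grad = self.gradient(x_train, y_train)
        self.w -= self.lr * grad
        return {'grad': grad}

    def fit(self, dataset: List[Batch], epochs: int) -> List[Dict[str, float]]:
        # outer loop over epochs, inner loop over batches
        return [self.train_step(xs, ys) for _ in range(epochs) for xs, ys in dataset]


class ClippedManager(ToyManager):
    """Custom training loop: the inherited fit(), but train_step clips the gradient first."""

    def __init__(self, max_grad: float, **kwargs) -> None:
        super().__init__(**kwargs)
        self.max_grad = max_grad

    def train_step(self, x_train: List[float], y_train: List[float]) -> Dict[str, float]:
        grad = self.gradient(x_train, y_train)
        clipped = max(-self.max_grad, min(self.max_grad, grad))
        self.w -= self.lr * clipped
        return {'grad': grad, 'clipped_grad': clipped}


# Usage: the inherited fit() now runs the overridden train_step
manager = ClippedManager(max_grad=1.0)
logs = manager.fit([([1.0, 2.0], [3.0, 6.0])], epochs=1)
```

Here the raw gradient is -15, so clipping to 1 moves w by 0.05 instead of 0.75.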

Download files

Source distribution: torchmanager-1.0.2.tar.gz (16.9 kB)

Built distribution (Python 3): torchmanager-1.0.2-py3-none-any.whl (21.4 kB)
