Focus on building and optimizing PyTorch models, not on writing training loops


torchtrainer

PyTorch model training made simpler. Focus on optimizing your model! Concepts are heavily inspired by the awesome projects torchsample and Keras.


Features

  • Torchtrainer
  • Logging utilities
  • Metrics
  • Visdom Visualization
  • Learning Rate Scheduler
  • Checkpointing
  • Flexible handling of multiple data inputs (see the sketch after the example below)
  • Validation after every N batches

Usage

Installation

pip install torchtrainer
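
A quick way to confirm the install is to import the package from the command line (this only checks that the package is importable; it does not exercise any training code):

python -c "import torchtrainer"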

Example

from torch import nn
from torch.optim import SGD
from torchtrainer.callbacks.checkpoint import Checkpoint
from torchtrainer.callbacks.csv_logger import CSVLogger
from torchtrainer.callbacks.early_stopping import EarlyStoppingEpoch
from torchtrainer.callbacks.progressbar import ProgressBar
from torchtrainer.callbacks.reducelronplateau import ReduceLROnPlateauCallback
from torchtrainer.callbacks.visdom import VisdomLinePlotter, VisdomEpoch
from torchtrainer.metrics.binary_accuracy import BinaryAccuracy
from torchtrainer.trainer import TorchTrainer


def transform_fn(batch):
    # Unpack the DataLoader batch and cast targets to float, as required by nn.BCELoss
    inputs, y_true = batch
    return inputs, y_true.float()


metrics = [BinaryAccuracy()]

train_loader = ...
val_loader = ...

model = ...
loss = nn.BCELoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

# Set up a Visdom environment for your model
plotter = VisdomLinePlotter(env_name='Model 11')

callbacks = [
    ProgressBar(log_every=10),
    VisdomEpoch(plotter, on_iteration_every=10),
    VisdomEpoch(plotter, on_iteration_every=10, monitor='binary_acc'),
    CSVLogger('test.log'),
    Checkpoint('./model'),
    EarlyStoppingEpoch(min_delta=0.1, monitor='val_running_loss', patience=10),
    ReduceLROnPlateauCallback(factor=0.1, threshold=0.1, patience=2, verbose=True)
]

trainer = TorchTrainer(model)
trainer.prepare(optimizer,
                loss,
                train_loader,
                val_loader,
                transform_fn=transform_fn,
                callbacks=callbacks,
                metrics=metrics)

# train your model
trainer.train(epochs=10, batch_size=10)
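
Multiple data inputs

The "multiple data inputs" feature is driven by transform_fn: it receives the raw batch from your DataLoader and returns the model input(s) plus the target. Below is a minimal sketch of a two-input setup. The toy TwoInputNet model, the assumed batch layout, and the assumption that the trainer passes the first element returned by transform_fn to the model unchanged are illustrative guesses, not part of the torchtrainer API.

import torch
from torch import nn


class TwoInputNet(nn.Module):
    # Toy model that consumes two tensors packed into a single tuple
    def __init__(self):
        super().__init__()
        self.encoder_a = nn.Linear(16, 8)
        self.encoder_b = nn.Linear(4, 8)
        self.head = nn.Linear(16, 1)

    def forward(self, inputs):
        a, b = inputs  # unpack the tuple built by transform_fn
        h = torch.cat([self.encoder_a(a), self.encoder_b(b)], dim=1)
        return torch.sigmoid(self.head(h)).squeeze(1)  # sigmoid output to pair with nn.BCELoss


def multi_input_transform_fn(batch):
    # Assumed DataLoader batch layout: (tensor_a, tensor_b, labels)
    a, b, y_true = batch
    return (a, b), y_true.float()  # pack both inputs; BCELoss needs float targets

You would then pass multi_input_transform_fn as transform_fn in trainer.prepare(...) exactly as in the example above.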

TODO

  • more tests
  • metrics
