Focus on building and optimizing PyTorch models, not on training loops

torchtrainer

PyTorch model training made simpler without losing control. Focus on optimizing your model! The concepts are heavily inspired by the awesome projects torchsample and Keras. Besides epoch callbacks, torchtrainer can also run callbacks after a specific number of batches (iterations), which is useful when a single epoch takes a long time.

Features

  • Torchtrainer
  • Logging utilities
  • Metrics
  • Visdom Visualization
  • Learning Rate Scheduler
  • Checkpointing
  • Flexible for multiple data inputs
  • Set up validation after every ... batches

Usage

Installation

pip install torchtrainer

Example

from torch import nn
from torch.optim import SGD
from torchtrainer import TorchTrainer
from torchtrainer.callbacks import VisdomLinePlotter, ProgressBar, VisdomEpoch, Checkpoint, CSVLogger, \
    EarlyStoppingEpoch, ReduceLROnPlateauCallback
from torchtrainer.metrics import BinaryAccuracy


metrics = [BinaryAccuracy()]

train_loader = ...
val_loader = ...

model = ...
loss = nn.BCELoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

# Set up the Visdom environment for your model
plotter = VisdomLinePlotter(env_name=f'Model {11}')


# Setup the callbacks of your choice

callbacks = [
    ProgressBar(log_every=10),
    VisdomEpoch(plotter, on_iteration_every=10),
    VisdomEpoch(plotter, on_iteration_every=10, monitor='binary_acc'),
    CSVLogger('test.csv'),
    Checkpoint('./model'),
    EarlyStoppingEpoch(min_delta=0.1, monitor='val_running_loss', patience=10),
    ReduceLROnPlateauCallback(factor=0.1, threshold=0.1, patience=2, verbose=True)
]

trainer = TorchTrainer(model)

# Function that transforms a batch into the inputs for your model and the y_true values.
# If your model accepts multiple inputs, return them as a tuple: (input1, input2), y_true
def transform_fn(batch):
    inputs, y_true = batch
    return inputs, y_true.float()

# prepare your trainer for training
trainer.prepare(optimizer,
                loss,
                train_loader,
                val_loader,
                transform_fn=transform_fn,
                callbacks=callbacks,
                metrics=metrics)

# train your model
result = trainer.train(epochs=10, batch_size=10)
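The comment on transform_fn in the example mentions models with several inputs. The following is a minimal sketch of such a function, assuming a hypothetical data loader that yields (input1, input2, y_true) triples. The name multi_input_transform_fn is illustrative only, and plain Python lists stand in for tensors so the snippet runs without PyTorch. How the trainer passes the tuple to the model is not covered in this description.

```python
def multi_input_transform_fn(batch):
    """Split a batch into model inputs and targets.

    Assumption: the data loader yields (input1, input2, y_true) triples.
    """
    input1, input2, y_true = batch
    # Group all model inputs into one tuple; the targets stay separate.
    # With real tensors and nn.BCELoss you would usually also cast the
    # targets, as in the example above: y_true.float()
    return (input1, input2), y_true


# Plain Python lists stand in for tensors here.
example_batch = ([1, 2], [3, 4], [0, 1])
inputs, y_true = multi_input_transform_fn(example_batch)
```

The function would then be passed as transform_fn=multi_input_transform_fn in the trainer.prepare call shown above.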

Callbacks

Logger

  • CSVLogger
  • CSVLoggerIteration
  • ProgressBar

Visualization and Logging

  • VisdomEpoch

Optimizers

  • ReduceLROnPlateauCallback
  • StepLRCallback

Regularization

  • EarlyStoppingEpoch
  • EarlyStoppingIteration

Checkpointing

  • Checkpoint
  • CheckpointIteration
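Several callbacks come in an epoch and an iteration variant. The idea behind the iteration variants, running a callback after every N batches instead of once per epoch, can be sketched in plain Python as follows. This is not torchtrainer's implementation: the class EveryNIterations, the method on_batch_end, and the function run_training are hypothetical names chosen for this illustration.

```python
class EveryNIterations:
    """Hypothetical callback: call `fn` after every `n` batches."""

    def __init__(self, n, fn):
        self.n = n
        self.fn = fn

    def on_batch_end(self, iteration, logs):
        # `iteration` counts batches across all epochs, starting at 1.
        if iteration % self.n == 0:
            self.fn(iteration, logs)


def run_training(num_epochs, batches_per_epoch, callbacks):
    """Toy loop standing in for a trainer; no model is trained."""
    iteration = 0
    for epoch in range(num_epochs):
        for _ in range(batches_per_epoch):
            iteration += 1
            logs = {"epoch": epoch, "iteration": iteration}
            for callback in callbacks:
                callback.on_batch_end(iteration, logs)
    return iteration


fired = []
callback = EveryNIterations(10, lambda iteration, logs: fired.append(iteration))
total = run_training(num_epochs=2, batches_per_epoch=25, callbacks=[callback])
```

With 25 batches per epoch, an epoch-level callback would fire twice in this run, while the iteration-level callback fires five times. That difference is what makes validation, logging, or checkpointing practical inside long epochs.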

Metrics

Currently only BinaryAccuracy is implemented. To implement other metrics, subclass the abstract base class torchtrainer.metrics.metric.Metric.
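The interface of the abstract Metric class is not documented here, so the sketch below defines its own stand-in base class instead of importing the real one. The methods __call__ and reset are assumptions about what a running metric might need, and BinaryPrecision is a hypothetical example. Check torchtrainer.metrics.metric.Metric for the methods a real subclass has to implement.

```python
from abc import ABC, abstractmethod


class Metric(ABC):
    """Stand-in for torchtrainer.metrics.metric.Metric (interface assumed)."""

    @abstractmethod
    def __call__(self, y_pred, y_true):
        """Update the running state and return the current value."""

    @abstractmethod
    def reset(self):
        """Clear the running state, e.g. at the start of an epoch."""


class BinaryPrecision(Metric):
    """Running precision for binary predictions with a fixed threshold."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.reset()

    def reset(self):
        self.true_positives = 0
        self.predicted_positives = 0

    def __call__(self, y_pred, y_true):
        for prediction, target in zip(y_pred, y_true):
            if prediction >= self.threshold:
                self.predicted_positives += 1
                if target == 1:
                    self.true_positives += 1
        if self.predicted_positives == 0:
            return 0.0
        return self.true_positives / self.predicted_positives


precision = BinaryPrecision()
first = precision([0.9, 0.2, 0.7], [1, 0, 0])  # first batch
second = precision([0.8], [1])                 # state accumulates across batches
```

Keeping counts rather than a per-batch average makes the reported value correct across batches of different sizes.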

TODO

  • more tests
  • metrics

Download files

Files for torchtrainer, version 0.3.9
  • torchtrainer-0.3.9-py3-none-any.whl (14.5 kB), file type: Wheel, Python version: py3
  • torchtrainer-0.3.9.tar.gz (11.0 kB), file type: Source, Python version: None