
Focus on building and optimizing PyTorch models, not on training loops


torchtrainer

PyTorch model training made simpler without losing control. Focus on optimizing your model! The concepts are heavily inspired by the awesome projects torchsample and Keras. Besides callbacks that run at the end of every epoch, torchtrainer can also call callbacks every time a specific number of batches (iterations) has passed, which is useful when epochs take a long time.
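The difference between epoch callbacks and iteration callbacks can be pictured with a small pure-Python sketch. The classes `IterationCallback` and `EpochCallback`, the hooks `on_batch_end`/`on_epoch_end` and the `run` loop are hypothetical illustrations, not torchtrainer's actual callback interface.

```python
class IterationCallback:
    """Hypothetical callback that fires every `every` batches (iterations)."""

    def __init__(self, every):
        self.every = every
        self.fired_at = []  # global iteration numbers at which the callback ran

    def on_batch_end(self, iteration):
        # `iteration` is 1-based and counts batches across all epochs
        if iteration % self.every == 0:
            self.fired_at.append(iteration)


class EpochCallback:
    """Hypothetical callback that fires once at the end of every epoch."""

    def __init__(self):
        self.fired_at = []

    def on_epoch_end(self, epoch):
        self.fired_at.append(epoch)


def run(epochs, batches_per_epoch, iteration_cbs, epoch_cbs):
    iteration = 0
    for epoch in range(1, epochs + 1):
        for _ in range(batches_per_epoch):
            iteration += 1
            # ... forward pass, loss, backward pass and optimizer step go here ...
            for cb in iteration_cbs:
                cb.on_batch_end(iteration)
        for cb in epoch_cbs:
            cb.on_epoch_end(epoch)


it_cb, ep_cb = IterationCallback(every=10), EpochCallback()
run(epochs=2, batches_per_epoch=25, iteration_cbs=[it_cb], epoch_cbs=[ep_cb])
print(it_cb.fired_at)  # [10, 20, 30, 40, 50]
print(ep_cb.fired_at)  # [1, 2]
```

With 25 batches per epoch, the iteration callback runs five times over two epochs while the epoch callback runs only twice.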


Features

  • Torchtrainer
  • Logging utilities
  • Metrics
  • Visdom Visualization
  • Learning Rate Scheduler
  • Checkpointing
  • Flexible handling of multiple data inputs
  • Set up validation after every ... batches

Usage

Installation

pip install torchtrainer

Example

from torch import nn
from torch.optim import SGD
from torchtrainer import TorchTrainer
from torchtrainer.callbacks import VisdomLinePlotter, ProgressBar, VisdomEpoch, Checkpoint, CSVLogger, \
    EarlyStoppingEpoch, ReduceLROnPlateauCallback
from torchtrainer.metrics import BinaryAccuracy


metrics = [BinaryAccuracy()]

train_loader = ...
val_loader = ...

model = ...
loss = nn.BCELoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

# Set up a Visdom environment for your model
plotter = VisdomLinePlotter(env_name='Model 11')


# Setup the callbacks of your choice

callbacks = [
    ProgressBar(log_every=10),
    VisdomEpoch(plotter, on_iteration_every=10),
    VisdomEpoch(plotter, on_iteration_every=10, monitor='binary_acc'),
    CSVLogger('test.csv'),
    Checkpoint('./model'),
    EarlyStoppingEpoch(min_delta=0.1, monitor='val_running_loss', patience=10),
    ReduceLROnPlateauCallback(factor=0.1, threshold=0.1, patience=2, verbose=True)
]

trainer = TorchTrainer(model)

# Transform a batch into the model inputs and the y_true values.
# If your model accepts multiple inputs, return them as a tuple: (input1, input2), y_true
def transform_fn(batch):
    inputs, y_true = batch
    return inputs, y_true.float()

# prepare your trainer for training
trainer.prepare(optimizer,
                loss,
                train_loader,
                val_loader,
                transform_fn=transform_fn,
                callbacks=callbacks,
                metrics=metrics)

# train your model
result = trainer.train(epochs=10, batch_size=10)
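For models with several inputs, `transform_fn` returns the inputs as a tuple, and the trainer has to unpack that tuple when calling the model. The sketch below assumes the call looks like `model(*inputs)`; it is an illustration, not torchtrainer's internal code. `forward_batch`, `TwoInputModel` and the three-element batch layout are hypothetical, and plain numbers stand in for tensors.

```python
def transform_fn(batch):
    # Hypothetical batch layout: (image, metadata, label)
    image, metadata, y_true = batch
    # With tensors this would be y_true.float(); float() plays that role here
    return (image, metadata), float(y_true)


def forward_batch(model, batch, transform_fn):
    inputs, y_true = transform_fn(batch)
    # Wrap a single input so that one- and multi-input models are called the same way
    if not isinstance(inputs, tuple):
        inputs = (inputs,)
    return model(*inputs), y_true


class TwoInputModel:
    # Stand-in for an nn.Module whose forward() takes two arguments
    def __call__(self, image, metadata):
        return image + metadata


y_pred, y_true = forward_batch(TwoInputModel(), (2, 3, 1), transform_fn)
print(y_pred, y_true)  # 5 1.0
```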

Callbacks

Logger

  • CSVLogger
  • CSVLoggerIteration
  • ProgressBar
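An epoch-level CSV logger essentially writes one row of metrics per epoch. The sketch below uses Python's `csv.DictWriter` to show the idea. `EpochCSVWriter` is hypothetical, not torchtrainer's `CSVLogger`, and the column name `running_loss` is an assumption (only `val_running_loss` appears in the example above).

```python
import csv
import io


class EpochCSVWriter:
    """Hypothetical epoch logger: one CSV row per epoch with all tracked metrics."""

    def __init__(self, file, fieldnames):
        self.writer = csv.DictWriter(file, fieldnames=["epoch", *fieldnames])
        self.writer.writeheader()

    def on_epoch_end(self, epoch, logs):
        self.writer.writerow({"epoch": epoch, **logs})


buffer = io.StringIO()  # stand-in for open('test.csv', 'w', newline='')
logger = EpochCSVWriter(buffer, ["running_loss", "val_running_loss"])
logger.on_epoch_end(1, {"running_loss": 0.69, "val_running_loss": 0.71})
logger.on_epoch_end(2, {"running_loss": 0.52, "val_running_loss": 0.60})
print(buffer.getvalue())
```

An iteration-level logger would follow the same pattern but write a row every n batches instead of every epoch.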

Visualization and Logging

  • VisdomEpoch

Learning Rate Schedulers

  • ReduceLROnPlateauCallback
  • StepLRCallback
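The parameters of `ReduceLROnPlateauCallback(factor=0.1, threshold=0.1, patience=2)` share their names with PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`; whether the callback wraps that class is not stated here. The sketch below reproduces the plateau logic of PyTorch's defaults (`mode='min'`, `threshold_mode='rel'`) in plain Python, leaving out cooldown and `min_lr`. `PlateauLR` is a hypothetical name.

```python
class PlateauLR:
    """Hypothetical reduce-on-plateau sketch (mode='min', relative threshold)."""

    def __init__(self, lr, factor=0.1, threshold=0.1, patience=2):
        self.lr = lr
        self.factor = factor
        self.threshold = threshold
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric):
        # Only counts as an improvement if it beats the best value by a relative margin
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # The LR is reduced once the number of bad epochs exceeds `patience`
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


sched = PlateauLR(lr=0.001, factor=0.1, threshold=0.1, patience=2)
lrs = [sched.step(loss) for loss in [1.0, 0.8, 0.79, 0.78, 0.77, 0.76]]
# The LR stays at 0.001 for four steps and drops to about 0.0001 on the fifth
```

With `threshold=0.1`, the losses 0.79, 0.78 and 0.77 are not 10 % below the best value of 0.8, so they count as bad epochs and the third one triggers the reduction.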

Regularization

  • EarlyStoppingEpoch
  • EarlyStoppingIteration
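Early stopping with `min_delta` and `patience`, as configured in the example above, can be sketched as follows for a metric that should decrease (such as `val_running_loss`). This follows the common Keras-style rule and is not necessarily how `EarlyStoppingEpoch` is implemented; `EarlyStopper` and `should_stop` are hypothetical names.

```python
class EarlyStopper:
    """Hypothetical early-stopping logic for a metric that should decrease."""

    def __init__(self, min_delta=0.1, patience=10):
        self.min_delta = min_delta
        self.patience = patience
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, value):
        if value < self.best - self.min_delta:
            # Improved by more than min_delta: remember it and reset the counter
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStopper(min_delta=0.1, patience=2)
history = [1.0, 0.8, 0.75, 0.72, 0.5]
stopped_at = next(i for i, v in enumerate(history) if stopper.should_stop(v))
print(stopped_at)  # 3
```

Training stops at index 3 because 0.75 and 0.72 are both less than `min_delta` below the best value 0.8, even though a larger improvement (0.5) would have followed.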

Checkpointing

  • Checkpoint
  • CheckpointIteration
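A typical checkpointing strategy saves the model only when the monitored validation loss improves. In the sketch below, saving is delegated to a `save_fn` callable so it runs without PyTorch; in practice that could be `lambda path: torch.save(model.state_dict(), path)`. `BestCheckpoint` and the file naming scheme `epoch_{n}.pt` are hypothetical; only the `./model` directory comes from the example above.

```python
class BestCheckpoint:
    """Hypothetical checkpoint logic: save only when the monitored loss improves."""

    def __init__(self, save_fn, directory="./model"):
        self.save_fn = save_fn  # e.g. lambda path: torch.save(model.state_dict(), path)
        self.directory = directory
        self.best = float("inf")

    def on_epoch_end(self, epoch, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.save_fn(f"{self.directory}/epoch_{epoch}.pt")


saved = []
ckpt = BestCheckpoint(save_fn=saved.append)
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    ckpt.on_epoch_end(epoch, val_loss)
print(saved)  # ['./model/epoch_1.pt', './model/epoch_2.pt', './model/epoch_4.pt']
```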

Metrics

Currently only BinaryAccuracy is implemented. To implement other metrics, subclass the abstract base class torchtrainer.metrics.metric.Metric.
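A binary accuracy metric that accumulates over batches boils down to counting correct thresholded predictions. The sketch below is a standalone illustration: the method names `reset`, `update` and `value` and the 0.5 threshold are assumptions, not the abstract methods of `torchtrainer.metrics.metric.Metric` or the exact behaviour of `BinaryAccuracy`. Plain lists stand in for tensors.

```python
class RunningBinaryAccuracy:
    """Hypothetical streaming metric; method names are illustrative only."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, y_pred, y_true):
        # y_pred are probabilities (e.g. sigmoid outputs), y_true are 0/1 labels
        for p, t in zip(y_pred, y_true):
            self.correct += int((p >= self.threshold) == bool(t))
            self.total += 1

    def value(self):
        return self.correct / self.total if self.total else 0.0


acc = RunningBinaryAccuracy()
acc.update([0.9, 0.2, 0.6], [1, 0, 0])
acc.update([0.4], [1])
print(acc.value())  # 0.5
```

Accumulating counts instead of averaging per-batch accuracies keeps the result correct when the last batch is smaller than the others.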

TODO

  • more tests
  • metrics



Download files


Source Distribution

torchtrainer-0.3.6.tar.gz (11.0 kB)

Built Distribution

torchtrainer-0.3.6-py3-none-any.whl (14.5 kB)
