Focus on building and optimizing PyTorch models, not on training loops

Project description

torchtrainer

PyTorch model training made simpler without losing control. Focus on optimizing your model! The concepts are heavily inspired by the awesome projects torchsample and Keras. Besides epoch callbacks, torchtrainer can also run callbacks after a specified number of batches (iterations), which helps when a single epoch takes a long time.

Features

  • Torchtrainer
  • Logging utilities
  • Metrics
  • Visdom Visualization
  • Learning Rate Scheduler
  • Checkpointing
  • Flexible for multiple data inputs
  • Set up validation after every ... batches

Usage

Installation

pip install torchtrainer

Example

from torch import nn
from torch.optim import SGD
from torchtrainer import TorchTrainer
from torchtrainer.callbacks import VisdomLinePlotter, ProgressBar, VisdomEpoch, Checkpoint, CSVLogger, \
    EarlyStoppingEpoch, ReduceLROnPlateauCallback
from torchtrainer.metrics import BinaryAccuracy


metrics = [BinaryAccuracy()]

train_loader = ...
val_loader = ...

model = ...
loss = nn.BCELoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

# Set up the Visdom environment for your model
plotter = VisdomLinePlotter(env_name=f'Model {11}')


# Set up the callbacks of your choice

callbacks = [
    ProgressBar(log_every=10),
    VisdomEpoch(plotter, on_iteration_every=10),
    VisdomEpoch(plotter, on_iteration_every=10, monitor='binary_acc'),
    CSVLogger('test.csv'),
    Checkpoint('./model'),
    EarlyStoppingEpoch(min_delta=0.1, monitor='val_running_loss', patience=10),
    ReduceLROnPlateauCallback(factor=0.1, threshold=0.1, patience=2, verbose=True)
]

trainer = TorchTrainer(model)

# Function that transforms a batch into the inputs to your model and the y_true values.
# If your model accepts multiple inputs, return them as a tuple: (input1, input2), y_true
def transform_fn(batch):
    inputs, y_true = batch
    return inputs, y_true.float()

# prepare your trainer for training
trainer.prepare(optimizer,
                loss,
                train_loader,
                val_loader,
                transform_fn=transform_fn,
                callbacks=callbacks,
                metrics=metrics)

# train your model
result = trainer.train(epochs=10, batch_size=10)
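
For a model with more than one input, the comment on transform_fn above suggests returning the inputs as a tuple. The following is a minimal sketch of that idea, assuming each batch is a hypothetical (input1, input2, y_true) triple; the real layout depends on your dataset, and plain lists stand in for tensors here.

```python
def transform_multi_input(batch):
    """Split a batch into a tuple of model inputs and the target values.

    Assumption: the data loader yields (input1, input2, y_true) triples.
    This layout is hypothetical and not prescribed by torchtrainer.
    """
    input1, input2, y_true = batch
    # Group all model inputs into one tuple and return the targets separately.
    return (input1, input2), y_true


# Plain lists stand in for tensors in this illustration.
inputs, targets = transform_multi_input(([0.1, 0.2], [1.0, 2.0], [0, 1]))
```

With real tensors you would typically also cast the targets, as the single-input example does with y_true.float().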

Callbacks

Logger

  • CSVLogger
  • CSVLoggerIteration
  • ProgressBar

Visualization and Logging

  • VisdomEpoch

Optimizers

  • ReduceLROnPlateauCallback
  • StepLRCallback
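
The callback names resemble PyTorch's StepLR and ReduceLROnPlateau schedulers, but whether the callbacks wrap them is an assumption, not something this README states. The sketch below only illustrates the two general techniques in plain Python: step decay multiplies the learning rate by gamma every step_size epochs, and plateau reduction multiplies it by factor once the monitored loss has failed to improve for more than patience epochs. The names step_decay_lr and PlateauReducer are hypothetical, and the plateau logic is simplified (relative threshold only, no cooldown or minimum learning rate).

```python
def step_decay_lr(base_lr, epoch, step_size, gamma=0.1):
    """Learning rate after `epoch` epochs of step decay."""
    return base_lr * gamma ** (epoch // step_size)


class PlateauReducer:
    """Simplified reduce-on-plateau logic (illustration only)."""

    def __init__(self, lr, factor=0.1, patience=2, threshold=0.1):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        # Count an epoch as an improvement only if the loss beats the best
        # value by a relative margin of `threshold`.
        if loss < self.best * (1 - self.threshold):
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce the learning rate once patience is exceeded, then start over.
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


reducer = PlateauReducer(lr=0.001, factor=0.1, patience=2, threshold=0.1)
history = [reducer.step(loss) for loss in [1.0, 0.95, 0.95, 0.95]]
```

In this run the learning rate stays at 0.001 for the first three epochs and drops by the factor on the fourth, after three epochs without a sufficient improvement.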

Regularization

  • EarlyStoppingEpoch
  • EarlyStoppingIteration
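
The min_delta and patience arguments used in the example follow the usual early-stopping pattern: training stops once the monitored value has not improved by more than min_delta for patience consecutive checks. A minimal sketch of that general technique follows; EarlyStopper is a hypothetical name, and this is not torchtrainer's implementation.

```python
class EarlyStopper:
    """Generic early-stopping logic for a value that should decrease."""

    def __init__(self, min_delta=0.0, patience=10):
        self.min_delta = min_delta
        self.patience = patience
        self.best = float("inf")
        self.waited = 0

    def should_stop(self, value):
        # An improvement must beat the best value by more than min_delta.
        if self.best - value > self.min_delta:
            self.best = value
            self.waited = 0
        else:
            self.waited += 1
        return self.waited >= self.patience


stopper = EarlyStopper(min_delta=0.1, patience=2)
decisions = [stopper.should_stop(loss) for loss in [1.0, 0.8, 0.75, 0.74]]
```

The Epoch and Iteration variants listed above presumably differ only in when such a check runs: once per epoch or after a number of batches.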

Checkpointing

  • Checkpoint
  • CheckpointIteration
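
Several callbacks come in an epoch variant and an iteration variant. The idea behind the iteration variants can be sketched as a hook that fires after every n batches. EveryNIterations and on_batch_end are hypothetical names for illustration, and counting iterations globally across epochs is an assumption; torchtrainer may count differently.

```python
class EveryNIterations:
    """Call `fn` after every `n` batches (hypothetical helper)."""

    def __init__(self, n, fn):
        self.n = n
        self.fn = fn

    def on_batch_end(self, iteration):
        # Fire only when the running batch count is a multiple of n.
        if iteration % self.n == 0:
            self.fn(iteration)


fired = []
hook = EveryNIterations(n=10, fn=fired.append)

batches_per_epoch = 25
for epoch in range(2):
    for batch_idx in range(batches_per_epoch):
        # Global iteration count, starting at 1 (an assumption).
        iteration = epoch * batches_per_epoch + batch_idx + 1
        hook.on_batch_end(iteration)
```

With 25 batches per epoch, an epoch-level callback would run twice here, while the iteration hook runs five times.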

Metrics

Currently only BinaryAccuracy is implemented. To implement other metrics, subclass the abstract base class torchtrainer.metrics.metric.Metric.
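
The interface of the Metric base class is not shown here, so the sketch below does not subclass it. It only illustrates what a binary accuracy computation typically looks like: predicted probabilities are thresholded and compared with the 0/1 labels. The function name and the 0.5 threshold are assumptions.

```python
def binary_accuracy(y_pred, y_true, threshold=0.5):
    """Fraction of predictions that match the 0/1 labels after thresholding."""
    if len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must have the same length")
    if not y_true:
        raise ValueError("cannot compute accuracy of an empty batch")
    # A probability at or above the threshold counts as class 1.
    correct = sum(
        int(p >= threshold) == int(t) for p, t in zip(y_pred, y_true)
    )
    return correct / len(y_true)


accuracy = binary_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 1])
```

A custom metric would wrap the same kind of computation in whatever methods the base class requires.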

TODO

  • more tests
  • metrics

Download files

Source Distribution

torchtrainer-0.3.9.tar.gz (11.0 kB)

Uploaded Source

Built Distribution

torchtrainer-0.3.9-py3-none-any.whl (14.5 kB)

Uploaded Python 3

File details

Details for the file torchtrainer-0.3.9.tar.gz.

File metadata

  • Download URL: torchtrainer-0.3.9.tar.gz
  • Size: 11.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.6.8 Darwin/18.6.0

File hashes

Hashes for torchtrainer-0.3.9.tar.gz

  • SHA256: 047cbbf7d92b9d7759666dead1e2847ef6c1ffe142fd9d57764bbde74e62ee4a
  • MD5: 90f67f88bad058690c59194472d5c13a
  • BLAKE2b-256: bff9eb04c322b7d8aaa30083f484b767ddc06594f9f06a02eaeadb71254e3d21

File details

Details for the file torchtrainer-0.3.9-py3-none-any.whl.

File metadata

  • Download URL: torchtrainer-0.3.9-py3-none-any.whl
  • Size: 14.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/0.12.17 CPython/3.6.8 Darwin/18.6.0

File hashes

Hashes for torchtrainer-0.3.9-py3-none-any.whl

  • SHA256: 73c190c26037e4876c24d9bb20b930423c73f964bf711632861e7b3354a9feaf
  • MD5: 937036b2dfbff0fb3f7051d713866a01
  • BLAKE2b-256: 0ff2dfd32580bed08a4dd01c4d1d9b6f95fef5d9041f0a975d3bdb1f2f1150e5
