Focus on building and optimizing PyTorch models, not on writing training loops.
torchtrainer
PyTorch model training made simpler without losing control. Focus on optimizing your model! Concepts are heavily inspired by the awesome projects torchsample and Keras. In addition to epoch callbacks, torchtrainer can invoke callbacks after a specific number of batches (iterations), which is useful when epochs take a long time.
Features
Torchtrainer
- Logging utilities
- Metrics
- Visdom Visualization
- Learning Rate Scheduler
- Checkpointing
- Flexible support for multiple model inputs
- Validation after every n batches, not just at epoch boundaries
Usage
Installation
pip install torchtrainer
Example
from torch import nn
from torch.optim import SGD
from torchtrainer import TorchTrainer
from torchtrainer.callbacks import VisdomLinePlotter, ProgressBar, VisdomEpoch, Checkpoint, CSVLogger, \
    EarlyStoppingEpoch, ReduceLROnPlateauCallback
from torchtrainer.metrics import BinaryAccuracy
metrics = [BinaryAccuracy()]
train_loader = ...
val_loader = ...
model = ...
loss = nn.BCELoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
# Setup a Visdom environment for your model
plotter = VisdomLinePlotter(env_name='Model 11')
# Setup the callbacks of your choice
callbacks = [
    ProgressBar(log_every=10),
    VisdomEpoch(plotter, on_iteration_every=10),
    VisdomEpoch(plotter, on_iteration_every=10, monitor='binary_acc'),
    CSVLogger('test.log'),
    Checkpoint('./model'),
    EarlyStoppingEpoch(min_delta=0.1, monitor='val_running_loss', patience=10),
    ReduceLROnPlateauCallback(factor=0.1, threshold=0.1, patience=2, verbose=True)
]
trainer = TorchTrainer(model)
# function to transform batch into inputs to your model and y_true values
# if your model accepts multiple inputs, just put all inputs into a tuple (input1, input2), y_true
def transform_fn(batch):
    inputs, y_true = batch
    return inputs, y_true.float()
# prepare your trainer for training
trainer.prepare(optimizer,
                loss,
                train_loader,
                val_loader,
                transform_fn=transform_fn,
                callbacks=callbacks,
                metrics=metrics)
# train your model
result = trainer.train(epochs=10, batch_size=10)
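As noted in the comment above, a model that accepts multiple inputs is handled by returning all inputs as one tuple from the transform function. A minimal sketch — the batch layout `(input1, input2, y_true)` and the name `transform_fn_multi` are assumptions for illustration, so adjust them to your dataset:

```python
# Hypothetical batch layout: (input1, input2, y_true).
def transform_fn_multi(batch):
    input1, input2, y_true = batch
    # Pack all model inputs into a single tuple; the trainer can then
    # call the model as model(input1, input2).
    return (input1, input2), y_true

# Plain-Python stand-ins for tensors, just to show the shapes:
inputs, target = transform_fn_multi(([1.0, 2.0], [3.0, 4.0], [1.0]))
```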
Callbacks
Callbacks with an Iteration suffix fire after a configurable number of batches; their Epoch counterparts fire at epoch boundaries.
Logger
- CSVLogger
- CSVLoggerIteration
- ProgressBar
Visualization and Logging
- VisdomEpoch
Optimizers
- ReduceLROnPlateauCallback
- StepLRCallback
Regularization
- EarlyStoppingEpoch
- EarlyStoppingIteration
Checkpointing
- Checkpoint
- CheckpointIteration
Metrics
Currently only BinaryAccuracy is implemented. To implement your own metric, subclass the abstract base class torchtrainer.metrics.metric.Metric.
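To get a feel for what a custom metric involves, here is a self-contained sketch. The method names (`reset`, `update`, `value`) and the stand-in base class below are assumptions for illustration — check the actual abstract interface of torchtrainer.metrics.metric.Metric before subclassing it:

```python
from abc import ABC, abstractmethod

# Stand-in for torchtrainer.metrics.metric.Metric; the real abstract
# base class may define a different set of methods.
class Metric(ABC):
    @abstractmethod
    def reset(self): ...

    @abstractmethod
    def update(self, y_pred, y_true): ...

    @abstractmethod
    def value(self): ...

class MeanAbsoluteError(Metric):
    """Running mean absolute error over all samples seen so far."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.total_error = 0.0
        self.count = 0

    def update(self, y_pred, y_true):
        # Accumulate per-sample absolute errors.
        for p, t in zip(y_pred, y_true):
            self.total_error += abs(p - t)
            self.count += 1

    def value(self):
        return self.total_error / max(self.count, 1)

m = MeanAbsoluteError()
m.update([0.5, 1.5], [1.0, 1.0])
print(m.value())  # → 0.5
```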
TODO
- more tests
- metrics