# torchtrainer

PyTorch model training made simpler. Focus on building and optimizing your models, not on writing training loops! Concepts are heavily inspired by the awesome project torchsample.
## Features

Torchtrainer provides:
- Logging utilities
- Metrics
- Visdom Visualization
- Learning Rate Scheduler
- Checkpointing
- Flexible handling of multiple data inputs (see the sketch after this list)
- Validation after every *n* batches
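For multiple data inputs, the `transform_fn` hook shown in the full example below lets you unpack whatever batch layout your `DataLoader` produces before it reaches the model and loss. A minimal, hypothetical sketch, assuming a dataset that yields two input tensors per sample (the `(x_image, x_meta)` names are illustrative, not part of the torchtrainer API):

```python
# Hypothetical sketch: unpack a batch that carries two model inputs.
# The ((x_image, x_meta), y_true) layout is whatever your own DataLoader
# produces; based on the example below, transform_fn appears to be
# expected to return an (inputs, target) pair.
def transform_fn(batch):
    (x_image, x_meta), y_true = batch
    return (x_image, x_meta), y_true.float()  # float targets for BCELoss
```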
## Usage

### Installation

```bash
pip install torchtrainer
```

### Example
```python
from torch import nn
from torch.optim import SGD

from torchtrainer.callbacks.checkpoint import Checkpoint
from torchtrainer.callbacks.csv_logger import CSVLogger
from torchtrainer.callbacks.early_stopping import EarlyStoppingEpoch
from torchtrainer.callbacks.progressbar import ProgressBar
from torchtrainer.callbacks.reducelronplateau import ReduceLROnPlateauCallback
from torchtrainer.callbacks.visdom import VisdomLinePlotter, VisdomEpoch
from torchtrainer.metrics.binary_accuracy import BinaryAccuracy
from torchtrainer.trainer import TorchTrainer


# Transform every batch before it is fed to the model and the loss
def transform_fn(batch):
    inputs, y_true = batch
    return inputs, y_true.float()


metrics = [BinaryAccuracy()]

train_loader = ...
val_loader = ...
model = ...

loss = nn.BCELoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

# Set up a Visdom environment for your model
plotter = VisdomLinePlotter(env_name=f'Model {11}')

callbacks = [
    ProgressBar(log_every=10),
    VisdomEpoch(plotter, on_iteration_every=10),
    VisdomEpoch(plotter, on_iteration_every=10, monitor='binary_acc'),
    CSVLogger('test.log'),
    Checkpoint('./model'),
    EarlyStoppingEpoch(min_delta=0.1, monitor='val_running_loss', patience=10),
    ReduceLROnPlateauCallback(factor=0.1, threshold=0.1, patience=2, verbose=True)
]

trainer = TorchTrainer(model)
trainer.prepare(optimizer,
                loss,
                train_loader,
                val_loader,
                transform_fn=transform_fn,
                callbacks=callbacks,
                metrics=metrics)

# train your model
trainer.train(epochs=10, batch_size=10)
```
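The `train_loader`, `val_loader`, and `model` lines above are placeholders. One minimal way to fill them, sketched with plain PyTorch (a toy random dataset and a tiny classifier ending in `Sigmoid` to match `nn.BCELoss`; all names here are illustrative and not part of torchtrainer):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy binary-classification data: 256 samples, 20 features each
x = torch.randn(256, 20)
y = torch.randint(0, 2, (256, 1))
dataset = TensorDataset(x, y)

train_loader = DataLoader(dataset, batch_size=10, shuffle=True)
val_loader = DataLoader(dataset, batch_size=10)

# Tiny model ending in Sigmoid, since the example uses nn.BCELoss()
model = nn.Sequential(
    nn.Linear(20, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Sigmoid(),
)
```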
## TODO

- More tests
- More metrics