TorchFit is a simple, easy-to-use, and minimalistic training-helper for PyTorch
TorchFit
TorchFit is a bare-bones, minimalistic training-helper for PyTorch that exposes an easy-to-use fit method in the style of fastai and Keras.
TorchFit is intended to be minimally invasive, with a tiny footprint and as little bloat as possible. It is well suited to those who are new to training models in PyTorch. For more complex training scenarios (e.g., training GANs or multi-node GPU training), PyTorch Lightning is highly recommended.
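To make the fastai/Keras-style fit pattern concrete, here is a minimal, framework-free sketch of what such a helper hides from the caller: the learner owns the model parameters and the data, and a single `fit(lr, epochs)` call runs the epoch and batch loops, records per-epoch training and validation loss, and leaves `predict()` ready to use. The `ToyLearner` class, its 1-D linear model, and its hand-derived gradients are hypothetical illustrations, not TorchFit's implementation; a PyTorch learner would instead compute gradients with `loss.backward()` and update weights with `optimizer.step()`.

```python
from typing import Dict, List, Optional, Tuple

Batch = List[Tuple[float, float]]  # one mini-batch of (x, y) pairs


class ToyLearner:
    """Hypothetical Keras-style learner wrapping a 1-D linear model y = w * x + b."""

    def __init__(self, train_batches: List[Batch],
                 val_batches: Optional[List[Batch]] = None) -> None:
        self.w, self.b = 0.0, 0.0  # model parameters
        self.train_batches = train_batches
        self.val_batches = val_batches or []
        self.history: Dict[str, List[float]] = {"loss": [], "val_loss": []}

    def _mse(self, batch: Batch) -> float:
        # mean squared error of the current model on one batch
        return sum((self.w * x + self.b - y) ** 2 for x, y in batch) / len(batch)

    def fit(self, lr: float, epochs: int) -> Dict[str, List[float]]:
        for _ in range(epochs):
            batch_losses = []
            for batch in self.train_batches:
                n = len(batch)
                # hand-derived MSE gradients (PyTorch would use loss.backward())
                grad_w = sum(2 * (self.w * x + self.b - y) * x for x, y in batch) / n
                grad_b = sum(2 * (self.w * x + self.b - y) for x, y in batch) / n
                # plain SGD update (PyTorch would use optimizer.step())
                self.w -= lr * grad_w
                self.b -= lr * grad_b
                batch_losses.append(self._mse(batch))
            # record per-epoch averages, similar to a Keras History dict
            self.history["loss"].append(sum(batch_losses) / len(batch_losses))
            if self.val_batches:
                val_losses = [self._mse(b) for b in self.val_batches]
                self.history["val_loss"].append(sum(val_losses) / len(val_losses))
        return self.history

    def predict(self, xs: List[float]) -> List[float]:
        return [self.w * x + self.b for x in xs]


# Usage: learn y = 2x + 1 from two mini-batches, validating on all points
data = [(i / 10, 2 * (i / 10) + 1) for i in range(10)]
learner = ToyLearner([data[:5], data[5:]], val_batches=[data])
history = learner.fit(lr=0.5, epochs=500)
print(round(learner.w, 3), round(learner.b, 3))
```

Keeping the loop inside the learner is what lets calling code stay at the level of fit, plot, and predict; changing the optimizer or loss only touches the learner.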
Usage
# normal PyTorch stuff
train_loader = create_your_training_data_loader()
val_loader = create_your_validation_data_loader()
test_loader = create_your_test_data_loader()
model = create_your_pytorch_model()
# wrap model and data in Learner
import torchfit
learner = torchfit.Learner(model, train_loader, val_loader=val_loader)
# estimate LR using Learning Rate Finder
learner.find_lr()
# train using 1cycle learning rate policy
learner.fit_onecycle(1e-4, 3)
# plot training vs. validation loss
learner.plot('loss')
# make predictions as easy as in Keras
y_pred = learner.predict(test_loader)
# save model and reload later
learner.save('/tmp/mymodel')
learner.load('/tmp/mymodel')
For more information, see the Tutorial Notebook.
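The find_lr and fit_onecycle calls in the usage example rely on two learning-rate schedules. As a rough, framework-free sketch of the generic techniques (not TorchFit's actual code): an LR range test sweeps the learning rate exponentially over a short run while recording the loss, and the 1cycle policy warms the learning rate up to a peak and then anneals it far below its starting value. The function names, the steepest-slope heuristic, the linear annealing, and the defaults (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, borrowed from PyTorch's `OneCycleLR`) are all assumptions for illustration.

```python
from typing import List


def lr_range_test_schedule(start_lr: float, end_lr: float, num_steps: int) -> List[float]:
    """Learning rates for an LR range test: grow exponentially from start_lr to end_lr."""
    ratio = (end_lr / start_lr) ** (1 / (num_steps - 1))
    return [start_lr * ratio ** i for i in range(num_steps)]


def suggest_lr(lrs: List[float], losses: List[float]) -> float:
    """One common heuristic: pick the LR where the recorded loss drops most steeply."""
    slopes = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    steepest = min(range(len(slopes)), key=lambda i: slopes[i])
    return lrs[steepest]


def one_cycle_schedule(max_lr: float, total_steps: int, pct_start: float = 0.3,
                       div_factor: float = 25.0,
                       final_div_factor: float = 1e4) -> List[float]:
    """1cycle policy with linear phases: warm up from max_lr / div_factor to
    max_lr over the first pct_start of the steps, then anneal down to
    (max_lr / div_factor) / final_div_factor by the last step."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = max(1, int(total_steps * pct_start))
    decay_steps = max(1, total_steps - warmup_steps - 1)
    schedule = []
    for step in range(total_steps):
        if step < warmup_steps:
            frac = step / warmup_steps  # rises from 0 to just below 1
            schedule.append(initial_lr + frac * (max_lr - initial_lr))
        else:
            frac = (step - warmup_steps) / decay_steps  # rises from 0 to 1
            schedule.append(max_lr - frac * (max_lr - min_lr))
    return schedule


# Usage: sweep 100 LRs for the range test, then build a 30-step 1cycle schedule
lrs = lr_range_test_schedule(1e-7, 10.0, num_steps=100)
schedule = one_cycle_schedule(max_lr=1e-4, total_steps=30)
```

Mapping this onto fit_onecycle(1e-4, 3) assumes that 1e-4 is the peak learning rate and that total_steps equals 3 epochs times the number of batches per epoch; both are assumptions about TorchFit. PyTorch's own OneCycleLR anneals with a cosine curve by default rather than linearly.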
Installation
After ensuring PyTorch is installed, install TorchFit with:
pip3 install torchfit
Download files
Source Distribution: torchfit-0.2.0.tar.gz (8.6 kB)