PyTorch module trainer
PyTorch trainer
Are you tired of writing the same epoch and data-loader loops every time you train a PyTorch module? Look no further: PyTorch trainer is a library that hides all the boring training boilerplate that should be native to PyTorch.
You will also benefit from the following features:
- Early stopping: stop training after a period of stagnation
- Checkpointing: save model and estimator at regular intervals
- CSV file writer to output logs
- Several metrics are available: all default PyTorch loss functions, Accuracy, MAE
- Progress bar in the console
- SIGINT handling: handle CTRL-C
- Model's data type: float32 or float64
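Early stopping is usually parameterized by a "patience": the number of consecutive epochs without meaningful improvement that are tolerated before training stops. A minimal, framework-agnostic sketch of that idea follows. The `EarlyStopping` class and its `patience` and `min_delta` parameters are hypothetical and not taken from pytorchtrainer, whose own definition of stagnation may differ.

```python
class EarlyStopping:
    """Signal a stop once the monitored loss has not improved for `patience` epochs.

    Hypothetical helper for illustration only; not part of pytorchtrainer.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # consecutive stagnant epochs tolerated
        self.min_delta = min_delta    # smallest decrease that counts as an improvement
        self.best = float("inf")      # best (lowest) loss seen so far
        self.stale_epochs = 0         # epochs since the last improvement

    def step(self, loss):
        """Record one epoch's loss and return True if training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


# Usage: simulated validation losses that plateau after the third epoch.
stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([1.0, 0.8, 0.7, 0.71, 0.72, 0.69]):
    if stopper.step(loss):
        print("stopping after epoch", epoch)
        break
```

A `min_delta` above zero keeps tiny, noisy improvements from resetting the patience counter.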
Example
Code examples can be found in the example folder.
Here is a simple example:
import torch
import pytorchtrainer as ptt
# Your usual model, optimizer, loss function and data loaders
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
train_loader = MyTrainDataloader()
validation_loader = MyValidationDataloader()
# instantiate a default trainer
trainer = ptt.create_default_trainer(model, optimizer, criterion)
# optionally save a checkpoint after every 10 epochs
trainer.register_post_epoch_callback(ptt.checkpoint.SaveCheckpointCallback(save_every=10))
# optionally compute validation loss after every epoch
validation_callback = ptt.callback.ValidationCallback(validation_loader, ptt.metric.TorchLoss(criterion), validate_every=1)
trainer.register_post_epoch_callback(validation_callback)
# optionally save training and validation loss after every iteration, using the default save directory
trainer.register_post_iteration_callback(ptt.callback.CsvWriter(
    save_every=1,
    extra_header=[validation_callback.state_attribute_name],
    callback=lambda state: [state.get(validation_callback.state_attribute_name)]))
# run the training
trainer.train(train_loader, max_epochs=100)
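The trainer above is driven by callbacks registered to run after each iteration or epoch. As a rough mental model, the following pure-Python sketch shows one way such a hook-based loop and a CSV-logging callback can be wired together. `MiniTrainer`, `CsvLogger`, the `state` dictionary and its keys are hypothetical names for illustration only; this is not pytorchtrainer's implementation, and `step_fn` is a stand-in for the real forward pass, loss computation, backward pass and optimizer step.

```python
import csv
import io


class MiniTrainer:
    """Hypothetical epoch/iteration loop with post-iteration and post-epoch hooks."""

    def __init__(self, step_fn):
        self.step_fn = step_fn              # callable(batch) -> loss; stand-in for a real update
        self.post_iteration_callbacks = []
        self.post_epoch_callbacks = []
        self.state = {}                     # shared values that callbacks can read

    def register_post_iteration_callback(self, callback):
        self.post_iteration_callbacks.append(callback)

    def register_post_epoch_callback(self, callback):
        self.post_epoch_callbacks.append(callback)

    def train(self, loader, max_epochs):
        for epoch in range(max_epochs):
            self.state["epoch"] = epoch
            for iteration, batch in enumerate(loader):
                self.state["iteration"] = iteration
                self.state["train_loss"] = self.step_fn(batch)
                for callback in self.post_iteration_callbacks:
                    callback(self.state)
            for callback in self.post_epoch_callbacks:
                callback(self.state)


class CsvLogger:
    """Hypothetical callback: write selected state values as one CSV row per call."""

    def __init__(self, stream, columns):
        self.columns = columns
        self.writer = csv.writer(stream)
        self.writer.writerow(columns)       # header row

    def __call__(self, state):
        self.writer.writerow([state.get(name) for name in self.columns])


# Usage: a fake "loss" equal to the batch mean, two epochs of two batches each.
buffer = io.StringIO()
trainer = MiniTrainer(step_fn=lambda batch: sum(batch) / len(batch))
trainer.register_post_iteration_callback(CsvLogger(buffer, ["epoch", "iteration", "train_loss"]))
trainer.train(loader=[[1.0, 3.0], [2.0, 4.0]], max_epochs=2)
print(buffer.getvalue())
```

Keeping the loop unaware of what each callback does is what lets features such as checkpointing, validation and logging be added or removed with a single registration call.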
Dependencies
- python > 3.5
- PyTorch 1.0.1 (see the official PyTorch website for install instructions)
Contributing
Feel free to submit an issue or pull request, but before you do, please read the contributing guidelines.
Download files
Source Distribution
- pytorchtrainer-0.1.0.tar.gz (11.5 kB)
Built Distribution
- pytorchtrainer-0.1.0-py3-none-any.whl (14.7 kB)
File details
Details for the file pytorchtrainer-0.1.0.tar.gz.
File metadata
- Download URL: pytorchtrainer-0.1.0.tar.gz
- Size: 11.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.5.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | e7ac6e037afc32805370b9113016f588155de8f80582438d40b52408bea05fd0
MD5 | a944f333bdf90117573eada8e80a8796
BLAKE2b-256 | e95e11decc1171256fcb303d32406ab033b63817f73156d78ec532638b3661de
File details
Details for the file pytorchtrainer-0.1.0-py3-none-any.whl.
File metadata
- Download URL: pytorchtrainer-0.1.0-py3-none-any.whl
- Size: 14.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.5.7
File hashes
Algorithm | Hash digest
---|---
SHA256 | d4fb87f011209bfde97b75f1b086ef8b7272b13a61f016e3861a8a24e60e650b
MD5 | fc7e1ca706fe5a3ca18ddba4bab57c48
BLAKE2b-256 | df3e5d3101cb3436a5c64fa224cd586e3ce8e89148cf30257fd678ea4c35742d
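The digests listed for each file let you check that a downloaded archive was not corrupted or altered. A minimal sketch using Python's standard `hashlib` follows, assuming the source distribution has been downloaded into the current directory; the helper name `sha256_of` is illustrative. SHA256 is preferable to MD5 for this, since MD5 is no longer collision-resistant.

```python
import hashlib
import os


def sha256_of(path, chunk_size=65536):
    """Return the hex SHA256 digest of a file, reading it in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Expected digest copied from the source distribution's hash table.
SDIST = "pytorchtrainer-0.1.0.tar.gz"
SDIST_SHA256 = "e7ac6e037afc32805370b9113016f588155de8f80582438d40b52408bea05fd0"

if os.path.exists(SDIST):  # only check when the archive is present locally
    print("OK" if sha256_of(SDIST) == SDIST_SHA256 else "MISMATCH")
```

pip can enforce the same check at install time through its hash-checking mode, using a requirements line of the form `pytorchtrainer==0.1.0 --hash=sha256:<digest>`.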