PyTorch module trainer

PyTorch trainer

Are you tired of writing the same epoch and data-loader loops every time you train a PyTorch module? Look no further: PyTorch trainer is a library that hides all those boring lines of training code that should be native to PyTorch.

You will also benefit from the following features:

  • Early stopping: stop training after a period of stagnation
  • Checkpointing: save model and estimator at regular intervals
  • CSV file writer to output logs
  • Several metrics are available: all default PyTorch loss functions, Accuracy, MAE
  • Console progress bar
  • SIGINT handling: handle Ctrl-C
  • Configurable model data type (float32, float64)
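
Early stopping generally means tracking a validation metric and stopping once it has not improved for a fixed number of epochs (the "patience"). The following is a minimal plain-Python sketch of that idea, assuming a lower-is-better loss. The EarlyStopping class and its patience and min_delta parameters are hypothetical illustrations, not part of pytorchtrainer's API.

```python
class EarlyStopping:
    """Hypothetical patience-based early stopping helper (not pytorchtrainer's API)."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0) -> None:
        self.patience = patience      # epochs allowed without improvement
        self.min_delta = min_delta    # minimum decrease that counts as an improvement
        self.best = float("inf")      # best (lowest) loss seen so far
        self.stale_epochs = 0         # consecutive epochs without improvement

    def step(self, loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopping(patience=2)
    losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.5]
    for epoch, loss in enumerate(losses):
        if stopper.step(loss):
            print(f"stopping after epoch {epoch}")  # prints: stopping after epoch 4
            break
```

Note that the later drop to 0.5 is never seen: once patience runs out, training ends, which is the trade-off a larger patience value relaxes.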

Example

Code examples can be found in the example folder.

Here is a simple example:

import torch
import pytorchtrainer as ptt


# Your usual model, optimizer, loss function and data loaders
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
train_loader = MyTrainDataloader()
validation_loader = MyValidationDataloader()


# instantiate a default trainer
trainer = ptt.create_default_trainer(model, optimizer, criterion)

# optionally save a checkpoint after every 10 epochs
trainer.register_post_epoch_callback(ptt.checkpoint.SaveCheckpointCallback(save_every=10))

# optionally compute validation loss after every epoch
validation_callback = ptt.callback.ValidationCallback(validation_loader, ptt.metric.TorchLoss(criterion), validate_every=1)
trainer.register_post_epoch_callback(validation_callback)

# optionally save training and validation loss after every iteration using default save directory
trainer.register_post_iteration_callback(ptt.callback.CsvWriter(save_every=1,
                                                                extra_header=[validation_callback.state_attribute_name],
                                                                callback=lambda state: [state.get(validation_callback.state_attribute_name)]))
# run the training
trainer.train(train_loader, max_epochs=100)
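
The CsvWriter call above pairs extra_header with a callback that returns one value per extra column. A minimal standard-library sketch of that pattern follows. CsvLogger, its base column names, and its row layout are hypothetical and do not reproduce the file format pytorchtrainer actually writes.

```python
import csv
import io
from typing import Any, Callable, Dict, List, Optional, TextIO


class CsvLogger:
    """Hypothetical per-iteration CSV logger (not pytorchtrainer's CsvWriter)."""

    def __init__(
        self,
        stream: TextIO,
        extra_header: Optional[List[str]] = None,
        callback: Optional[Callable[[Dict[str, Any]], List[Any]]] = None,
    ) -> None:
        self.writer = csv.writer(stream)
        self.callback = callback
        # Assumed base columns, followed by the user-supplied extra columns.
        self.writer.writerow(["iteration", "train_loss"] + list(extra_header or []))

    def log(self, state: Dict[str, Any]) -> None:
        # The callback must return one value per extra_header entry.
        extras = self.callback(state) if self.callback is not None else []
        self.writer.writerow([state["iteration"], state["train_loss"]] + list(extras))


if __name__ == "__main__":
    buffer = io.StringIO()  # stands in for a file opened with newline=""
    logger = CsvLogger(
        buffer,
        extra_header=["validation_loss"],
        callback=lambda state: [state.get("validation_loss")],
    )
    logger.log({"iteration": 1, "train_loss": 0.9})  # no validation value yet
    logger.log({"iteration": 2, "train_loss": 0.7, "validation_loss": 0.75})
    print(buffer.getvalue())
```

Because csv.writer writes None as an empty field, iterations logged before the first validation pass simply leave that column blank.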

Dependencies

  • python > 3.5
  • pytorch > 1.0.0 (see the installation instructions on the official PyTorch website)

Contributing

Feel free to submit an issue or pull request, but please read the contributing guidelines first.

Download files

Source Distribution

pytorchtrainer-0.2.1.tar.gz (14.0 kB)

Built Distribution

pytorchtrainer-0.2.1-py3-none-any.whl (15.3 kB)

File details

Details for the file pytorchtrainer-0.2.1.tar.gz.

File metadata

  • Download URL: pytorchtrainer-0.2.1.tar.gz
  • Upload date:
  • Size: 14.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.5.7

File hashes

Hashes for pytorchtrainer-0.2.1.tar.gz
  • SHA256: 28821a04052fe527d6a770e5cdb9eb70f31302c3a6e726bd7533c76c140fb47c
  • MD5: 9f6b621e1c5dc10a3e817f0f1b41bfde
  • BLAKE2b-256: 728c8c60ea5a271da20463c08a2df382371107f52be5b69a21c07b3eb063c6f8

File details

Details for the file pytorchtrainer-0.2.1-py3-none-any.whl.

File metadata

  • Download URL: pytorchtrainer-0.2.1-py3-none-any.whl
  • Upload date:
  • Size: 15.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.14.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.2.0 requests-toolbelt/0.9.1 tqdm/4.35.0 CPython/3.5.7

File hashes

Hashes for pytorchtrainer-0.2.1-py3-none-any.whl
  • SHA256: 92b9319e9be77d801e31448f1b243fa7b9183b4babd170679de80271391d44a1
  • MD5: c45c019325f8659e2fef8f4dd68a52ed
  • BLAKE2b-256: fd15a213a07abd20f3f9063b16dcf50793e7823a963b75495132787fab4d5a28
