
A simple yet powerful CNN trainer for PyTorch and Lightning.

Project description

PyTorch CNN Trainer


Train CNNs for your task

A simple engine to fine-tune CNNs from torchvision and PyTorch Image Models (timm) by Ross Wightman.


Why This Package?

Writing the training loop and supporting code for CNN training by hand is tedious, and adding support for all the usual training features takes a lot of time.

Usually we don't need distributed training, and wiring everything up through argparse just to get the job done is cumbersome.

This package simplifies training. It provides a powerful engine.py that handles many training features, plus a dataset.py for loading datasets in common scenarios.
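To give a feel for the kind of loop that engine.py is meant to save you from writing, here is a minimal sketch of a Keras-like fit() built from per-epoch train_step() and validation_step() callables. The function names mirror those mentioned in this README, but the signatures are assumptions for illustration, not engine.py's actual API. A toy one-parameter linear model fit to y = 2x stands in for a CNN so the sketch runs without PyTorch.

```python
from typing import Callable, Dict, List


def fit(
    epochs: int,
    train_step: Callable[[], float],
    validation_step: Callable[[], float],
) -> Dict[str, List[float]]:
    """Hypothetical Keras-like fit(): one training and one validation pass per epoch."""
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    for epoch in range(epochs):
        train_loss = train_step()       # one pass over the training data (updates weights)
        val_loss = validation_step()    # one pass over held-out data (no updates)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        print(f"epoch {epoch + 1}/{epochs}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}")
    return history


# Toy stand-in for a CNN: a single weight w fit to y = 2x with per-sample SGD.
state = {"w": 0.0}
train_data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
val_data = [(4.0, 8.0)]
LR = 0.05


def train_step() -> float:
    total = 0.0
    for x, y in train_data:
        err = state["w"] * x - y
        total += err * err
        state["w"] -= LR * 2 * err * x  # gradient of (w*x - y)^2 w.r.t. w is 2*err*x
    return total / len(train_data)


def validation_step() -> float:
    # Evaluate only; the weight is not modified here.
    return sum((state["w"] * x - y) ** 2 for x, y in val_data) / len(val_data)


history = fit(5, train_step, validation_step)
```

In a real PyTorch setup, train_step() would iterate over a DataLoader and call optimizer.zero_grad(), loss.backward() and optimizer.step(), while validation_step() would run under torch.no_grad().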

Note: PyTorch Trainer is not a distributed training script.

It works well on a single-GPU machine such as Google Colab or Kaggle.

For distributed training, you will be able to use the PyTorch Lightning Trainer (coming soon).

It will train on multiple GPUs in the same way Lightning does (coming soon).

To install

Install torch and torchvision from PyTorch, then run the following in a terminal (prefix it with ! in a Colab or Jupyter notebook). I will release on PyPI soon.

pip install -q git+https://github.com/oke-aditya/pytorch_cnn_trainer.git

Docs:

I have provided examples of how to use this trainer in multiple scenarios; please check the examples folder. The provided examples include:

  • Fine-tune torchvision models using fit().
  • Fine-tune torchvision models using train_step() and validation_step().
  • Fine-tune Ross Wightman's models (timm).
  • Quantization-aware training of the head only.
  • Full quantization-aware training.
  • Mixed precision training.
  • Training with Stochastic Weight Averaging (SWA).
  • LR finder implementation.
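Stochastic Weight Averaging (SWA) keeps an equal-weight running mean of the weights visited late in training and uses that mean as the final model. The sketch below is framework-free and assumes weights are flat lists of floats; the SWAAverager class is hypothetical and not part of this package. In PyTorch itself, torch.optim.swa_utils.AveragedModel maintains the same kind of running average over a model's parameters.

```python
from typing import List, Optional


class SWAAverager:
    """Hypothetical helper: equal-weight running average of weight snapshots."""

    def __init__(self) -> None:
        self.averaged: Optional[List[float]] = None
        self.n_averaged = 0

    def update(self, weights: List[float]) -> None:
        if self.averaged is None:
            # First snapshot: start the average from a copy of the weights.
            self.averaged = list(weights)
        else:
            # Incremental mean: avg += (w - avg) / (n + 1)
            self.averaged = [
                avg + (w - avg) / (self.n_averaged + 1)
                for avg, w in zip(self.averaged, weights)
            ]
        self.n_averaged += 1


# Usage: average snapshots taken at the end of three (made-up) epochs.
swa = SWAAverager()
for snapshot in ([1.0, 4.0], [2.0, 5.0], [3.0, 9.0]):
    swa.update(snapshot)
print(swa.averaged)  # [2.0, 6.0]
```

Because the averaged weights no longer match the BatchNorm running statistics collected during training, PyTorch provides torch.optim.swa_utils.update_bn() to recompute them with one pass over the training data.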

Features:

  • Supports training and transfer learning with PyTorch Image Models (timm).
  • Quantization-aware training example.
  • Early stopping with patience.
  • Supports training and transfer learning with torchvision models.
  • Supports transfer learning with torchvision quantized models.
  • Supports mixed precision training.
  • L2 norm gradient penalty.
  • LR finder implementation.
  • Stochastic Weight Averaging (SWA) support for training.
  • Keras-like fit() method.
  • Sanity check method.
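Early stopping with patience halts training once the validation loss has failed to improve for a fixed number of consecutive epochs. Here is a minimal sketch of that rule; the EarlyStopping class, its patience and min_delta parameters, and the loss values are hypothetical illustrations, not this package's implementation.

```python
class EarlyStopping:
    """Hypothetical helper: stop when val loss hasn't improved by more than
    min_delta for `patience` consecutive epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss   # improvement: remember it and reset the counter
            self.counter = 0
        else:
            self.counter += 1      # no improvement this epoch
        return self.counter >= self.patience


# Usage with made-up validation losses: improvement stalls after epoch 3.
stopper = EarlyStopping(patience=2)
losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.50]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 5 0.65
```

A larger patience tolerates noisy validation curves at the cost of extra epochs, while min_delta ignores improvements too small to matter.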

I hope this repo helps people train models using transfer learning.

If you like it, give it a star and tell people about it.



Download files


Source Distribution

pytorch_cnn_trainer-0.2.0rc1.tar.gz (16.5 kB)

Uploaded Source

Built Distribution

pytorch_cnn_trainer-0.2.0rc1-py3-none-any.whl (24.7 kB)

Uploaded Python 3

File details

Details for the file pytorch_cnn_trainer-0.2.0rc1.tar.gz.

File metadata

  • Download URL: pytorch_cnn_trainer-0.2.0rc1.tar.gz
  • Upload date:
  • Size: 16.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.3.1 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for pytorch_cnn_trainer-0.2.0rc1.tar.gz
  • SHA256: cda1e84013b3e9b192846a373ed718462fea522ca1875afce81aceb269f316db
  • MD5: ae1910ea3459d97690c79cc5799d016c
  • BLAKE2b-256: 7c0b6b9ccf91dd7e200b27c43a8930a47ec246d9c022155a960a38402c8b6778


File details

Details for the file pytorch_cnn_trainer-0.2.0rc1-py3-none-any.whl.

File metadata

  • Download URL: pytorch_cnn_trainer-0.2.0rc1-py3-none-any.whl
  • Upload date:
  • Size: 24.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/49.3.1 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.8.5

File hashes

Hashes for pytorch_cnn_trainer-0.2.0rc1-py3-none-any.whl
  • SHA256: 35f96de757c84ede763195f034c85efb96569e7a94d250ca67f3175bb25e4fa7
  • MD5: 66469c544ddc74389fa8dfa96ab5a4f2
  • BLAKE2b-256: ba64bc73cbf90487260e4bcacd0ccba6070d19315c88f2d85d15cc56eae8fac8

