Make PyTorch implementation easy, so that one can focus on what is really important!

Project description

SamPytorchHelper

A PyTorch package that works directly with PyTorch and TensorBoard to simplify the training process while improving code readability. One no longer needs to explicitly write the training loop, save the model, or upload the training information to TensorBoard.

Installation

Firstly, install all the required libraries listed in the requirements.txt file, for example with pip:

pip install -r ./requirements.txt

Secondly, install the SamPytorchHelper package from PyPI:

pip install SamPytorchHelper

Once that is done, we can import the class from the package:

from SamPytorchHelper.TorchHelper import TorchHelperClass

The TorchHelper module contains the TorchHelperClass that is used to train the network.

TorchHelperClass description

TorchHelperClass has four attributes and two methods.

  1. Attributes

    class TorchHelperClass:
        def __init__(self, model, loss_function, optimizer, comment='')
    

    At initialization, the class receives four arguments: the network, the loss function, the optimizer, and a comment string. The comment is embedded in the saved model's name and in the TensorBoard run name, providing extra information for identifying a run. It may include hyperparameter values such as the epoch count, learning rate, batch size, etc.
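    The comment string is ordinary text; for example, it can be assembled from the run's hyperparameters (the values below are illustrative, taken from the test script further down):

    ```python
    # Illustrative hyperparameter values for one run.
    epochs, lr, batch = 10, 0.01, 32

    # The comment tags both the saved model file and the TensorBoard run.
    comment = f' epch={epochs} lr={lr} bch={batch}'
    print(comment)  # " epch=10 lr=0.01 bch=32"
    ```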

  2. Methods

    The TorchHelperClass exposes two methods:

    • train_model:
      def train_model(self, train_dataloader, val_dataloader, num_epoch=50, iter_print=100):
         """
         :param train_dataloader: DataLoader over the training set
         :param val_dataloader: DataLoader over the validation set
         :param num_epoch: total number of training epochs. default = 50
         :param iter_print: print the running loss every iter_print iterations. default = 100
         :return: the trained model
         """
      
    • save_model:
      def save_model(self, path):
         """
         :param path: folder in which to save the trained model
         :return: None
         """
      
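Putting the two methods together: the helper consumes the same objects a hand-written training loop would. A minimal sketch with toy tensors (the model, shapes, and hyperparameter values are illustrative; the helper calls are shown as comments since they mirror the full example below):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for a real dataset: 64 samples, 10 features, 3 classes.
X = torch.randn(64, 10)
y = torch.randint(0, 3, (64,))
train_dataloader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
val_dataloader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=False)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# With the package installed, training and saving reduce to:
# from SamPytorchHelper.TorchHelper import TorchHelperClass
# helper = TorchHelperClass(model=model, loss_function=criterion,
#                           optimizer=optimizer, comment=' epch=10 lr=0.01 bch=16')
# trained = helper.train_model(train_dataloader, val_dataloader, num_epoch=10, iter_print=100)
# helper.save_model('trained_models')
```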

Example

A complete example implementation can be found in the test folder:

  • data folder: contains the FashionMNIST dataset downloaded via torchvision.
  • runs folder: contains the event files read by TensorBoard.
  • trained_models folder: holds the models saved after training with the save_model method.
  • test.py: tests the TorchHelperClass and, most importantly, shows step by step how to use the package to train a network efficiently:
      ...
      # hyper-parameters
      parameters = dict(
          lr=[0.01, 0.001],
          batch=[32, 64, 128],
          shuffle=[True],
          epochs=[10, 20],
          momentum=[0.9]
      )
      ...
      param_values = [v for v in parameters.values()]
      for run_id, (lr, batch, shuffle, epochs, momentum) in enumerate(product(*param_values)):
          print("Current Hyperparams id:", run_id + 1)
          train_dataloader = DataLoader(train_data, batch_size=batch, shuffle=shuffle)
          test_dataloader = DataLoader(test_data, batch_size=batch, shuffle=False)
    
          net = Network()
    
          criterion = nn.CrossEntropyLoss()
          optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    
          comment = f' epch={epochs} lr={lr} bch={batch}'
          helper = TorchHelperClass(model=net, loss_function=criterion, optimizer=optimizer, comment=comment)
          helper.train_model(train_dataloader, test_dataloader, epochs, 1000)
          helper.save_model('trained_models')
          print()
      ...
    

The parameters dict collects the candidate values of each hyper-parameter. It drives the hyper-parameter tuning loop: each parameter can take one value or a list of values, and itertools.product enumerates every combination.
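As a sanity check, the number of training runs is the product of the list lengths. A stdlib-only sketch of the same enumeration:

```python
from itertools import product

# Same grid as in test.py above.
parameters = dict(
    lr=[0.01, 0.001],
    batch=[32, 64, 128],
    shuffle=[True],
    epochs=[10, 20],
    momentum=[0.9],
)

# product expands the grid; the rightmost list varies fastest.
combos = list(product(*parameters.values()))
print(len(combos))  # 2 * 3 * 1 * 2 * 1 = 12 runs
print(combos[0])    # (0.01, 32, True, 10, 0.9)
```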

Results after running test.py

From the Terminal

[Screenshot: terminal output during training]

From TensorBoard

[Screenshots: TensorBoard dashboard; training and validation loss curves]

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

SamPytorchHelper-0.1.0.tar.gz (5.7 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

SamPytorchHelper-0.1.0-py3-none-any.whl (5.5 kB)

Uploaded Python 3

File details

Details for the file SamPytorchHelper-0.1.0.tar.gz.

File metadata

  • Download URL: SamPytorchHelper-0.1.0.tar.gz
  • Upload date:
  • Size: 5.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for SamPytorchHelper-0.1.0.tar.gz

  • SHA256: c089c09ff27e7b88a363d9c638b2f2369175742f332d388f7b1b04c2d61b0588
  • MD5: 437245f1587c6638153f6bfc79ce5ae0
  • BLAKE2b-256: f792c7e7b0b8d121bf750eff7703e6bdc99fd1cd6783db20dfd7e47c08b98904


File details

Details for the file SamPytorchHelper-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: SamPytorchHelper-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 5.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for SamPytorchHelper-0.1.0-py3-none-any.whl

  • SHA256: 2b38eba5bc1cc2eac0e0f77b110922f74cd6f2a316fe9b68ca6040295e59949f
  • MD5: 4c4c13e5cf0bb69b80beb1123d484954
  • BLAKE2b-256: 10d69ce08ac1460cbc0a9e89b29012eec3157ae5a357de91bf288fadd2779a9f

