
Wrapper for managing neural network training processes

Project description

HyperModule

HyperModule is a wrapper that manages the training, validation, and testing of neural networks. It makes the training process easier to monitor by logging the loss and validation accuracy, and it provides convenient functions for loading and saving pre-trained models, streamlining the whole workflow.

Usage

Getting Started

The simplest way to use HyperModule is to create the network, criterion, optimizer, and scheduler first and then pass them to a HyperModule:

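# Example 1.
import torch

model = NN()  # NN is your own torch.nn.Module subclass
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

hm = HyperModule(model, criterion, optimizer, scheduler)

# train_dataloader, valid_dataloader, and test_dataloader are torch DataLoaders;
# save_path is where the best model is saved during training.
hm.train(train_dataloader, valid_dataloader, save_path, num_epochs=100)
hm.test(test_dataloader)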

Partials

A more recommended approach, however, is to specify the optimizer and scheduler with partial functions, which spare you from wiring the model parameters into the optimizer and the optimizer into the learning rate scheduler yourself.

# Example 2.
import torch
from .partials import optim, sched

hm = HyperModule(
    model=NN(),
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=optim("SGD", lr=0.01, momentum=0.9),
    scheduler=sched("ExponentialLR", gamma=0.9),
)

This is equivalent to Example 1.
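The idea behind Example 2 can be sketched with `functools.partial`: the optimizer or scheduler class is looked up by name and bound to its hyperparameters, while the model parameters (or the optimizer) are supplied later. This is a minimal sketch that assumes `optim` and `sched` work roughly this way; `make_optim` and `make_sched` are hypothetical names, not HyperModule's API.

```python
import functools

import torch


def make_optim(name, **hyperparams):
    # Hypothetical helper: look up e.g. torch.optim.SGD by name and defer
    # the `params` argument until the model exists.
    return functools.partial(getattr(torch.optim, name), **hyperparams)


def make_sched(name, **hyperparams):
    # Hypothetical helper: the same idea for torch.optim.lr_scheduler classes,
    # deferring the `optimizer` argument.
    return functools.partial(getattr(torch.optim.lr_scheduler, name), **hyperparams)


model = torch.nn.Linear(4, 2)
optimizer = make_optim("SGD", lr=0.01, momentum=0.9)(model.parameters())
scheduler = make_sched("ExponentialLR", gamma=0.9)(optimizer)
```

Deferring construction this way is what lets the wrapper, rather than the user, connect the model parameters, the optimizer, and the scheduler.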

The partial functions optim and sched can also take an existing optimizer or scheduler as their argument and generate a new optimizer or scheduler instance with new hyperparameters.

# Example 3.
model = NN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

hm = HyperModule(
    model=model,
    criterion=criterion,
    optimizer=optim(optimizer, lr=0.005, momentum=0.99),
    scheduler=sched(scheduler, gamma=0.11),
)

The optimizer used in training is now SGD with a learning rate of 0.005 and momentum of 0.99, and the scheduler is ExponentialLR with gamma 0.11.
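Rebuilding from an existing optimizer, as in Example 3, can be approximated with PyTorch's public `Optimizer.defaults` dictionary: reuse the instance's class and its default hyperparameters, then override the ones you pass. `rebuild_optim` is a hypothetical name and this is a sketch, not HyperModule's implementation. Schedulers have no direct counterpart to `defaults`, so the sketch covers only the optimizer.

```python
import inspect

import torch


def rebuild_optim(optimizer, params, **overrides):
    # Keep only the entries of `optimizer.defaults` that the optimizer's
    # constructor accepts (lr, momentum, ...), then apply the overrides.
    accepted = inspect.signature(type(optimizer)).parameters
    hyperparams = {k: v for k, v in optimizer.defaults.items() if k in accepted}
    hyperparams.update(overrides)
    # Same optimizer class, bound to a (possibly new) set of parameters.
    return type(optimizer)(params, **hyperparams)


old_model = torch.nn.Linear(4, 2)
old_opt = torch.optim.SGD(old_model.parameters(), lr=0.01, momentum=0.9)

new_model = torch.nn.Linear(4, 2)
new_opt = rebuild_optim(old_opt, new_model.parameters(), lr=0.005, momentum=0.99)
```

Hyperparameters that are not overridden (for example `dampening` or `weight_decay`) carry over from the original optimizer.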

Hyperparameters

With partials, the hyperparameters can also be provided in a dict, separately from the optimizer and scheduler functions.

import torch
from .partials import optim, sched

hyperparams = {
    'lr': 0.01,
    'momentum': 0.9,
    'gamma': 0.9,
}

hm = HyperModule(
    model=NN(),
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=optim("SGD"),
    scheduler=sched("ExponentialLR"),
    hyperparams=hyperparams,
)

Note that this is equivalent to Examples 1 and 2.
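When hyperparameters arrive as one flat dict, something has to decide which keys go to the optimizer and which go to the scheduler. One possible approach, assumed here rather than taken from HyperModule, is to match each key against the constructor signatures with `inspect.signature`. `split_hyperparams` is a hypothetical helper.

```python
import inspect

import torch


def split_hyperparams(hyperparams, optim_cls, sched_cls):
    # Route each key to whichever constructor declares a parameter of that name.
    optim_keys = inspect.signature(optim_cls).parameters
    sched_keys = inspect.signature(sched_cls).parameters
    optim_kwargs = {k: v for k, v in hyperparams.items() if k in optim_keys}
    sched_kwargs = {k: v for k, v in hyperparams.items() if k in sched_keys}
    return optim_kwargs, sched_kwargs


hyperparams = {"lr": 0.01, "momentum": 0.9, "gamma": 0.9}
optim_kwargs, sched_kwargs = split_hyperparams(
    hyperparams, torch.optim.SGD, torch.optim.lr_scheduler.ExponentialLR
)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), **optim_kwargs)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, **sched_kwargs)
```

In this sketch, keys that match neither constructor are silently dropped and a key accepted by both would be passed to both; a real implementation would need to decide how to handle those cases.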

Class structure

  • __init__: The class constructor, which takes the model, criterion, optimizer, scheduler, and hyperparameters as input arguments.
  • optimizer: A property that returns the optimizer used for training the model.
    • optimizer.setter: A method that sets the optimizer used for training the model, either by accepting an instance of torch.optim.Optimizer or by creating an optimizer based on the provided optimizer configuration.
  • scheduler: A property that returns the scheduler used for adjusting the learning rate of the optimizer during training.
    • scheduler.setter: A method that sets the scheduler used for adjusting the learning rate of the optimizer during training, either by accepting an instance of torch.optim.lr_scheduler.LRScheduler or by creating a scheduler based on the provided scheduler configuration.
  • hyperparams: A property that returns the hyperparameters used for optimizing the model.
    • hyperparams.setter: A method that sets the hyperparameters used for optimizing the model, either by accepting a dictionary of hyperparameters or by creating a new optimizer and scheduler based on the provided hyperparameters.
  • train: A method that trains the model using the provided training and validation data loaders for the specified number of epochs. It also saves the best model based on the lowest validation loss and returns the training and validation losses and accuracy.
    • _update: A method that computes the loss for a batch of input images and targets and updates the model parameters accordingly.
    • _update_progress: A method that updates the training progress bar based on the current epoch and batch loss.
    • _update_scheduler: A method that updates the learning rate of the optimizer using the scheduler.
    • _update_history: A method that updates the training history with the current epoch's training and validation losses and accuracy.
    • _perform_validation: A method that validates the model using the provided validation data loader and loss function and returns the validation loss and accuracy.
  • validate: Evaluates the model on the given dataloader using the given criterion and returns the loss and accuracy.
  • test: Evaluates the model on the given test dataloader using the given criterion and returns the accuracy.
  • predict: Returns the model's predictions for the given dataloader, optionally applying softmax to the output.
  • save: Saves the model and training information to the given save path, or to the HyperModule's load_path attribute if none is given.
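The `train` entry above (train for an epoch, step the scheduler, record the history, validate, and keep the model with the lowest validation loss) can be illustrated with a plain PyTorch loop. This is a simplified sketch under assumptions: a classification model on the CPU and the hypothetical helper names `run_validation` and `fit`. It is not HyperModule's source.

```python
import copy

import torch
from torch.utils.data import DataLoader, TensorDataset


def run_validation(model, loader, criterion):
    # Average loss and accuracy over a dataloader, without tracking gradients.
    model.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item() * len(targets)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            count += len(targets)
    return total_loss / count, correct / count


def fit(model, criterion, optimizer, scheduler, train_loader, valid_loader, num_epochs):
    history = {"train_loss": [], "valid_loss": [], "valid_acc": []}
    best_loss, best_state = float("inf"), None
    for _ in range(num_epochs):
        model.train()
        epoch_loss, seen = 0.0, 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * len(targets)
            seen += len(targets)
        scheduler.step()  # one learning-rate update per epoch
        valid_loss, valid_acc = run_validation(model, valid_loader, criterion)
        history["train_loss"].append(epoch_loss / seen)
        history["valid_loss"].append(valid_loss)
        history["valid_acc"].append(valid_acc)
        if valid_loss < best_loss:  # keep the weights with the lowest validation loss
            best_loss = valid_loss
            best_state = copy.deepcopy(model.state_dict())
    return history, best_state


torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
loader = DataLoader(data, batch_size=16)
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
history, best_state = fit(
    model, torch.nn.CrossEntropyLoss(), optimizer, scheduler, loader, loader, num_epochs=3
)
```

For brevity the sketch keeps the best weights in memory instead of writing them to `save_path`, and it reuses the same toy loader for training and validation.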

Loading and Saving

HyperModule loads and saves the following information:

  • state_dict of neural network
  • state_dict of optimizer
  • state_dict of scheduler
  • number of epochs the neural network has been trained for
  • training loss in each epoch
  • validation accuracy in each epoch
  • testing accuracy
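A checkpoint holding the items listed above could be a single dictionary passed to `torch.save`. The key names (`model`, `optimizer`, `scheduler`, `epochs`, `train_loss`, `valid_acc`, `test_acc`) are assumptions for illustration, and HyperModule's actual file layout may differ. An in-memory buffer stands in for a file path.

```python
import io

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Hypothetical checkpoint layout mirroring the list above (nothing trained yet).
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "epochs": 0,
    "train_loss": [],
    "valid_acc": [],
    "test_acc": None,
}

buffer = io.BytesIO()  # stands in for a path such as "checkpoint.pt"
torch.save(checkpoint, buffer)
buffer.seek(0)
loaded = torch.load(buffer)

# Restore each component into freshly constructed objects.
new_model = torch.nn.Linear(4, 2)
new_model.load_state_dict(loaded["model"])
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
new_optimizer.load_state_dict(loaded["optimizer"])
new_scheduler = torch.optim.lr_scheduler.ExponentialLR(new_optimizer, gamma=0.9)
new_scheduler.load_state_dict(loaded["scheduler"])
```

Saving state_dicts rather than whole objects means the model, optimizer, and scheduler must be reconstructed before their states are loaded, as the last six lines do.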


Download files


Source Distribution

hypermodule-0.1.2.tar.gz (9.7 kB)

Built Distribution

hypermodule-0.1.2-py3-none-any.whl (9.2 kB)

File details

Details for the file hypermodule-0.1.2.tar.gz.

File metadata

  • Download URL: hypermodule-0.1.2.tar.gz
  • Size: 9.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.4

File hashes

Hashes for hypermodule-0.1.2.tar.gz
Algorithm    Hash digest
SHA256       d719e65e22844d1e8771c61032285c1eb999c4abb9fda4889ed6f0a541cfcb55
MD5          fb1b438d87e361540af5e845eb539816
BLAKE2b-256  08fef108f59a74fe02649563b128e0644ef6ae4df3658cc89e9930d97ee047c6


File details

Details for the file hypermodule-0.1.2-py3-none-any.whl.

File metadata

  • Download URL: hypermodule-0.1.2-py3-none-any.whl
  • Size: 9.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.4

File hashes

Hashes for hypermodule-0.1.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       44d73a7a772e1a41cb5ffcbc19208d734ed5a5cca0718c54afa08ce0ab21c57c
MD5          ce3ed9defdc8d309b697021e7fa5a616
BLAKE2b-256  0030bde65869809922e88435a722f676f3c3e913ba188671a73cd72ab8784695

