
A wrapper class for neural networks that makes working with them easier.

Project description

Description:

NetworkWrapper is a convenience class for working with neural networks using PyTorch. With it you can:

  • Train and retrain models
  • Display metrics (including per-class metrics)
  • Build clean visualizations of model performance as histograms and graphs
  • Predict object labels, or the probabilities of those labels
  • Save the model and optimizer state for every epoch
  • Save the best weights separately

All of this comes with clean, formatted output, and it works in just three lines of code!

The class contains about 650 lines of code and took almost a month to write. Everything is handled AUTOMATICALLY: file paths and separators depending on the OS, the device used for training, and even the display format of progress bars and figure sizes (without this handling, a progress bar tuned for PyCharm breaks in Jupyter/Colab, and one tuned for Jupyter/Colab breaks in PyCharm) :)

And the class can do much more, including resuming training from a chosen epoch :) I would start training for 20, 50, or even 100 epochs and go for a walk or mind my own business, then come back, review all the statistics, and load whichever epoch I wanted. With one line of code I could trim the saved model and optimizer weights for every epoch starting from a chosen one, or even drop all epoch weights except the best. There is no need to restart training over and over, fearing that the model will be over- or under-trained for the chosen number of epochs. Just a fairy tale :)

The best way to support my creation is to star the project on Github :)

GitHub: https://github.com/JohnConnor123/nn-wrapper

Contact email: ivan.eudokimoff2014@gmail.com

P.S. There may be minor bugs (I tried very hard to avoid them and spent more than a week debugging the code). If you find one, open an issue on the project's GitHub or write to me by email.

Quick usage guide

P.S. This guide covers only the basic features of the class.

Installing

First install the package using pip:

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip3 install nn-wrapper

Without the first command, pip will install the CPU-only (non-CUDA) build of torch by default.

Importing

P.S. Optional: set the root paths for Windows and Colab; relative paths passed to the wrapper are then resolved against the appropriate root.

from nn_wrapper import NetworkWrapper  # the package is installed as nn-wrapper, but imported with an underscore

main_windows_path = "D:\\Python_Projects\\Jupyter\\DL MIPT Stepik\\"
main_colab_path = r'/content/gdrive/MyDrive/Colab Notebooks/Deep Learning School/'
NetworkWrapper.set_main_paths(main_windows_path, main_colab_path)
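The library's internals are not shown here, but the idea behind `set_main_paths` can be sketched as follows. This is a hypothetical illustration, not the package's actual code: pick an environment-specific root, then resolve relative checkpoint paths against it with OS-correct separators.

```python
import os
import sys

def resolve_save_path(relative_path, windows_root, colab_root):
    """Join a relative path to the root matching the current environment
    (hypothetical helper, for illustration only)."""
    in_colab = "google.colab" in sys.modules  # crude Colab detection
    root = colab_root if in_colab else windows_root
    # os.path.join + normpath produce separators valid for the current OS
    return os.path.normpath(os.path.join(root, relative_path))
```

For example, `resolve_save_path('Models/net.pth', main_windows_path, main_colab_path)` would yield a Windows-rooted path in PyCharm and a Drive-rooted path in Colab.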

Creating NetworkWrapper object

Create a NetworkWrapper object by wrapping any neural network model in it and passing the required parameters.

model_testing = NetworkWrapper(model=model, epochs=5, batch_size=32, num_workers=0,
                               train_dataset=train_dataset, val_dataset=val_dataset,
                               n_classes=n_classes, colab_view=False,
                               relative_path='Models\\Transfer learning\\efficientnet_b1.pth',
                               lr=1e-3, scheduler_gamma=0.9,
                               load_pretrained_model=True)

P.S. The optimizer, scheduler and criterion are not passed to the initializer; baseline choices for classification tasks are used:

self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
self.scheduler = ExponentialLR(self.optimizer, gamma=scheduler_gamma)
self.criterion = nn.CrossEntropyLoss()

But you can change objects by explicitly specifying them after initialization:

model_testing.optimizer = ...
model_testing.scheduler = ...
model_testing.criterion = ...
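To see what the default schedule does: ExponentialLR multiplies the learning rate by `gamma` after each scheduler step, so after n steps the rate is lr0 * gamma**n. A quick torch-free sketch using the `lr=1e-3` and `scheduler_gamma=0.9` defaults shown above:

```python
# ExponentialLR decay: after n scheduler steps, lr_n = lr0 * gamma ** n
def lr_after(n_steps, lr0=1e-3, gamma=0.9):
    return lr0 * gamma ** n_steps

print(lr_after(0))  # 0.001 (initial learning rate)
print(lr_after(5))  # five epochs shrink the lr to ~59% of lr0
```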

Initializing the model

Initialize the model: it is trained, or loaded if a previously trained model is found. By default, the main metrics are then calculated; this can be controlled with the "calculate_metrics" parameter of the "train_load_model" method.

model_testing.train_load_model()

Printing and displaying all information after model initialization

print(f"Best epoch: {model_testing.best_epoch}    "
      f"Loaded epoch: {model_testing.loaded_epoch}    "
      f"Total epochs count: {model_testing.total_epochs}")
print("Unpredicted classes", set(model_testing.actual_labels) - set(model_testing.y_preds))
print(model_testing.get_metrics())
print(model_testing.get_metrics('stats_by_class'))
model_testing.plot_metrics()
model_testing.plot_correct_class_prediction_hist()
model_testing.plot_confidence_on_examples()

Loading a specific epoch

model_testing.load_epoch(epoch_to_load)

Trimming the save file

Removing model weights and optimizer weights after "last_untruncated_epoch" epoch:

model_testing.truncate_dump_file(last_untruncated_epoch=last_untruncated_epoch)
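The on-disk dump format is internal to the library, but the truncation idea can be sketched over a hypothetical per-epoch checkpoint dict (this layout is an assumption, for illustration only):

```python
def truncate_checkpoints(checkpoints, last_untruncated_epoch):
    """Keep only entries for epochs <= last_untruncated_epoch
    (hypothetical layout: epoch number -> saved state)."""
    return {epoch: state for epoch, state in checkpoints.items()
            if epoch <= last_untruncated_epoch}

dump = {1: "state_1", 2: "state_2", 3: "state_3", 4: "state_4", 5: "state_5"}
print(sorted(truncate_checkpoints(dump, last_untruncated_epoch=3)))  # [1, 2, 3]
```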

Resuming training from a specific epoch

model_testing.resume_model_training(start_epoch=start_epoch, total_epochs=6,
                                    relative_path='Transfer learning\\resumed_trained.pth')

Removing all epochs from the dump file except the best epoch

model_testing.drop_all_epochs_from_dump_file_except_best_epoch()
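Again, the dump format is internal; conceptually, the operation prunes a hypothetical per-epoch dict down to its single best entry (layout assumed for illustration):

```python
def keep_only_best_epoch(checkpoints, best_epoch):
    """Drop every saved epoch except the best one
    (hypothetical layout: epoch number -> saved state)."""
    return {best_epoch: checkpoints[best_epoch]}

dump = {1: "state_1", 2: "state_2", 3: "state_3"}
print(keep_only_best_epoch(dump, best_epoch=2))  # {2: 'state_2'}
```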

Viewing and changing system settings

Additional feature: you can view and change most of the protected attributes that control various wrapper settings, all through simple dot access, without cluttering the namespace!

For example, you can change pyplot figsize:

model_testing.protected_attributes.figsize = (6, 4)

Example code to demonstrate the operation of the main functionality of the library:

import torch.nn as nn
from torchvision import models
from nn_wrapper import NetworkWrapper
main_windows_path = "D:\\Python_Projects\\Jupyter\\DL MIPT Stepik\\"
main_colab_path = r'/content/gdrive/MyDrive/Colab Notebooks/Deep Learning School/'
NetworkWrapper.set_main_paths(main_windows_path, main_colab_path)

model = models.efficientnet_b1(pretrained=True)
model.classifier[1] = nn.Linear(in_features=1280, out_features=n_classes)

model_testing = NetworkWrapper(model=model, epochs=5, batch_size=32, num_workers=0,
                               train_dataset=train_dataset, val_dataset=val_dataset,
                               n_classes=n_classes, colab_view=False,
                               relative_path='Models\\Transfer learning\\efficientnet_b1.pth',
                               lr=1e-3, scheduler_gamma=0.9,
                               load_pretrained_model=True)

model_testing.train_load_model()

print(f"Best epoch: {model_testing.best_epoch}    "
      f"Loaded epoch: {model_testing.loaded_epoch}    "
      f"Total epochs count: {model_testing.total_epochs}")
print("Unpredicted classes", set(model_testing.actual_labels) - set(model_testing.y_preds))
print(model_testing.get_metrics())
print(model_testing.get_metrics('stats_by_class'))
model_testing.plot_metrics()
model_testing.plot_correct_class_prediction_hist()
model_testing.plot_confidence_on_examples()

epoch_to_load = model_testing.total_epochs//2
print(f"\n\nLoading epoch #{epoch_to_load}")
model_testing.load_epoch(epoch_to_load)
print(f"Best epoch: {model_testing.best_epoch}    "
      f"Loaded epoch: {model_testing.loaded_epoch}    "
      f"Total epochs count: {model_testing.total_epochs}")
print(model_testing.get_metrics())
print(model_testing.get_metrics('stats_by_class'))
print("Unpredicted classes", set(model_testing.actual_labels) - set(model_testing.y_preds))
model_testing.plot_metrics()
model_testing.plot_correct_class_prediction_hist()
model_testing.plot_confidence_on_examples()

print("\nLoading best epoch")
model_testing.load_epoch(model_testing.best_epoch)
print(f"Best epoch: {model_testing.best_epoch}    "
      f"Loaded epoch: {model_testing.loaded_epoch}    "
      f"Total epochs count: {model_testing.total_epochs}")
print(model_testing.get_metrics())
print(model_testing.get_metrics('stats_by_class'))
print("Unpredicted classes", set(model_testing.actual_labels) - set(model_testing.y_preds))
model_testing.plot_metrics()
model_testing.plot_correct_class_prediction_hist()
model_testing.plot_confidence_on_examples()

last_untruncated_epoch = 3  # example
print(f"\nTruncate dump file. Last_untruncated_epoch: {last_untruncated_epoch}")
model_testing.truncate_dump_file(last_untruncated_epoch=last_untruncated_epoch)
print(f"Best epoch: {model_testing.best_epoch}    "
      f"Loaded epoch: {model_testing.loaded_epoch}    "
      f"Total epochs count: {model_testing.total_epochs}")
print(model_testing.get_metrics())
print(model_testing.get_metrics('stats_by_class'))
print("Unpredicted classes", set(model_testing.actual_labels) - set(model_testing.y_preds))
model_testing.plot_metrics()
model_testing.plot_correct_class_prediction_hist()
model_testing.plot_confidence_on_examples()

start_epoch = 2
model_testing.resume_model_training(start_epoch=start_epoch, total_epochs=6,
                                    relative_path='Transfer learning\\resumed_trained.pth')
print(f"Best epoch: {model_testing.best_epoch}    "
      f"Loaded epoch: {model_testing.loaded_epoch}    "
      f"Total epochs count: {model_testing.total_epochs}")
print("Unpredicted classes", set(model_testing.actual_labels) - set(model_testing.y_preds))
print(model_testing.get_metrics())
print(model_testing.get_metrics('stats_by_class'))
model_testing.plot_metrics()
model_testing.plot_correct_class_prediction_hist()
model_testing.plot_confidence_on_examples()

print("\nDrop all epochs from dump file except best epoch")
model_testing.drop_all_epochs_from_dump_file_except_best_epoch()
print(model_testing.get_metrics())
print(model_testing.get_metrics('stats_by_class'))
print("Unpredicted classes", set(model_testing.actual_labels) - set(model_testing.y_preds))
model_testing.plot_metrics()
model_testing.plot_correct_class_prediction_hist()
model_testing.plot_confidence_on_examples()

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

nn_wrapper-1.0.0.tar.gz (12.6 kB)

Uploaded Source

Built Distribution


nn_wrapper-1.0.0-py3-none-any.whl (12.6 kB)

Uploaded Python 3

File details

Details for the file nn_wrapper-1.0.0.tar.gz.

File metadata

  • Download URL: nn_wrapper-1.0.0.tar.gz
  • Size: 12.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.2 CPython/3.11.7 Windows/10

File hashes

Hashes for nn_wrapper-1.0.0.tar.gz
Algorithm Hash digest
SHA256 da0cd86a7657a8371ccb638a8562d874d6ad0f11d612ee66887aeeee33d6ba3a
MD5 41d49bff796bf845663648cc930ff0c9
BLAKE2b-256 4274525895ead768f3a8c3fbb0239f1fd06ab5a6640ffd8f73a3744ed22f93b3

See more details on using hashes here.

File details

Details for the file nn_wrapper-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: nn_wrapper-1.0.0-py3-none-any.whl
  • Size: 12.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.2 CPython/3.11.7 Windows/10

File hashes

Hashes for nn_wrapper-1.0.0-py3-none-any.whl
Algorithm Hash digest
SHA256 e99bc3633c6ae3c8eb71f202a3ca303a55e7368542354bc762e31b4897c19bb5
MD5 1a371974218aa176711eea80bb75bddd
BLAKE2b-256 97152ae9cbe27c7094df5c8e34db4339bf95d70dcee558daec6241a8ffcc887e

See more details on using hashes here.
