PyTorch utility functions.
TorchUtils is a Python package providing helpful utility functions for your PyTorch projects. TorchUtils helps speed up PyTorch development by implementing trivial but useful functionality so that you don’t have to.
Key Features
Calculate total model parameters.
Calculate model FLOPs.
Print model summary in Keras style.
Save/load checkpoints.
Get/set learning rate.
Set random seed.
Examples
Model Summary:
import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
tu.get_model_summary(model, torch.rand((1, 3, 224, 224)))

# Output
=========================================================================================
Layer                                 Kernel             Output      Params         FLOPs
=========================================================================================
0_features.Conv2d_0          [3, 64, 11, 11]    [1, 64, 55, 55]      23,296    70,470,400
1_features.ReLU_1                          -    [1, 64, 55, 55]           0             0
2_features.MaxPool2d_2                     -    [1, 64, 27, 27]           0             0
3_features.Conv2d_3          [64, 192, 5, 5]   [1, 192, 27, 27]     307,392   224,088,768
4_features.ReLU_4                          -   [1, 192, 27, 27]           0             0
5_features.MaxPool2d_5                     -   [1, 192, 13, 13]           0             0
6_features.Conv2d_6         [192, 384, 3, 3]   [1, 384, 13, 13]     663,936   112,205,184
7_features.ReLU_7                          -   [1, 384, 13, 13]           0             0
8_features.Conv2d_8         [384, 256, 3, 3]   [1, 256, 13, 13]     884,992   149,563,648
9_features.ReLU_9                          -   [1, 256, 13, 13]           0             0
10_features.Conv2d_10       [256, 256, 3, 3]   [1, 256, 13, 13]     590,080    99,723,520
11_features.ReLU_11                        -   [1, 256, 13, 13]           0             0
12_features.MaxPool2d_12                   -     [1, 256, 6, 6]           0             0
13_classifier.Dropout_0                    -          [1, 9216]           0             0
14_classifier.Linear_1          [9216, 4096]          [1, 4096]  37,752,832    75,493,376
15_classifier.ReLU_2                       -          [1, 4096]           0             0
16_classifier.Dropout_3                    -          [1, 4096]           0             0
17_classifier.Linear_4          [4096, 4096]          [1, 4096]  16,781,312    33,550,336
18_classifier.ReLU_5                       -          [1, 4096]           0             0
19_classifier.Linear_6          [4096, 1000]          [1, 1000]   4,097,000     8,191,000
=========================================================================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
Total FLOPs: 773,286,232 / 773.29 MFLOPs
-----------------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 8.31
Params size (MB): 233.08
Estimated Total Size (MB): 241.96
=========================================================================================
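The per-layer figures in the AlexNet summary can be cross-checked by hand. The sketch below is pure Python and does not call TorchUtils. Its formulas are inferred from the printed numbers (they reproduce every Conv2d and Linear row), not taken from the library's source, and the helper names `conv2d_cost` and `linear_cost` are hypothetical.

```python
# Cross-check of the AlexNet summary using only the shapes printed in the table.
# Assumption: these formulas are inferred from the numbers, not from TorchUtils' source.


def conv2d_cost(c_in, c_out, k_h, k_w, out_h, out_w):
    """Return (params, flops) for a Conv2d layer with bias, batch size 1."""
    params = c_in * c_out * k_h * k_w + c_out
    # One operation per kernel weight plus one for the bias, per output element.
    flops = c_out * out_h * out_w * (c_in * k_h * k_w + 1)
    return params, flops


def linear_cost(n_in, n_out):
    """Return (params, flops) for a Linear layer with bias, batch size 1."""
    params = n_in * n_out + n_out
    # n_in multiplications and n_in - 1 additions per output element.
    flops = n_out * (2 * n_in - 1)
    return params, flops


# Only layers with weights contribute; ReLU, MaxPool2d and Dropout rows are all zero.
layers = [
    conv2d_cost(3, 64, 11, 11, 55, 55),
    conv2d_cost(64, 192, 5, 5, 27, 27),
    conv2d_cost(192, 384, 3, 3, 13, 13),
    conv2d_cost(384, 256, 3, 3, 13, 13),
    conv2d_cost(256, 256, 3, 3, 13, 13),
    linear_cost(9216, 4096),
    linear_cost(4096, 4096),
    linear_cost(4096, 1000),
]

total_params = sum(params for params, _ in layers)
total_flops = sum(flops for _, flops in layers)
# Parameter memory, assuming 4-byte (float32) weights.
params_mb = total_params * 4 / 1024**2

print(f"Total params: {total_params:,}")      # Total params: 61,100,840
print(f"Total FLOPs: {total_flops:,}")        # Total FLOPs: 773,286,232
print(f"Params size (MB): {params_mb:.2f}")   # Params size (MB): 233.08
```

The two layer types appear to be counted under slightly different conventions (bias included for convolutions, not for linear layers), so treat the FLOPs column as an estimate when comparing against other tools.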
Learning Rate:
import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())
current_lr = tu.get_lr(optimizer)
print('Current learning rate:', current_lr)

optimizer = tu.set_lr(optimizer, current_lr*0.1)
revised_lr = tu.get_lr(optimizer)
print('Revised learning rate:', revised_lr)

# Output
Current learning rate: 0.001
Revised learning rate: 0.0001
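For readers curious what such helpers involve, here is a minimal sketch of reading and writing a learning rate through `param_groups`, the attribute every `torch.optim` optimizer exposes. This is an illustration under assumptions, not the TorchUtils implementation: `read_lr` and `write_lr` are hypothetical names, and a plain stand-in object replaces the optimizer so the snippet runs without PyTorch installed.

```python
from types import SimpleNamespace


def read_lr(optimizer):
    """Return the learning rate of every parameter group, in order."""
    return [group["lr"] for group in optimizer.param_groups]


def write_lr(optimizer, lr):
    """Set the same learning rate on every parameter group and return the optimizer."""
    for group in optimizer.param_groups:
        group["lr"] = lr
    return optimizer


# Stand-in with the same `param_groups` layout as a torch.optim optimizer:
# a list of dicts, each holding an "lr" entry. A real optimizer can be passed instead.
optimizer = SimpleNamespace(param_groups=[{"lr": 0.001}])

current = read_lr(optimizer)[0]
optimizer = write_lr(optimizer, current * 0.1)
revised = read_lr(optimizer)[0]

print("Current learning rate:", current)
print("Revised learning rate:", revised)
```

Returning a list keeps the sketch correct for optimizers with several parameter groups, each of which may carry its own rate.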
Checkpoint:
import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.1)
print('Original learning rate:', tu.get_lr(optimizer))

# load checkpoint model_20190814-212442_e0_0.7531.pt
start_epoch = tu.load_checkpoint(model_path='.',
                                 ckpt_name='model_20190814-212442_e0_0.7531.pt',
                                 model=model, optimizer=optimizer,
                                 scheduler=scheduler)
print('Checkpoint learning rate:', tu.get_lr(optimizer))
print('Start from epoch:', start_epoch)

# Output
Original learning rate: 0.001
Checkpoint learning rate: 0.1234
Start from epoch: 1
Requirements
NumPy >= 1.16.2
PyTorch >= 1.0.0
Installation
$ pip install torchutils
Documentation
API documentation is available at: https://anjandeepsahni.github.io/torchutils/
License
TorchUtils is distributed under the MIT license; see LICENSE.