
PyTorch utility APIs.

TorchUtils is a Python package providing helpful utility APIs for your PyTorch projects.

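Features

  • Load model, optimizer, and scheduler state from a checkpoint.

  • Calculate dataset statistics (mean, std).

  • Get and set the optimizer learning rate.

  • Track the latest and average loss during training.

  • Print a model summary, and count model FLOPs and parameters.

  • Set the random seed.

  • Plot gradient flow.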

Requirements

  • PyTorch >= 1.0.0

  • NumPy >= 1.16.2

  • Matplotlib >= 3.0.3

Installation

PyPI:

$ pip install torchutils

Conda:

$ conda install -c sahni torchutils

Documentation

Detailed API documentation is available here.

Examples

Checkpoint:

import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.1)
print('Original learning rate:', tu.get_lr(optimizer))

# load checkpoint
start_epoch = tu.load_checkpoint(model_path='.',
                        ckpt_name='model_20190814-212442_e0_0.7531.pt',
                        model=model, optimizer=optimizer,
                        scheduler=scheduler)

print('Checkpoint learning rate:', tu.get_lr(optimizer))
print('Start from epoch:', start_epoch)

Output

Original learning rate: 0.001
Checkpoint learning rate: 0.1234
Start from epoch: 1
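
To give a feel for what a checkpoint like this might hold, here is a minimal sketch in plain PyTorch. The dictionary keys, the helper name `make_training_objects`, and the `epoch + 1` resume convention are assumptions for illustration; they are not TorchUtils' actual file format.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


def make_training_objects(lr):
    # Hypothetical helper: builds a tiny model with its optimizer and scheduler.
    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
    return model, optimizer, scheduler


# "Training" side: bundle the epoch and all state dicts into one object.
model, optimizer, scheduler = make_training_objects(lr=0.5)
buffer = io.BytesIO()  # stands in for a .pt file on disk
torch.save({'epoch': 0,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()}, buffer)

# "Resume" side: fresh objects with a different learning rate.
new_model, new_optimizer, new_scheduler = make_training_objects(lr=0.001)
buffer.seek(0)
ckpt = torch.load(buffer)
new_model.load_state_dict(ckpt['model'])
new_optimizer.load_state_dict(ckpt['optimizer'])
new_scheduler.load_state_dict(ckpt['scheduler'])

# Assumed convention: resume from the epoch after the saved one.
start_epoch = ckpt['epoch'] + 1
restored_lr = new_optimizer.param_groups[0]['lr']
print('Restored learning rate:', restored_lr)
print('Start from epoch:', start_epoch)
```

Loading the optimizer state dict is what overwrites the fresh learning rate with the saved one, which mirrors the change between the two learning rates printed above.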

Statistics:

import torch
import torchutils as tu

# define your dataset and dataloader
# (MyDataset is a placeholder for your own torch.utils.data.Dataset subclass)
dataset = MyDataset()
trainloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                          num_workers=1,
                                          shuffle=False)

# get statistics
stats = tu.get_dataset_stats(trainloader, verbose=True)
print('Mean:', stats['mean'])
print('Std:', stats['std'])

Output

Calculating dataset stats...
Batch 100/100
Mean: tensor([10000.0098,  9999.9795,  9999.9893])
Std: tensor([0.9969, 1.0003, 0.9972])
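
For reference, per-channel statistics of this kind can be accumulated batch by batch without loading the whole dataset into memory. The sketch below is one possible approach in plain PyTorch, not TorchUtils' implementation; the function name `channel_stats` is hypothetical, it assumes image batches shaped (B, C, H, W), and it computes the population (biased) standard deviation, which may differ from what `tu.get_dataset_stats` reports.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_stats(loader):
    """Accumulate per-channel mean and population std over a loader."""
    count = 0
    total = 0.0
    total_sq = 0.0
    for (images,) in loader:
        # Use float64 so sum-of-squares stays accurate for large means.
        images = images.double()
        # Number of values per channel in this batch: B * H * W.
        count += images.numel() // images.shape[1]
        total = total + images.sum(dim=(0, 2, 3))
        total_sq = total_sq + (images ** 2).sum(dim=(0, 2, 3))
    mean = total / count
    # Var = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
    std = (total_sq / count - mean ** 2).clamp(min=0).sqrt()
    return {'mean': mean, 'std': std}


# Synthetic 3-channel data standing in for a real dataset.
torch.manual_seed(0)
data = torch.randn(100, 3, 8, 8) * 2 + 5
loader = DataLoader(TensorDataset(data), batch_size=10, shuffle=False)

stats = channel_stats(loader)
print('Mean:', stats['mean'])
print('Std:', stats['std'])
```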

Learning Rate:

import torchvision
import torchutils as tu
import torch.optim as optim

model = torchvision.models.alexnet()
optimizer = optim.Adam(model.parameters())

# get learning rate
current_lr = tu.get_lr(optimizer)
print('Current learning rate:', current_lr)

# set learning rate
optimizer = tu.set_lr(optimizer, current_lr*0.1)
revised_lr = tu.get_lr(optimizer)
print('Revised learning rate:', revised_lr)

Output

Current learning rate: 0.001
Revised learning rate: 0.0001
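
For comparison, the same read and write can be done directly on the optimizer's `param_groups` in plain PyTorch. This is a sketch of the underlying mechanism, assuming a single learning rate shared by all parameter groups; it is not necessarily how `tu.get_lr` and `tu.set_lr` are implemented.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters())  # Adam defaults to lr=0.001

# Read the learning rate of the first parameter group.
current_lr = optimizer.param_groups[0]['lr']
print('Current learning rate:', current_lr)

# Write a new learning rate into every parameter group.
for group in optimizer.param_groups:
    group['lr'] = current_lr * 0.1

revised_lr = optimizer.param_groups[0]['lr']
print('Revised learning rate:', revised_lr)
```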

Evaluation Metrics:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchutils as tu

# define your network
# (MyNet is a placeholder for your own nn.Module subclass)
model = MyNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
trainset = torchvision.datasets.MNIST(root='./data/', train=True,
                                    download=True,
                                    transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=60,
                                        shuffle=True, num_workers=2,
                                        drop_last=True)
n_epochs = 1
model.train()
for epoch in range(n_epochs):
    print('Epoch: %d/%d' % (epoch + 1, n_epochs))
    # define loss tracker
    loss_tracker = tu.RunningLoss()
    for batch_idx, (data, target) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        # update loss tracker with latest loss
        loss_tracker.update(loss.item())
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            # easily print latest and average loss
            print(loss_tracker)

Output

Epoch: 1/1
Loss - Val: 2.2921 Avg: 2.2921
Loss - Val: 0.5084 Avg: 0.9639
Loss - Val: 0.6027 Avg: 0.6588
Loss - Val: 0.1817 Avg: 0.5255
Loss - Val: 0.1005 Avg: 0.4493
Loss - Val: 0.2982 Avg: 0.3984
Loss - Val: 0.3103 Avg: 0.3615
Loss - Val: 0.0940 Avg: 0.3296
Loss - Val: 0.0957 Avg: 0.3071
Loss - Val: 0.0229 Avg: 0.2875
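
A tracker with this behaviour can be sketched in a few lines of plain Python. The class below, `RunningLossSketch`, is a hypothetical stand-in written to match the printed format above; it assumes the average is a simple mean over all updates and is not TorchUtils' actual `RunningLoss` class.

```python
class RunningLossSketch:
    """Keeps the latest loss value and the mean of all values seen so far."""

    def __init__(self):
        self.val = 0.0    # most recent loss
        self.total = 0.0  # sum of all losses
        self.count = 0    # number of updates

    def update(self, value):
        self.val = value
        self.total += value
        self.count += 1

    @property
    def avg(self):
        # Avoid division by zero before the first update.
        return self.total / self.count if self.count else 0.0

    def __str__(self):
        return 'Loss - Val: {:.4f} Avg: {:.4f}'.format(self.val, self.avg)


tracker = RunningLossSketch()
for batch_loss in [2.0, 1.0]:
    tracker.update(batch_loss)
print(tracker)
```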

Model Summary:

import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
# easily print model summary
tu.get_model_summary(model, torch.rand((1, 3, 224, 224)))

Output

=========================================================================================
Layer                           Kernel             Output          Params           FLOPs
=========================================================================================
0_features.Conv2d_0         [3, 64, 11, 11]    [1, 64, 55, 55]       23,296    70,470,400
1_features.ReLU_1                         -    [1, 64, 55, 55]            0             0
2_features.MaxPool2d_2                    -    [1, 64, 27, 27]            0             0
3_features.Conv2d_3         [64, 192, 5, 5]   [1, 192, 27, 27]      307,392   224,088,768
4_features.ReLU_4                         -   [1, 192, 27, 27]            0             0
5_features.MaxPool2d_5                    -   [1, 192, 13, 13]            0             0
6_features.Conv2d_6        [192, 384, 3, 3]   [1, 384, 13, 13]      663,936   112,205,184
7_features.ReLU_7                         -   [1, 384, 13, 13]            0             0
8_features.Conv2d_8        [384, 256, 3, 3]   [1, 256, 13, 13]      884,992   149,563,648
9_features.ReLU_9                         -   [1, 256, 13, 13]            0             0
10_features.Conv2d_10      [256, 256, 3, 3]   [1, 256, 13, 13]      590,080    99,723,520
11_features.ReLU_11                       -   [1, 256, 13, 13]            0             0
12_features.MaxPool2d_12                  -     [1, 256, 6, 6]            0             0
13_classifier.Dropout_0                   -          [1, 9216]            0             0
14_classifier.Linear_1         [9216, 4096]          [1, 4096]   37,752,832    75,493,376
15_classifier.ReLU_2                      -          [1, 4096]            0             0
16_classifier.Dropout_3                   -          [1, 4096]            0             0
17_classifier.Linear_4         [4096, 4096]          [1, 4096]   16,781,312    33,550,336
18_classifier.ReLU_5                      -          [1, 4096]            0             0
19_classifier.Linear_6         [4096, 1000]          [1, 1000]    4,097,000     8,191,000
=========================================================================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
Total FLOPs: 773,286,232 / 773.29 MFLOPs
-----------------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 8.31
Params size (MB): 233.08
Estimated Total Size (MB): 241.96
=========================================================================================
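
The numbers in this table can be reproduced by hand. The formulas in the sketch below are inferred from the table values themselves, not taken from TorchUtils' source, and the function names are hypothetical: each convolution output element is counted as one multiply-accumulate per kernel weight plus one bias add, each linear output as `2 * in_features - 1` operations, and sizes assume 4-byte float32 values.

```python
def conv2d_flops(in_ch, out_ch, kernel_h, kernel_w, out_h, out_w):
    # One MAC per kernel weight plus one bias add, for every output element.
    per_output = in_ch * kernel_h * kernel_w + 1
    return per_output * out_ch * out_h * out_w


def linear_flops(in_features, out_features):
    # in_features multiplies and (in_features - 1) adds per output.
    return out_features * (2 * in_features - 1)


def size_mb(num_values, bytes_per_value=4):
    # Size in MB (1 MB = 1024 ** 2 bytes) of float32 values.
    return num_values * bytes_per_value / 1024 ** 2


# (in_ch, out_ch, kernel_h, kernel_w, out_h, out_w) per Conv2d row.
conv_layers = [
    (3, 64, 11, 11, 55, 55),
    (64, 192, 5, 5, 27, 27),
    (192, 384, 3, 3, 13, 13),
    (384, 256, 3, 3, 13, 13),
    (256, 256, 3, 3, 13, 13),
]
# (in_features, out_features) per Linear row.
linear_layers = [(9216, 4096), (4096, 4096), (4096, 1000)]

total_flops = (sum(conv2d_flops(*layer) for layer in conv_layers)
               + sum(linear_flops(*layer) for layer in linear_layers))

print('First conv FLOPs: {:,}'.format(conv2d_flops(*conv_layers[0])))
print('Total FLOPs: {:,}'.format(total_flops))
print('Params size (MB): {:.2f}'.format(size_mb(61100840)))
print('Input size (MB): {:.2f}'.format(size_mb(3 * 224 * 224)))
```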

Model FLOPs:

import torch
import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
# calculate model FLOPs
total_flops = tu.get_model_flops(model, torch.rand((1, 3, 224, 224)))
print('Total model FLOPs: {:,}'.format(total_flops))

Output

Total model FLOPs: 773,304,664

Model Parameters:

import torchvision
import torchutils as tu

model = torchvision.models.alexnet()
# calculate total model parameters
total_params = tu.get_model_param_count(model)
print('Total model params: {:,}'.format(total_params))

Output

Total model params: 61,100,840
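
The same count can be obtained in plain PyTorch by summing `numel()` over the model's parameters. The sketch below uses a small stand-in model so it runs without torchvision; the function name `count_params` is hypothetical and this is not necessarily how `tu.get_model_param_count` works internally.

```python
import torch.nn as nn


def count_params(model):
    """Return (total, trainable) parameter counts for a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


# Linear(10, 5) has 10*5 + 5 = 55 params; Linear(5, 2) has 5*2 + 2 = 12.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
total, trainable = count_params(model)
print('Total model params: {:,}'.format(total))

# Freezing the first layer leaves only the second layer trainable.
for p in model[0].parameters():
    p.requires_grad = False
frozen_total, frozen_trainable = count_params(model)
print('Trainable after freezing: {:,}'.format(frozen_trainable))
```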

Random Seed:

import torchutils as tu

# set numpy, torch and cuda seed
tu.set_random_seed(2222)
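
For readers who want to see what seeding these libraries involves, here is a minimal sketch using only standard NumPy and PyTorch calls. The function name `seed_everything` is hypothetical, and seeding Python's built-in `random` module is an addition beyond the three sources named in the comment above; this is not TorchUtils' implementation.

```python
import random

import numpy as np
import torch


def seed_everything(seed):
    random.seed(seed)        # Python's built-in RNG (extra, see note above)
    np.random.seed(seed)     # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch CPU RNG
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # all visible CUDA devices


# Seeding twice with the same value reproduces the same draws.
seed_everything(2222)
first = (np.random.rand(3), torch.rand(3))
seed_everything(2222)
second = (np.random.rand(3), torch.rand(3))
print('NumPy draws match:', bool(np.array_equal(first[0], second[0])))
print('Torch draws match:', bool(torch.equal(first[1], second[1])))
```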

Gradient Flow:

import torch
import torchvision
import torchutils as tu

criterion = torch.nn.CrossEntropyLoss()
net = torchvision.models.alexnet(num_classes=10)
out = net(torch.rand(1, 3, 224, 224))
ground_truth = torch.randint(0, 10, (1, ))
loss = criterion(out, ground_truth)
loss.backward()

# save model gradient flow to image
tu.plot_gradients(net, './grad_figures/grad_01.png', plot_type='line')

Saved File

[Figure: example gradient flow plot]
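
A gradient flow plot of this kind is typically built from one summary value per parameter tensor. The sketch below collects the mean absolute gradient per named parameter on a small stand-in network; it illustrates the kind of data such a plot might show and is an assumption, not the computation `tu.plot_gradients` actually performs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
criterion = nn.CrossEntropyLoss()

# One forward/backward pass on random data to populate .grad.
out = net(torch.rand(4, 8))
ground_truth = torch.randint(0, 3, (4,))
loss = criterion(out, ground_truth)
loss.backward()

# Mean absolute gradient per parameter tensor, in layer order.
grad_means = {name: p.grad.abs().mean().item()
              for name, p in net.named_parameters()
              if p.grad is not None}

for name, value in grad_means.items():
    print('{:<10s} {:.6f}'.format(name, value))
```

Values that shrink toward zero in the early layers would suggest vanishing gradients, which is the usual reason for inspecting such a plot.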

License

TorchUtils is distributed under the MIT license; see LICENSE.
