Fine-tune pretrained Convolutional Neural Networks with PyTorch
Features
- Gives access to the most popular CNN architectures pretrained on ImageNet.
- Automatically replaces the classifier on top of the network, which allows you to train the network on a dataset with a different number of classes.
- Allows you to use images with any resolution (and not only the resolution that was used for training the original model on ImageNet).
- Allows adding a Dropout layer or a custom pooling layer.
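The features listed above fit together as one wrapper around a feature extractor: global pooling, optional dropout, then a freshly initialized classifier. The sketch below is a hypothetical illustration of that general idea, not cnn_finetune's own code; FineTuneNet and the tiny two-layer backbone are made-up stand-ins for a real pretrained network.

```python
import torch
import torch.nn as nn

# Hypothetical tiny backbone standing in for a pretrained feature extractor.
backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(inplace=True),
)


class FineTuneNet(nn.Module):
    """Backbone + global pooling + optional dropout + new classifier (illustrative only)."""

    def __init__(self, backbone, in_features, num_classes, dropout_p=None, pool=None):
        super().__init__()
        self.backbone = backbone
        # Global pooling collapses any H x W feature map to 1 x 1.
        self.pool = pool if pool is not None else nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout_p) if dropout_p else nn.Identity()
        # Fresh classifier sized for the new dataset's number of classes.
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        x = self.pool(self.backbone(x))
        x = torch.flatten(x, 1)
        return self.classifier(self.dropout(x))


model = FineTuneNet(backbone, in_features=32, num_classes=10, dropout_p=0.5)
model.eval()
with torch.no_grad():
    for size in (224, 320):
        out = model(torch.randn(2, 3, size, size))
        print(size, tuple(out.shape))  # (2, 10) for both input sizes
```

Because global pooling reduces each channel to a single value, the classifier's input size depends only on the number of channels, which is why the same model accepts both 224x224 and 320x320 images.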
Supported architectures and models
From the torchvision package:
- ResNet (resnet18, resnet34, resnet50, resnet101, resnet152)
- ResNeXt (resnext50_32x4d, resnext101_32x8d)
- DenseNet (densenet121, densenet169, densenet201, densenet161)
- Inception v3 (inception_v3)
- VGG (vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn)
- SqueezeNet (squeezenet1_0, squeezenet1_1)
- MobileNet V2 (mobilenet_v2)
- ShuffleNet v2 (shufflenet_v2_x0_5, shufflenet_v2_x1_0)
- AlexNet (alexnet)
- GoogLeNet (googlenet)
From the Pretrained models for PyTorch package:
- ResNeXt (resnext101_32x4d, resnext101_64x4d)
- NASNet-A Large (nasnetalarge)
- NASNet-A Mobile (nasnetamobile)
- Inception-ResNet v2 (inceptionresnetv2)
- Dual Path Networks (dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107)
- Inception v4 (inception_v4)
- Xception (xception)
- Squeeze-and-Excitation Networks (senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d)
- PNASNet-5-Large (pnasnet5large)
- PolyNet (polynet)
Requirements
- Python 3.5+
- PyTorch 1.1+
Installation
pip install cnn_finetune
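The minimum versions listed under Requirements can be checked programmatically. A minimal sketch using only the standard library (torch is imported only if it is installed); version_tuple and meets_requirements are hypothetical helper names, not part of cnn_finetune.

```python
import re
import sys


def version_tuple(version):
    """Parse the leading 'major.minor' of a version string such as '1.13.1+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    return int(match.group(1)), int(match.group(2))


def meets_requirements(torch_version, python_version=sys.version_info[:2]):
    """True if Python >= 3.5 and PyTorch >= 1.1, the minimums listed above."""
    return tuple(python_version) >= (3, 5) and version_tuple(torch_version) >= (1, 1)


if __name__ == "__main__":
    try:
        import torch
        print("requirements met:", meets_requirements(torch.__version__))
    except ImportError:
        print("PyTorch is not installed")
```

Comparing integer tuples rather than raw strings avoids the pitfall where "1.10" < "1.9" as strings.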
Major changes:
Version 0.4
- The default value of the pretrained argument in make_model has changed from False to True. Calling make_model('resnet18', num_classes=10) is now equivalent to make_model('resnet18', num_classes=10, pretrained=True).
Example usage:
Make a model with ImageNet weights for 10 classes
from cnn_finetune import make_model
model = make_model('resnet18', num_classes=10, pretrained=True)
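Fine-tuning then proceeds with an ordinary PyTorch training loop. The sketch below shows repeated optimization steps on one batch; to keep it runnable offline it uses a small hypothetical StandIn module in place of make_model's output, and the optimizer settings (SGD, lr=0.01, momentum=0.9) are arbitrary example values rather than recommendations from this project.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class StandIn(nn.Module):
    """Hypothetical stand-in for make_model('resnet18', num_classes=10, pretrained=True)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(inplace=True))
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.pool(self.features(x))
        return self.classifier(torch.flatten(x, 1))


model = StandIn(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)


def train_step(images, labels):
    model.train()                 # enable dropout / batch-norm training behaviour
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()


# Random data standing in for a real batch of 32x32 RGB images with 10 classes.
images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
losses = [train_step(images, labels) for _ in range(20)]
print(losses[0], losses[-1])
```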
Make a model with Dropout
model = make_model('nasnetalarge', num_classes=10, pretrained=True, dropout_p=0.5)
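Assuming, as the parameter name suggests, that dropout_p configures a standard nn.Dropout applied to the features before the new classifier, its behaviour differs between training and evaluation. A minimal sketch using nn.Dropout directly (not cnn_finetune's code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
features = torch.ones(1, 1000)

dropout.train()
train_out = dropout(features)
# In training mode each element is zeroed with probability p and survivors are
# scaled by 1 / (1 - p) = 2, so the expected value is unchanged.
print(sorted(set(train_out.flatten().tolist())))  # [0.0, 2.0]

dropout.eval()
eval_out = dropout(features)
# In eval mode dropout is the identity.
print(torch.equal(eval_out, features))  # True
```

In practice this means calling model.eval() before validation or inference so that predictions are deterministic.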
Make a model with Global Max Pooling instead of Global Average Pooling
import torch.nn as nn
model = make_model('inceptionresnetv2', num_classes=10, pretrained=True, pool=nn.AdaptiveMaxPool2d(1))
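Both pooling layers reduce each channel of the final feature map to a single value, so the classifier's input size is the same either way; they differ only in how that value is computed. A minimal sketch on a hand-made 4x4 feature map:

```python
import torch
import torch.nn as nn

# A single 4x4 feature map with values 0..15 (batch=1, channels=1).
fmap = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

avg = nn.AdaptiveAvgPool2d(1)(fmap)  # mean of all 16 values
mx = nn.AdaptiveMaxPool2d(1)(fmap)   # largest of all 16 values

print(avg.item(), mx.item())  # 7.5 15.0
```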
Make a VGG16 model that takes images of size 256x256 pixels
VGG and AlexNet models use fully-connected layers, so you must also pass the input image size when constructing a new model. This information is needed to determine the input size of the fully-connected layers.
model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256))
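To see why the input size matters: in a VGG16-style network the 3x3 convolutions (padding 1) keep the spatial size and each of the five stride-2 max-pooling layers halves it, so the number of values reaching the first fully-connected layer depends on the image size. A minimal sketch of that arithmetic, assuming the final 512-channel feature map is flattened directly with no adaptive pooling in between; this is an illustration, not a description of how cnn_finetune computes the size internally, and vgg_flat_features is a hypothetical helper.

```python
def vgg_flat_features(height, width, channels=512, num_pools=5):
    """Flattened size of a VGG-style conv output fed straight into a Linear layer.

    Assumes convolutions preserve spatial size and each of `num_pools`
    2x2 stride-2 max-pools halves height and width (rounding down).
    """
    for _ in range(num_pools):
        height, width = height // 2, width // 2
    return channels * height * width


for size in (224, 256):
    print(size, vgg_flat_features(size, size))
# 224 25088
# 256 32768
```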
Make a VGG16 model that takes images of size 256x256 pixels and uses a custom classifier
import torch.nn as nn
def make_classifier(in_features, num_classes):
    return nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Linear(4096, num_classes),
    )
model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256), classifier_factory=make_classifier)
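A classifier factory can be sanity-checked on its own before it is passed to make_model. The sketch below assumes the factory is called with in_features and num_classes and that the returned module receives flattened features of shape (N, in_features); the smaller hidden layer, the added Dropout and the extra hidden parameter are hypothetical variations for illustration. The in_features value 32768 (512 channels x 8 x 8) assumes a 256x256 input halved by five stride-2 poolings with no adaptive pooling.

```python
import torch
import torch.nn as nn


def make_classifier(in_features, num_classes, hidden=256):
    # Same structure as the factory above, with a smaller hidden layer and a
    # Dropout added for illustration; `hidden` is a hypothetical extra parameter.
    return nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5),
        nn.Linear(hidden, num_classes),
    )


classifier = make_classifier(in_features=32768, num_classes=10)
classifier.eval()
with torch.no_grad():
    logits = classifier(torch.randn(2, 32768))
print(tuple(logits.shape))  # (2, 10)
```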
Show preprocessing that was used to train the original model on ImageNet
>>> model = make_model('resnext101_64x4d', num_classes=10, pretrained=True)
>>> print(model.original_model_info)
ModelInfo(input_space='RGB', input_size=[3, 224, 224], input_range=[0, 1], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
>>> print(model.original_model_info.mean)
[0.485, 0.456, 0.406]
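The printed fields describe how images should be prepared for this model: RGB channel order, values scaled to input_range [0, 1], then normalized per channel with mean and std. A minimal sketch that hard-codes the values printed above (preprocess is a hypothetical helper; in a real pipeline you would read them from model.original_model_info):

```python
import torch

# Values printed by model.original_model_info above, shaped for (3, H, W) broadcasting.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def preprocess(image_uint8):
    """Convert an RGB uint8 tensor of shape (3, H, W) to the model's input format."""
    image = image_uint8.float() / 255.0  # scale to input_range [0, 1]
    return (image - mean) / std          # per-channel normalization


image = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
x = preprocess(image)
print(tuple(x.shape))  # (3, 4, 4)
```

With torchvision, the same effect is usually obtained with transforms.ToTensor() followed by transforms.Normalize(mean, std).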
CIFAR10 Example
See the examples/cifar10.py file (requires PyTorch 1.1+).
Download files
Source Distribution
cnn_finetune-0.6.0.tar.gz (11.2 kB)
File details
Details for the file cnn_finetune-0.6.0.tar.gz.
File metadata
- Download URL: cnn_finetune-0.6.0.tar.gz
- Upload date:
- Size: 11.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/1.13.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.0.1 requests-toolbelt/0.9.1 tqdm/4.32.2 CPython/3.7.3
File hashes
Algorithm | Hash digest
---|---
SHA256 | 557582afa0acfbdc93c4d18a34db0b34714086c1194f9731056fcb425a5937bd
MD5 | 35b49ffbcfc63da9218576c2a76ecb84
BLAKE2b-256 | e46303a442d31401c43fc17a814f22bd7c39ab8f13f42a6b2467ca0d0d042b3a
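A downloaded archive can be checked against the SHA256 digest in the table with the standard library's hashlib. A minimal sketch; sha256_of_file and verify are hypothetical helper names, and the file path you pass in is wherever you saved the download.

```python
import hashlib

# SHA256 digest listed above for cnn_finetune-0.6.0.tar.gz.
EXPECTED_SHA256 = "557582afa0acfbdc93c4d18a34db0b34714086c1194f9731056fcb425a5937bd"


def sha256_of_file(path, chunk_size=1 << 16):
    """Hash a file in chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected=EXPECTED_SHA256):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of_file(path) == expected.lower()
```

For example, verify("cnn_finetune-0.6.0.tar.gz") should return True for an intact download. pip can perform the same check at install time through its hash-checking mode (a --hash=sha256:... entry in a requirements file).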