
ResNet implementation in PyTorch


torch-resnet

Unified PyTorch implementation of resnets, with or without pre-activation and widening.

We implement (pre-act) resnets for ImageNet for each size described in [1] and [2] (18, 34, 50, 101, 152, 200). We also implement (pre-act) resnets for Cifar for each size described in [1] and [2] (20, 32, 44, 56, 110, 164, 1001, 1202).
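These depths follow the usual counting of [1] and [2]: ImageNet models stack four stages of blocks (2 convolutions per basic block, 3 per bottleneck block), plus the stem convolution and the final linear layer. Cifar models use three stages of n blocks, which gives 6n+2 layers with basic blocks and 9n+2 with bottleneck blocks. The pure-Python sketch below (our own illustration, not part of torch-resnet) recovers each listed depth from its block configuration. The `IMAGENET_BLOCKS` table and the helper names are assumptions taken from the papers, not the package API, and shortcut layers are not counted.

```python
# Depth = (convs per block) * (total number of blocks) + 2
# (+2 for the stem convolution and the final linear layer; shortcut convs are not counted).

BASIC_CONVS = 2       # 3x3 -> 3x3
BOTTLENECK_CONVS = 3  # 1x1 -> 3x3 -> 1x1

# Blocks per stage of the ImageNet models (assumed from [1] and [2]).
IMAGENET_BLOCKS = {
    18: ([2, 2, 2, 2], BASIC_CONVS),
    34: ([3, 4, 6, 3], BASIC_CONVS),
    50: ([3, 4, 6, 3], BOTTLENECK_CONVS),
    101: ([3, 4, 23, 3], BOTTLENECK_CONVS),
    152: ([3, 8, 36, 3], BOTTLENECK_CONVS),
    200: ([3, 24, 36, 3], BOTTLENECK_CONVS),
}


def depth(blocks_per_stage, convs_per_block):
    """Number of weighted layers, counted the way [1] and [2] do."""
    return convs_per_block * sum(blocks_per_stage) + 2


def cifar_blocks_per_stage(target_depth, convs_per_block):
    """Solve depth = 3 * convs_per_block * n + 2 for n (3 stages of n blocks)."""
    n, remainder = divmod(target_depth - 2, 3 * convs_per_block)
    if remainder:
        raise ValueError(f"depth {target_depth} is not reachable with this block type")
    return n


for d, (blocks, convs) in IMAGENET_BLOCKS.items():
    assert depth(blocks, convs) == d

for d in (20, 32, 44, 56, 110, 1202):  # basic blocks: 6n + 2
    print(d, "-> n =", cifar_blocks_per_stage(d, BASIC_CONVS))
for d in (164, 1001):  # bottleneck blocks: 9n + 2
    print(d, "-> n =", cifar_blocks_per_stage(d, BOTTLENECK_CONVS))
```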

Following common practice in the literature, we also propose a version of the ImageNet resnets for Cifar10 (see the SimCLR paper, https://arxiv.org/pdf/2002.05709.pdf). To adapt these architectures to small (32x32) images, the initial 7x7 convolution and max-pooling layers are replaced by a single 3x3 convolution. All other layers and parameters are kept, which results in a variant of wide resnets for Cifar.
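Tracking the spatial resolution through the network shows why the stem is replaced. The pure-Python sketch below (not part of torch-resnet) applies the standard convolution output-size formula. It assumes the usual ImageNet stem (7x7 conv with stride 2 and padding 3, then 3x3 max-pool with stride 2 and padding 1) and assumes that each of the last three stages halves the resolution with a 3x3, stride-2, padding-1 convolution.

```python
def conv_out(size, kernel, stride, padding):
    """Output spatial size of a convolution or pooling layer (square inputs)."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_sizes(image_size, small_stem):
    """Spatial size at the output of each of the four stages of an ImageNet-style resnet."""
    if small_stem:
        size = conv_out(image_size, kernel=3, stride=1, padding=1)  # conv 3x3, stride 1
    else:
        size = conv_out(image_size, kernel=7, stride=2, padding=3)  # conv 7x7, stride 2
        size = conv_out(size, kernel=3, stride=2, padding=1)        # max_pool 3x3, stride 2
    sizes = [size]  # stage 1 keeps the resolution
    for _ in range(3):  # stages 2-4 each downsample by 2 (assumption, see above)
        size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes.append(size)
    return sizes


print(stage_sizes(224, small_stem=False))  # ImageNet input, original stem -> [56, 28, 14, 7]
print(stage_sizes(32, small_stem=False))   # Cifar input, original stem    -> [8, 4, 2, 1]
print(stage_sizes(32, small_stem=True))    # Cifar input, conv 3x3 stem    -> [32, 16, 8, 4]
```

With the original stem, a 32x32 image is already reduced to 8x8 before the first stage and to 1x1 at the end. The 3x3 stem keeps a 4x4 map in the last stage.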

Finally, we implement wide resnets (with or without pre-activation) for Cifar and ImageNet following [3].
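In [3], a Cifar wide resnet WRN-d-k has three stages of n basic blocks with d = 6n + 4. The widths of the original Cifar resnet (16, 32, 64) are multiplied by the widening factor k. The pure-Python sketch below only derives this configuration. The function name is hypothetical and does not correspond to the torch-resnet API.

```python
def wide_resnet_cifar_config(depth, widen_factor):
    """Blocks per stage and channel widths of a Cifar WRN-depth-k, following [3]."""
    n, remainder = divmod(depth - 4, 6)  # depth = 6n + 4 in [3]
    if remainder:
        raise ValueError("WRN depth must be of the form 6n + 4")
    widths = [16] + [base * widen_factor for base in (16, 32, 64)]  # stem, then 3 stages
    return n, widths


print(wide_resnet_cifar_config(28, 10))  # -> (4, [16, 160, 320, 640])
print(wide_resnet_cifar_config(16, 8))   # -> (2, [16, 128, 256, 512])
```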

Additional models can easily be created with the generic ResNet and PreActResNet classes. You can also define your own blocks by following the same template as the implemented ones.

By default, we use projection shortcuts only where they are required (option B from [1]). We have also implemented option A (IdentityShortcut, identity with zero-padding when dimensions increase) and option C (FullProjectionShortcut, a projection on every shortcut), and more can be added following the same template. For instance, we introduce our own shortcut, ConvolutionShortcut, which applies a 3x3 convolution on the shortcut path when dimensions do not match (instead of the 1x1 convolution of ProjectionShortcut).
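The shortcut options mainly differ in where they add weights. As a rough illustration (pure Python, independent of the package's classes; the option strings are our own hypothetical names), the sketch below counts the convolution weights that each option adds to a Cifar ResNet20 (3 stages of 3 basic blocks, widths 16/32/64). BatchNorm parameters are ignored.

```python
WIDTHS = [16, 32, 64]
BLOCKS_PER_STAGE = 3
STEM_WIDTH = 16
OPTIONS = ("identity", "projection", "full_projection", "convolution")


def blocks():
    """Yield (in_channels, out_channels, stride) for every residual block."""
    in_ch = STEM_WIDTH
    for stage, out_ch in enumerate(WIDTHS):
        for i in range(BLOCKS_PER_STAGE):
            stride = 2 if stage > 0 and i == 0 else 1  # downsample at the start of stages 2 and 3
            yield in_ch, out_ch, stride
            in_ch = out_ch


def shortcut_params(option):
    """Conv weights added on the shortcut paths (hypothetical option names)."""
    if option not in OPTIONS:
        raise ValueError(f"unknown option {option!r}")
    total = 0
    for in_ch, out_ch, stride in blocks():
        mismatch = in_ch != out_ch or stride != 1
        if option == "projection" and mismatch:
            total += in_ch * out_ch        # option B: 1x1 conv only where dimensions change
        elif option == "full_projection":
            total += in_ch * out_ch        # option C: 1x1 conv on every block
        elif option == "convolution" and mismatch:
            total += 9 * in_ch * out_ch    # 3x3 conv only where dimensions change
        # "identity" (option A): zero-padding, no weights
    return total


for option in OPTIONS:
    print(option, shortcut_params(option))
```

Under these assumptions, option B adds 2,560 weights, option C adds 13,568, and the 3x3 convolutional shortcut adds 23,040. These are small increments on a network of roughly 0.27M parameters, in line with the ResNet20 variants in the shortcut comparison of the Results section.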

We have validated our implementation by testing it on Cifar10/Cifar100 (See Results).

Install

$ pip install torch-resnet

Getting started

import torch
from torch import nn

import torch_resnet
from torch_resnet.utils import count_layers, count_parameters

model = torch_resnet.PreActResNet50()  # Build a pre-activation ResNet50 backbone for ImageNet
model.set_head(nn.Linear(model.out_planes, 1000))  # Attach a final linear head (1000 ImageNet classes)

count_layers(model)  # -> 54 (the original paper does not count shortcut/downsampling layers)
count_parameters(model) / 10**6  # Number of parameters, in millions

out = model(torch.randn(1, 3, 224, 224))  # Forward a random 224x224 RGB image -> logits of shape (1, 1000)

See example/example.py for a more complete example.
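The 54 returned by `count_layers` can be reconciled with the usual depth of 50 by also counting the projection shortcuts. The arithmetic below is our own sketch. It assumes that `count_layers` counts every convolution and linear layer, including the 1x1 projection shortcuts. We have not checked the package implementation.

```python
# ResNet50: 4 stages of bottleneck blocks (3 convolutions each).
blocks_per_stage = [3, 4, 6, 3]
convs_per_block = 3

block_convs = convs_per_block * sum(blocks_per_stage)  # 48
paper_depth = block_convs + 2  # + stem 7x7 conv + final linear layer = 50

# A projection (1x1 conv) shortcut is needed in the first block of every stage:
# stages 2-4 change resolution and width, and stage 1 widens 64 -> 256 channels.
projection_shortcuts = len(blocks_per_stage)  # 4

print(paper_depth, paper_depth + projection_shortcuts)  # -> 50 54
```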

Results

Results were obtained with example/example.py, closely following the training settings of the papers. Values are test error rates (%) on Cifar10/Cifar100. Most are reported as mean $\pm$ std over 5 runs with different seeds. When a single number is reported, only one run was done because of the computational cost. Unless stated otherwise, all models are trained with Automatic Mixed Precision (AMP).

For all resnets and pre-act resnets, the learning rate follows the schedule of [1] and [2]. It starts with a warm-up at 0.01 during the first 400 iterations, then is set to 0.1 and divided by 10 at 32k and 48k iterations. Training is stopped after 160 epochs (~62.5k iterations).
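The schedule can be written as a function of the iteration index. This is a minimal sketch, not the code of example/example.py. It assumes that each decay applies from iteration 32k (respectively 48k) onward. It also assumes a batch size of 128 on the 50k Cifar training images, which is what [1] uses and matches the ~62.5k iterations quoted above.

```python
WARMUP_ITERS = 400
MILESTONES = (32_000, 48_000)


def learning_rate(iteration, base_lr=0.1, warmup_lr=0.01, gamma=0.1):
    """Step schedule of [1] and [2] with a constant warm-up."""
    if iteration < WARMUP_ITERS:
        return warmup_lr
    n_decays = sum(iteration >= m for m in MILESTONES)  # number of milestones already passed
    return base_lr * gamma ** n_decays


# Length of training, assuming a batch size of 128 on the 50k Cifar training images.
iters_per_epoch = 50_000 / 128       # 390.625
total_iters = 160 * iters_per_epoch  # 62500.0

for it in (0, 399, 400, 31_999, 32_000, 48_000, int(total_iters) - 1):
    print(it, f"{learning_rate(it):g}")
```

In a PyTorch training loop, such a function could be passed to `torch.optim.lr_scheduler.LambdaLR` (after dividing by the base learning rate, since `LambdaLR` multiplies it) and stepped once per iteration.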

Unlike [1], we do not split the training set into 45k/5k images for validation. We directly evaluate the final model on the test set (no validation is done). Also, on Cifar, [1] uses option A (identity shortcut), whereas we use the default option B (projection shortcut when required).

Follow the link to access the training curves for each model and each seed used (111, 222, 333, 444, 555). (Upload coming soon.)

| Model | Params | Cifar10 error (%) | Cifar10 (paper) | Cifar100 error (%) | Cifar100 (paper) |
|---|---|---|---|---|---|
| ResNet20 | 0.3M | 8.64 $\pm$ 0.16 | 8.75 [1] | 33.23 $\pm$ 0.32 | xxx |
| ResNet32 | 0.5M | 7.64 $\pm$ 0.23 | 7.51 [1] | 31.64 $\pm$ 0.54 | xxx |
| ResNet44 | 0.7M | 7.47 $\pm$ 0.19 | 7.17 [1] | 30.88 $\pm$ 0.22 | xxx |
| ResNet56 | 0.9M | 7.04 $\pm$ 0.26 | 6.97 [1] | 30.15 $\pm$ 0.29 | xxx |
| ResNet110 | 1.7M | 6.60 $\pm$ 0.09 | 6.61 $\pm$ 0.16 [1] | 28.99 $\pm$ 0.22 | xxx |
| ResNet164 | 1.7M | 5.97 $\pm$ 0.20 | xxx | 25.79 $\pm$ 0.51 | 25.16 [2] |
| ResNet1001* | 10.3M | 7.95 | xxx | 29.94 | 27.82 [2] |
| ResNet1202 | 19.4M | 7.90 | 7.93 [1] | 33.20 | xxx |
| PreActResNet20 | 0.3M | 8.61 $\pm$ 0.23 | xxx | 33.40 $\pm$ 0.30 | xxx |
| PreActResNet32 | 0.5M | 7.76 $\pm$ 0.10 | xxx | 32.02 $\pm$ 0.27 | xxx |
| PreActResNet44 | 0.7M | 7.63 $\pm$ 0.10 | xxx | 30.78 $\pm$ 0.17 | xxx |
| PreActResNet56 | 0.9M | 7.42 $\pm$ 0.13 | xxx | 30.18 $\pm$ 0.39 | xxx |
| PreActResNet110 | 1.7M | 6.79 $\pm$ 0.12 | xxx | 28.45 $\pm$ 0.25 | xxx |
| PreActResNet164 | 1.7M | 5.61 $\pm$ 0.16 | 5.46 [2] | 25.23 $\pm$ 0.21 | 24.33 [2] |
| PreActResNet1001 | 10.3M | 4.92 | 4.89 $\pm$ 0.14 [2] | 23.18 | 22.68 $\pm$ 0.22 [2] |
| PreActResNet1202 | 19.4M | 6.66 | xxx | 27.65 | xxx |
| ResNet18-small | 11.2M | 5.88 $\pm$ 0.15 | xxx | 26.74 $\pm$ 0.42 | xxx |
| ResNet34-small | 21.3M | 5.50 $\pm$ 0.17 | xxx | 25.34 $\pm$ 0.29 | xxx |
| ResNet50-small | 23.5M | 5.86 $\pm$ 0.30 | xxx | 25.20 $\pm$ 0.89 | xxx |
| ResNet101-small | 42.5M | 5.45 $\pm$ 0.14 | xxx | 23.93 $\pm$ 0.56 | xxx |
| PreActResNet18-small | 11.2M | 5.65 $\pm$ 0.12 | xxx | 25.46 $\pm$ 0.34 | xxx |
| PreActResNet34-small | 21.3M | 5.29 $\pm$ 0.17 | xxx | 24.75 $\pm$ 0.31 | xxx |
| PreActResNet50-small | 23.5M | 5.83 $\pm$ 0.47 | xxx | 23.97 $\pm$ 0.36 | xxx |
| PreActResNet101-small | 42.5M | 5.18 $\pm$ 0.11 | xxx | 23.69 $\pm$ 0.41 | xxx |

* ResNet1001 could not be trained with AMP (training was unstable), so it was trained without AMP. Also note that AMP usually leads to slightly worse performance, so most of the results reported here are probably slightly underestimated.

Note that in [2] and in most GitHub implementations, the test set is used as a validation set: the best accuracy reached on it during training is reported as the final result (as done in the official implementation of [2]), which makes the reported performance optimistic. When we drop AMP and report the best value rather than the last one, we also obtain better results. (We only tested this on PreActResNet164 and PreActResNet1001, where our results were slightly behind the paper.)

| Model (no AMP, best) | Params | Cifar100 error (%) | Cifar100 (paper) |
|---|---|---|---|
| PreActResNet164 | 1.7M | 24.83 $\pm$ 0.16 | 24.33 [2] |
| PreActResNet1001 | 10.3M | 22.86 | 22.68 $\pm$ 0.22 [2] |
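A small simulation shows why reporting the best test value is optimistic. Even when a model has stopped improving, picking the minimum of several noisy test evaluations is biased downwards. All numbers below are made up for illustration (pure Python). They are not tied to our actual training runs.

```python
import random
import statistics

TRUE_ERROR = 25.0  # hypothetical "true" test error (%) of a converged model
NOISE_STD = 0.2    # hypothetical epoch-to-epoch fluctuation of the measured test error
N_EVALS = 20       # hypothetical number of test-set evaluations near the end of training
N_TRIALS = 1_000

rng = random.Random(0)
last, best = [], []
for _ in range(N_TRIALS):
    curve = [rng.gauss(TRUE_ERROR, NOISE_STD) for _ in range(N_EVALS)]
    last.append(curve[-1])   # report the final model (what we do)
    best.append(min(curve))  # report the best test error seen during training

print(f"last: {statistics.mean(last):.2f}  best: {statistics.mean(best):.2f}")
```

With these settings, the "best" estimate lands about 0.37 points below the true error on average. This matches the expected minimum of 20 Gaussian draws, about 1.87 standard deviations below the mean. The "last" estimate stays unbiased.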

We also ran a quick comparison of our shortcut implementations (with AMP and last-model evaluation):

| Model | Params | Cifar10 error (%) | Cifar100 error (%) |
|---|---|---|---|
| ResNet20 (Proj) | 0.27M | 8.64 $\pm$ 0.16 | xxx |
| ResNet20-Id | 0.27M | 8.65 $\pm$ 0.08 | xxx |
| ResNet20-FullProj | 0.28M | 8.22 $\pm$ 0.14 | xxx |
| ResNet20-Conv | 0.29M | 8.41 $\pm$ 0.19 | xxx |
| PreActResNet164 (Proj) | 1.70M | 5.61 $\pm$ 0.16 | 25.23 $\pm$ 0.21 |
| PreActResNet164-Id | 1.68M | 5.52 $\pm$ 0.14 | 24.71 $\pm$ 0.12 |
| PreActResNet164-FullProj | 3.19M | Failed (90.0) | Failed (99.0) |
| PreActResNet164-Conv | 2.06M | 5.55 $\pm$ 0.18 | 23.86 $\pm$ 0.16 |

More work is needed to fully investigate shortcuts, but from the few experiments we ran, most options seem to work correctly. FullProjectionShortcut should not be used: since no shortcut is a true identity anymore, it makes training unstable, and PreActResNet164 with this shortcut failed to train (chance-level error). The introduced convolutional shortcut (3x3 conv instead of 1x1) seems to help on Cifar100. Finally, with the identity shortcut on PreActResNet164 (plus no AMP and best-model evaluation, as in [2]), we would likely reach roughly the same Cifar100 error as the paper (24.33).

References

Build and Deploy

$ python -m build
$ python -m twine upload dist/*

Download files

Source distribution: torch-resnet-0.0.2.dev2.tar.gz (15.6 kB)

Built distribution: torch_resnet-0.0.2.dev2-py3-none-any.whl (15.5 kB, Python 3)