
A PyTorch implementation of ResNet.


ResNet-PyTorch

Update (Feb 20, 2020)

The update is for ease of use and deployment.

It is also now incredibly simple to load a pretrained model with a new number of classes for transfer learning:

from resnet_pytorch import ResNet 
model = ResNet.from_pretrained('resnet18', num_classes=10)

Update (February 2, 2020)

This update adds support for NVIDIA's Apex tool for accelerated training. By default, it uses mixed-precision training with dynamic loss scaling. For more details about Apex, visit https://github.com/NVIDIA/apex.
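
To make the idea of dynamic loss scaling concrete, here is a minimal pure-Python sketch of the general technique: the loss is multiplied by a scale factor before backpropagation, the scale is reduced when the gradients overflow, and it is raised again after a run of stable steps. The class name `ToyLossScaler` and its parameters are hypothetical and chosen for illustration; this is not Apex's API and may not match its exact policy.

```python
import math


class ToyLossScaler:
    """Illustrative dynamic loss scaler (hypothetical; not Apex's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, factor=2.0, growth_interval=3):
        self.scale = init_scale
        self.factor = factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, grads):
        """Return True if the optimizer step should be applied, False if skipped."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            # Overflow: skip this step and shrink the scale.
            self.scale /= self.factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # Several stable steps in a row: try a larger scale again.
            self.scale *= self.factor
            self._good_steps = 0
        return True


scaler = ToyLossScaler()
# Made-up gradient values; the second step simulates an overflow.
for grads in ([0.1, 0.2], [float("inf"), 0.3], [0.1], [0.2], [0.3]):
    applied = scaler.update(grads)
    print(f"applied={applied} scale={scaler.scale:.0f}")
```

In this toy run the scale is halved on the overflowing step and restored after three stable steps in a row.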

Overview

This repository contains an op-for-op PyTorch reimplementation of Deep Residual Learning for Image Recognition.

The goal of this implementation is to be simple, highly extensible, and easy to integrate into your own projects. This implementation is a work in progress -- new features are currently being implemented.

At the moment, you can easily:

  • Load pretrained ResNet models
  • Use ResNet models for classification or feature extraction

Upcoming features: In the next few days, you will be able to:

  • Quickly fine-tune a ResNet on your own dataset
  • Export ResNet models for production

Table of contents

  1. About ResNet
  2. Installation
  3. Usage
  4. Contributing

About ResNet

If you're new to ResNets, here is an explanation straight from the official PyTorch implementation:

ResNet models were proposed in "Deep Residual Learning for Image Recognition". Here we have 5 versions of ResNet models, which contain 18, 34, 50, 101, and 152 layers respectively. Detailed model architectures can be found in Table 1 of the paper.
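
The layer counts in the model names can be checked with a little arithmetic. The sketch below assumes the per-stage block counts from Table 1 of the paper. It counts the first convolution, the convolutions inside each residual block (two per basic block, three per bottleneck block), and the final fully connected layer. The names `CONFIGS` and `count_layers` are illustrative only and are not part of this package.

```python
# (convolutions per block, residual blocks per stage), following Table 1 of the paper.
CONFIGS = {
    "resnet18": (2, [2, 2, 2, 2]),
    "resnet34": (2, [3, 4, 6, 3]),
    "resnet50": (3, [3, 4, 6, 3]),
    "resnet101": (3, [3, 4, 23, 3]),
    "resnet152": (3, [3, 8, 36, 3]),
}


def count_layers(name):
    """Count weighted layers; projection shortcuts are not included."""
    convs_per_block, blocks_per_stage = CONFIGS[name]
    # 1 stem convolution + block convolutions + 1 fully connected layer
    return 1 + convs_per_block * sum(blocks_per_stage) + 1


for name in CONFIGS:
    print(name, count_layers(name))
```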

Installation

Install from PyPI:

$ pip3 install resnet_pytorch

Install from source:

$ git clone https://github.com/Lornatang/ResNet-PyTorch.git
$ cd ResNet-PyTorch
$ pip3 install -e .

Usage

Loading pretrained models

Load a resnet18 network:

from resnet_pytorch import ResNet
model = ResNet.from_name("resnet18")

Load a pretrained resnet18:

from resnet_pytorch import ResNet
model = ResNet.from_pretrained("resnet18")

The 1-crop error rates (%) of the pretrained models on the ImageNet dataset are listed below.

Model structure    Top-1 error (%)    Top-5 error (%)
resnet18           30.24              10.92
resnet34           26.70              8.58
resnet50           23.85              7.13
resnet101          22.63              6.44
resnet152          21.69              5.94

Option B of resnet-18/34/50/101/152 only uses projections to increase dimensions.
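
If you prefer accuracies to error rates, the conversion is simply 100 minus the error. The short sketch below applies it to the numbers from the table; the names `ERRORS` and `to_accuracy` are illustrative and not part of this package.

```python
# Top-1 / top-5 error rates (%) copied from the table above.
ERRORS = {
    "resnet18": (30.24, 10.92),
    "resnet34": (26.70, 8.58),
    "resnet50": (23.85, 7.13),
    "resnet101": (22.63, 6.44),
    "resnet152": (21.69, 5.94),
}


def to_accuracy(error_percent):
    """Convert an error rate in percent to an accuracy in percent."""
    return round(100.0 - error_percent, 2)


for name, (top1, top5) in ERRORS.items():
    print(f"{name:<10} top-1 acc {to_accuracy(top1):.2f}%  top-5 acc {to_accuracy(top5):.2f}%")
```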

For results on the CIFAR-10 dataset, see examples/cifar.

Example: Classification

We assume that your current directory contains an img.jpg file and a labels_map.txt file (ImageNet class names). Both are included in examples/simple.

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
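
To make the arithmetic concrete, here is a minimal pure-Python sketch of what the scaling and normalization do to a single RGB pixel. The helper name `normalize_pixel` is hypothetical; in practice `transforms.ToTensor()` and `transforms.Normalize(...)` apply the same per-channel formula to the whole image.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb_uint8):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    scaled = [value / 255.0 for value in rgb_uint8]
    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]


print(normalize_pixel((255, 255, 255)))  # a white pixel
print(normalize_pixel((0, 0, 0)))        # a black pixel
```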

Here's a sample execution.

import json

import torch
import torchvision.transforms as transforms
from PIL import Image

from resnet_pytorch import ResNet 

# Open image
input_image = Image.open("img.jpg").convert("RGB")  # force 3 channels (handles grayscale/RGBA)

# Preprocess image
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

# Load class names (the file is closed automatically)
with open("labels_map.txt") as f:
    labels_map = json.load(f)
labels_map = [labels_map[str(i)] for i in range(1000)]

# Classify with ResNet18
model = ResNet.from_pretrained("resnet18")
model.eval()

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to("cuda")
    model.to("cuda")

with torch.no_grad():
    logits = model(input_batch)
probs = torch.softmax(logits, dim=1)  # compute class probabilities once
preds = torch.topk(logits, k=5).indices.squeeze(0).tolist()

print("-----")
for idx in preds:
    label = labels_map[idx]
    prob = probs[0, idx].item()
    print(f"{label:<75} ({prob * 100:.2f}%)")
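
The last step of the example turns raw scores (logits) into probabilities and picks the highest-scoring classes. A minimal pure-Python sketch of those two operations, using made-up logits for four classes, might look like the following. The helper names `softmax` and `top_k` are illustrative and stand in for `torch.softmax` and `torch.topk`.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(values, k):
    """Return the indices of the k largest values, largest first."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]


toy_logits = [2.0, 0.5, 3.0, -1.0]  # made-up scores for four classes
probs = softmax(toy_logits)
for idx in top_k(toy_logits, k=2):
    print(f"class {idx}: {probs[idx] * 100:.2f}%")
```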

Example: Feature Extraction

You can easily extract features with model.extract_features:

import torch
from resnet_pytorch import ResNet 
model = ResNet.from_pretrained('resnet18')
model.eval()  # use inference behavior for batch normalization

# ... image preprocessing as in the classification example ...
inputs = torch.randn(1, 3, 224, 224)
print(inputs.shape) # torch.Size([1, 3, 224, 224])

features = model.extract_features(inputs)
print(features.shape) # torch.Size([1, 512, 1, 1])

Example: Export to ONNX

Exporting to ONNX for deploying to production is now simple:

import torch 
from resnet_pytorch import ResNet 

model = ResNet.from_pretrained('resnet18')
dummy_input = torch.randn(16, 3, 224, 224)

torch.onnx.export(model, dummy_input, "demo.onnx", verbose=True)

Example: Visual

cd $REPO$/framework
sh start.sh

Then open http://127.0.0.1:10004/ in your browser.

Enjoy it.

ImageNet

See examples/imagenet for details about evaluating on ImageNet.

Credit

Deep Residual Learning for Image Recognition

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun

Abstract

Deeper neural networks are more difficult to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence showing that these residual networks are easier to optimize, and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets with a depth of up to 152 layers, 8× deeper than VGG nets [41] but still having lower complexity. An ensemble of these residual nets achieves 3.57% error on the ImageNet test set. This result won the 1st place on the ILSVRC 2015 classification task. We also present analysis on CIFAR-10 with 100 and 1000 layers. The depth of representations is of central importance for many visual recognition tasks. Solely due to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are foundations of our submissions to ILSVRC & COCO 2015 competitions, where we also won the 1st places on the tasks of ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation.


@article{He2015,
	author = {Kaiming He and Xiangyu Zhang and Shaoqing Ren and Jian Sun},
	title = {Deep Residual Learning for Image Recognition},
	journal = {arXiv preprint arXiv:1512.03385},
	year = {2015}
}
