
Image segmentation models with pre-trained backbones, built on Keras.


Segmentation models Zoo

Segmentation models with pretrained backbones

Unet- and FPN-like models

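| Backbone model      | backbone_name     | Weights                           | UNet | FPN |
|---------------------|-------------------|-----------------------------------|------|-----|
| VGG16               | vgg16             | imagenet                          | +    | +   |
| VGG19               | vgg19             | imagenet                          | +    | +   |
| ResNet18            | resnet18          | imagenet                          | +    | +   |
| ResNet34            | resnet34          | imagenet                          | +    | +   |
| ResNet50            | resnet50          | imagenet, imagenet11k-places365ch | +    | +   |
| ResNet101           | resnet101         | imagenet                          | +    | +   |
| ResNet152           | resnet152         | imagenet, imagenet11k             | +    | +   |
| ResNeXt50           | resnext50         | imagenet                          | +    | +   |
| ResNeXt101          | resnext101        | imagenet                          | +    | +   |
| DenseNet121         | densenet121       | imagenet                          | +    | +   |
| DenseNet169         | densenet169       | imagenet                          | +    | +   |
| DenseNet201         | densenet201       | imagenet                          | +    | +   |
| Inception V3        | inceptionv3       | imagenet                          | +    | +   |
| Inception ResNet V2 | inceptionresnetv2 | imagenet                          | +    | +   |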
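The table can also be kept as data, which makes it easy to catch a misspelled backbone name or an unavailable weights option before building a model. This is a minimal sketch: BACKBONE_WEIGHTS and check_backbone are hypothetical helpers, not part of segmentation_models, and treating weights=None as "random initialization" is an assumption borrowed from the usual Keras convention.

```python
# Hypothetical lookup mirroring the backbone table; NOT part of segmentation_models.
BACKBONE_WEIGHTS = {
    "vgg16": ("imagenet",),
    "vgg19": ("imagenet",),
    "resnet18": ("imagenet",),
    "resnet34": ("imagenet",),
    "resnet50": ("imagenet", "imagenet11k-places365ch"),
    "resnet101": ("imagenet",),
    "resnet152": ("imagenet", "imagenet11k"),
    "resnext50": ("imagenet",),
    "resnext101": ("imagenet",),
    "densenet121": ("imagenet",),
    "densenet169": ("imagenet",),
    "densenet201": ("imagenet",),
    "inceptionv3": ("imagenet",),
    "inceptionresnetv2": ("imagenet",),
}


def check_backbone(name, weights="imagenet"):
    """Return (name, weights) if the pair appears in the table, else raise ValueError.

    Assumption: weights=None means "no pretrained weights" and is always allowed.
    """
    if name not in BACKBONE_WEIGHTS:
        raise ValueError(f"unknown backbone {name!r}; choose from {sorted(BACKBONE_WEIGHTS)}")
    if weights is not None and weights not in BACKBONE_WEIGHTS[name]:
        raise ValueError(
            f"{name!r} has no {weights!r} weights; available: {BACKBONE_WEIGHTS[name]}"
        )
    return name, weights
```

For example, check_backbone('resnet152', 'imagenet11k') passes, while check_backbone('resnet18', 'imagenet11k') raises ValueError because only imagenet weights are listed for ResNet18.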

Installation

  1. Clone the repository into your project:
$ git clone https://github.com/qubvel/segmentation_models.git
  2. Update the submodules:
$ cd segmentation_models
$ git submodule update --init --recursive

Code examples

Train a Unet model:

from segmentation_models import Unet

# prepare data: x holds images of shape (N, H, W, 3), y holds binary masks of shape (N, H, W, 1)
x, y = ...

# prepare model
model = Unet(backbone_name='resnet34', encoder_weights='imagenet')
model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])

# train model
model.fit(x, y)
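The x, y = ... placeholder is left for the reader to fill in. Below is a minimal NumPy sketch of random stand-in data with plausible shapes for a single-class Unet, plus a pixel-wise accuracy check in the spirit of Keras' binary_accuracy (threshold the prediction at 0.5 and compare with the mask). It assumes channels-last images, one-channel binary masks, and height and width that are multiples of 32, the assumed total downsampling factor of the encoders. make_dummy_batch and pixel_binary_accuracy are hypothetical names.

```python
import numpy as np


def make_dummy_batch(n=4, height=256, width=256, seed=0):
    """Random RGB images and binary masks shaped for a one-class segmentation model."""
    # Assumption: the encoder downsamples by 32, so spatial sizes must be multiples of 32.
    if height % 32 or width % 32:
        raise ValueError("height and width should be divisible by 32")
    rng = np.random.default_rng(seed)
    x = rng.random((n, height, width, 3), dtype=np.float32)  # images scaled to [0, 1)
    y = (rng.random((n, height, width, 1)) > 0.5).astype(np.float32)  # 0/1 masks
    return x, y


def pixel_binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of pixels where the thresholded prediction equals the binary mask."""
    return float(np.mean((y_pred > threshold) == (y_true > 0.5)))
```

With real data, x and y would come from your own loader; the shape checks are the part worth keeping.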

Train an FPN model:

from segmentation_models import FPN

model = FPN(backbone_name='resnet34', encoder_weights='imagenet')
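An FPN decoder is built from a top-down pathway: a coarser feature map is upsampled by 2 and added to a 1x1-convolved encoder feature of matching resolution. The NumPy sketch below shows one such merge step for the general Feature Pyramid Network technique; it is not this library's implementation, and upsample2x_nearest and top_down_merge are hypothetical names. Nearest-neighbour upsampling is assumed for simplicity.

```python
import numpy as np


def upsample2x_nearest(feature_map):
    """Nearest-neighbour 2x upsampling of an (H, W, C) array to (2H, 2W, C)."""
    return feature_map.repeat(2, axis=0).repeat(2, axis=1)


def top_down_merge(coarse, lateral, lateral_kernel):
    """One FPN merge step: upsampled coarse level + 1x1-projected lateral feature.

    coarse:         (H, W, C_out) feature from the level above
    lateral:        (2H, 2W, C_in) encoder feature at the finer resolution
    lateral_kernel: (C_in, C_out) weights of the 1x1 convolution
    """
    # A 1x1 convolution is a per-pixel matrix product over channels.
    projected = lateral @ lateral_kernel  # (2H, 2W, C_in) @ (C_in, C_out) -> (2H, 2W, C_out)
    return upsample2x_nearest(coarse) + projected
```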

Useful trick

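Freeze the encoder weights during the first epochs of training so that the randomly initialized decoder can adapt without disturbing the pretrained encoder features, then unfreeze all layers to fine-tune the whole model: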

from segmentation_models import FPN
from segmentation_models.utils import set_trainable

model = FPN(backbone_name='resnet34', encoder_weights='imagenet', freeze_encoder=True)
model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])

# train only the decoder while the encoder stays frozen
model.fit(x, y, epochs=2)

# release all layers for training
set_trainable(model) # set all layers trainable and recompile model

# continue training
model.fit(x, y, epochs=100)
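The comment on set_trainable says it makes every layer trainable and recompiles the model. A sketch of that idea follows, assuming only what that comment states; unfreeze_and_recompile is a hypothetical name, not the library's function. It relies on Keras behaviour: a change to a layer's trainable flag only takes effect in training after compile() is called again.

```python
def unfreeze_and_recompile(model, optimizer, loss, metrics):
    """Mark every layer trainable, then recompile so the change takes effect.

    Works with any object exposing `.layers` (each with a `.trainable` flag)
    and `.compile(optimizer, loss, metrics)`, such as a Keras model.
    """
    for layer in model.layers:
        layer.trainable = True
    # Keras only picks up changes to `trainable` when the model is compiled again.
    model.compile(optimizer, loss, metrics)
    return model
```

Passing the optimizer as a string such as 'Adam' creates a fresh optimizer, so any optimizer state from the frozen phase is discarded.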

TODO

  • Update Unet API
  • Update FPN API
  • Add Linknet models
  • Add PSP models
  • Add DPN backbones



Download files


Source Distribution

segmentation_models-0.1.0.tar.gz (20.9 kB)

Built Distribution

segmentation_models-0.1.0-py2.py3-none-any.whl (33.6 kB), for Python 2 and Python 3
