Image segmentation models with pre-trained backbones. PyTorch.
Segmentation models
Segmentation models is a Python library with neural networks for image segmentation, based on PyTorch.
The main features of this library are:
- High-level API (just two lines to create a neural network)
- 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 30 available encoders for each architecture
- All encoders have pre-trained weights for faster and better convergence
Quick start
Since the library is built on the PyTorch framework, a segmentation model is just a PyTorch nn.Module, which can be created as easily as:
import segmentation_models_pytorch as smp
model = smp.Unet()
Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and use pretrained weights to initialize it:
model = smp.Unet('resnet34', encoder_weights='imagenet')
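One way to compare backbones by size is to count the trainable parameters of the resulting module. The sketch below is illustrative only: `count_parameters` is a hypothetical helper, and the toy layers stand in for a real model. Because a segmentation model is a plain nn.Module, the same helper should apply to it as well.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> int:
    """Return the number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Toy stand-in for a model; not a real segmentation architecture.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # 3*8*3*3 weights + 8 biases = 224
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),             # 8*1*1*1 weights + 1 bias = 9
)

n_params = count_parameters(toy)
print(n_params)
```

Calling the helper on two candidate models gives a quick, if rough, indication of their relative capacity and memory cost.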
Change the number of output classes in the model:
model = smp.Unet('resnet34', classes=3, activation='softmax')
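To make the difference between multi-class and binary outputs concrete, here is a minimal sketch of how raw scores are usually turned into masks. It assumes the class scores sit on the channel axis (the usual PyTorch layout) and uses random tensors in place of real model output; it is not a description of the library's internals.

```python
import torch

torch.manual_seed(0)

# Multi-class: logits shaped (batch, classes, height, width).
logits = torch.randn(2, 3, 4, 4)
probs = torch.softmax(logits, dim=1)   # probabilities sum to 1 across the class axis
class_mask = probs.argmax(dim=1)       # (batch, height, width) integer class ids

# Binary: a single output channel, squashed independently per pixel.
binary_logits = torch.randn(2, 1, 4, 4)
binary_mask = torch.sigmoid(binary_logits) > 0.5  # boolean foreground mask

print(class_mask.shape, binary_mask.shape)
```

The 0.5 threshold is an arbitrary choice for illustration; in practice it is often tuned on validation data.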
All models have pretrained encoders, so you have to prepare your data the same way as it was prepared during weight pretraining:
from segmentation_models_pytorch.encoders import get_preprocessing_fn
preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
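For intuition, preprocessing for ImageNet-pretrained weights typically means scaling pixel values and standardizing each channel. The sketch below assumes the widely used ImageNet RGB mean and standard deviation; `normalize_image` is a hypothetical helper, and the function returned by get_preprocessing_fn may differ in input range, channel order, or constants depending on the encoder.

```python
import numpy as np

# Commonly used ImageNet channel statistics (RGB, for inputs scaled to [0, 1]).
# Assumption: the chosen encoder was trained with these; some encoders use others.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_image(image: np.ndarray) -> np.ndarray:
    """Scale an HxWx3 uint8 RGB image to [0, 1], then standardize per channel."""
    scaled = image.astype(np.float32) / 255.0
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD


image = np.full((4, 4, 3), 255, dtype=np.uint8)  # all-white dummy image
normalized = normalize_image(image)
print(normalized.shape, normalized.dtype)
```

Using the library's own preprocessing function is the safer choice, since it is tied to the selected encoder and weights.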
Examples
- Training a model for car segmentation on the CamVid dataset: here.
- Training a model with Catalyst (a high-level framework for PyTorch): here.
Models
Architectures
- Unet
- Linknet
- FPN
- PSPNet
Encoders
Type | Encoder names |
---|---|
VGG | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn |
DenseNet | densenet121, densenet169, densenet201, densenet161 |
DPN | dpn68, dpn68b, dpn92, dpn98, dpn107, dpn131 |
Inception | inceptionresnetv2 |
ResNet | resnet18, resnet34, resnet50, resnet101, resnet152 |
ResNeXt | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
SE-ResNet | se_resnet50, se_resnet101, se_resnet152 |
SE-ResNeXt | se_resnext50_32x4d, se_resnext101_32x4d |
SENet | senet154 |
Weights
Weights name | Encoder names |
---|---|
imagenet+5k | dpn68b, dpn92, dpn107 |
imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, inceptionresnetv2, resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d, senet154 |
resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
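When picking a value for encoder_weights, it can help to look up the table by encoder name. The sketch below transcribes only the two named rows of the table into a dictionary; `WEIGHTS_TO_ENCODERS` and `weights_for` are hypothetical names for illustration and are not part of the library.

```python
# Transcribed from the named rows of the weights table above.
# The final table row has no weights name in this text, so it is left out here.
WEIGHTS_TO_ENCODERS = {
    "imagenet+5k": ["dpn68b", "dpn92", "dpn107"],
    "imagenet": [
        "vgg11", "vgg13", "vgg16", "vgg19",
        "vgg11bn", "vgg13bn", "vgg16bn", "vgg19bn",
        "densenet121", "densenet169", "densenet201", "densenet161",
        "dpn68", "dpn98", "dpn131",
        "inceptionresnetv2",
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
        "resnext50_32x4d", "resnext101_32x8d",
        "se_resnet50", "se_resnet101", "se_resnet152",
        "se_resnext50_32x4d", "se_resnext101_32x4d",
        "senet154",
    ],
}


def weights_for(encoder_name: str) -> list:
    """Return the weights names listed for an encoder in the table."""
    return [
        weights
        for weights, encoders in WEIGHTS_TO_ENCODERS.items()
        if encoder_name in encoders
    ]


print(weights_for("resnet34"))
print(weights_for("dpn92"))
```

An empty result only means the encoder does not appear in the transcribed rows, not that no weights exist for it.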
Models API
- model.encoder - pretrained backbone to extract features of different spatial resolution
- model.decoder - segmentation head, depends on the model architecture (Unet/Linknet/PSPNet/FPN)
- model.activation - output activation function, one of sigmoid, softmax
- model.forward(x) - sequentially passes x through the model's encoder and decoder (returns logits!)
- model.predict(x) - inference method: switches the model to .eval() mode, calls .forward(x), and applies the activation function with torch.no_grad()
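The split between forward (logits) and predict (activated output, no gradients) can be sketched with a toy module. This is a simplified illustration under stated assumptions, not the library's implementation: `TinySegModel` and its two convolution layers are hypothetical stand-ins for a real encoder and decoder.

```python
import torch
import torch.nn as nn


class TinySegModel(nn.Module):
    """Toy model mirroring the encoder/decoder/activation layout described above."""

    def __init__(self, in_channels: int = 3, classes: int = 1):
        super().__init__()
        self.encoder = nn.Conv2d(in_channels, 8, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(8, classes, kernel_size=1)
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns raw logits, which is what most loss functions expect in training.
        features = torch.relu(self.encoder(x))
        return self.decoder(features)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        # Inference path: eval mode, no gradient tracking, activation applied.
        self.eval()
        with torch.no_grad():
            return self.activation(self.forward(x))


model = TinySegModel()
x = torch.randn(1, 3, 8, 8)
out = model.predict(x)
print(out.shape)
```

Note that predict leaves the module in eval mode, so a training loop that calls it would need to switch back with .train() afterwards.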
Installation
PyPI version:
$ pip install segmentation-models-pytorch
Latest version from source:
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
License
The project is distributed under the MIT License.
Run tests
$ docker build -f docker/Dockerfile.dev -t smp:dev .
$ docker run --rm smp:dev pytest -p no:cacheprovider