Image segmentation models with pre-trained backbones. PyTorch. Adapted for deepflash2
Python library with Neural Networks for Image Segmentation based on PyTorch.
The main features of this library are:
- High level API (just two lines to create a neural network)
- 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 113 available encoders (and 400+ encoders from timm)
- All encoders have pre-trained weights for faster and better convergence
- Popular metrics and losses for training routines
📚 Project Documentation 📚
Visit the Read the Docs project page or read the following README to learn more about the Segmentation Models PyTorch (SMP for short) library.
📋 Table of contents
- Quick start
- Examples
- Models
- Models API
- Installation
- Competitions won with the library
- Contributing
- Citing
- License
⏳ Quick start
1. Create your first Segmentation model with SMP
A segmentation model is just a PyTorch nn.Module, which can be created as easily as:
import segmentation_models_pytorch as smp
model = smp.Unet(
    encoder_name="resnet34",     # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,               # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                   # model output channels (number of classes in your dataset)
)
- see the table with available model architectures
- see the table with available encoders and their corresponding weights
2. Configure data preprocessing
All encoders have pre-trained weights. Preparing your data the same way as during weight pre-training may give you better results (higher metric score and faster convergence). This is not necessary if you train the whole model rather than only the decoder.
from segmentation_models_pytorch.encoders import get_preprocessing_fn
preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
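For ImageNet weights, this kind of preprocessing essentially rescales pixel values to [0, 1] and normalizes each channel with the ImageNet mean and standard deviation. The pure-Python sketch below reproduces that arithmetic for a single RGB pixel, assuming 0-255 RGB input and the standard ImageNet statistics; `normalize_pixel` is a hypothetical helper for illustration, not part of SMP's API.

```python
# Standard ImageNet channel statistics (RGB order), assumed here to match the
# settings of encoders pre-trained on "imagenet".
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Hypothetical helper: map one 0-255 RGB pixel to normalized floats."""
    scaled = [value / 255.0 for value in rgb]  # rescale to [0, 1]
    return [
        (value - mean) / std  # per-channel standardization
        for value, mean, std in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)
    ]


print([round(v, 4) for v in normalize_pixel((124, 116, 104))])
```

A pixel close to the ImageNet mean color maps to values near zero, which is what the pre-trained encoder saw during training.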
Congratulations! You are done! Now you can train your model with your favorite framework!
💡 Examples
- Training model for pets binary segmentation with Pytorch-Lightning: notebook
- Training model for cars segmentation on CamVid dataset here.
- Training SMP model with Catalyst (high-level framework for PyTorch), TTAch (TTA library for PyTorch) and Albumentations (fast image augmentation library) - here
- Training SMP model with Pytorch-Lightning framework - here (clothes binary segmentation by @ternaus).
📦 Models
Architectures
- Unet [paper] [docs]
- Unet++ [paper] [docs]
- MAnet [paper] [docs]
- Linknet [paper] [docs]
- FPN [paper] [docs]
- PSPNet [paper] [docs]
- PAN [paper] [docs]
- DeepLabV3 [paper] [docs]
- DeepLabV3+ [paper] [docs]
Encoders
The following is a list of encoders supported by SMP. Select the appropriate encoder family, then choose a specific encoder and its pre-trained weights (the encoder_name and encoder_weights parameters).
ResNet
Encoder | Weights | Params, M |
---|---|---|
resnet18 | imagenet / ssl / swsl | 11M |
resnet34 | imagenet | 21M |
resnet50 | imagenet / ssl / swsl | 23M |
resnet101 | imagenet | 42M |
resnet152 | imagenet | 58M |
ResNeXt
Encoder | Weights | Params, M |
---|---|---|
resnext50_32x4d | imagenet / ssl / swsl | 22M |
resnext101_32x4d | ssl / swsl | 42M |
resnext101_32x8d | imagenet / instagram / ssl / swsl | 86M |
resnext101_32x16d | instagram / ssl / swsl | 191M |
resnext101_32x32d | instagram | 466M |
resnext101_32x48d | instagram | 826M |
ResNeSt
Encoder | Weights | Params, M |
---|---|---|
timm-resnest14d | imagenet | 8M |
timm-resnest26d | imagenet | 15M |
timm-resnest50d | imagenet | 25M |
timm-resnest101e | imagenet | 46M |
timm-resnest200e | imagenet | 68M |
timm-resnest269e | imagenet | 108M |
timm-resnest50d_4s2x40d | imagenet | 28M |
timm-resnest50d_1s4x24d | imagenet | 23M |
Res2Ne(X)t
Encoder | Weights | Params, M |
---|---|---|
timm-res2net50_26w_4s | imagenet | 23M |
timm-res2net101_26w_4s | imagenet | 43M |
timm-res2net50_26w_6s | imagenet | 35M |
timm-res2net50_26w_8s | imagenet | 46M |
timm-res2net50_48w_2s | imagenet | 23M |
timm-res2net50_14w_8s | imagenet | 23M |
timm-res2next50 | imagenet | 22M |
RegNet(x/y)
Encoder | Weights | Params, M |
---|---|---|
timm-regnetx_002 | imagenet | 2M |
timm-regnetx_004 | imagenet | 4M |
timm-regnetx_006 | imagenet | 5M |
timm-regnetx_008 | imagenet | 6M |
timm-regnetx_016 | imagenet | 8M |
timm-regnetx_032 | imagenet | 14M |
timm-regnetx_040 | imagenet | 20M |
timm-regnetx_064 | imagenet | 24M |
timm-regnetx_080 | imagenet | 37M |
timm-regnetx_120 | imagenet | 43M |
timm-regnetx_160 | imagenet | 52M |
timm-regnetx_320 | imagenet | 105M |
timm-regnety_002 | imagenet | 2M |
timm-regnety_004 | imagenet | 3M |
timm-regnety_006 | imagenet | 5M |
timm-regnety_008 | imagenet | 5M |
timm-regnety_016 | imagenet | 10M |
timm-regnety_032 | imagenet | 17M |
timm-regnety_040 | imagenet | 19M |
timm-regnety_064 | imagenet | 29M |
timm-regnety_080 | imagenet | 37M |
timm-regnety_120 | imagenet | 49M |
timm-regnety_160 | imagenet | 80M |
timm-regnety_320 | imagenet | 141M |
GERNet
Encoder | Weights | Params, M |
---|---|---|
timm-gernet_s | imagenet | 6M |
timm-gernet_m | imagenet | 18M |
timm-gernet_l | imagenet | 28M |
SE-Net
Encoder | Weights | Params, M |
---|---|---|
senet154 | imagenet | 113M |
se_resnet50 | imagenet | 26M |
se_resnet101 | imagenet | 47M |
se_resnet152 | imagenet | 64M |
se_resnext50_32x4d | imagenet | 25M |
se_resnext101_32x4d | imagenet | 46M |
SK-ResNe(X)t
Encoder | Weights | Params, M |
---|---|---|
timm-skresnet18 | imagenet | 11M |
timm-skresnet34 | imagenet | 21M |
timm-skresnext50_32x4d | imagenet | 25M |
DenseNet
Encoder | Weights | Params, M |
---|---|---|
densenet121 | imagenet | 6M |
densenet169 | imagenet | 12M |
densenet201 | imagenet | 18M |
densenet161 | imagenet | 26M |
Inception
Encoder | Weights | Params, M |
---|---|---|
inceptionresnetv2 | imagenet / imagenet+background | 54M |
inceptionv4 | imagenet / imagenet+background | 41M |
xception | imagenet | 22M |
EfficientNet
Encoder | Weights | Params, M |
---|---|---|
efficientnet-b0 | imagenet | 4M |
efficientnet-b1 | imagenet | 6M |
efficientnet-b2 | imagenet | 7M |
efficientnet-b3 | imagenet | 10M |
efficientnet-b4 | imagenet | 17M |
efficientnet-b5 | imagenet | 28M |
efficientnet-b6 | imagenet | 40M |
efficientnet-b7 | imagenet | 63M |
timm-efficientnet-b0 | imagenet / advprop / noisy-student | 4M |
timm-efficientnet-b1 | imagenet / advprop / noisy-student | 6M |
timm-efficientnet-b2 | imagenet / advprop / noisy-student | 7M |
timm-efficientnet-b3 | imagenet / advprop / noisy-student | 10M |
timm-efficientnet-b4 | imagenet / advprop / noisy-student | 17M |
timm-efficientnet-b5 | imagenet / advprop / noisy-student | 28M |
timm-efficientnet-b6 | imagenet / advprop / noisy-student | 40M |
timm-efficientnet-b7 | imagenet / advprop / noisy-student | 63M |
timm-efficientnet-b8 | imagenet / advprop | 84M |
timm-efficientnet-l2 | noisy-student | 474M |
timm-efficientnet-lite0 | imagenet | 4M |
timm-efficientnet-lite1 | imagenet | 5M |
timm-efficientnet-lite2 | imagenet | 6M |
timm-efficientnet-lite3 | imagenet | 8M |
timm-efficientnet-lite4 | imagenet | 13M |
MobileNet
Encoder | Weights | Params, M |
---|---|---|
mobilenet_v2 | imagenet | 2M |
timm-mobilenetv3_large_075 | imagenet | 1.78M |
timm-mobilenetv3_large_100 | imagenet | 2.97M |
timm-mobilenetv3_large_minimal_100 | imagenet | 1.41M |
timm-mobilenetv3_small_075 | imagenet | 0.57M |
timm-mobilenetv3_small_100 | imagenet | 0.93M |
timm-mobilenetv3_small_minimal_100 | imagenet | 0.43M |
DPN
Encoder | Weights | Params, M |
---|---|---|
dpn68 | imagenet | 11M |
dpn68b | imagenet+5k | 11M |
dpn92 | imagenet+5k | 34M |
dpn98 | imagenet | 58M |
dpn107 | imagenet+5k | 84M |
dpn131 | imagenet | 76M |
VGG
Encoder | Weights | Params, M |
---|---|---|
vgg11 | imagenet | 9M |
vgg11_bn | imagenet | 9M |
vgg13 | imagenet | 9M |
vgg13_bn | imagenet | 9M |
vgg16 | imagenet | 14M |
vgg16_bn | imagenet | 14M |
vgg19 | imagenet | 20M |
vgg19_bn | imagenet | 20M |
* ssl, swsl - semi-supervised and weakly-supervised learning on ImageNet (repo).
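When choosing encoder_name and encoder_weights, it can help to validate the pair before building a model. The sketch below hard-codes a few rows from the encoder tables above into a plain dictionary; `SUPPORTED_WEIGHTS` and `check_encoder` are hypothetical names for illustration, not SMP's API, and the library itself remains the authority on which combinations it accepts. It assumes encoder_weights=None means random initialization.

```python
# A few rows copied from the encoder tables above (encoder -> available weights).
SUPPORTED_WEIGHTS = {
    "resnet18": {"imagenet", "ssl", "swsl"},
    "resnet34": {"imagenet"},
    "resnext101_32x8d": {"imagenet", "instagram", "ssl", "swsl"},
    "timm-efficientnet-b8": {"imagenet", "advprop"},
}


def check_encoder(encoder_name, encoder_weights):
    """Hypothetical helper: raise ValueError if the pair is not listed."""
    if encoder_name not in SUPPORTED_WEIGHTS:
        raise ValueError(f"Unknown encoder: {encoder_name!r}")
    # None means random initialization, so no pre-trained weights are needed.
    if encoder_weights is None:
        return
    available = SUPPORTED_WEIGHTS[encoder_name]
    if encoder_weights not in available:
        raise ValueError(
            f"{encoder_name!r} has no {encoder_weights!r} weights; "
            f"choose one of {sorted(available)}"
        )


check_encoder("resnet18", "swsl")  # listed in the table, passes silently
```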
Timm Encoders
Pytorch Image Models (a.k.a. timm) has a lot of pretrained models and an interface which allows using these models as encoders in smp. However, not all models are supported:
- transformer models do not have features_only functionality implemented
- some models do not have appropriate strides
Total number of supported encoders: 467
🔁 Models API
- model.encoder - pretrained backbone to extract features of different spatial resolutions
- model.decoder - depends on the model architecture (Unet/Linknet/PSPNet/FPN)
- model.segmentation_head - last block to produce the required number of mask channels (also includes optional upsampling and activation)
- model.classification_head - optional block which creates a classification head on top of the encoder
- model.forward(x) - sequentially passes x through the model's encoder, decoder and segmentation head (and classification head if specified)
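The data flow of model.forward(x) can be sketched in plain Python, with toy callables standing in for the real PyTorch modules. It follows the order described above (encoder, then decoder, then segmentation head), with the optional classification head applied to the deepest encoder feature. `ToySegmentationModel` and the stand-in functions are hypothetical and only illustrate the control flow, not SMP's actual classes.

```python
class ToySegmentationModel:
    """Hypothetical stand-in showing the order of calls in forward()."""

    def __init__(self, encoder, decoder, segmentation_head, classification_head=None):
        self.encoder = encoder
        self.decoder = decoder
        self.segmentation_head = segmentation_head
        self.classification_head = classification_head

    def forward(self, x):
        features = self.encoder(x)  # list of features, shallow to deep
        decoder_output = self.decoder(*features)
        masks = self.segmentation_head(decoder_output)
        if self.classification_head is not None:
            # The classification head sits on top of the deepest encoder feature.
            labels = self.classification_head(features[-1])
            return masks, labels
        return masks


# Toy callables that only record what they received.
def encoder(x):
    return [f"f1({x})", f"f2({x})", f"f3({x})"]


def decoder(*feats):
    return f"dec[{', '.join(feats)}]"


def seg_head(d):
    return f"mask<{d}>"


def cls_head(f):
    return f"label<{f}>"


print(ToySegmentationModel(encoder, decoder, seg_head).forward("x"))
print(ToySegmentationModel(encoder, decoder, seg_head, cls_head).forward("x"))
```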
Input channels
The input channels parameter allows you to create models that process tensors with an arbitrary number of channels.
If you use pretrained weights from imagenet, the weights of the first convolution will be reused. For the
1-channel case, the new weight is the sum of the first convolution layer's weights over its input channels; otherwise, channels are
populated with weights like new_weight[:, i] = pretrained_weight[:, i % 3]
and then scaled with new_weight * 3 / new_in_channels.
import torch
import segmentation_models_pytorch as smp
model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))  # batch of one single-channel 64x64 image
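The first-convolution weight rule can be sketched without PyTorch by treating the weights as a nested list indexed [out_channel][in_channel], with each kernel collapsed to a single number for brevity. `adapt_first_conv` is a hypothetical helper that illustrates the arithmetic described above, not the library's own function; it assumes the pre-trained convolution has 3 input channels.

```python
def adapt_first_conv(weight, new_in_channels, default_in_channels=3):
    """Hypothetical helper: adapt [out][in] conv weights to new_in_channels.

    Each kernel is collapsed to a single float to keep the sketch readable.
    """
    if new_in_channels == 1:
        # 1-channel case: sum the pre-trained weights over the input channels.
        return [[sum(row)] for row in weight]
    scale = default_in_channels / new_in_channels
    # Other cases: cycle through the pre-trained channels (i % 3), then rescale
    # by 3 / new_in_channels so activations keep a similar magnitude.
    return [
        [row[i % default_in_channels] * scale for i in range(new_in_channels)]
        for row in weight
    ]


pretrained = [[0.3, 0.6, 0.9], [1.0, 2.0, 3.0]]  # 2 output x 3 input channels
print(adapt_first_conv(pretrained, 1))
print(adapt_first_conv(pretrained, 6))
```

With 6 input channels every pre-trained channel is used exactly twice and halved, so each output channel's summed weight equals the original 3-channel sum.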
Auxiliary classification output
All models support the aux_params parameter, which is set to None by default.
If aux_params = None, the classification auxiliary output is not created; otherwise, the model produces not only a mask but also a label output with shape NC.
The classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be configured by aux_params as follows:
import torch
import segmentation_models_pytorch as smp
aux_params = dict(
    pooling='avg',         # one of 'avg', 'max'
    dropout=0.5,           # dropout ratio, default is None
    activation='sigmoid',  # activation function, default is None
    classes=4,             # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
x = torch.rand(2, 3, 64, 64)  # example batch: N x C x H x W
mask, label = model(x)        # label has shape N x 4
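To make the GlobalPooling->Dropout->Linear->Activation sequence concrete, here is a dependency-free sketch of one forward pass for a single sample with pooling='avg' and activation='sigmoid'. Dropout is omitted because it acts as an identity at inference time; the toy feature map, weights, biases and the `classification_head` function are hypothetical values chosen only for illustration.

```python
import math


def classification_head(features, weights, biases):
    """Hypothetical sketch: avg-pool -> linear -> sigmoid for one sample.

    features: [channels][height][width] feature map from the encoder
    weights:  [classes][channels] linear layer weights
    biases:   [classes] linear layer biases
    """
    # Global average pooling: one value per channel.
    pooled = [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in features
    ]
    # Dropout would be applied here during training; it is a no-op at inference.
    logits = [
        sum(w * p for w, p in zip(class_weights, pooled)) + b
        for class_weights, b in zip(weights, biases)
    ]
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]  # sigmoid per label


features = [[[1.0, 3.0], [5.0, 7.0]], [[0.0, 0.0], [2.0, 2.0]]]  # 2 x 2 x 2
weights = [[1.0, -1.0], [0.0, 0.5]]  # 2 labels x 2 channels
biases = [0.0, -0.5]
print(classification_head(features, weights, biases))
```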
Depth
The depth parameter specifies the number of downsampling operations in the encoder, so you can make
your model lighter by specifying a smaller depth.
model = smp.Unet('resnet34', encoder_depth=4, decoder_channels=(256, 128, 64, 32))  # Unet needs one decoder channel entry per encoder stage
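Each downsampling operation roughly halves the spatial resolution, so encoder_depth sets both the coarsest feature map size and the multiple that input height and width should be divisible by. The sketch below assumes a factor of 2 per stage, as in the ResNet family; `feature_sizes` is a hypothetical helper, and encoders with dilated or unusual strides may behave differently.

```python
def feature_sizes(height, width, encoder_depth=5):
    """Hypothetical helper: spatial size after each of `encoder_depth` stages,
    assuming every stage downsamples by a factor of 2."""
    output_stride = 2 ** encoder_depth
    if height % output_stride or width % output_stride:
        raise ValueError(
            f"Input {height}x{width} is not divisible by {output_stride} "
            f"(2 ** encoder_depth with encoder_depth={encoder_depth})"
        )
    return [(height // 2 ** i, width // 2 ** i) for i in range(1, encoder_depth + 1)]


print(feature_sizes(256, 256, encoder_depth=4))
```

Under this assumption, a smaller depth also relaxes the divisibility requirement: 16 for encoder_depth=4 instead of 32 for the default depth of 5.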
🛠 Installation
PyPI version:
$ pip install segmentation-models-pytorch
Latest version from source:
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
🏆 Competitions won with the library
The Segmentation Models package is widely used in image segmentation competitions.
Here you can find competitions, names of the winners and links to their solutions.
🤝 Contributing
Install linting and formatting pre-commit hooks
pip install pre-commit black flake8
pre-commit install
Run tests
pytest -p no:cacheprovider
Run tests in docker
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
Generate table with encoders (in case you add a new encoder)
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
📝 Citing
@misc{Yakubovskiy:2019,
Author = {Pavel Yakubovskiy},
Title = {Segmentation Models Pytorch},
Year = {2020},
Publisher = {GitHub},
Journal = {GitHub repository},
Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}
🛡️ License
The project is distributed under the MIT License.
Download files
File details
Details for the file segmentation_models_pytorch_deepflash2-0.3.0.tar.gz.
File metadata
- Download URL: segmentation_models_pytorch_deepflash2-0.3.0.tar.gz
- Upload date:
- Size: 61.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 27562ba46d5c58be2d7790365f7adbbfe4f53d265ce97edfbec4d7c5ee28c9f6 |
MD5 | b427cf76b4638697d6d116f805d02bda |
BLAKE2b-256 | 9801addb987459bc18443f319a32535322e35c69fab85dd405bd217335fef6b9 |
File details
Details for the file segmentation_models_pytorch_deepflash2-0.3.0-py3-none-any.whl.
File metadata
- Download URL: segmentation_models_pytorch_deepflash2-0.3.0-py3-none-any.whl
- Upload date:
- Size: 186.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.1 CPython/3.10.5
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | 061a7b84886fa23844eb8bce54c97472e872cfd6d02acea3319a1d74f89479ca |
MD5 | d2c754f251c1e9ecb1da3fda28bac845 |
BLAKE2b-256 | aedc4e51471765a1b103dcf0e9538780ffe6c03678ecf8bda5755304dc6c1e80 |