Deep Learning segmentation architectures for PyTorch and FastAI
SemTorch
This repository contains definitions of several deep learning architectures that can be applied to image segmentation.
All the architectures are implemented in PyTorch and can be trained easily with FastAI 2.
An example of how to apply them to a custom dataset can be found in the Deep-Tumour-Spheroid repository; in that case, brain tumour images are used.
These architectures are classified as:
- Semantic Segmentation: each pixel of an image is linked to a class label.
- Instance Segmentation: similar to semantic segmentation, but it goes one step further and also identifies, for each pixel, the object instance it belongs to.
- Salient Object Detection (binary classes only): detection of the most noticeable/important object in an image.
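To make the difference between the three task types concrete, here is a minimal sketch of what the targets for the same toy image could look like. It assumes masks are plain NumPy integer arrays; SemTorch's own target format may differ, and the array names are hypothetical.

```python
import numpy as np

# Toy 4x4 image containing two separate objects of the same class (label 1).
# Semantic segmentation: one class label per pixel (0 = background).
semantic_mask = np.array([
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 1],
])

# Instance segmentation: each pixel also carries an instance id, so the two
# objects of class 1 are told apart (0 = background, 1 and 2 = instances).
instance_mask = np.array([
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 2],
    [0, 0, 2, 2],
])

# Salient object detection: a binary mask of the most noticeable object only
# (assumed here to be the larger, top-left object).
salient_mask = (instance_mask == 1).astype(np.uint8)

print("classes:", np.unique(semantic_mask))        # background + one class
print("instances:", np.unique(instance_mask)[1:])  # instance ids without background
print("salient pixels:", int(salient_mask.sum()))
```

The semantic mask cannot distinguish the two objects, while the instance mask can; the salient mask keeps only one of them.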
🚀 Getting Started
To start using this package, install it using pip.
For example, to install it on Ubuntu, use:
pip3 install SemTorch
👩‍💻 Usage
This package provides an abstract API for building segmentation models with different architectures. Its main function, get_segmentation_learner, returns a FastAI 2 learner that can be combined with all of fastai's functionality.
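# SemTorch
from fastai.vision.all import *  # fastai training API, including the Dice and JaccardCoeff metrics
from semtorch import get_segmentation_learner

# dls (the fastai DataLoaders), tumour (a custom metric) and segmentron_splitter are not
# defined in this snippet; see the Deep-Tumour-Spheroid repository for how they are built
learn = get_segmentation_learner(dls=dls, number_classes=2, segmentation_type="Semantic Segmentation",
                                 architecture_name="deeplabv3+", backbone_name="resnet50",
                                 metrics=[tumour, Dice(), JaccardCoeff()], wd=1e-2,
                                 splitter=segmentron_splitter).to_fp16()  # mixed-precision training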
A more complete example can be found in the Deep-Tumour-Spheroid repository, where the package is used to segment brain tumours.
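The usage example tracks fastai's Dice() and JaccardCoeff() metrics. As a rough sketch of what these two overlap scores measure for a single binary mask, here they are in plain NumPy. This is not fastai's implementation: fastai's metric classes accumulate the intersection and union over the whole validation set, so per-image values computed this way will generally differ.

```python
import numpy as np

def dice(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Dice coefficient 2|A∩B| / (|A| + |B|) for binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    return float(2 * inter / (pred.sum() + target.sum() + eps))

def jaccard(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> float:
    """Jaccard index (IoU) |A∩B| / |A∪B| for binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return float(inter / (union + eps))

pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
print(round(dice(pred, target), 3), round(jaccard(pred, target), 3))  # 0.667 0.5
```

Both scores rank predictions the same way, since for a single mask they are related by J = D / (2 - D); Dice simply weights the overlap more generously.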
def get_segmentation_learner(dls, number_classes, segmentation_type, architecture_name, backbone_name,
                             loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params,
                             cbs=None, pretrained=True, normalize=True, image_size=None, metrics=None,
                             path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
                             moms=(0.95,0.85,0.95)):
This function returns a learner for the provided architecture and backbone.
Parameters:
- dls (DataLoaders): the fastai DataLoaders to use with the learner
- number_classes (int): the number of classes in the project. It should be >= 2
- segmentation_type (str): only "Semantic Segmentation" is accepted for now
- architecture_name (str): name of the architecture. The following ones are supported: unet, deeplabv3+, hrnet, maskrcnn and u2^net
- backbone_name (str): name of the backbone
- loss_func (callable): loss function
- opt_func (callable): optimizer function (Adam by default)
- lr (float): default learning rate
- splitter (callable): function that splits the model into parameter groups, used for freezing the learner
- cbs (List[cb]): list of callbacks
- pretrained (bool): whether the backbone should use pretrained weights
- normalize (bool): whether normalization is applied
- image_size (int): REQUIRED for MaskRCNN. It indicates the desired size of the image.
- metrics (List[metric]): list of metrics
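- path (str or Path): base path of the learner; model_dir is resolved relative to it
- model_dir (str): the directory, relative to path, in which models are saved
- wd (float): weight decay
- wd_bn_bias (bool): whether weight decay is also applied to batch-norm and bias parameters
- train_bn (bool): whether batch-norm layers are trained even when their parameter group is frozen
- moms (Tuple[float]): momentum values (start, middle, end) used by the one-cycle schedule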
Returns:
- learner (Learner): the fastai Learner object
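The documented constraints on these arguments can be collected into a small validation helper. This is a hypothetical sketch: check_learner_args is not part of SemTorch, and the real function may perform different or additional checks.

```python
from typing import Optional

# Values taken from the parameter descriptions above; the helper itself is
# hypothetical and not part of SemTorch.
SUPPORTED_ARCHITECTURES = {"unet", "deeplabv3+", "hrnet", "maskrcnn", "u2^net"}

def check_learner_args(number_classes: int, segmentation_type: str,
                       architecture_name: str, image_size: Optional[int] = None) -> None:
    """Raise ValueError if the arguments break one of the documented constraints."""
    if number_classes < 2:
        raise ValueError("number_classes should be >= 2")
    if segmentation_type != "Semantic Segmentation":
        raise ValueError('only "Semantic Segmentation" is accepted for now')
    if architecture_name not in SUPPORTED_ARCHITECTURES:
        raise ValueError(f"unsupported architecture: {architecture_name!r}")
    if architecture_name == "maskrcnn" and image_size is None:
        raise ValueError("image_size is required for maskrcnn")

# Passes: the same arguments as the usage example.
check_learner_args(2, "Semantic Segmentation", "deeplabv3+")

try:
    check_learner_args(2, "Semantic Segmentation", "maskrcnn")
except ValueError as err:
    print(err)  # image_size is required for maskrcnn
```

Failing early on these arguments gives a clearer error than one raised deep inside model construction.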
Supported configs
Architecture | Supported configs | Backbones |
---|---|---|
unet | Semantic Segmentation (binary, multiple classes) | resnet18, resnet34, resnet50, resnet101, resnet152, xresnet18, xresnet34, xresnet50, xresnet101, xresnet152, squeezenet1_0, squeezenet1_1, densenet121, densenet169, densenet201, densenet161, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn, alexnet |
deeplabv3+ | Semantic Segmentation (binary, multiple classes) | resnet18, resnet34, resnet50, resnet101, resnet152, resnet50c, resnet101c, resnet152c, xception65, mobilenet_v2 |
hrnet | Semantic Segmentation (binary, multiple classes) | hrnet_w18_small_model_v1, hrnet_w18_small_model_v2, hrnet_w18, hrnet_w30, hrnet_w32, hrnet_w48 |
maskrcnn | Semantic Segmentation (binary) | resnet50 |
u2^net | Semantic Segmentation (binary) | small, normal |
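The table can be transcribed into a dictionary to check a configuration before building a learner. The helper below is a hypothetical sketch, not part of SemTorch; it assumes that "binary" means exactly two classes (background plus object), as in the usage example with number_classes=2.

```python
# The "Supported configs" table transcribed into a dict.
# The is_supported helper is hypothetical; SemTorch performs its own checks.
SUPPORTED_CONFIGS = {
    "unet": {"binary_only": False, "backbones": [
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
        "xresnet18", "xresnet34", "xresnet50", "xresnet101", "xresnet152",
        "squeezenet1_0", "squeezenet1_1", "densenet121", "densenet169",
        "densenet201", "densenet161", "vgg11_bn", "vgg13_bn", "vgg16_bn",
        "vgg19_bn", "alexnet"]},
    "deeplabv3+": {"binary_only": False, "backbones": [
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
        "resnet50c", "resnet101c", "resnet152c", "xception65", "mobilenet_v2"]},
    "hrnet": {"binary_only": False, "backbones": [
        "hrnet_w18_small_model_v1", "hrnet_w18_small_model_v2", "hrnet_w18",
        "hrnet_w30", "hrnet_w32", "hrnet_w48"]},
    "maskrcnn": {"binary_only": True, "backbones": ["resnet50"]},
    "u2^net": {"binary_only": True, "backbones": ["small", "normal"]},
}

def is_supported(architecture_name: str, backbone_name: str, number_classes: int) -> bool:
    """Check an (architecture, backbone, number of classes) combination against the table."""
    config = SUPPORTED_CONFIGS.get(architecture_name)
    if config is None or backbone_name not in config["backbones"]:
        return False
    # Assumption: binary-only architectures accept exactly two classes.
    return number_classes == 2 if config["binary_only"] else number_classes >= 2

print(is_supported("deeplabv3+", "resnet50", 2))  # True
print(is_supported("u2^net", "small", 3))         # False: binary only
```

Note that backbone names are only valid per architecture: resnet50 works with unet, deeplabv3+ and maskrcnn, but not with hrnet.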
📩 Contact
📧 dvdlacallecastillo@gmail.com
💼 LinkedIn David Lacalle Castillo
Download files
Source Distribution
File details
Details for the file SemTorch-0.1.1.tar.gz.
File metadata
- Download URL: SemTorch-0.1.1.tar.gz
- Upload date:
- Size: 41.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.21.0 setuptools/49.6.0 requests-toolbelt/0.9.1 tqdm/4.31.1 CPython/3.7.3
File hashes
Algorithm | Hash digest |
---|---|
SHA256 | f0814cd387fc03581fbff4a63b63a50dc889e19780b52273d5a734328c5f1647 |
MD5 | f4ae1f6ec463150b9565f3dbed53d42a |
BLAKE2b-256 | 7859e41afbd4cf5f8bcbf6d0b8117f60586e87f9053610fcb262d2af950ab7b4 |
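To check that a downloaded archive matches the SHA256 digest listed above, a short sketch using only Python's standard hashlib module could look like this (file_digest and verify are illustrative helper names, not part of any package):

```python
import hashlib

def file_digest(path: str, algorithm: str = "sha256", chunk_size: int = 1 << 16) -> str:
    """Hex digest of a file, read in chunks so large archives need little memory."""
    h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# SHA256 digest listed above for SemTorch-0.1.1.tar.gz.
EXPECTED_SHA256 = "f0814cd387fc03581fbff4a63b63a50dc889e19780b52273d5a734328c5f1647"

def verify(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 digest matches the expected hex string."""
    return file_digest(path) == expected.lower()
```

For example, verify("SemTorch-0.1.1.tar.gz") checks an archive downloaded to the current directory. pip can enforce the same check at install time through its hash-checking mode, using a `SemTorch==0.1.1 --hash=sha256:<digest>` line in a requirements file installed with `--require-hashes`.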