
Segmentation Networks without a Backbone


🔥 OctoPyTorch: Segmentation Neural Networks 🔥

No backbones needed, with a focus on medical images

An implementation of segmentation neural networks for PyTorch, with features such as:

  • 🐟 No backbones, the architectures remain simple to understand and train.
  • 💾 Memory-efficient version (trade-off between memory and speed).
  • 🖼 Works with any input size (not only powers of 2 anymore).
  • 👁 Different types of upsampling (transposed convolution, upsampling and pixel shuffle).
  • 🏊‍♀️ Different types of pooling (max-pooling, avg-pooling, blur-pooling).
  • 🏗 The depth and width of the models are fully configurable.
  • 🔬 Early-transition can be enabled when the input images are big.
  • 👸🏼 The activation functions of all layers can be modified to something trendier.


Roadmap

Support for the following neural networks:

Getting Started

The package can be installed with:

> pip3 install octopytorch

You can try the model in Python with:

from functools import partial
import torch
from torch import nn
from octopytorch import models, DEFAULT_MODULE_BANK, ModuleName

module_bank = DEFAULT_MODULE_BANK.copy()
# Dropout
module_bank[ModuleName.DROPOUT] = partial(nn.Dropout2d, p=0.2, inplace=True)
# Every activation in the model is going to be a GELU (Gaussian Error Linear 
# Units function). GELU(x) = x * Φ(x)
# See: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
module_bank[ModuleName.ACTIVATION] = nn.GELU
# Example for segmentation:
module_bank[ModuleName.ACTIVATION_FINAL] = partial(nn.LogSoftmax, dim=1)
# Example for regression (default):
#module_bank[ModuleName.ACTIVATION_FINAL] = nn.Identity

model = models.Tiramisu(
    in_channels = 3,          # RGB images
    out_channels = 5,         # 5-channel output (5 classes)
    init_conv_filters = 48,   # Number of channels outputted by the 1st convolution
    structure = (
        [4, 4, 4, 4, 4],      # Down blocks
        4,                    # bottleneck layers
        [4, 4, 4, 4, 4],      # Up blocks
    ),
    growth_rate = 12,         # Growth rate of the DenseLayers
    compression = 1.0,        # No compression
    early_transition = False, # No early transition
    include_top = True,       # Includes last layer and activation
    checkpoint = False,       # No memory checkpointing
    module_bank = module_bank # Modules to use
)

# Initializes all the convolutional kernel weights.
model.initialize_kernels(nn.init.kaiming_uniform_, conv=True)
# Shows some information about the model.
model.summary()

This example Tiramisu has a depth of len(down_blocks) = 5; since each level halves the resolution, the input images should be at least 32x32 pixels (2^5 = 32).
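As a quick sanity check (a sketch reusing the model defined above; a segmentation network of this kind is expected to preserve the spatial resolution of its input), a dummy batch can be pushed through the network to inspect the output shape:

# Dummy batch of two RGB images, 224x224 (a multiple of 2^5 = 32).
x = torch.randn(2, 3, 224, 224)
model.eval()
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([2, 5, 224, 224]), i.e. one value per class per pixel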

Documentation

The parameters of the constructor are explained as follows:

  • in_channels: The number of channels of the input image (e.g. 1 for grayscale, 3 for RGB).
  • out_channels: The number of output channels (e.g. C for C classes).
  • init_conv_filters: The number of filters in the very first convolution.
  • structure: Divided into three parts (down blocks, bottleneck, and up blocks), which describe the depth of the neural network (how many levels there are) and how many DenseLayers each of those levels has (see the sketch after this list).
  • growth_rate: The number of channels added by each convolution in the DenseLayers; at each convolution, a DenseLayer grows by this many channels.
  • compression: The compression of the DenseLayers to reduce the memory footprint and computational complexity of the model.
  • early_transition: Optimization where the input is downscaled by a factor of two after the first layer by using a down-transition (without skip-connection) early on.
  • include_top: If True, includes the top layer (the last convolution and the final activation); if False, returns the per-pixel embeddings instead.
  • checkpoint: Activates memory checkpointing, a memory efficient version of the Tiramisu. See: https://arxiv.org/pdf/1707.06990.pdf
  • module_bank: The bank of layers the Tiramisu uses to build itself. See next subsection for details.
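To illustrate the structure parameter, below is a sketch of a shallower and narrower configuration; the values are arbitrary and only meant for demonstration:

# Illustrative only: 3 levels with 2 DenseLayers each and 2 bottleneck layers.
small_model = models.Tiramisu(
    in_channels = 1,                        # grayscale input
    out_channels = 2,                       # binary segmentation
    init_conv_filters = 24,
    structure = ([2, 2, 2], 2, [2, 2, 2]),  # down blocks, bottleneck, up blocks
    growth_rate = 8,
    compression = 0.5,                      # halve the channels at each transition
    early_transition = False,
    include_top = True,
    checkpoint = False,
    module_bank = DEFAULT_MODULE_BANK,      # default layer choices
)

With a depth of 3, this model accepts inputs as small as 8x8 pixels (2^3 = 8).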

Module bank

The Tiramisu base layers (e.g. Conv2d, activation functions, etc.) can be set to different types of layers. The module bank was introduced to gather what would otherwise be many arguments of the main class into a single object and to make it easier to swap layers.

The layers that can be redefined are:

  • CONV: Convolution operation used throughout the model. Change with care.
  • CONV_INIT: Initial (1st) convolution operation. Note: the kernel size must be provided.
  • CONV_FINAL: Final convolution. It will be set to a 1x1 kernel and reduces the output to C classes.
  • BATCHNORM: Batch normalization used throughout the model.
  • POOLING: Pooling operation. Note: it must reduce the input size by a factor of two; if the size is odd, it should round up to the nearest integer.
  • DROPOUT: Dropout. The p value must be provided through partial.
  • UPSAMPLE: Upsampling operation (must upsample by a factor of two).
  • ACTIVATION: Activation function used everywhere.
  • ACTIVATION_FINAL: Activation function at the last layer (e.g. softmax, nn.Identity).

Notes:

  • For pooling, common options are nn.MaxPool2d, nn.AvgPool2d, or even tiramisu.layers.blurpool.BlurPool2d.
  • For upsampling, there are some presets: UPSAMPLE_NEAREST (default), UPSAMPLE_PIXELSHUFFLE, and UPSAMPLE_TRANSPOSE (known to produce artifacts).
  • Layers can be set to nn.Identity to bypass them (e.g. to remove the dropout layer or the final activation).
  • The partial function can be used to prefill some of a layer's arguments before it is handed to the model (see the sketch after this list).
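Putting these notes together, here is a sketch of a customized module bank; the keys come from the list above, while the specific layer choices are purely illustrative:

from functools import partial
from torch import nn
from octopytorch import DEFAULT_MODULE_BANK, ModuleName

module_bank = DEFAULT_MODULE_BANK.copy()
# Average pooling instead of the default max-pooling.
module_bank[ModuleName.POOLING] = nn.AvgPool2d
# Bypass dropout entirely.
module_bank[ModuleName.DROPOUT] = nn.Identity
# Reflection padding in the convolutions to soften border artifacts
# (see "Tips and tricks" below).
module_bank[ModuleName.CONV] = partial(nn.Conv2d, padding_mode="reflect")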

Tips and tricks

  • Make sure the features you are interested in fit approximately within the receptive field. For instance, if an object measures 50 pixels, you need approximately 6 levels of resolution in the down/up blocks (since 2^6 = 64 > 50), or you can use the early transition, which downsamples the input by two (see the short calculation after this list).
  • If you need to reduce the memory footprint, trying the memory-efficient version (checkpoint=True) and enabling the early transition are great ways to start. Then use compression, reduce the growth rate, and finally reduce the number of DenseLayers in the down/up blocks.
  • Use upsampling instead of transposed convolution, seriously. Transposed convolutions are hard to manage and may create a lot of gridding artefacts.
  • Use blurpooling if you want the neural network to be shift-invariant (good accuracy even when shifting the input).
  • The model creates artifacts at the image borders, which can be mitigated by changing the padding_mode argument of the Conv2d in the module bank. For instance, using "reflect" instead of "zeros" creates a smooth continuation at the boundaries instead of a hard edge.
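A back-of-the-envelope version of the first tip (a rough rule of thumb, not an exact receptive-field computation):

import math

object_px = 50                                   # size of the feature of interest, in pixels
levels_needed = math.ceil(math.log2(object_px))  # 6, since 2^6 = 64 > 50
print(levels_needed)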

Built With

  • PyTorch - version >= 1.4.0 (required for the memory-efficient version)

Contributing

See also the list of contributors who participated in this project. When contributing, make sure the code passes the Pylama, Bandit, and Mypy checks. Additionally, the code is formatted with Black.

License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Acknowledgments

Many thanks to @RaphaelaHeil for her much appreciated advice on best practices.
