Pretrained Backbones with UNet

A PyTorch-based Python library with UNet architecture and multiple backbones for Image Semantic Segmentation.

Overview

This is a simple package for semantic segmentation with UNet and pre-trained backbones. It uses timm models as the pre-trained encoders.

When training data is limited, initializing the network with weights pre-trained on a large dataset is often the difference between successful and failed training. Using a state-of-the-art model such as ConvNeXt as the encoder lets you tackle the task with little effort while still getting strong performance.

The primary characteristics of this library are as follows:

  • 430 pre-trained backbone networks are available for the UNet semantic segmentation model.

  • Supports popular, state-of-the-art backbone families for the UNet model, such as ConvNeXt, ResNet, EfficientNet, DenseNet, RegNet, and VGG.

  • Lets you adjust parametrically which layers of the backbone are trainable (see the sketch after this list).

  • Includes a Dataset class for binary and multi-class semantic segmentation.

  • Comes with a ready-made Trainer class for rapid custom training.
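
As a concrete illustration of the trainable-layers point above, here is a minimal sketch of freezing the backbone in plain PyTorch, given a model built as in the Usage section below. It does not use the library's own parameter for this, and the assumption that backbone parameter names contain "encoder" may not match the actual implementation:

# Sketch: freeze every backbone parameter so only the decoder is trained.
# ASSUMPTION: backbone parameter names contain "encoder"; adjust the
# filter to match the real parameter names of the Unet instance.
for name, param in model.named_parameters():
    if 'encoder' in name:
        param.requires_grad = False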

Installation

PyPI version:

pip install pretrained-backbones-unet

Source code version:

pip install git+https://github.com/mberkay0/pretrained-backbones-unet

Usage

import torch
from torch.utils.data import DataLoader

from backbones_unet.model.unet import Unet
from backbones_unet.utils.dataset import SemanticSegmentationDataset
from backbones_unet.model.losses import DiceLoss
from backbones_unet.utils.trainer import Trainer

# create a torch.utils.data.Dataset/DataLoader
train_img_path = 'example_data/train/images' 
train_mask_path = 'example_data/train/masks'

val_img_path = 'example_data/val/images' 
val_mask_path = 'example_data/val/masks'

train_dataset = SemanticSegmentationDataset(train_img_path, train_mask_path)
val_dataset = SemanticSegmentationDataset(val_img_path, val_mask_path)

train_loader = DataLoader(train_dataset, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)

model = Unet(
    backbone='convnext_base', # backbone network name
    in_channels=3,            # input channels (1 for gray-scale images, 3 for RGB, etc.)
    num_classes=1,            # output channels (number of classes in your dataset)
)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=1e-4)

trainer = Trainer(
    model,                    # UNet model with a pretrained backbone
    criterion=DiceLoss(),     # loss function driving convergence
    optimizer=optimizer,      # optimizer that updates the model weights
    epochs=10                 # number of training epochs
)

trainer.fit(train_loader, val_loader)
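
Once training finishes, inference is plain PyTorch. Below is a minimal sketch for the binary setup above; the assumption that the dataset yields (image, mask) tensor pairs and the 0.5 threshold are illustrative choices, not taken from the library:

model.eval()
with torch.no_grad():
    image, _ = val_dataset[0]           # assumed (image, mask) pair, image shaped (C, H, W)
    logits = model(image.unsqueeze(0))  # add a batch dimension -> (1, num_classes, H, W)
    pred = (torch.sigmoid(logits) > 0.5).float()  # binarize the single-class output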

Available Pretrained Backbones

import backbones_unet

print(backbones_unet.__available_models__)
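
These are presumably the names accepted by the backbone argument of Unet, and __available_models__ behaves like an ordinary Python collection, so it can be filtered. A small sketch, assuming the entries are plain name strings:

# List only the ResNet-family backbones (assumes entries are name strings).
resnet_backbones = [m for m in backbones_unet.__available_models__ if 'resnet' in m]
print(resnet_backbones)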
