
Universal implementation of the UNet architecture for image segmentation.


UNET SEGMENTATION PYTORCH

Installation

pip install segment-torch

Usage

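import torch
from torch import nn

from segment_torch.unet import UNet

# Fall back to the CPU when no CUDA device is available
device = "cuda" if torch.cuda.is_available() else "cpu"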

config = dict(
    in_channels=3,
    out_channels=1,
    hiddens=[4, 8, 16, 32],
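    dropouts=[0, 0.15, 0.15, 0.15],  # one value per entry in hiddens
    maxpools=2,  # as a list: len(hiddens) - 1 values
    kernel_sizes=3,  # as a list: 2*len(hiddens) + 3*len(hiddens) + 2 values
    paddings='same',  # as a list: 2*len(hiddens) + 3*len(hiddens) + 2 values
    strides=1,  # as a list: 2*len(hiddens) + 3*len(hiddens) values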
    dilation=1,
    criterion=nn.BCELoss(),
    output_activation=nn.Sigmoid(),
    activation=nn.ReLU(),
    dimensions=2,
    device=device
)
unet = UNet(**config)
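The length comments in the config above are terse. The sketch below is one reading of them, assuming each `hiddens` in a comment stands for `len(hiddens)`. `expected_lengths` and `check_lengths` are hypothetical helpers, not segment_torch functions, and the library's own validation may differ.

```python
# Hypothetical helpers, not part of segment_torch: they read each "hiddens"
# in the config comments as len(hiddens) and check flat per-layer lists.

def expected_lengths(hiddens):
    """Per-layer list lengths implied by the README's config comments."""
    n = len(hiddens)
    return {
        "dropouts": n,                    # one value per hidden level
        "maxpools": n - 1,                # one value per downsampling step
        "kernel_sizes": 2 * n + 3 * n + 2,
        "paddings": 2 * n + 3 * n + 2,
        "strides": 2 * n + 3 * n,
    }


def check_lengths(config):
    """Raise ValueError when a flat list has a different length than expected."""
    for name, length in expected_lengths(config["hiddens"]).items():
        value = config.get(name)
        # Scalars, tuples, strings and None apply to every layer; nested
        # [encoder, decoder] lists are a separate form and are skipped here.
        is_flat_list = isinstance(value, list) and not any(isinstance(v, list) for v in value)
        if is_flat_list and len(value) != length:
            raise ValueError(f"{name}: expected {length} values, got {len(value)}")


if __name__ == "__main__":
    hiddens = [4, 8, 16, 32]
    check_lengths(dict(hiddens=hiddens, dropouts=[0, 0.15, 0.15, 0.15], maxpools=2))
    print(expected_lengths(hiddens))
```

Under this reading, `hiddens=[4, 8, 16, 32]` calls for 4 dropouts, 3 maxpools, 22 kernel sizes and paddings, and 20 strides when those parameters are given as flat lists.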

Different ways to define configs

# 0. None: default values are used
kernel_sizes = None

# 1. Single value or tuple: all layers use the same value
kernel_sizes = 3
kernel_sizes = (3, 3)

# 2. Lists of values: one list for the encoder, one for the decoder
encoder_kernel_sizes = [3, 3, 3, 3]
decoder_kernel_sizes = [3, 3, 3, 3, 3]
kernel_sizes = [encoder_kernel_sizes, decoder_kernel_sizes]
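The three forms above can be illustrated by a small normalization function that turns any of them into explicit encoder and decoder lists. This is a hypothetical sketch, not segment_torch's implementation: the function name `expand`, its `default` argument, and the choice to treat a tuple as one per-layer value (for example a per-dimension kernel size) are assumptions.

```python
# Hypothetical sketch, not segment_torch's code: expands one layer parameter
# given in any of the three README forms into (encoder, decoder) lists.

def expand(value, n_encoder, n_decoder, default):
    """Return (encoder_values, decoder_values) for one layer parameter."""
    if value is None:
        # 0. None: fall back to the (assumed) default for every layer
        value = default
    if isinstance(value, list):
        # 2. [encoder_list, decoder_list]: use the lists as given
        encoder, decoder = value
        if len(encoder) != n_encoder or len(decoder) != n_decoder:
            raise ValueError("list lengths do not match the number of layers")
        return list(encoder), list(decoder)
    # 1. Single value or tuple: repeat it for every layer
    return [value] * n_encoder, [value] * n_decoder


if __name__ == "__main__":
    print(expand(3, 4, 5, default=3))
    print(expand((3, 3), 4, 5, default=3))
    print(expand([[3, 3, 3, 3], [3, 3, 3, 3, 3]], 4, 5, default=3))
```

With 4 encoder and 5 decoder layers, `expand(3, 4, 5, default=3)` returns `([3, 3, 3, 3], [3, 3, 3, 3, 3])`, the same lists as the explicit form in example 2.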


Download files

Source distribution: segment_torch-0.0.10.tar.gz (8.5 kB)

Built distribution: segment_torch-0.0.10-py3-none-any.whl (10.8 kB, Python 3)
