
Training of image segmentation models with popular architectures.


pytorch_segmentation_models_trainer


A framework based on PyTorch, PyTorch Lightning, segmentation_models.pytorch and Hydra for training semantic segmentation models from YAML config files such as the following:

model:
  _target_: segmentation_models_pytorch.Unet
  encoder_name: resnet34
  encoder_weights: imagenet
  in_channels: 3
  classes: 1

loss:
  _target_: segmentation_models_pytorch.utils.losses.DiceLoss

optimizer:
  _target_: torch.optim.AdamW
  lr: 0.001
  weight_decay: 1e-4

hyperparameters:
  batch_size: 1
  epochs: 2
  max_lr: 0.1

pl_trainer:
  max_epochs: ${hyperparameters.epochs}
  gpus: 0

train_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /path/to/input.csv
  data_loader:
    shuffle: True
    num_workers: 1
    pin_memory: True
    drop_last: True
    prefetch_factor: 1
  augmentation_list:
    - _target_: albumentations.HueSaturationValue
      always_apply: false
      hue_shift_limit: 0.2
      p: 0.5
    - _target_: albumentations.RandomBrightnessContrast
      brightness_limit: 0.2
      contrast_limit: 0.2
      p: 0.5
    - _target_: albumentations.RandomCrop
      always_apply: true
      height: 256
      width: 256
      p: 1.0
    - _target_: albumentations.Flip
      always_apply: true
    - _target_: albumentations.Normalize
      p: 1.0
    - _target_: albumentations.pytorch.transforms.ToTensorV2
      always_apply: true

val_dataset:
  _target_: pytorch_segmentation_models_trainer.dataset_loader.dataset.SegmentationDataset
  input_csv_path: /path/to/input.csv
  data_loader:
    shuffle: True
    num_workers: 1
    pin_memory: True
    drop_last: True
    prefetch_factor: 1
  augmentation_list:
    - _target_: albumentations.Resize
      always_apply: true
      height: 256
      width: 256
      p: 1.0
    - _target_: albumentations.Normalize
      p: 1.0
    - _target_: albumentations.pytorch.transforms.ToTensorV2
      always_apply: true
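In this config, each `_target_` key names the dotted import path of a class, the sibling keys become its keyword arguments, and `${...}` strings refer to other config values. Hydra and OmegaConf handle this through `hydra.utils.instantiate` and interpolation. The following is a simplified, standard-library-only sketch of that mechanism, not Hydra's actual implementation: the helpers `resolve` and `instantiate` are hypothetical, and `fractions.Fraction` merely stands in for targets such as `segmentation_models_pytorch.Unet`.

```python
import importlib
import re
from typing import Any

# Matches a value that is exactly one "${dotted.path}" reference. OmegaConf also
# supports references embedded inside longer strings; this sketch does not.
_REF = re.compile(r"^\$\{([^}]+)\}$")


def _lookup(root: dict, dotted: str) -> Any:
    """Follow a dotted path such as 'hyperparameters.epochs' through nested dicts."""
    node: Any = root
    for key in dotted.split("."):
        node = node[key]
    return node


def resolve(node: Any, root: dict) -> Any:
    """Return a copy of `node` with whole-string ${...} references replaced."""
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, root) for v in node]
    if isinstance(node, str):
        match = _REF.match(node)
        if match:
            # Resolve again in case the referenced value is itself a reference.
            return resolve(_lookup(root, match.group(1)), root)
    return node


def _import_target(dotted: str) -> Any:
    """Import 'package.module.Name' and return the Name attribute."""
    module_path, _, attr = dotted.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)


def instantiate(node: Any) -> Any:
    """Recursively build objects from every dict that has a '_target_' key."""
    if isinstance(node, list):
        return [instantiate(v) for v in node]
    if not isinstance(node, dict):
        return node
    kwargs = {k: instantiate(v) for k, v in node.items() if k != "_target_"}
    if "_target_" in node:
        return _import_target(node["_target_"])(**kwargs)
    return kwargs  # a plain mapping such as data_loader stays a dict


# Usage with a standard-library target standing in for a model or optimizer.
config = {
    "hyperparameters": {"epochs": 2},
    "pl_trainer": {"max_epochs": "${hyperparameters.epochs}"},
    "demo": {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 4},
}
resolved = resolve(config, config)
print(resolved["pl_trainer"]["max_epochs"])  # 2
print(instantiate(resolved["demo"]))  # 1/4
```

Resolving references before instantiating keeps the two concerns separate: by the time a constructor is called, every argument is already a concrete value.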

To train a model using the configuration file test.yaml located in /path/to/config/folder:

pytorch-smt --config-dir /path/to/config/folder --config-name test +mode=train

The mode can also be set in the configuration YAML; in that case, do not pass the +mode= argument. To override a mode that is already set in the YAML, drop the + prefix and pass mode= instead.
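The `+` prefix comes from Hydra's command-line override grammar: `+key=value` appends a key that must not already exist, while `key=value` overrides a key that must already exist, and Hydra reports an error when the wrong form is used. The sketch below mimics only those two rules on a plain dict; `apply_override` is a hypothetical helper, not part of Hydra or this package, and it ignores the other override forms Hydra supports.

```python
def apply_override(config: dict, override: str) -> dict:
    """Apply one 'key=value' or '+key=value' override to `config` in place.

    Mimics two Hydra rules only: '+' appends a key that must not exist yet,
    and a bare key overrides one that must already exist. Values stay strings
    (Hydra would parse them into ints, floats, bools, and so on).
    """
    key, sep, value = override.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {override!r}")
    append = key.startswith("+")
    if append:
        key = key[1:]
    *parents, leaf = key.split(".")
    node = config
    for part in parents:
        # Appending may create intermediate mappings; overriding may not.
        node = node.setdefault(part, {}) if append else node[part]
    if append and leaf in node:
        raise KeyError(f"{key!r} already exists; use {key}={value} to override it")
    if not append and leaf not in node:
        raise KeyError(f"{key!r} not found; use +{key}={value} to add it")
    node[leaf] = value
    return config


# 'mode' absent from the YAML: it must be appended with '+'.
print(apply_override({}, "+mode=train"))  # {'mode': 'train'}
# 'mode' already in the YAML ('other' is a placeholder): override it without '+'.
print(apply_override({"mode": "other"}, "mode=train"))  # {'mode': 'train'}
```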

This module supports Hydra features such as configuration composition. For further information, see https://hydra.cc/docs/intro
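Configuration composition means the final config is assembled from several YAML fragments, with later fragments overriding earlier ones key by key. A minimal standard-library sketch of that recursive merge follows, assuming OmegaConf-like semantics in which nested mappings are merged and any other value (including a list) is replaced outright; `deep_merge` is a hypothetical helper, not Hydra's API.

```python
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict with `override` merged into `base` recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are mappings: merge them key by key.
            merged[key] = deep_merge(merged[key], value)
        else:
            # Scalars and lists from the later fragment replace earlier values.
            merged[key] = copy.deepcopy(value)
    return merged


base = {"model": {"encoder_name": "resnet34", "classes": 1}, "optimizer": {"lr": 0.001}}
experiment = {"model": {"encoder_name": "resnet50"}, "hyperparameters": {"epochs": 10}}
print(deep_merge(base, experiment))
# {'model': {'encoder_name': 'resnet50', 'classes': 1}, 'optimizer': {'lr': 0.001}, 'hyperparameters': {'epochs': 10}}
```

Copying both inputs keeps the base fragment reusable, so one shared base config can be composed with many experiment fragments.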

Install

If you are not using Docker and want GPU acceleration, install pytorch_scatter before installing this package, following the instructions at https://github.com/rusty1s/pytorch_scatter

After installing pytorch_scatter, run:

pip install pytorch_segmentation_models_trainer
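To confirm that the installation is visible to the current Python environment, one option is to check the top-level import names without importing them. A small sketch using `importlib.util.find_spec`, assuming the import names `torch`, `torch_scatter` (for pytorch_scatter) and `pytorch_segmentation_models_trainer`; `missing_modules` is a hypothetical helper.

```python
import importlib.util
from typing import List


def missing_modules(names: List[str]) -> List[str]:
    """Return the top-level module names that cannot be found.

    Only top-level names are safe here: for dotted names, find_spec imports
    the parent package and may raise ModuleNotFoundError instead of returning None.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


required = ["torch", "torch_scatter", "pytorch_segmentation_models_trainer"]
missing = missing_modules(required)
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("All required modules are importable.")
```

This only checks that the modules can be located; it does not verify GPU support.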

A Docker image with all dependencies installed and ready for GPU usage is available on Docker Hub:

docker pull phborba/pytorch_segmentation_models_trainer:latest

Citing:


@software{philipe_borba_2021_5115127,
  author       = {Philipe Borba},
  title        = {{phborba/pytorch\_segmentation\_models\_trainer:
                   Version 0.8.0}},
  month        = jul,
  year         = 2021,
  publisher    = {Zenodo},
  version      = {v0.8.0},
  doi          = {10.5281/zenodo.5115127},
  url          = {https://doi.org/10.5281/zenodo.5115127}
}


