PyTorch Deep Extreme Cut library

PyTorch implementation of DEXTR

An implementation of DEXTR. The original implementation can be found at https://github.com/scaelles/DEXTR-PyTorch.

This implementation is intended for use as a library.

Installation

> pip install dextr

Python Inference API

See demo.py for an example of using the dextr inference API.

We have trained a ResNet-101 based U-Net DEXTR model on the Pascal VOC 2012 training set. You can download it here.

You can load this model -- downloading it automatically -- like so:

from dextr.model import DextrModel

# Load the model (automatically downloads if necessary).
# You can also provide a `map_location` parameter to load it onto a specific device.
dextr_model = DextrModel.pascalvoc_resunet101()

Alternatively you can load a model that you have trained yourself from a file:

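import torch

MODEL_PATH = '...'
# On PyTorch 2.6 and later, also pass weights_only=False to load a full pickled model.
dextr_model = torch.load(MODEL_PATH, map_location='cuda:0')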

Use the predict method to predict a mask for an object in an image, identified by its extreme points:

mask = dextr_model.predict([image], [extreme_points])[0]

You can perform inference on multiple images with one call. The DextrModel.predict method takes a list of images and their extreme points. The extreme points are given either as a list of NumPy arrays of shape (4, 2), with one [y, x] row per point, or as a single NumPy array of shape (N, 4, 2) covering N images.

The input images can be either NumPy arrays or PIL Images, and each image needs a corresponding set of four extreme points. The predict method returns a list of masks; each mask is the same size as the corresponding input image.
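If you want to try the API on data where ground truth masks are available, extreme points in the [y, x] layout described above can be derived from a binary mask. The following is a minimal NumPy sketch; the helper name extreme_points_from_mask and the row order (top, bottom, left, right) are assumptions made for illustration, not part of the dextr API.

```python
import numpy as np


def extreme_points_from_mask(mask):
    """Return a (4, 2) array of [y, x] extreme points for a binary mask.

    Row order is top, bottom, left, right. This ordering is an assumption
    of this sketch; the dextr documentation does not specify one.
    """
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        raise ValueError("mask contains no object pixels")
    # Index of one pixel on each side of the object
    idx = [np.argmin(ys), np.argmax(ys), np.argmin(xs), np.argmax(xs)]
    return np.stack([ys[idx], xs[idx]], axis=1)


# Toy 6x8 mask with a rectangle covering rows 1..4 and columns 2..5
mask = np.zeros((6, 8), dtype=bool)
mask[1:5, 2:6] = True

points = extreme_points_from_mask(mask)  # one image: shape (4, 2)
batch = np.stack([points, points])       # N images: shape (N, 4, 2)
```

Either points (inside a list) or batch has the shape that predict is described as accepting.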

Training using the command line train_dextr.py program

Train a DEXTR network using the Pascal VOC dataset

This will train a DEXTR model using a U-Net with a ResNet-101 based encoder. It should take several hours on an NVIDIA GeForce GTX 1080 Ti GPU.

  • Download the Pascal VOC 2012 dataset development kit
  • Create a file called dextr.cfg with the following contents:
[paths]
pascal_voc=<path to VOC2012 directory>
  • Train the DEXTR model by running:

> python train_dextr.py pascal_resunet101 --dataset=pascal_voc --arch=resunet101

The name pascal_resunet101 is the name of the job; STDOUT will be logged to logs/log_pascal_resunet101.txt and the model file will be saved to checkpoints/pascal_resunet101.pth. You can give the job any name you like.
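Since training takes hours, it can be worth checking that dextr.cfg parses and points at a real directory before launching. The sketch below assumes only that the file is standard INI, as the example contents suggest; how train_dextr.py reads it internally is not shown here. The helper name read_pascal_voc_path and the dataset path in the demonstration are hypothetical.

```python
import configparser
import os
import tempfile


def read_pascal_voc_path(cfg_path="dextr.cfg"):
    """Return the pascal_voc entry from the [paths] section of cfg_path."""
    parser = configparser.ConfigParser()
    if not parser.read(cfg_path):
        raise FileNotFoundError(f"could not read {cfg_path}")
    return parser.get("paths", "pascal_voc")


# Demonstration with a throwaway config file and a hypothetical dataset path
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_cfg = os.path.join(tmp_dir, "dextr.cfg")
    with open(demo_cfg, "w") as f:
        f.write("[paths]\npascal_voc=/data/VOCdevkit/VOC2012\n")
    voc_path = read_pascal_voc_path(demo_cfg)
    voc_exists = os.path.isdir(voc_path)  # check this before a long run
```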

Fine tuning a DEXTR network using a custom data set

There are two types of data set you can use:

  1. Each input image has a corresponding label image, where label images have an integer pixel type such that each pixel gives the index of the object that covers it, or 0 for background. The Pascal VOC dataset is arranged in this way.
  2. Each input image has a corresponding set of mask images that form a stack. Each mask image is an 8-bit greyscale image that corresponds to an object/instance and identifies the pixels covered by it.
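The two layouts carry the same information, which a short NumPy sketch can make concrete by converting a label image into a mask stack. The function name is hypothetical, and using 255 for covered pixels is an assumption of this sketch, since the text only says that each mask identifies the pixels covered by its object.

```python
import numpy as np


def label_image_to_mask_stack(labels):
    """Convert an (H, W) integer label image to an (N, H, W) uint8 mask stack.

    Label 0 is background. Covered pixels are written as 255, which is an
    assumed convention for this illustration.
    """
    object_ids = np.unique(labels)
    object_ids = object_ids[object_ids != 0]  # drop background
    # Compare every pixel against every object index via broadcasting
    masks = labels[None, :, :] == object_ids[:, None, None]
    return masks.astype(np.uint8) * 255


labels = np.array([[0, 1, 1],
                   [2, 2, 0]], dtype=np.int32)
stack = label_image_to_mask_stack(labels)  # one mask per object
```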

Please arrange your custom data set so that each image file name (excluding extension) matches, or is a prefix of, the corresponding label/mask image file names. E.g. the image img0.jpg will match the label file img0.png or img0_labels.png. For mask stack datasets, img0.jpg would match the mask images img0_mask0.png, ... img0_maskN.png. The images and labels can live in separate directories; they are matched by filename only.
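To check a directory layout before training, the matching rule can be approximated with a few lines of Python. This is a sketch of the rule as described above, not the library's own matching code; the function name match_targets and the example paths are hypothetical.

```python
import os


def match_targets(image_paths, target_paths):
    """Map each image path to the target paths whose stem starts with its stem."""

    def stem(path):
        # File name without directory or extension
        return os.path.splitext(os.path.basename(path))[0]

    matches = {}
    for image_path in image_paths:
        image_stem = stem(image_path)
        matches[image_path] = sorted(
            t for t in target_paths if stem(t).startswith(image_stem)
        )
    return matches


images = ["/mydataset/train/input/img0.jpg"]
targets = [
    "/mydataset/train/masks/img0_mask0.png",
    "/mydataset/train/masks/img0_mask1.png",
    "/mydataset/train/masks/img1_mask0.png",
]
matched = match_targets(images, targets)
```

Note that a plain prefix test like this one would also pair img1.jpg with img10_mask0.png, so zero-padded or otherwise unambiguous file names are the safer choice.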

In these examples, we assume that you have downloaded the pre-trained DEXTR model linked above.

Training using a label image data set

> python train_dextr.py my_model_from_labels --dataset=custom_label --train_image_pat=/mydataset/train/input/*.jpg --train_target_pat=/mydataset/train/labels/*.png --arch=resunet101 --load_model=dextr_pascalvoc_resunet101-a2d81727.pth

The input and label images are given to the --train_image_pat and --train_target_pat options. You can specify validation images using the --val_image_pat and --val_target_pat options in a similar way.

--load_model=dextr_pascalvoc_resunet101-a2d81727.pth indicates that we should start by loading the model trained on Pascal VOC above and fine-tune it, rather than starting from an ImageNet classifier.

You can specify that the label index 255 should be ignored by adding --label_ignore_index=255.
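As an illustration of what ignoring a label index means, the sketch below treats pixels with value 255 as neither object nor background when listing the objects in a label image. This is an assumed interpretation for illustration; the exact handling inside train_dextr.py is not shown here.

```python
import numpy as np

IGNORE_INDEX = 255  # same value as passed to --label_ignore_index

labels = np.array([[0, 1, 255],
                   [2, 255, 0]], dtype=np.uint8)

# Pixels that take part in training under the assumed interpretation
valid = labels != IGNORE_INDEX

# Object indices, excluding background (0) and the ignored index
object_ids = np.unique(labels[valid])
object_ids = object_ids[object_ids != 0]
```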

You could train using the entire (train and validation) Pascal VOC data set using:

> python train_dextr.py my_model_from_pascal --dataset=custom_label --train_image_pat=/pascal/VOC2012/JPEGImages/*.jpg --train_target_pat=/pascal/VOC2012/SegmentationObject/*.png --label_ignore_index=255 --arch=resunet101

Training using a mask stack data set

> python train_dextr.py my_model_from_masks --dataset=custom_mask --train_image_pat=/mydataset/train/input/*.jpg --train_target_pat=/mydataset/train/masks/*.png --arch=resunet101 --load_model=dextr_pascalvoc_resunet101-a2d81727.pth

Python training API

The training_loop function within the dextr.model module provides a simple training loop that can be used for training or fine-tuning models. See train_dextr.py for usage.
