PyTorch Deep Extreme Cut library
Project description
PyTorch implementation of DEXTR
An implementation of DEXTR. The original implementation can be found at https://github.com/scaelles/DEXTR-PyTorch.
This implementation is intended for use as a library.
Installation
> pip install dextr
Python Inference API
See demo.py
for an example of using the dextr
inference API.
We have trained a ResNet-101 based U-Net DEXTR model on the Pascal VOC 2012 training set. You can download it here.
You can load this model -- downloading it automatically -- like so:
from dextr.model import DextrModel
# Load the model (automatically downloads if necessary)
# You can also provide a `map_location` paramter to load it onto a specific device
model = DextrModel.pascalvoc_resunet101()
Alternatively you can load a model that you have trained yourself from a file:
MODEL_PATH = '...'
dextr_model = torch.load(MODEL_PATH, map_location='cuda:0')
Use the predict
method to predict a mask for an object in an image, identified by its extreme points:
mask = dextr_model.predict([image], [extreme_points])[0]
You can perform inference on multiple images with one call.
The DextrModel.predict
method takes a list
of images and extreme points as either a list of (4, [y, x])
NumPy arrays or
one (N, 4, [y, x])
shaped NumPy array.
The images that you use as input can take the form of either NumPy arrays or PIL Images. Each image should have a corresponding list of four extreme points. It returns a list of masks; each mask is the same size as the corresponding input image:
Training using the command line train_dextr.py
program
Train a DEXTR network using the Pascal VOC dataset
This will train a DEXTR model using a U-Net with a ResNet-101 based encoder. It should take several hours on an nVidia 1080Ti GPU.
- Download the Pascal VOC 2012 dataset development kit
- Create a file called
dextr.cfg
with the following contents:
[paths]
pascal_voc=<path to VOC2012 diretory>
- Train the DEXTR model by running:
> python train_dextr.py pascal_resunet101 --dataset=pascal_voc --arch=resunet101
The name pascal_resunet101
is the name of the job; STDOUT will be logged to logs/log_pascal_resunet101.txt
and the model
file will be saved to checkpoints/pascal_resunet101.pth
. You can give the job any name you like.
Fine tuning a DEXTR network using a custom data set
There are two types of data set you can use:
- Each input image has a corresponding label image, where label images have an integer pixel type such that each pixel gives the index of the object that covers it, or 0 for background. The Pascal VOC dataset is arranged in this way.
- Each input image has a corresponding set of mask images that form a stack. Each mask image is an 8-bit greyscale image that corresponds to an object/instance and identifies the pixels covered by it.
Please arrange your custom data set so that the image file names (excluding extension) match or are a prefix
to the label/mask image file names. E.g. the image img0.jpg
will match the label file img0.png
or img0_labels.png
.
For mask stack datasets img0.jpg
would match to the mask images img0_mask0.png
, ... img0_maskN.png
.
The images and labels can live in separate directories; they are matched by filename only.
In these examples, we assume that you have downloaded the pre-trained DEXTR model linked above.
Training using a label image data set
> python train_dextr.py my_model_from_labels --dataset=custom_label --train_image_pat=/mydataset/train/input/*.jpg --train_target_pat=/mydataset/train/labels/*.png --arch=resunet101 --load_model=dextr_pascalvoc_resunet101-a2d81727.pth
The input and label images are given to the --train_image_pat
and --train_target_pat
options.
You can specify validation images using the --val_image_pat
and --val_target_pat
options in a similar way.
--load_model=dextr_pascalvoc_resunet101-a2d81727.pth
indicates that we should start by loading the
model trained on Pascal VOC above and fine-tune it, rather than starting from an ImageNet classifier.
You can specify that the label index 255 should be ignore by adding --label_ignore_index=255
.
You could train using the entire (train and validation) Pascal VOC data set using:
> python train_dextr.py my_model_from_pascal --dataset=custom_label --train_image_pat=/pascal/VOC2012/JPEGImages/*.jpg --train_target_pat=/pascal/VOC2012/SegmentationObjects/*.png --label_ignore_index=255 --arch=resunet101
Training using a mask stack data set
> python train_dextr.py my_model_from_masks --dataset=custom_mask --train_image_pat=/mydataset/train/input/*.jpg --train_target_pat=/mydataset/train/masks/*.png --arch=resunet101 --load_model=dextr_pascalvoc_resunet101-a2d81727.pth
Python training API
The training_loop
function within the dextr.model
module provides a simple training loop that can be used
for training or fine-tuning models. See train_dextr.py
for usage.
Project details
Release history Release notifications | RSS feed
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
File details
Details for the file dextr-0.1.2.tar.gz
.
File metadata
- Download URL: dextr-0.1.2.tar.gz
- Upload date:
- Size: 22.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.1.post20200323 requests-toolbelt/0.9.1 tqdm/4.43.0 CPython/3.7.7
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 5e14716b26ffa64d16fdbb1b5880f9e658c58752e0d1952a022eb897c62bda7a |
|
MD5 | ed72205592a563e98544e6327f180b02 |
|
BLAKE2b-256 | 33dbfe3ee27125b7b68cc46f523c472fce33456d45842da018bd02967253ae09 |
File details
Details for the file dextr-0.1.2-py3-none-any.whl
.
File metadata
- Download URL: dextr-0.1.2-py3-none-any.whl
- Upload date:
- Size: 32.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.23.0 setuptools/46.1.1.post20200323 requests-toolbelt/0.9.1 tqdm/4.43.0 CPython/3.7.7
File hashes
Algorithm | Hash digest | |
---|---|---|
SHA256 | 2d3f857020723affb921396d7cfd88d2d3fb7581b74a21ada5841922c2031e1e |
|
MD5 | fec92801fdce45e7dce3abfd538efe25 |
|
BLAKE2b-256 | 7883a05884955bdaaa4aaa117656b1a627b50ae409e373b130fe11670a277e8e |