PyTorch Deep Extreme Cut library
PyTorch implementation of DEXTR
An implementation of DEXTR. The original implementation can be found at https://github.com/scaelles/DEXTR-PyTorch.
This implementation is intended for use as a library.
Installation
> pip install dextr
Python Inference API
See demo.py for an example of using the dextr inference API.
We have trained a ResNet-101 based U-Net DEXTR model on the Pascal VOC 2012 training set. You can download it here.
You can load this model -- downloading it automatically -- like so:
from dextr.model import DextrModel
# Load the model (automatically downloads if necessary)
# You can also provide a `map_location` parameter to load it onto a specific device
dextr_model = DextrModel.pascalvoc_resunet101()
Alternatively you can load a model that you have trained yourself from a file:
import torch
MODEL_PATH = '...'
dextr_model = torch.load(MODEL_PATH, map_location='cuda:0')
Use the predict method to predict a mask for an object in an image, identified by its extreme points:
mask = dextr_model.predict([image], [extreme_points])[0]
You can perform inference on multiple images with one call.
The DextrModel.predict method takes a list of images and their extreme points.
The images can be either NumPy arrays or PIL Images.
The extreme points are given as either a list of (4, [y, x]) NumPy arrays, one per image, or
a single (N, 4, [y, x]) shaped NumPy array, so that each image has four corresponding extreme points.
The method returns a list of masks; each mask is the same size as the corresponding input image.
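To make the expected array shapes concrete, here is a minimal sketch that builds extreme points in the (4, [y, x]) layout from a binary mask and stacks them into the (N, 4, [y, x]) batch form. The helper `extreme_points_from_mask` is hypothetical and not part of dextr, and the point ordering it uses (top, bottom, left, right) is an assumption. In interactive use the points would normally come from user clicks rather than from a mask.

```python
import numpy as np


def extreme_points_from_mask(mask):
    """Return a (4, 2) array of [y, x] points: top, bottom, left, right.

    Hypothetical helper for illustration; not part of the dextr package.
    """
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        raise ValueError("mask contains no foreground pixels")
    # Indices of the top-most, bottom-most, left-most and right-most pixels
    idx = [np.argmin(ys), np.argmax(ys), np.argmin(xs), np.argmax(xs)]
    return np.stack([ys[idx], xs[idx]], axis=1)


# Toy object: a 3x3 square inside a 5x6 image
mask = np.zeros((5, 6), dtype=bool)
mask[1:4, 2:5] = True

points = extreme_points_from_mask(mask)      # shape (4, 2), [y, x] per row
batch_points = np.stack([points, points])    # shape (2, 4, 2) for two images

# With a loaded model and two images, either form described above could be used:
# masks = dextr_model.predict([image_a, image_b], [points, points])
# masks = dextr_model.predict([image_a, image_b], batch_points)
```

The `predict` calls are left commented out so that the sketch runs without a downloaded model.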
Training using the command line train_dextr.py program
Train a DEXTR network using the Pascal VOC dataset
This will train a DEXTR model using a U-Net with a ResNet-101 based encoder. It should take several hours on an NVIDIA GeForce GTX 1080 Ti GPU.
- Download the Pascal VOC 2012 dataset development kit
- Create a file called dextr.cfg with the following contents:
[paths]
pascal_voc=<path to VOC2012 directory>
- Train the DEXTR model by running:
> python train_dextr.py pascal_resunet101 --dataset=pascal_voc --arch=resunet101
The name pascal_resunet101 is the name of the job; STDOUT will be logged to logs/log_pascal_resunet101.txt and the model
file will be saved to checkpoints/pascal_resunet101.pth. You can give the job any name you like.
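The dextr.cfg file described in the steps above uses INI syntax. As a rough sketch of how such a file can be read, the snippet below parses the same `[paths]` section with Python's standard `configparser` module. This is not necessarily how train_dextr.py reads it; the helper `read_dataset_path` and the example path are hypothetical placeholders.

```python
import configparser

# Example contents of dextr.cfg; the path is a placeholder, not a real location
EXAMPLE_CFG = """
[paths]
pascal_voc=/data/VOCdevkit/VOC2012
"""


def read_dataset_path(cfg_text, name):
    """Look up a dataset path by name in the [paths] section (hypothetical helper)."""
    parser = configparser.ConfigParser()
    parser.read_string(cfg_text)
    return parser.get("paths", name)


pascal_root = read_dataset_path(EXAMPLE_CFG, "pascal_voc")
print(pascal_root)
```

Parsing the file yourself like this can be a quick way to check that the path you entered is the one you intended before starting a multi-hour training job.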
Fine tuning a DEXTR network using a custom data set
There are two types of data set you can use:
- Each input image has a corresponding label image, where label images have an integer pixel type such that each pixel gives the index of the object that covers it, or 0 for background. The Pascal VOC dataset is arranged in this way.
- Each input image has a corresponding set of mask images that form a stack. Each mask image is an 8-bit greyscale image that corresponds to an object/instance and identifies the pixels covered by it.
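The relationship between the two formats above can be sketched by converting a label image into a mask stack with NumPy. The function `label_image_to_mask_stack` is a hypothetical helper, not part of dextr, and the choice of 255 as the foreground value in each mask is an assumption, since the description only says that a mask identifies the pixels covered by an object.

```python
import numpy as np


def label_image_to_mask_stack(labels, ignore_index=None):
    """Convert an integer label image into a stack of 8-bit masks.

    Hypothetical helper for illustration. Label 0 is treated as background;
    pixels equal to `ignore_index` (if given) do not produce a mask.
    """
    labels = np.asarray(labels)
    object_ids = [i for i in np.unique(labels) if i != 0 and i != ignore_index]
    if not object_ids:
        return np.zeros((0,) + labels.shape, dtype=np.uint8)
    # One mask per object: 255 where the object covers a pixel (assumed value)
    masks = [(labels == i).astype(np.uint8) * 255 for i in object_ids]
    return np.stack(masks, axis=0)


# Toy label image with two objects and one pixel marked 255 (ignored)
labels = np.array([[0, 1, 1],
                   [2, 2, 255]])
stack = label_image_to_mask_stack(labels, ignore_index=255)
print(stack.shape)  # one mask per object, each the size of the label image
```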
Please arrange your custom data set so that the image file names (excluding extension) either match or are a prefix
of the label/mask image file names. E.g. the image img0.jpg will match the label file img0.png or img0_labels.png.
For mask stack data sets, img0.jpg would match the mask images img0_mask0.png, ..., img0_maskN.png.
The images and labels can live in separate directories; they are matched by filename only.
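The naming rule can be illustrated with a short sketch that pairs each image with the target files whose names start with the image's file name (extension excluded). This is an illustration of the convention only, not the matching code used by train_dextr.py; `match_targets` and the example paths are hypothetical.

```python
from pathlib import Path


def match_targets(image_paths, target_paths):
    """Map each image path to the target paths whose stem starts with the image stem.

    Hypothetical helper; directories are ignored and only file names are compared.
    """
    targets_by_stem = [(Path(t).stem, t) for t in target_paths]
    matches = {}
    for img in image_paths:
        stem = Path(img).stem
        matches[img] = sorted(t for t_stem, t in targets_by_stem
                              if t_stem.startswith(stem))
    return matches


images = ["input/img0.jpg", "input/img1.jpg"]
targets = ["masks/img0_mask0.png", "masks/img0_mask1.png", "labels/img1_labels.png"]
matches = match_targets(images, targets)
print(matches["input/img0.jpg"])
```

Note that with plain prefix matching, a name such as img1 is also a prefix of img10_mask0. Zero-padded names (img001, img010) avoid this kind of ambiguity.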
In these examples, we assume that you have downloaded the pre-trained DEXTR model linked above.
Training using a label image data set
> python train_dextr.py my_model_from_labels --dataset=custom_label --train_image_pat=/mydataset/train/input/*.jpg --train_target_pat=/mydataset/train/labels/*.png --arch=resunet101 --load_model=dextr_pascalvoc_resunet101-a2d81727.pth
The input and label images are given to the --train_image_pat and --train_target_pat options.
You can specify validation images using the --val_image_pat and --val_target_pat options in a similar way.
--load_model=dextr_pascalvoc_resunet101-a2d81727.pth indicates that we should start by loading the
model trained on Pascal VOC above and fine-tune it, rather than starting from an ImageNet classifier.
You can specify that the label index 255 should be ignored by adding --label_ignore_index=255.
You could train using the entire (train and validation) Pascal VOC data set using:
> python train_dextr.py my_model_from_pascal --dataset=custom_label --train_image_pat=/pascal/VOC2012/JPEGImages/*.jpg --train_target_pat=/pascal/VOC2012/SegmentationObject/*.png --label_ignore_index=255 --arch=resunet101
Training using a mask stack data set
> python train_dextr.py my_model_from_masks --dataset=custom_mask --train_image_pat=/mydataset/train/input/*.jpg --train_target_pat=/mydataset/train/masks/*.png --arch=resunet101 --load_model=dextr_pascalvoc_resunet101-a2d81727.pth
Python training API
The training_loop function within the dextr.model module provides a simple training loop that can be used
for training or fine-tuning models. See train_dextr.py for usage.