Image test-time augmentation with Paddle.
Patta
Image Test Time Augmentation with Paddle2.0!
           Input
             |           # input batch of images
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
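The flow in the diagram can be sketched in plain Python. This is only an illustration of the idea, not PaTTA's implementation: `toy_model`, `tta_predict`, and the list-of-rows "image" are hypothetical stand-ins, and only a horizontal flip with mean merging is shown.

```python
# Minimal sketch of the TTA flow; all names here are illustrative assumptions.

def identity(img):
    """Return a copy of the image unchanged."""
    return [row[:] for row in img]

def hflip(img):
    """Mirror every row left-to-right; applying it twice restores the input."""
    return [row[::-1] for row in img]

def toy_model(img):
    """Hypothetical per-pixel model that is NOT flip-invariant:
    it adds the column index to each pixel value."""
    return [[value + col for col, value in enumerate(row)] for row in img]

def tta_predict(model, img, transforms):
    """transforms is a list of (augment, reverse) function pairs."""
    outputs = []
    for augment, reverse in transforms:
        prediction = model(augment(img))     # pass the augmented input through the model
        outputs.append(reverse(prediction))  # map the prediction back to the input frame
    n = len(outputs)
    height, width = len(img), len(img[0])
    # merge the reversed predictions with an element-wise mean
    return [[sum(out[r][c] for out in outputs) / n for c in range(width)]
            for r in range(height)]

image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
merged = tta_predict(toy_model, image, [(identity, identity), (hflip, hflip)])
print(merged)  # [[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]
```

The flipped branch is un-flipped before merging, so every prediction is averaged in the coordinate frame of the original input.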
Quick start (Default Transforms)
Test
After defining the network, wrap it as follows to test with TTA.
Segmentation model wrapping [docstring]:
import patta as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
Note: the model must return keypoints in the format Tensor([x1, y1, ..., xn, yn])
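To make the flat `[x1, y1, ..., xn, yn]` format concrete, here is a rough sketch of how a horizontal flip could be reversed on keypoints. It assumes `scaled=True` means the coordinates are normalized to [0, 1]; that reading, and the helper name `hflip_keypoints`, are assumptions for illustration rather than PaTTA's documented behavior.

```python
def hflip_keypoints(flat_keypoints):
    """Mirror the x coordinates of keypoints given as [x1, y1, ..., xn, yn].

    Assumes coordinates are normalized to [0, 1] (our reading of scaled=True).
    A horizontal flip is its own inverse, so the same function undoes it.
    """
    flipped = list(flat_keypoints)
    for i in range(0, len(flipped), 2):  # even positions hold x values
        flipped[i] = 1.0 - flipped[i]
    return flipped

keypoints = [0.25, 0.5, 1.0, 0.0]          # two points: (0.25, 0.5) and (1.0, 0.0)
mirrored = hflip_keypoints(keypoints)
print(mirrored)                            # [0.75, 0.5, 0.0, 0.0]
print(hflip_keypoints(mirrored))           # back to the original keypoints
```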
Predict
If you have a static model (*.pdmodel, *.pdiparams, *.pdiparams.info), load and wrap it as follows for prediction.
Load model [docstring]:
import patta as tta
model = tta.load_model(path='output/model')
Segmentation model wrapping [docstring]:
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
Use-Tools
Segmentation model [docstring]:
We recommend modifying the file seg.py to suit your own model.
python seg.py --model_path='output/model' \
--batch_size=16 \
--test_dataset='test.txt'
Note: this tool is related to PaddleSeg.
Advanced-Examples (DIY Transforms)
Custom transform:
# defines 2 * 2 * 3 * 3 = 36 augmentations
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
)
tta_model = tta.SegmentationTTAWrapper(model, transforms)
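The count of 36 comes from taking one parameter value per transform in every possible combination. The sketch below reproduces that arithmetic with `itertools.product`. The dictionary is purely illustrative, and representing `HorizontalFlip` as the two states `False`/`True` is an assumption about how the flip is parameterized.

```python
from itertools import product

# One list of candidate parameter values per transform (illustrative only).
param_lists = {
    "HorizontalFlip": [False, True],   # assumed: not flipped / flipped
    "Rotate90": [0, 180],
    "Scale": [1, 2, 4],
    "Multiply": [0.9, 1, 1.1],
}

# Each augmentation is one choice from every list: a Cartesian product.
combinations = list(product(*param_lists.values()))
print(len(combinations))   # 36
print(combinations[0])     # (False, 0, 1, 0.9)
```

Because the count is multiplicative, each extra value in any list adds a full set of forward passes, so inference cost scales with the product of the list lengths.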
Custom model (multi-input / multi-output):
# Example of how to process ONE batch of images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N)
masks, labels = [], []
for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()
    # augment image
    augmented_image = transformer.augment_image(image)
    # pass to model
    model_output = model(augmented_image, another_input_data)
    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])
    # save results
    masks.append(deaug_mask)
    labels.append(deaug_label)
# reduce results as you want, e.g. mean/max/min
# (`mean` stands for any reduction over the list; it is not a PaTTA function)
label = mean(labels)
mask = mean(masks)
Optional Transforms
Transform | Parameters | Values |
---|---|---|
HorizontalFlip | - | - |
VerticalFlip | - | - |
Rotate90 | angles | List[0, 90, 180, 270] |
Scale | scales, interpolation | List[float], "nearest"/"linear" |
Resize | sizes, original_size, interpolation | List[Tuple[int, int]], Tuple[int, int], "nearest"/"linear" |
Add | values | List[float] |
Multiply | factors | List[float] |
FiveCrops | crop_height, crop_width | int, int |
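As an illustration of what the `crop_height` and `crop_width` parameters of FiveCrops control, the sketch below computes four corner boxes plus a center box for a given image size. The function name, the `(top, left, bottom, right)` box layout, and the rounding of the center crop are assumptions for this example, not PaTTA's code.

```python
def five_crop_boxes(image_height, image_width, crop_height, crop_width):
    """Return five (top, left, bottom, right) boxes: four corners, then the center."""
    if crop_height > image_height or crop_width > image_width:
        raise ValueError("crop size must not exceed image size")
    max_top = image_height - crop_height    # top edge of the bottom crops
    max_left = image_width - crop_width     # left edge of the right-hand crops
    origins = [
        (0, 0),                             # top-left
        (0, max_left),                      # top-right
        (max_top, 0),                       # bottom-left
        (max_top, max_left),                # bottom-right
        (max_top // 2, max_left // 2),      # center (rounded down)
    ]
    return [(top, left, top + crop_height, left + crop_width)
            for top, left in origins]

boxes = five_crop_boxes(image_height=6, image_width=8, crop_height=4, crop_width=4)
for box in boxes:
    print(box)
```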
Aliases (Combos)
- flip_transform (horizontal + vertical flips)
- hflip_transform (horizontal flip)
- d4_transform (flips + rotation 0, 90, 180, 270)
- multiscale_transform (scale transform, take scales as input parameter)
- five_crop_transform (corner crops + center crop)
- ten_crop_transform (five crops + five crops on horizontal flip)
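For intuition about `d4_transform`, the pure-Python sketch below enumerates the views produced by combining an optional horizontal flip with the four rotations. The helper names are hypothetical and this is not PaTTA's code. It shows that these combinations give 8 distinct views of a square grid, which is the full set of symmetries of a square (a vertical flip equals a horizontal flip followed by a 180 degree rotation, so it adds nothing new).

```python
def hflip(grid):
    """Mirror each row left-to-right."""
    return [row[::-1] for row in grid]

def rot90(grid):
    """Rotate the grid 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*grid)][::-1]

def d4_views(grid):
    """All views from {no flip, hflip} x {0, 90, 180, 270 degree rotation}."""
    views = []
    for flipped in (False, True):
        current = hflip(grid) if flipped else [row[:] for row in grid]
        for _ in range(4):
            views.append(current)
            current = rot90(current)
    return views

grid = [[1, 2],
        [3, 4]]
views = d4_views(grid)
distinct = {tuple(map(tuple, view)) for view in views}
print(len(views), len(distinct))  # 8 8
```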
Merge-modes
- mean
- gmean (geometric mean)
- sum
- max
- min
- tsharpen (temperature sharpen with t=0.5)
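The merge modes can be read as element-wise reductions over the de-augmented predictions. The sketch below shows one plausible reading on plain lists of scores. It is not PaTTA's implementation, and in particular `tsharpen` is assumed here to mean raising each prediction to the power `t` before averaging.

```python
import math

def merge(predictions, mode="mean", temperature=0.5):
    """Merge equally sized flat lists of scores element-wise."""
    n = len(predictions)
    columns = list(zip(*predictions))  # group the i-th score of every prediction
    if mode == "mean":
        return [sum(col) / n for col in columns]
    if mode == "gmean":                # geometric mean; expects non-negative scores
        return [math.prod(col) ** (1.0 / n) for col in columns]
    if mode == "sum":
        return [sum(col) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]
    if mode == "min":
        return [min(col) for col in columns]
    if mode == "tsharpen":             # assumed: mean of scores raised to temperature
        return [sum(v ** temperature for v in col) / n for col in columns]
    raise ValueError(f"unknown merge mode: {mode}")

predictions = [[0.25, 1.0],
               [1.0, 4.0]]
for mode in ("mean", "gmean", "sum", "max", "min", "tsharpen"):
    print(mode, merge(predictions, mode))
```

The geometric mean and `tsharpen` both reduce the influence of a single large score compared with the arithmetic mean, which can matter when one augmented view produces an overconfident prediction.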
Installation
PyPI:
# Use pip install PaTTA
$ pip install patta
or
# After downloading the whole dir
$ git clone https://github.com/AgentMaker/PaTTA.git
$ pip install PaTTA/
Run tests
# run test_transforms.py and test_base.py for test
python test/test_transforms.py
python test/test_base.py