TTAch

Image Test Time Augmentation with PyTorch!

Just as data augmentation modifies the training set, test time augmentation (TTA) applies random modifications to the test images. Instead of showing the trained model the regular, “clean” images only once, we show it several augmented versions of each image. We then average the predictions for each image and take that as the final guess [1].

           Input
             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output
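
The flow in the diagram can be sketched without any library. This is a minimal illustration under stated assumptions, not ttach's implementation: images are plain nested lists instead of tensors, `toy_model` is a hypothetical stand-in for a segmentation network, and only two augmentations (identity and horizontal flip) are used.

```python
from statistics import mean


def identity(image):
    """Leave the image unchanged."""
    return image


def hflip(image):
    """Flip a 2D image (a list of rows) left to right."""
    return [row[::-1] for row in image]


def toy_model(image):
    """Hypothetical model: each output pixel is the pixel plus its right neighbour."""
    return [
        [row[i] + (row[i + 1] if i + 1 < len(row) else 0) for i in range(len(row))]
        for row in image
    ]


# Each entry pairs an augmentation with the function that reverses it.
# A horizontal flip is its own inverse.
AUGMENTATIONS = [(identity, identity), (hflip, hflip)]


def predict_with_tta(model, image):
    predictions = []
    for augment, deaugment in AUGMENTATIONS:
        mask = model(augment(image))          # pass the augmented image through the model
        predictions.append(deaugment(mask))   # map the prediction back to the original frame
    height, width = len(image), len(image[0])
    # Merge step: per-pixel mean over all de-augmented predictions.
    return [
        [mean(p[r][c] for p in predictions) for c in range(width)]
        for r in range(height)
    ]


image = [[1, 2, 3], [4, 5, 6]]
merged = predict_with_tta(toy_model, image)
```

Because the merged result averages over both orientations, flipping the input simply flips the merged output, even though `toy_model` on its own is not symmetric.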

Table of Contents

  1. Quick Start
  2. Transforms
  3. Aliases
  4. Merge modes
  5. Installation

Quick start

Segmentation model wrapping:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')

Classification model wrapping:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())

Keypoints model wrapping:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)

Note: the model must return keypoints in the format torch([x1, y1, ..., xn, yn])
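
To illustrate why the flat keypoint format matters, here is a rough sketch of how a horizontal flip could be reversed on `[x1, y1, ..., xn, yn]` keypoints. It assumes that `scaled=True` means coordinates are normalized to [0, 1], which is not stated above, and `deaugment_hflip_keypoints` is a hypothetical helper rather than part of ttach.

```python
def deaugment_hflip_keypoints(keypoints):
    """Map keypoints predicted on a horizontally flipped image back to the
    original image frame.

    `keypoints` is a flat list [x1, y1, ..., xn, yn]. Coordinates are assumed
    to be normalized to [0, 1], so undoing the flip maps x to 1 - x.
    """
    if len(keypoints) % 2 != 0:
        raise ValueError("expected an even number of values: x1, y1, ..., xn, yn")
    restored = []
    for i in range(0, len(keypoints), 2):
        x, y = keypoints[i], keypoints[i + 1]
        restored.extend([1.0 - x, y])  # only x changes under a horizontal flip
    return restored


predicted_on_flipped = [0.25, 0.5, 0.75, 0.125]
restored = deaugment_hflip_keypoints(predicted_on_flipped)
```

With pixel coordinates instead of normalized ones, the same step would need the image width, which is presumably why the wrapper has to be told which convention the model uses.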

Advanced Examples

Custom transform:
# defines 2 * 2 * 3 * 3 = 36 augmentations
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),        
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
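
The count of 36 comes from taking one parameter value per transform in every possible combination. The sketch below reproduces that arithmetic with the standard library only. Modelling the flip as two states (not applied / applied) is an assumption of this sketch, and the dictionary keys are illustrative names, not ttach identifiers.

```python
from itertools import product

# Parameter lists mirroring the Compose example above.
parameter_space = {
    "hflip": [False, True],       # assumed: flip either skipped or applied
    "rotate90": [0, 180],
    "scale": [1, 2, 4],
    "multiply": [0.9, 1, 1.1],
}

# One dict per augmentation: the Cartesian product of all parameter lists.
combinations = [
    dict(zip(parameter_space, values))
    for values in product(*parameter_space.values())
]

print(len(combinations))  # 2 * 2 * 3 * 3
```

Every extra transform multiplies the number of forward passes, so inference cost grows quickly as the composition gets longer.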

Custom model (multi-input / multi-output):
# Example of how to process ONE batch of images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N)
import torch

masks = []
labels = []

for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()

    # augment image
    augmented_image = transformer.augment_image(image)

    # pass to model
    model_output = model(augmented_image, another_input_data)

    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])

    # save results
    masks.append(deaug_mask)
    labels.append(deaug_label)

# reduce results as you want, e.g. mean/max/min
mask = torch.stack(masks).mean(dim=0)
label = torch.stack(labels).mean(dim=0)

Transforms

Transform        Parameters       Values
HorizontalFlip   -                -
VerticalFlip     -                -
Rotate90         angles           List[0, 90, 180, 270]
Scale            scales           List[float]
                 interpolation    "nearest"/"linear"
Resize           sizes            List[Tuple[int, int]]
                 original_size    Tuple[int, int]
                 interpolation    "nearest"/"linear"
Add              values           List[float]
Multiply         factors          List[float]
FiveCrops        crop_height      int
                 crop_width       int
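
As an illustration of what the `crop_height` and `crop_width` parameters of FiveCrops describe, the sketch below computes the four corner boxes and the center box for a given image size. `five_crop_boxes` is a hypothetical helper, and rounding the center crop down with integer division is an assumption; ttach may place it differently.

```python
def five_crop_boxes(image_height, image_width, crop_height, crop_width):
    """Return (top, left, bottom, right) boxes for four corner crops and a center crop."""
    if crop_height > image_height or crop_width > image_width:
        raise ValueError("crop must fit inside the image")

    max_top = image_height - crop_height   # top edge of the bottom crops
    max_left = image_width - crop_width    # left edge of the right-hand crops

    origins = {
        "top_left": (0, 0),
        "top_right": (0, max_left),
        "bottom_left": (max_top, 0),
        "bottom_right": (max_top, max_left),
        "center": (max_top // 2, max_left // 2),  # assumed rounding: floor
    }
    return {
        name: (top, left, top + crop_height, left + crop_width)
        for name, (top, left) in origins.items()
    }


boxes = five_crop_boxes(image_height=8, image_width=10, crop_height=4, crop_width=6)
```

Each box can be used to slice an image as `image[..., top:bottom, left:right]`.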

Aliases

  • flip_transform (horizontal + vertical flips)
  • hflip_transform (horizontal flip)
  • d4_transform (flips + rotation 0, 90, 180, 270)
  • multiscale_transform (scale transform, take scales as input parameter)
  • five_crop_transform (corner crops + center crop)
  • ten_crop_transform (five crops + five crops on horizontal flip)
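
The name `d4_transform` refers to the eight symmetries of a square (the dihedral group D4). The sketch below checks, on a small grid of nested lists, that a horizontal flip combined with the four rotations already produces all eight orientations, including the vertical flip. It illustrates the geometry only and makes no claim about how the alias is composed inside ttach.

```python
def hflip(grid):
    """Flip a 2D grid left to right."""
    return [row[::-1] for row in grid]


def rot90(grid):
    """Rotate a 2D grid by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*grid)][::-1]


def rotate(grid, angle):
    """Rotate by a multiple of 90 degrees."""
    for _ in range(angle // 90):
        grid = rot90(grid)
    return grid


grid = [[1, 2], [3, 4]]

orientations = []
for flip in (False, True):
    for angle in (0, 90, 180, 270):
        view = hflip(grid) if flip else grid
        orientations.append(rotate(view, angle))

# Convert to tuples so the orientations can be placed in a set.
unique = {tuple(map(tuple, o)) for o in orientations}
vertical_flip = grid[::-1]
```

Here `unique` holds eight distinct grids, and `vertical_flip` equals the horizontal flip rotated by 180 degrees, so it is already among them.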

Merge modes
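
The pipeline diagram names mean, max and gmean as ways to merge predictions. Below is a small sketch of those three reductions applied to the scores that different augmentations produced for a single output value. `merge` is a hypothetical helper written for illustration; ttach's own implementation and its full list of modes may differ.

```python
import math


def merge(predictions, mode="mean"):
    """Reduce per-augmentation predictions for one output value to a single number."""
    if not predictions:
        raise ValueError("need at least one prediction")
    if mode == "mean":
        return sum(predictions) / len(predictions)
    if mode == "max":
        return max(predictions)
    if mode == "gmean":
        # Geometric mean, computed in log space; requires strictly positive values.
        return math.exp(sum(math.log(p) for p in predictions) / len(predictions))
    raise ValueError(f"unknown merge mode: {mode!r}")


scores = [0.5, 0.25, 1.0]
merged = {mode: merge(scores, mode) for mode in ("mean", "max", "gmean")}
```

The geometric mean never exceeds the arithmetic mean, so it is the more conservative choice when one augmentation yields a low score.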

Installation

PyPI:

$ pip install ttach

Source:

$ pip install git+https://github.com/qubvel/ttach

Run tests

docker build -f Dockerfile.dev -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider
