Image Test Time Augmentation with PyTorch!

Just as data augmentation modifies the training set, test time augmentation (TTA) applies random modifications to the test images. Instead of showing the trained model each regular, "clean" image only once, we show it several augmented versions of that image. We then average the predictions for each image and take the average as the final prediction [1].

             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
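
The four stages in the diagram can be sketched without any dependencies. This is only an illustration: `toy_model`, `predict_with_tta` and the nested-list "image" are hypothetical stand-ins, not part of ttach, and a single horizontal flip is assumed as the only augmentation.

```python
from statistics import fmean


def identity(img):
    return img


def hflip(img):
    """Mirror a 2D image (list of rows) left to right."""
    return [row[::-1] for row in img]


def toy_model(img):
    """Hypothetical stand-in for a segmentation model.

    The output depends on the column index, so flipping the input
    really does change the prediction.
    """
    return [[pixel + 0.1 * x for x, pixel in enumerate(row)] for row in img]


# (augment, de-augment) pairs; a horizontal flip is its own inverse.
TRANSFORMS = [(identity, identity), (hflip, hflip)]


def predict_with_tta(model, img):
    # augment -> predict -> reverse the augmentation on the output
    outputs = [deaug(model(aug(img))) for aug, deaug in TRANSFORMS]
    # merge: element-wise mean over the de-augmented outputs
    height, width = len(img), len(img[0])
    return [
        [fmean(out[y][x] for out in outputs) for x in range(width)]
        for y in range(height)
    ]


image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
merged = predict_with_tta(toy_model, image)
```

The de-augmentation step matters for masks: without flipping the prediction back, the two outputs would be averaged pixel against mirrored pixel.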

Table of Contents

  1. Quick Start
  2. Transforms
  3. Aliases
  4. Merge modes
  5. Installation

Quick start

Segmentation model wrapping:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
Classification model wrapping:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())
Keypoints model wrapping:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)

Note: the model must return keypoints in the format torch([x1, y1, ..., xn, yn])

Advanced Examples

Custom transform:
# defines 2 * 2 * 3 * 3 = 36 augmentations!
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)
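
The count of 36 in the comment above is the size of the Cartesian product of the per-transform parameter lists. A small sketch of that arithmetic follows. The dictionary keys are illustrative names only, and it assumes the leading factor of 2 comes from a flip that is either applied or not.

```python
from itertools import product

# Parameter grids mirroring the custom transform; keys are illustrative.
# Assumption: a flip contributes two states (not applied / applied).
param_grid = {
    "horizontal_flip": [False, True],
    "rotate90_angle": [0, 180],
    "scale": [1, 2, 4],
    "multiply_factor": [0.9, 1, 1.1],
}

# Every combination of one value per transform is one augmentation.
combinations = list(product(*param_grid.values()))
n_augmentations = len(combinations)  # 2 * 2 * 3 * 3
```

Each added transform multiplies the number of forward passes, so the parameter lists are worth keeping short.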
Custom model (multi-input / multi-output):
# Example of how to process ONE batch of images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is a 2D tensor (B, N)
import torch

masks, labels = [], []

for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()

    # augment image
    augmented_image = transformer.augment_image(image)

    # pass to model
    model_output = model(augmented_image, another_input_data)

    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])

    # save results
    masks.append(deaug_mask)
    labels.append(deaug_label)

# reduce results as you want, e.g. mean/max/min
mask = torch.stack(masks).mean(dim=0)
label = torch.stack(labels).mean(dim=0)


Transforms

Transform        Parameters       Values
HorizontalFlip   -                -
VerticalFlip     -                -
Rotate90         angles           List[0, 90, 180, 270]
Scale            scales           List[float]
                 interpolation    "nearest"/"linear"
Resize           sizes            List[Tuple[int, int]]
                 original_size    Tuple[int, int]
                 interpolation    "nearest"/"linear"
Add              values           List[float]
Multiply         factors          List[float]
FiveCrops        crop_height      int
                 crop_width       int

Aliases

  • flip_transform (horizontal + vertical flips)
  • hflip_transform (horizontal flip)
  • d4_transform (flips + rotation 0, 90, 180, 270)
  • multiscale_transform (scale transform, take scales as input parameter)
  • five_crop_transform (corner crops + center crop)
  • ten_crop_transform (five crops + five crops on horizontal flip)
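
To see why flips combined with the four rotations are named after the dihedral group D4, the sketch below enumerates the resulting views of a tiny image. It assumes a single horizontal flip paired with rotations by 0, 90, 180 and 270 degrees, which is an assumption about how such an alias could be built, not a statement about the library's code. All helper names are hypothetical.

```python
def hflip(img):
    """Mirror a 2D image (list of rows) left to right."""
    return [row[::-1] for row in img]


def vflip(img):
    """Mirror a 2D image top to bottom."""
    return img[::-1]


def rot90(img):
    """Rotate a 2D image by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*img)][::-1]


def rotate(img, angle):
    for _ in range(angle // 90):
        img = rot90(img)
    return img


def as_key(img):
    """Hashable form of an image, so views can be collected in a set."""
    return tuple(tuple(row) for row in img)


# All pixels differ, so no two symmetries produce the same grid.
image = [[1, 2],
         [3, 4]]

views = {
    as_key(rotate(hflip(image) if flip else image, angle))
    for flip in (False, True)
    for angle in (0, 90, 180, 270)
}
n_views = len(views)

# A vertical flip equals a horizontal flip followed by a 180 degree
# rotation, so it adds no new view to the set.
vflip_is_redundant = as_key(vflip(image)) in views
```

Under these assumptions the set holds eight distinct views, one per element of D4.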

Merge modes
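
The diagram at the top of this page names mean, max and gmean as ways to merge predictions. The sketch below shows what those three reductions compute for the scores of one pixel or class across augmentations. The `merge` helper is hypothetical and written in plain Python for illustration; it is not the library's implementation.

```python
import math
from statistics import fmean


def merge(predictions, mode="mean"):
    """Reduce per-augmentation scores to a single value."""
    if mode == "mean":
        return fmean(predictions)
    if mode == "max":
        return max(predictions)
    if mode == "gmean":
        # Geometric mean; only meaningful for positive values
        # such as probabilities.
        return math.prod(predictions) ** (1.0 / len(predictions))
    raise ValueError(f"unknown merge mode: {mode!r}")


# Scores predicted for the same pixel/class under three augmentations.
scores = [0.2, 0.4, 0.8]
merged_mean = merge(scores, "mean")
merged_max = merge(scores, "max")
merged_gmean = merge(scores, "gmean")
```

The geometric mean is pulled down more strongly by a single low score than the arithmetic mean, while max keeps only the most confident prediction.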

Installation

PyPI:
$ pip install ttach

Source:
$ pip install git+

Run tests

docker build -f -t ttach:dev . && docker run --rm ttach:dev pytest -p no:cacheprovider
