Discover augmentation strategies tailored for your data

Project description

DeepAugment

PyPI · License: MIT · Code style: black

DeepAugment discovers optimized augmentation strategies tailored for your images. It uses Bayesian Optimization to tune augmentation hyperparameters. The tool:

  1. boosts deep learning model accuracy (a 5.2% accuracy increase, i.e. a 36% decrease in error, was shown for CIFAR-10 on WRN-28-10 compared to no augmentation)
  2. saves time by automating the augmentation-tuning process

Resources: slides

Installation/Usage

$ pip install deepaugment

Simple usage (with any dataset)

from deepaugment.deepaugment import DeepAugment

deepaug = DeepAugment(my_images, my_labels)

best_policies = deepaug.optimize(300)

Simple usage (with the CIFAR-10 dataset)

deepaug = DeepAugment("cifar10")

best_policies = deepaug.optimize(300)

Advanced usage (with a custom configuration, using the Fashion-MNIST dataset)

from keras.datasets import fashion_mnist

# my configuration
my_config = {
    "model": "basiccnn",                 # child model architecture
    "method": "bayesian_optimization",   # controller method (Bayesian Optimization or Random Search)
    "train_set_size": 2000,              # number of training images used during optimization
    "opt_samples": 3,                    # child-model trainings per sampled policy
    "opt_last_n_epochs": 3,              # K used by the reward function (see below)
    "opt_initial_points": 10,            # random policies evaluated before the optimizer kicks in
    "child_epochs": 50,                  # epochs to train the child model per trial
    "child_first_train_epochs": 0,       # epochs of initial pre-training of the child model (0 = none)
    "child_batch_size": 64               # child model batch size
}

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# x_train.shape -> (N, 28, 28) for Fashion-MNIST; RGB datasets have shape (N, M, M, 3)
# y_train.shape -> (N,)
deepaug = DeepAugment(images=x_train, labels=y_train, config=my_config)

best_policies = deepaug.optimize(300)
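
Once optimization finishes, the returned best_policies can be used to augment the full training set before training a final model. As a rough, hand-rolled illustration (not DeepAugment's own API), a general-purpose library such as imgaug can apply a rotate + sharpen policy to a portion of the images; the transforms, magnitudes, and portion below are made-up placeholders, not optimized results:

import numpy as np
import imgaug.augmenters as iaa

# Illustrative only: apply one (made-up) policy to half of the training set.
policy_aug = iaa.Sequential([
    iaa.Affine(rotate=(-30, 30)),    # stands in for a "rotate" augmentation
    iaa.Sharpen(alpha=(0.0, 0.7)),   # stands in for a "sharpen" augmentation
])

portion = 0.5  # fraction of the data the policy is applied to
idx = np.random.choice(len(x_train), int(portion * len(x_train)), replace=False)
x_train[idx] = policy_aug(images=x_train[idx])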

Results

CIFAR-10 best policies tested on WRN-28-10

  • Method: a Wide-ResNet-28-10 was trained on CIFAR-10, once with images augmented by the best found policies and once with unaugmented images (everything else the same).
  • Result: a 5.2% accuracy increase with DeepAugment

How it works

The package consists of three main components: the controller, the augmenter, and the child model. The overall workflow is as follows: the controller samples a new augmentation policy, the augmenter transforms the images according to that policy, and the child model is trained from scratch on the augmented images. A reward is then calculated from the child model's validation accuracy curve by the formula explained in the Reward function section. This reward is returned to the controller, which updates its internal state and samples a new augmentation policy, starting the cycle again.

The controller can be set to use Bayesian Optimization (the default) or Random Search. If Bayesian Optimization is selected, it samples new policies using a Random Forest estimator.
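
For intuition, here is a minimal, self-contained sketch of such a controller loop written with scikit-optimize and a Random Forest surrogate. The policy space and the evaluate_policy stub are illustrative assumptions, not the package's actual internals:

from skopt import Optimizer
from skopt.space import Categorical, Real

AUG_SUBSET = ["rotate", "sharpen", "gaussian-blur"]  # small subset, for brevity

# A policy: two augmentation types, their magnitudes, and the portion of data to augment.
policy_space = [
    Categorical(AUG_SUBSET, name="aug1_type"),
    Real(0.0, 1.0, name="aug1_magnitude"),
    Categorical(AUG_SUBSET, name="aug2_type"),
    Real(0.0, 1.0, name="aug2_magnitude"),
    Real(0.0, 1.0, name="portion"),
]

def evaluate_policy(policy):
    # Stand-in for: augment images by `policy`, train the child model from
    # scratch, and compute the reward from its validation accuracy curve.
    return 0.5

controller = Optimizer(policy_space, base_estimator="RF", n_initial_points=10)

for _ in range(20):                    # e.g. 300 iterations in real runs
    policy = controller.ask()          # controller samples a new policy
    reward = evaluate_policy(policy)
    controller.tell(policy, -reward)   # skopt minimizes, so negate the reward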

(Figure: simplified workflow)

Augmentation policy

A policy describes the augmentation that will be applied to a dataset. Each policy consists of variables for two augmentation types, their magnitudes, and the portion of the data to be augmented. An example policy is shown below:

(Figure: example policy)

There are currently 20 augmentation techniques that each augmentation-type variable can take (this list may grow in later versions):

AUG_TYPES = [ "crop", "gaussian-blur", "rotate", "shear", "translate-x", "translate-y", "sharpen", "emboss", "additive-gaussian-noise", "dropout", "coarse-dropout", "gamma-contrast", "brighten", "invert", "fog", "clouds", "add-to-hue-and-saturation", "coarse-salt-pepper", "horizontal-flip", "vertical-flip"]
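
Since the example-policy figure is not reproduced above, here is what a single policy of this form might look like as plain data; the key names and values below are illustrative, not DeepAugment's exact internal representation:

example_policy = {
    "aug1_type": "rotate",      # first augmentation technique (one of AUG_TYPES)
    "aug1_magnitude": 0.4,      # its magnitude, scaled to [0, 1]
    "aug2_type": "sharpen",     # second augmentation technique
    "aug2_magnitude": 0.7,      # its magnitude, scaled to [0, 1]
    "portion": 0.5,             # portion of the dataset the policy is applied to
}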

Child model

(Figure: child CNN architecture)

The child model is the basic CNN ("basiccnn" in the configuration above) that is trained from scratch to evaluate each sampled policy.

Reward function

The reward is calculated as the mean of the K highest validation accuracies of the child model, considering only epochs whose validation accuracy is not more than 0.05 below the corresponding training accuracy (which filters out strongly overfit epochs). K can be set by the user via the opt_last_n_epochs key of the config dictionary passed to DeepAugment() (K is 3 by default).
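
A minimal sketch of this computation, assuming the child model's training and validation accuracy histories are available as plain lists (the package's actual implementation may differ):

import numpy as np

def compute_reward(val_accuracies, train_accuracies, k=3, margin=0.05):
    # Keep only epochs whose validation accuracy is not more than `margin`
    # below the corresponding training accuracy, then average the K best.
    val = np.asarray(val_accuracies)
    train = np.asarray(train_accuracies)
    eligible = val[val >= train - margin]
    if eligible.size == 0:
        return 0.0
    return float(np.sort(eligible)[-k:].mean())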

Data pipeline

(Figures: data pipeline, parts 1 and 2)

Class diagram

(Figure: class diagram)

Package diagram

(Figure: package diagram)

Contact

Baris Ozmen, hbaristr@gmail.com

