
Discover augmentation strategies tailored for your data


# DeepAugment

![pypi](https://img.shields.io/pypi/v/deepaugment.svg?style=flat)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black)

DeepAugment discovers the best augmentation strategies tailored to your images. It optimizes augmentation hyperparameters using Bayesian Optimization, a method widely used for hyperparameter tuning. The tool:
- boosts deep learning model accuracy by about 5% compared to models trained without augmentation (5.2% on CIFAR-10 with WRN-28-10; see Results)
- saves time by automating the search for augmentation policies


Resources: [slides](https://docs.google.com/presentation/d/1toRUTT9X26ACngr6DXCKmPravyqmaGjy-eIU5cTbG1A/edit#slide=id.g4cc092dbc6_0_0)

## Installation/Usage
```console
$ pip install deepaugment
```


Simple usage (with any dataset)
```Python
from deepaugment.deepaugment import DeepAugment

# my_images: array of images, shape (N, M, M, 3); my_labels: array of labels, shape (N,)
deepaug = DeepAugment(my_images, my_labels)

best_policies = deepaug.optimize(300)
```

Simple usage (with the CIFAR-10 dataset)
```Python
from deepaugment.deepaugment import DeepAugment

deepaug = DeepAugment("cifar10")

best_policies = deepaug.optimize(300)
```


Advanced usage (with a custom configuration and the Fashion-MNIST dataset)
```Python
from keras.datasets import fashion_mnist

from deepaugment.deepaugment import DeepAugment

# my configuration
my_config = {
    "model": "basiccnn",
    "method": "bayesian_optimization",
    "train_set_size": 2000,
    "opt_samples": 3,
    "opt_last_n_epochs": 3,
    "opt_initial_points": 10,
    "child_epochs": 50,
    "child_first_train_epochs": 0,
    "child_batch_size": 64,
}

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# expected input shapes: images (N, M, M, 3), labels (N,)
deepaug = DeepAugment(images=x_train, labels=y_train, config=my_config)

best_policies = deepaug.optimize(300)
```


## Results
### CIFAR-10 best policies tested on WRN-28-10
- Method: Wide-ResNet-28-10 trained on CIFAR-10 images augmented by the best policies found, compared against the same model trained on unaugmented images (all other settings identical).
- Result: **5.2% accuracy increase** by DeepAugment

<img src="https://user-images.githubusercontent.com/14996155/52544784-e0541900-2d67-11e9-93db-0b8b192f5b37.png" width="400"> <img src="https://user-images.githubusercontent.com/14996155/52545044-63c23a00-2d69-11e9-9879-3d7bcb8f88f4.png" width="400">

## How it works

DeepAugment consists of three main components: a controller, an augmenter, and a child model. The controller samples a new augmentation policy ([see below](#augmentation-policy)), the augmenter transforms images according to that policy, and the child model is trained from scratch on the augmented images. A reward is then calculated from the child model's validation accuracy curve using the formula explained in the [reward function section](#reward-function). This reward is returned to the controller, which updates its internal state and samples a new augmentation policy, starting the next cycle (iteration).

The controller can be set to use Bayesian Optimization (default) or Random Search. If Bayesian Optimization is set, the controller samples new policies using a Random Forest estimator, which is updated at each iteration.

<img width="600" alt="simplified_workflow" src="https://user-images.githubusercontent.com/14996155/52587711-797a4280-2def-11e9-84f8-2368fd709ab9.png">
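The cycle described above can be sketched as a short loop. This is a minimal sketch, not DeepAugment's implementation: it uses the Random Search controller option, and `sample_policy`, `train_child_and_score`, `optimize`, and the policy dictionary keys are hypothetical names. Child-model training is replaced by a placeholder that returns a random reward so the loop runs on its own.

```python
import random

# Hypothetical stand-ins for DeepAugment's components; these names are
# illustrative only and are not the library's API.
AUG_TYPES = ["crop", "rotate", "shear", "horizontal-flip"]  # subset for brevity


def sample_policy(rng):
    """Controller (random-search variant): draw a policy uniformly at random."""
    return {
        "aug1_type": rng.choice(AUG_TYPES),
        "aug1_magnitude": rng.random(),
        "aug2_type": rng.choice(AUG_TYPES),
        "aug2_magnitude": rng.random(),
        "portion": rng.random(),
    }


def train_child_and_score(policy, rng):
    """Placeholder for the augmenter + child model: a real version would
    augment the images with `policy`, train the child model from scratch,
    and compute the reward from its accuracy curve. Here it returns a
    random number so the loop runs end to end."""
    return rng.random()


def optimize(iterations, seed=0):
    rng = random.Random(seed)
    history = []  # (reward, policy) pairs, one per iteration
    for _ in range(iterations):
        policy = sample_policy(rng)                  # 1. controller samples a policy
        reward = train_child_and_score(policy, rng)  # 2-3. augment, train, score
        # 4. Feed the reward back. Random search ignores it; a Bayesian
        #    controller would update its surrogate model here.
        history.append((reward, policy))
    # Best policies first
    return sorted(history, key=lambda item: item[0], reverse=True)


best = optimize(iterations=20)
print(best[0])
```

Sorting on the reward alone (`key=lambda item: item[0]`) avoids comparing policy dictionaries when two rewards tie.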

### Augmentation policy

A policy describes the augmentations that will be applied to a dataset. Each policy consists of variables for two augmentation types, their magnitudes, and the portion of the data to be augmented. An example policy is as follows:

<img width="400" alt="example policy" src="https://user-images.githubusercontent.com/14996155/52595719-59ed1500-2e03-11e9-9a40-a79462006924.png">

There are currently 20 augmentation techniques (above, right) that each augmentation-type variable can take. The full list (which might grow in later versions) is:
```
AUG_TYPES = [ "crop", "gaussian-blur", "rotate", "shear", "translate-x", "translate-y", "sharpen", "emboss", "additive-gaussian-noise", "dropout", "coarse-dropout", "gamma-contrast", "brighten", "invert", "fog", "clouds", "add-to-hue-and-saturation", "coarse-salt-pepper", "horizontal-flip", "vertical-flip"]
```
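A policy as described above can be modeled as a small record plus a sampler over `AUG_TYPES`. This is a minimal sketch under stated assumptions, not the library's data structure: the `Policy` class, its field names, `random_policy`, and the choice to keep magnitudes and portion in the range [0, 1] are all hypothetical.

```python
import random
from dataclasses import dataclass

AUG_TYPES = [
    "crop", "gaussian-blur", "rotate", "shear", "translate-x", "translate-y",
    "sharpen", "emboss", "additive-gaussian-noise", "dropout", "coarse-dropout",
    "gamma-contrast", "brighten", "invert", "fog", "clouds",
    "add-to-hue-and-saturation", "coarse-salt-pepper", "horizontal-flip",
    "vertical-flip",
]


@dataclass
class Policy:
    """Hypothetical container for one policy: two augmentation types, their
    magnitudes, and the portion of the data to augment."""
    aug1_type: str
    aug1_magnitude: float
    aug2_type: str
    aug2_magnitude: float
    portion: float

    def __post_init__(self):
        # Reject augmentation types that are not in the list above
        for aug_type in (self.aug1_type, self.aug2_type):
            if aug_type not in AUG_TYPES:
                raise ValueError(f"unknown augmentation type: {aug_type}")
        # Assumption: magnitudes and portion are normalized to [0, 1]
        for value in (self.aug1_magnitude, self.aug2_magnitude, self.portion):
            if not 0.0 <= value <= 1.0:
                raise ValueError("magnitudes and portion must lie in [0, 1]")


def random_policy(rng):
    """Sample a policy uniformly, as a random-search controller might."""
    return Policy(
        aug1_type=rng.choice(AUG_TYPES),
        aug1_magnitude=rng.uniform(0.0, 1.0),
        aug2_type=rng.choice(AUG_TYPES),
        aug2_magnitude=rng.uniform(0.0, 1.0),
        portion=rng.uniform(0.0, 1.0),
    )


policy = random_policy(random.Random(42))
print(policy)
```

Validating in `__post_init__` means a malformed policy fails at construction time rather than midway through augmenting a dataset.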
### Child model
<img width="800" alt="child-cnn" src="https://user-images.githubusercontent.com/14996155/52545277-10e98200-2d6b-11e9-9639-48b671711eba.png">

### Reward function
The reward is the mean of the K highest validation accuracies of the child model, counting only epochs in which the validation accuracy is not more than 0.05 below the corresponding training accuracy. K can be set by updating the `opt_last_n_epochs` key in the config dictionary passed to the `DeepAugment()` class (K is 3 by default).
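The rule above can be written as a small function. This is a minimal sketch, not DeepAugment's code: `reward`, `train_acc`, `val_acc`, and `max_gap` are hypothetical names, `k` plays the role of `opt_last_n_epochs`, and the behavior when fewer than K epochs (or none) qualify is an assumption.

```python
def reward(train_acc, val_acc, k=3, max_gap=0.05):
    """Mean of the k highest validation accuracies, counting only epochs in
    which the validation accuracy is at most `max_gap` below the training
    accuracy of the same epoch."""
    if len(train_acc) != len(val_acc):
        raise ValueError("train_acc and val_acc must have the same length")
    # Keep epochs that do not show overfitting beyond max_gap
    eligible = [v for t, v in zip(train_acc, val_acc) if v >= t - max_gap]
    if not eligible:
        return 0.0  # assumption: no eligible epoch gives zero reward
    # Assumption: if fewer than k epochs qualify, average the ones that do
    top_k = sorted(eligible, reverse=True)[:k]
    return sum(top_k) / len(top_k)


# Per-epoch accuracies of a hypothetical child model
train = [0.60, 0.70, 0.80, 0.90, 0.95]
val = [0.58, 0.68, 0.74, 0.80, 0.82]
print(round(reward(train, val, k=3), 4))  # 0.63 (only epochs 1 and 2 qualify)
```

In this example the later epochs reach higher validation accuracy but trail training accuracy by more than 0.05, so they are excluded; the filter rewards policies whose gains do not come with overfitting.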

## Data pipeline
<img width="600" alt="data-pipeline-1" src="https://user-images.githubusercontent.com/14996155/52740937-0c9ab000-2f89-11e9-9e94-beca71caed41.png">
<img width="600" alt="data-pipeline-2" src="https://user-images.githubusercontent.com/14996155/52740938-0d334680-2f89-11e9-8d68-117d139d9ab8.png">

## Class diagram
![classes_deepaugment](https://user-images.githubusercontent.com/14996155/52743629-4969a580-2f8f-11e9-8eb2-35aa1af161bb.png)

## Package diagram
<img width="600" alt="package-diagram" src="https://user-images.githubusercontent.com/14996155/52743630-4a023c00-2f8f-11e9-9b12-32b2ded6071b.png">
--------

## Contact
Baris Ozmen, hbaristr@gmail.com


Download files

Source Distribution

deepaugment-0.10.0.tar.gz (24.3 kB)

Uploaded Source

Built Distribution

deepaugment-0.10.0-py2.py3-none-any.whl (28.6 kB)

Uploaded Python 2 Python 3

File details

Details for the file deepaugment-0.10.0.tar.gz.

File metadata

  • Download URL: deepaugment-0.10.0.tar.gz
  • Size: 24.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.5

File hashes

Hashes for deepaugment-0.10.0.tar.gz
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3adcdfe268e7b3931e8b65484dfe834ec0d2c1c6ea37f411d68f365038cd9a64 |
| MD5 | 9e1550c5f8285236b4c9725c4cef40b4 |
| BLAKE2b-256 | ad7b8d625715712a34d43e74c7faa47012675a8aa8934b124205301fb3e6adf8 |


File details

Details for the file deepaugment-0.10.0-py2.py3-none-any.whl.

File metadata

  • Download URL: deepaugment-0.10.0-py2.py3-none-any.whl
  • Size: 28.6 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.5

File hashes

Hashes for deepaugment-0.10.0-py2.py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | eef452436a7e4c5cf724bc87b845398fb82acdf5b2563d2ef8bc108e994cf677 |
| MD5 | 466b0823e145aa00e0c9aebf734e3aaa |
| BLAKE2b-256 | 2aa3296d713f32a4d2e162a1cfb5a6de771324f6640995a3c40c33287c8e7ab1 |
