Skip to main content

Discover augmentation strategies tailored for your data

Project description

# DeepAugment

<img width="400" alt="concise_workflow" src="https://user-images.githubusercontent.com/14996155/52543808-6d47a400-2d61-11e9-8df7-8271872ba0ad.png">

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black)

DeepAugment discovers best augmentation strategies tailored for your images. It optimizes augmentation hyperparameters using Bayesian Optimization, which is widely used for hyperparameter tuning. The tool:
- boosts deep learning model accuracy 5% compared to models not using augmentation.
- saves times by automating the process


Resources: [slides](https://docs.google.com/presentation/d/1toRUTT9X26ACngr6DXCKmPravyqmaGjy-eIU5cTbG1A/edit#slide=id.g4cc092dbc6_0_0)

## Installation/Usage
```console
$ pip install deepaugment
```


Simple usage (with any dataset)
```Python
from deepaugment.deepaugment import DeepAugment

deepaug = DeepAugment(my_data, my_labels)

best_policies = deepaug.optimize(300)
```

Simple usage (with cifar-10 dataset)
```Python
deepaug = DeepAugment("cifar10")

best_policies = deepaug.optimize(300)
```


Advanced usage (by changing configurations, and with fashion-mnist dataset)
```Python
from keras.datasets import fashion_mnist

# my configuration
my_config = {
"model": "basiccnn",
"method": "bayesian_optimization",
"train_set_size": 2000,
"opt_samples": 3,
"opt_last_n_epochs": 3,
"opt_initial_points": 10,
"child_epochs": 50,
"child_first_train_epochs": 0,
"child_batch_size": 64
}

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# X_train.shape -> (N, M, M, 3)
# y_train.shape -> (N)
deepaug = DeepAugment(data=x_train, labels=y_train, config=my_config)

best_policies = deepaug.optimize(300)
```


## Results
### CIFAR-10 best policies tested on WRN-28-10
- Method: Wide-ResNet-28-10 trained with CIFAR-10 augmented images by best found policies, and with unaugmented images (everything else same).
- Result: **5.2% accuracy increase** by DeepAugment

<img src="https://user-images.githubusercontent.com/14996155/52544784-e0541900-2d67-11e9-93db-0b8b192f5b37.png" width="400"> <img src="https://user-images.githubusercontent.com/14996155/52545044-63c23a00-2d69-11e9-9879-3d7bcb8f88f4.png" width="400">

## How it works?

<img width="600" alt="simplified_workflow" src="https://user-images.githubusercontent.com/14996155/52587711-797a4280-2def-11e9-84f8-2368fd709ab9.png">

DeepAugment consists three main components: Controller, Augmenter, and Child model. Controller samples new augmentation policies ()

### What is an augmentation policy?

A policy describes the augmentation will be applied on a dataset. Each policy consists variables for two augmentation types, their magnitude and the portion of the data to be augmented. An example policy is as following:
<img width="400" alt="example policy" src="https://user-images.githubusercontent.com/14996155/52595719-59ed1500-2e03-11e9-9a40-a79462006924.png">

There are currently 20 types of augmentation techniques (above, right) that each aug. type variable can take. All techniques are (this list might grow in later versions):
```
AUG_TYPES = [ "crop", "gaussian-blur", "rotate", "shear", "translate-x", "translate-y", "sharpen", "emboss", "additive-gaussian-noise", "dropout", "coarse-dropout", "gamma-contrast", "brighten", "invert", "fog", "clouds", "add-to-hue-and-saturation", "coarse-salt-pepper", "horizontal-flip", "vertical-flip"]
```


DeepAugment working method can be dissected into three areas:
1. Search space of augmentation
2. Optimizer
3. Child model

### 1. Search space of augmentation
### 2. Optimizer
### 3. Child model
<img width="800" alt="child-cnn" src="https://user-images.githubusercontent.com/14996155/52545277-10e98200-2d6b-11e9-9639-48b671711eba.png">


### Repo Organization

├── LICENSE
├── Makefile <- Makefile with commands like `make data` or `make train`
├── README.md <- The top-level README for developers using this project.
├── data
│   ├── external <- Data from third party sources.
│   ├── interim <- Intermediate data that has been transformed.
│   ├── processed <- The final, canonical data sets for modeling.
│   └── raw <- The original, immutable data dump.

├── docs <- A default Sphinx project; see sphinx-doc.org for details

├── models <- Trained and serialized models, model predictions, or model summaries

├── notebooks <- Jupyter notebooks. Naming convention is a number (for ordering),
│ the creator's initials, and a short `-` delimited description, e.g.
│ `1.0-jqp-initial-data-exploration`.

├── references <- Data dictionaries, manuals, and all other explanatory materials.

├── reports <- Generated analysis as HTML, PDF, LaTeX, etc.
│   └── figures <- Generated graphics and figures to be used in reporting

├── requirements.txt <- The requirements file for reproducing the analysis environment, e.g.
│ generated with `pip freeze > requirements.txt`

├── setup.py <- makes project pip installable (pip install -e .) so src can be imported
├── src <- Source code for use in this project.
│   ├── __init__.py <- Makes src a Python module
│ │
│   ├── data <- Scripts to download or generate data
│   │   └── make_dataset.py
│ │
│   ├── features <- Scripts to turn raw data into features for modeling
│   │   └── build_features.py
│ │
│   ├── models <- Scripts to train models and then use trained models to make
│ │ │ predictions
│   │   ├── predict_model.py
│   │   └── train_model.py
│ │
│   └── visualization <- Scripts to create exploratory and results oriented visualizations
│   └── visualize.py

└── tox.ini <- tox file with settings for running tox; see tox.testrun.org


--------

<p><small>Project based on the <a target="_blank" href="https://drivendata.github.io/cookiecutter-data-science/">cookiecutter data science project template</a>. #cookiecutterdatascience</small></p>

## Contact
Baris Ozmen, hbaristr@gmail.com


Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

deepaugment-0.7.0.tar.gz (28.9 kB view details)

Uploaded Source

Built Distribution

deepaugment-0.7.0-py2.py3-none-any.whl (32.5 kB view details)

Uploaded Python 2 Python 3

File details

Details for the file deepaugment-0.7.0.tar.gz.

File metadata

  • Download URL: deepaugment-0.7.0.tar.gz
  • Upload date:
  • Size: 28.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.5

File hashes

Hashes for deepaugment-0.7.0.tar.gz
Algorithm Hash digest
SHA256 37c1e7bca3b130a895f0a232a99796c78ef51c583be034ae336f8e72500a7ae0
MD5 36c77c28ce9762a6eebb4ada9a666354
BLAKE2b-256 e7091eeca7b5089b522b997f986321db84fee42dc04ab5aaf28b2586ba119274

See more details on using hashes here.

File details

Details for the file deepaugment-0.7.0-py2.py3-none-any.whl.

File metadata

  • Download URL: deepaugment-0.7.0-py2.py3-none-any.whl
  • Upload date:
  • Size: 32.5 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/1.12.1 pkginfo/1.4.2 requests/2.21.0 setuptools/40.6.3 requests-toolbelt/0.8.0 tqdm/4.28.1 CPython/3.6.5

File hashes

Hashes for deepaugment-0.7.0-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 571bd1f32f8989d59dc2ea8a1b1d6d47ecbdb4b8069d2b9729c27845e821a2f0
MD5 7a7b3af4dca04d8345c451236b36d51b
BLAKE2b-256 daf3012cf390209973bb14f22edf1cec9d783f42224a5050a4f2fb8950969422

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page