
The unified corpus building environment for Language Models.


Differentiable RandAugment

Optimize RandAugment with differentiable operations



Introduction

Differentiable RandAugment is a differentiable version of RandAugment. The original paper finds the optimal parameters by grid search. Instead, this library implements the augmentation operations differentiably, so the gradient with respect to the magnitude parameter can be computed and the magnitude can be optimized by gradient descent. See Getting Started.
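To make the contrast concrete, here is a toy, pure-Python sketch (not the library's code) that tunes a single magnitude both ways. The brightness-style operation `x * (1 + m)`, the target-matching loss, and the hand-written derivative are all hypothetical; in the library, PyTorch's autograd computes the gradient of the training loss instead.

```python
# Toy problem (hypothetical): find the magnitude m such that the
# brightness-style operation x * (1 + m) maps each pixel x to 1.5 * x.
# The optimum is therefore m = 0.5.
pixels = [0.2, 0.4, 0.6]
targets = [1.5 * x for x in pixels]


def loss(m: float) -> float:
    """Mean squared error between augmented pixels and targets."""
    return sum((x * (1 + m) - t) ** 2 for x, t in zip(pixels, targets)) / len(pixels)


def grad(m: float) -> float:
    """Analytic dL/dm (autograd would compute this in PyTorch)."""
    return sum(2 * (x * (1 + m) - t) * x for x, t in zip(pixels, targets)) / len(pixels)


# Grid search: evaluate a fixed set of candidate magnitudes and keep the best.
grid = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
best_grid = min(grid, key=loss)

# Gradient descent: start from m = 0 and repeatedly step against the gradient.
m = 0.0
learning_rate = 1.0
for _ in range(50):
    m -= learning_rate * grad(m)

print(best_grid, round(m, 6))
```

Grid search can only return one of its predefined candidates, while gradient descent moves the magnitude continuously; the price is that every operation between the magnitude and the loss must be differentiable with respect to it.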

Installation

To install the latest version from PyPI:

$ pip install -U differentiable_randaugment

Or you can install from source by cloning the repository and running:

$ git clone https://github.com/affjljoo3581/Differentiable-RandAugment.git
$ cd Differentiable-RandAugment
$ python setup.py install

Dependencies

  • opencv_python
  • torch>=1.7
  • albumentations
  • numpy

Getting Started

First, create a RandAugmentModule with the desired number of operations. This module is a differentiable version of the RandAugment policy that operates on torch.Tensor inputs, so the policy can be trained like any other neural-network module. Note that num_ops randomly selected operations are applied to the images.

  from differentiable_randaugment import RandAugmentModule

  augmentor = RandAugmentModule(num_ops=2)
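The random selection of num_ops operations can be sketched in plain Python. This follows the pseudocode of the original RandAugment paper, which draws operations uniformly with replacement; whether RandAugmentModule samples with or without replacement, and per image or per batch, is an assumption here. The operation names are the 14 listed later in this README, and `sample_ops` is a hypothetical helper, not part of the library.

```python
import random
from typing import List, Optional

# The 14 operations supported by Differentiable RandAugment.
OPERATIONS = [
    "Identity", "ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
    "AutoContrast", "Equalize", "Solarize", "Posterize",
    "Contrast", "Color", "Brightness", "Sharpness",
]


def sample_ops(num_ops: int, rng: Optional[random.Random] = None) -> List[str]:
    """Hypothetical helper: draw num_ops operation names uniformly, with replacement."""
    if rng is None:
        rng = random.Random()
    return rng.choices(OPERATIONS, k=num_ops)


# A seeded generator makes the draw reproducible.
print(sample_ops(num_ops=2, rng=random.Random(0)))
```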

Next, apply the module to the images. Augmentations are usually applied inside a Dataset, where the images are np.ndarray arrays. Operations on np.ndarray are not tracked by PyTorch's autograd, so the gradients with respect to the image and the magnitude parameter cannot be computed there (the entire optimization procedure is based on torch.Tensor). Instead, apply this module to torch.Tensor batches inside the training loop:

  for inputs, labels in train_dataloader:
      inputs = inputs.cuda()  # move the batch to the GPU
      # Augment the torch.Tensor batch so gradients can reach the magnitude.
      logits = model(augmentor(inputs))
      ...
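The sketch below illustrates why the module has to run on torch.Tensor batches. `ToyAugmentor` is a hypothetical stand-in for RandAugmentModule with a single learnable magnitude applied as brightness-style scaling; the small linear model and random batch are placeholders. Applied to a tensor, the augmentation becomes part of the autograd graph and the magnitude receives a gradient, whereas the same arithmetic done on an np.ndarray yields a tensor with no history.

```python
import torch
from torch import nn


class ToyAugmentor(nn.Module):
    """Hypothetical stand-in for RandAugmentModule: one differentiable op."""

    def __init__(self) -> None:
        super().__init__()
        self.magnitude = nn.Parameter(torch.zeros(()))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # Brightness-style scaling, differentiable w.r.t. images and magnitude.
        return images * (1.0 + self.magnitude)


torch.manual_seed(0)
augmentor = ToyAugmentor()
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
inputs, labels = torch.rand(4, 3, 8, 8), torch.randint(0, 10, (4,))

# Tensor path: the augmentation is recorded in the autograd graph.
loss = nn.functional.cross_entropy(model(augmentor(inputs)), labels)
loss.backward()
print(augmentor.magnitude.grad)

# NumPy path (what a Dataset-side transform does): no graph is recorded,
# so nothing can propagate back to a magnitude parameter.
augmented_np = inputs.numpy() * 1.5
print(torch.from_numpy(augmented_np).requires_grad)  # False
```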

Of course, other augmentations should be removed from preprocessing:

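  from albumentations import Compose, Normalize, Resize
  from albumentations.pytorch import ToTensorV2

  # Dataset-side preprocessing only: RandAugmentModule now does the augmenting.
  transform = Compose([
      Resize(...),
      Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
      ToTensorV2(),
  ])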
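With mean and std both set to 0.5, `Normalize` maps 8-bit pixel values into [-1, 1] before `ToTensorV2` converts the image, so that is the value range of the batches passed to the augmentor in the training loop. The arithmetic below follows albumentations' documented formula `(img - mean * max_pixel_value) / (std * max_pixel_value)` with the default `max_pixel_value=255.0`; `normalize_pixel` is an illustrative helper, not part of either library.

```python
def normalize_pixel(value: float, mean: float = 0.5, std: float = 0.5,
                    max_pixel_value: float = 255.0) -> float:
    """Apply albumentations-style normalization to a single pixel value."""
    return (value - mean * max_pixel_value) / (std * max_pixel_value)


# Black, mid-gray and white pixels map to -1, 0 and 1.
for pixel in (0, 127.5, 255):
    print(pixel, "->", normalize_pixel(pixel))
```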

Finally, create an optimizer that covers the parameters of both the model and this module. We recommend using a different learning rate for the augmentor than for the model:

  import torch.optim as optim

  # The augmentor gets a 10x larger learning rate than the model.
  param_groups = [
      {"params": augmentor.parameters(), "lr": 10 * learning_rate},
      {"params": model.parameters(), "lr": learning_rate},
  ]
  optimizer = optim.Adam(param_groups)
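Putting the pieces together, here is a self-contained sketch of one joint update using the two parameter groups. `ToyAugmentor` is a hypothetical stand-in for RandAugmentModule (a single learnable magnitude applied as brightness-style scaling), the model and batch are placeholders, and `learning_rate = 1e-3` is an arbitrary example value; only the 10x ratio comes from the recommendation above.

```python
import torch
from torch import nn, optim


class ToyAugmentor(nn.Module):
    """Hypothetical stand-in for RandAugmentModule with one learnable magnitude."""

    def __init__(self) -> None:
        super().__init__()
        self.magnitude = nn.Parameter(torch.zeros(()))

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return images * (1.0 + self.magnitude)


torch.manual_seed(0)
learning_rate = 1e-3  # example base learning rate
augmentor = ToyAugmentor()
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

param_groups = [
    {"params": augmentor.parameters(), "lr": 10 * learning_rate},
    {"params": model.parameters(), "lr": learning_rate},
]
optimizer = optim.Adam(param_groups)

# One joint update of the model weights and the augmentation magnitude.
inputs, labels = torch.rand(4, 3, 8, 8), torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(model(augmentor(inputs)), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(augmentor.magnitude.item())
```

Keeping the magnitude in its own parameter group lets you tune its step size independently of the model, which matters because a single scalar shared across every augmented image can otherwise move very slowly relative to the network weights.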

The RandAugment policy is now trained jointly with your prediction model.

After training, retrieve the learned magnitude by calling augmentor.get_magnitude() and pass it to RandAugment in your Dataset preprocessing:

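  from albumentations import Compose, Normalize, Resize
  from albumentations.pytorch import ToTensorV2
  from differentiable_randaugment import RandAugment

  transform = Compose([
      Resize(...),
      # magnitude: the value returned by augmentor.get_magnitude()
      RandAugment(num_ops=..., magnitude=...),
      Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
      ToTensorV2(),
  ])
  dataset = Dataset(..., transform=transform)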

Since RandAugment is implemented as an albumentations transform, you can combine it with any other albumentations augmentation.

Supported Operations

Differentiable RandAugment supports the 14 operations described in the original paper. The table below shows, for each operation, whether it is differentiable with respect to the input image and to the magnitude.

Input Image Magnitude
Identity
ShearX
ShearY
TranslateX
TranslateY
Rotate
AutoContrast
Equalize
Solarize
Posterize
Contrast
Color
Brightness
Sharpness
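As one example of how an operation can be differentiable in its magnitude: PIL's `ImageEnhance` defines Contrast, Color, Brightness and Sharpness as a blend between the image and a "degenerate" image, `degenerate + factor * (image - degenerate)`. The output is linear in `factor`, so its derivative is simply `image - degenerate`. The pure-Python sketch below assumes this PIL-style formulation on a flat list of grayscale intensities and, as a simplification, uses the plain mean for Contrast's flat image; it is not the library's implementation, and how the library maps its magnitude to `factor` is not specified here.

```python
from typing import List


def blend(image: List[float], degenerate: List[float], factor: float) -> List[float]:
    """PIL-style blend: factor=0 gives the degenerate image, factor=1 the original."""
    return [d + factor * (x - d) for x, d in zip(image, degenerate)]


def blend_grad_factor(image: List[float], degenerate: List[float]) -> List[float]:
    """Per-pixel d(output)/d(factor); the blend is linear in factor."""
    return [x - d for x, d in zip(image, degenerate)]


def brightness(image: List[float], factor: float) -> List[float]:
    # Brightness blends with an all-black image.
    return blend(image, [0.0] * len(image), factor)


def contrast(image: List[float], factor: float) -> List[float]:
    # Contrast blends with a flat image at the mean intensity (simplified).
    mean = sum(image) / len(image)
    return blend(image, [mean] * len(image), factor)


pixels = [0.0, 0.25, 1.0]
print(brightness(pixels, 1.5))
print(contrast(pixels, 0.5))
print(blend_grad_factor(pixels, [0.0] * len(pixels)))
```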

License

Differentiable RandAugment is Apache-2.0 Licensed.

Download files

Source Distribution

differentiable_randaugment-0.1.0.tar.gz (11.8 kB)

Uploaded Source

Built Distribution

differentiable_randaugment-0.1.0-py3-none-any.whl (17.5 kB)

Uploaded Python 3

File details

Details for the file differentiable_randaugment-0.1.0.tar.gz.

File metadata

  • Download URL: differentiable_randaugment-0.1.0.tar.gz
  • Upload date:
  • Size: 11.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.7.9

File hashes

Hashes for differentiable_randaugment-0.1.0.tar.gz
Algorithm Hash digest
SHA256 dd360379854aa6c3cc9ad817eeafdae6c4ef9c21b794e860578b470fd823e668
MD5 78ce562f907e81137e1ceb0a5c3fa4dd
BLAKE2b-256 7002696ded95e8ddef0304a361175c6af0093e884b3b963c2bf76bfd4adfd54a


File details

Details for the file differentiable_randaugment-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: differentiable_randaugment-0.1.0-py3-none-any.whl
  • Upload date:
  • Size: 17.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/47.1.0 requests-toolbelt/0.9.1 tqdm/4.56.0 CPython/3.7.9

File hashes

Hashes for differentiable_randaugment-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 6ff3edaefbae2401427710e6beaa9f701e7e36680aa5f72a11840279b16973bf
MD5 92f510817a85dbbfc27f6b9ecea055bd
BLAKE2b-256 affe5b7654858e32ffaa6a39f5461e016a9665e9dc1fb792a3f4635544a5e3a0

