
State-of-the-art data augmentation search algorithms in PyTorch



MuarAugment

Description

MuarAugment is the easiest way to build a state-of-the-art data augmentation pipeline.

It adapts two leading pipeline search algorithms, RandAugment[1] and the model uncertainty-based augmentation scheme of [2] (called MuAugment here), and modifies them to run batch-wise on the GPU. Kornia[3] is used for the batch-wise transforms and Albumentations for the item-wise transforms.
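As a rough illustration of what "RandAugment, batch-wise, on the GPU" means: RandAugment [1] samples n_tfms operations uniformly at random (with replacement) and applies each one at a single shared magnitude magn. The sketch below does this on a whole (B, C, H, W) tensor in plain PyTorch. It is not MuarAugment's BatchRandAugment; the three operations, the linear magnitude scaling, and MAX_MAGN = 10 are assumptions chosen for brevity.

```python
import random

import torch

MAX_MAGN = 10  # assumed top of the integer magnitude scale (illustrative choice)


def brightness(x: torch.Tensor, m: float) -> torch.Tensor:
    # Shift pixel values up by as much as 0.5 at full magnitude; inputs assumed in [0, 1]
    return (x + 0.5 * m).clamp(0.0, 1.0)


def contrast(x: torch.Tensor, m: float) -> torch.Tensor:
    # Scale each image's deviation from its own mean by (1 + m)
    mean = x.mean(dim=(1, 2, 3), keepdim=True)
    return ((x - mean) * (1.0 + m) + mean).clamp(0.0, 1.0)


def hflip(x: torch.Tensor, m: float) -> torch.Tensor:
    # Magnitude-free op: flip every image in the batch along the width axis
    return torch.flip(x, dims=[3])


OPS = [brightness, contrast, hflip]


def batch_rand_augment(x: torch.Tensor, n_tfms: int, magn: int) -> torch.Tensor:
    """Apply n_tfms randomly chosen ops, all at magnitude magn, to the whole batch."""
    m = magn / MAX_MAGN  # map the integer magnitude onto [0, 1]
    for op in random.choices(OPS, k=n_tfms):  # uniform sampling with replacement
        x = op(x, m)
    return x
```

Every op here is a tensor operation over the full batch, so moving the input to the GPU (for example x.cuda()) is enough to run the sketch there. In this sketch all images in a batch share the same sampled ops, which is the price paid for applying them in one vectorized call.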

If you are seeking SOTA data augmentation pipelines without laborious trial-and-error, MuarAugment is the package for you.

How to use

You can install MuarAugment via pip (in a Jupyter or Colab cell, prefix the command with !):

pip install muaraugment

Examples

For MuAugment, simply modify the training logic to pass each batch through MuAugment, then train as normal.

In PyTorch Lightning

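import pytorch_lightning as pl
from muar.augmentations import BatchRandAugment, MuAugment

class LitModule(pl.LightningModule):
    def __init__(self, n_tfms, magn, mean, std, n_compositions, n_selected):
        super().__init__()
        ...
        rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
        self.mu_transform = MuAugment(rand_augment, n_compositions, n_selected)

    def training_step(self, batch, batch_idx):
        # Pass the current model to MuAugment before augmenting the batch
        self.mu_transform.setup(self)
        input, target = self.mu_transform((batch['input'], batch['target']))
        ...

model = LitModule(n_tfms, magn, mean, std, n_compositions, n_selected)
trainer = pl.Trainer()
# datamodule: your LightningDataModule yielding {'input': ..., 'target': ...} batches
trainer.fit(model, datamodule)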

In pure PyTorch

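from muar.augmentations import BatchRandAugment, MuAugment

def train_fn(model):
    # n_tfms, magn, mean, std, n_compositions, n_selected, N_EPOCHS and
    # train_dataloader are assumed to be defined elsewhere
    rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
    mu_transform = MuAugment(rand_augment, n_compositions, n_selected)

    for epoch in range(N_EPOCHS):
        for i, batch in enumerate(train_dataloader):
            # Pass the current model to MuAugment before augmenting each batch
            mu_transform.setup(model)
            input, target = mu_transform(batch)
            ...  # forward pass, loss, backward, optimizer step

train_fn(model)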

See the Colab notebook tutorial (#1) for more detail on implementing MuAugment.
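The uncertainty-based scheme of [2] generates several randomly augmented copies of each image and keeps the ones with the highest loss, i.e. the copies the model finds hardest. The sketch below shows that selection step in plain PyTorch, with n_compositions and n_selected playing the roles of the arguments passed to MuAugment in the examples above. It is a minimal sketch, not MuarAugment's implementation: the function name select_hardest, the augment callable, and cross-entropy as the loss are assumptions.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()  # scoring only; the selected batch is trained on in a separate pass
def select_hardest(model, augment, x, y, n_compositions, n_selected):
    """Return the n_selected highest-loss augmented copies of every image.

    x: (B, C, H, W) images, y: (B,) integer class labels.
    augment: any callable mapping a batch to a randomly augmented batch.
    """
    B = x.shape[0]
    # n_compositions independently augmented versions: (n_compositions, B, C, H, W)
    candidates = torch.stack([augment(x) for _ in range(n_compositions)])
    # Flattened index is c * B + b, so labels are tiled the same way
    logits = model(candidates.flatten(0, 1))
    losses = F.cross_entropy(logits, y.repeat(n_compositions), reduction="none")
    losses = losses.view(n_compositions, B)            # loss per (copy, image)
    top = losses.topk(n_selected, dim=0).indices       # (n_selected, B)
    batch_idx = torch.arange(B, device=x.device).expand(n_selected, B)
    selected = candidates[top, batch_idx]              # (n_selected, B, C, H, W)
    return selected.flatten(0, 1), y.repeat(n_selected)
```

The returned batch would then go through an ordinary forward and backward pass with gradients enabled. The scoring pass costs roughly n_compositions extra forward passes per batch, which is why running the augmentations batch-wise on the GPU matters for this method.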

RandAugment using Albumentations

MuarAugment also contains a straightforward implementation of RandAugment using Albumentations:

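import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
# Assumption: AlbumentationsRandAugment is importable from muar.augmentations
from muar.augmentations import AlbumentationsRandAugment

class RandAugmentDataset(Dataset):
    def __init__(self, N_TFMS=0, MAGN=0, stage='train', ...):
        ...
        # Only the training set gets random augmentations
        if stage == 'train':
            self.rand_augment = AlbumentationsRandAugment(N_TFMS, MAGN)
        else:
            self.rand_augment = None

    def __getitem__(self, idx):
        ...
        # get_transform runs per item, so rand_augment() can draw new transforms per image
        transform = get_transform(self.rand_augment, self.stage, self.size)
        augmented = transform(image=image)['image']
        ...

def get_transform(rand_augment, stage='train', size=(28, 28)):
    if stage == 'train':
        resize_tfm = [A.Resize(*size)]
        rand_tfms = rand_augment()  # returns a list of transforms
        tensor_tfms = [A.Normalize(), ToTensorV2()]
        return A.Compose(resize_tfm + rand_tfms + tensor_tfms)
    ...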

See the Colab notebook tutorial (#2) for more detail on AlbumentationsRandAugment.

Tutorials

  1. MuAugment tutorial and implementation in a classification task (Colab Notebook)
  2. RandAugment tutorial in an end-to-end pipeline (Colab Notebook)
  3. Overview of data augmentation policy search algorithms (Medium)

Papers Referenced

  1. Cubuk, Ekin et al. "RandAugment: Practical data augmentation with no separate search," 2019, arXiv.
  2. Wu, Sen et al. "On the Generalization Effects of Linear Transformations in Data Augmentation," 2020, arXiv.
  3. Riba, Edgar et al. "Kornia: an Open Source Differentiable Computer Vision Library for PyTorch," 2019, arXiv.



Download files


Source distribution: MuarAugment-1.1.1.tar.gz (12.0 kB)

Built distribution: MuarAugment-1.1.1-py3-none-any.whl (20.3 kB, py3)
