
State-of-the-art data augmentation search algorithms in PyTorch

Project description


MuarAugment

Description

MuarAugment aims to be the easiest way to build a state-of-the-art data augmentation pipeline.

It adapts two leading pipeline search algorithms, RandAugment [1] and the model uncertainty-based augmentation scheme of [2] (called MuAugment here), and modifies them to run batch-wise on the GPU. Kornia [3] and Albumentations are used for batch-wise and item-wise transforms, respectively.
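RandAugment [1] keeps the search cheap by describing a whole policy with two integers: the number of transforms applied per image (n_tfms) and one magnitude shared by all of them (magn). Because the search space is just a grid of (n_tfms, magn) pairs, a plain grid search can stand in for a learned policy search. The sketch below illustrates that idea in plain Python; grid_search_randaugment and the evaluate callback (which in practice would train and validate a model for each pair) are hypothetical and not part of MuarAugment.

```python
from itertools import product
from typing import Callable, Iterable, Tuple


def grid_search_randaugment(
    n_tfms_values: Iterable[int],
    magn_values: Iterable[int],
    evaluate: Callable[[int, int], float],
) -> Tuple[Tuple[int, int], float]:
    """Return the (n_tfms, magn) pair with the highest score.

    `evaluate` is a hypothetical callback: in practice it would train a model
    with RandAugment(n_tfms, magn) and return a validation metric.
    """
    best_pair, best_score = None, float("-inf")
    # The entire RandAugment search space is the product of two integer ranges
    for n_tfms, magn in product(n_tfms_values, magn_values):
        score = evaluate(n_tfms, magn)
        if score > best_score:
            best_pair, best_score = (n_tfms, magn), score
    return best_pair, best_score


# Toy stand-in for "train and validate": peaks at n_tfms=2, magn=9
def toy_evaluate(n_tfms: int, magn: int) -> float:
    return -((n_tfms - 2) ** 2) - 0.1 * (magn - 9) ** 2


best, score = grid_search_randaugment(range(1, 4), range(0, 11), toy_evaluate)
print(best, score)
```

Each grid point costs one full training run, which is why keeping the grid to two small integer ranges matters.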

If you want state-of-the-art (SOTA) data augmentation pipelines without laborious trial and error, MuarAugment is the package for you.

How to use

You can install MuarAugment via pip:

pip install muaraugment

(In a Jupyter or Colab notebook cell, run !pip install muaraugment instead.)

Examples

For MuAugment, simply add the transform to your training logic and train as usual: build a BatchRandAugment, wrap it in MuAugment, and apply it to each training batch.

In PyTorch Lightning

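import pytorch_lightning as pl
from muar.augmentations import BatchRandAugment, MuAugment

class LitModule(pl.LightningModule):
    def __init__(self, n_tfms, magn, mean, std, n_compositions, n_selected):
        super().__init__()
        ...
        rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
        self.mu_transform = MuAugment(rand_augment, n_compositions, n_selected)

    def training_step(self, batch, batch_idx):
        # Hand MuAugment the current model before transforming the batch
        self.mu_transform.setup(self)
        inputs, targets = self.mu_transform((batch['input'], batch['target']))
        ...

# n_tfms, magn, mean, std, n_compositions, n_selected and datamodule are defined elsewhere
model = LitModule(n_tfms, magn, mean, std, n_compositions, n_selected)
trainer = pl.Trainer()
trainer.fit(model, datamodule)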

In pure PyTorch

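from muar.augmentations import BatchRandAugment, MuAugment

def train_fn(model):
    # n_tfms, magn, mean, std, n_compositions, n_selected, N_EPOCHS and
    # train_dataloader are assumed to be defined elsewhere
    rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
    mu_transform = MuAugment(rand_augment, n_compositions, n_selected)

    for epoch in range(N_EPOCHS):
        for batch in train_dataloader:
            mu_transform.setup(model)
            # batch is assumed to be an (inputs, targets) pair
            inputs, targets = mu_transform(batch)
            ...  # forward pass, loss, backward pass, optimizer step

train_fn(model)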

See the Colab notebook tutorial (#1) for more detail on implementing MuAugment.
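The uncertainty-based scheme in [2] works roughly as follows: draw several random compositions of transforms for each image, measure the current model's loss on each augmented copy, and keep the copies with the highest loss, i.e. the ones the model is least certain about. The constructor arguments n_compositions and n_selected presumably play those two roles, which would also explain why setup(model) is called before each batch. The sketch below shows only that selection step, in plain Python on toy data; select_hardest, augment and loss_fn are hypothetical stand-ins, not MuarAugment's actual implementation, which works on whole batches on the GPU.

```python
import random
from typing import Callable, List, TypeVar

T = TypeVar("T")


def select_hardest(
    item: T,
    augment: Callable[[T], T],
    loss_fn: Callable[[T], float],
    n_compositions: int,
    n_selected: int,
) -> List[T]:
    """Keep the `n_selected` augmented copies of `item` with the highest loss.

    `augment` draws one random composition of transforms and `loss_fn` stands
    in for the current model's loss on an augmented input (both hypothetical).
    """
    if not 0 < n_selected <= n_compositions:
        raise ValueError("need 0 < n_selected <= n_compositions")
    candidates = [augment(item) for _ in range(n_compositions)]
    # Highest loss first: these are the copies the model is most uncertain about
    candidates.sort(key=loss_fn, reverse=True)
    return candidates[:n_selected]


# Toy example: items are numbers, "augmentation" adds random noise,
# and "loss" is the distance from the clean value 0.0
rng = random.Random(0)
kept = select_hardest(0.0, lambda x: x + rng.uniform(-1, 1), abs, n_compositions=8, n_selected=2)
print(kept)
```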

RandAugment using Albumentations

MuarAugment also contains a straightforward implementation of RandAugment using Albumentations:

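import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
from muar.augmentations import AlbumentationsRandAugment  # import path assumed

class RandAugmentDataset(Dataset):
    def __init__(self, N_TFMS=0, MAGN=0, stage='train', ...):
        ...
        # Only the training set gets random augmentations
        if stage == 'train':
            self.rand_augment = AlbumentationsRandAugment(N_TFMS, MAGN)
        else:
            self.rand_augment = None

    def __getitem__(self, idx):
        ...
        # Rebuilt per item; rand_augment() presumably samples new transforms on each call
        transform = get_transform(self.rand_augment, self.stage, self.size)
        augmented = transform(image=image)['image']
        ...

def get_transform(rand_augment, stage='train', size=(28, 28)):
    if stage == 'train':
        resize_tfm = [A.Resize(*size)]
        rand_tfms = rand_augment()  # returns a list of transforms
        tensor_tfms = [A.Normalize(), ToTensorV2()]
        return A.Compose(resize_tfm + rand_tfms + tensor_tfms)
    ...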

See the Colab notebook tutorial (#2) for more detail on AlbumentationsRandAugment.
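In get_transform, the pipeline is rebuilt for every item and rand_augment() returns a list of transforms. Assuming each call draws a fresh random selection, as RandAugment [1] prescribes (n transforms sampled uniformly, all at one shared magnitude), every image receives its own augmentations. The following plain-Python sketch mimics that pattern on numbers instead of images; ToyRandAugment, its three toy operations, the max_magn scaling and compose are all hypothetical and are not how AlbumentationsRandAugment is implemented.

```python
import random
from typing import Callable, List, Optional

Op = Callable[[float], float]


class ToyRandAugment:
    """Hypothetical RandAugment-style sampler: each call returns a new list of ops."""

    def __init__(self, n_tfms: int, magn: int, max_magn: int = 10, seed: Optional[int] = None):
        self.n_tfms = n_tfms
        self.strength = magn / max_magn  # one shared magnitude, scaled to [0, 1]
        self.rng = random.Random(seed)

    def _ops(self) -> List[Op]:
        s = self.strength
        # Toy "transforms" acting on a single float instead of an image
        return [
            lambda x: x + s,        # brightness-like shift
            lambda x: x * (1 + s),  # contrast-like scaling
            lambda x: -x,           # invert (ignores the magnitude)
        ]

    def __call__(self) -> List[Op]:
        # Sample n_tfms ops uniformly, with replacement
        return self.rng.choices(self._ops(), k=self.n_tfms)


def compose(ops: List[Op]) -> Op:
    """Apply the ops in order, like a minimal Compose."""
    def apply(x: float) -> float:
        for op in ops:
            x = op(x)
        return x
    return apply


rand_augment = ToyRandAugment(n_tfms=2, magn=5, seed=0)
for item in [1.0, 2.0, 3.0]:
    transform = compose(rand_augment())  # fresh sample per item, as in get_transform
    print(item, "->", transform(item))
```

Sampling with replacement mirrors the pseudocode in the RandAugment paper, so the same operation can appear twice in one pipeline.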

Tutorials

  1. MuAugment tutorial and implementation in a classification task (Colab Notebook)
  2. RandAugment tutorial in an end-to-end pipeline (Colab Notebook)
  3. Overview of data augmentation policy search algorithms (Medium)

Papers Referenced

  1. Cubuk, Ekin et al. "RandAugment: Practical data augmentation with no separate search," 2019, arXiv.
  2. Wu, Sen et al. "On the Generalization Effects of Linear Transformations in Data Augmentation," 2020, arXiv.
  3. Riba, Edgar et al. "Kornia: an Open Source Differentiable Computer Vision Library for PyTorch," 2019, arXiv.



Download files


Source Distribution

MuarAugment-1.1.1.tar.gz (12.0 kB)

Built Distribution

MuarAugment-1.1.1-py3-none-any.whl (20.3 kB)

File details

Details for the file MuarAugment-1.1.1.tar.gz.

File metadata

  • Download URL: MuarAugment-1.1.1.tar.gz
  • Upload date:
  • Size: 12.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.4.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.0

File hashes

Hashes for MuarAugment-1.1.1.tar.gz
Algorithm Hash digest
SHA256 a61022ada173659991538fce108427c588f8757cb7b187f4649e334119966ae4
MD5 129981062ccc40538e58367bcfeff50d
BLAKE2b-256 bba1811d25a12d62cf9d0bfd7e5142e0d6d83655f3d1ac00367a43b6138542fe


File details

Details for the file MuarAugment-1.1.1-py3-none-any.whl.

File metadata

  • Download URL: MuarAugment-1.1.1-py3-none-any.whl
  • Upload date:
  • Size: 20.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.4.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.0

File hashes

Hashes for MuarAugment-1.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 e7cff04e827f09ac82c313f04ca4fc75c8858a67732136b5027c0ec8706a49a3
MD5 4fd0462582d796f81795c5ab3294af01
BLAKE2b-256 15177cabcc66bce41c56a83ad825aac794c708b636c8bd02afc86eca10f8d79b

