State-of-the-art data augmentation search algorithms in PyTorch
MuarAugment
Description
MuarAugment is an easy way to get a state-of-the-art data augmentation pipeline.
It adapts the leading pipeline search algorithms, RandAugment[1] and the model uncertainty-based augmentation scheme[2] (called MuAugment here), and modifies them to work batch-wise on the GPU. Kornia[3] and Albumentations are used for batch-wise and item-wise transforms, respectively.
If you are seeking SOTA data augmentation pipelines without laborious trial-and-error, MuarAugment is the package for you.
How to use
You can install MuarAugment via pip (prefix the command with ! when running it in a notebook cell):
pip install muaraugment
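To confirm that the install succeeded, one option is to query the installed distribution metadata from Python. The following is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical and not part of MuarAugment.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # "muaraugment" is the distribution name used in the pip command above.
    print(installed_version("muaraugment") or "muaraugment is not installed")
```

The helper returns None rather than raising, so it can be used in a setup cell that decides whether to run the install command.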
Examples
For MuAugment, simply modify the training logic and train as usual.
In PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from muar.augmentations import BatchRandAugment, MuAugment

class LitModule(pl.LightningModule):
    def __init__(self, n_tfms, magn, mean, std, n_compositions, n_selected):
        ...
        rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
        self.mu_transform = MuAugment(rand_augment, n_compositions, n_selected)

    def training_step(self, batch, batch_idx):
        # Pass the module to MuAugment, then augment the current batch
        self.mu_transform.setup(self)
        input, target = self.mu_transform((batch['input'], batch['target']))
        ...

# model and datamodule are defined elsewhere
trainer = Trainer()
trainer.fit(model, datamodule)
In pure PyTorch
from muar.augmentations import BatchRandAugment, MuAugment

def train_fn(model):
    # n_tfms, magn, mean, std, n_compositions, n_selected, N_EPOCHS,
    # and train_dataloader are defined elsewhere
    rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
    mu_transform = MuAugment(rand_augment, n_compositions, n_selected)
    for epoch in range(N_EPOCHS):
        for i, batch in enumerate(train_dataloader):
            mu_transform.setup(model)
            input, target = mu_transform(batch)
            ...

train_fn(model)
See the Colab notebook tutorial (#1) for more detail on implementing MuAugment.
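The selection idea behind the uncertainty-based scheme[2] can be sketched without any deep learning dependencies: draw several augmented candidates of a sample, score each with the model's loss, and keep only the highest-loss ones. The sketch below is a simplified, per-sample illustration under that assumption, not MuarAugment's batched GPU implementation; `select_hardest`, `toy_augment`, and `toy_loss` are hypothetical names, and the toy loss stands in for a real model.

```python
import random
from typing import Callable, List

Sample = List[float]


def select_hardest(
    sample: Sample,
    augment: Callable[[Sample], Sample],
    loss_fn: Callable[[Sample], float],
    n_compositions: int,
    n_selected: int,
) -> List[Sample]:
    """Draw n_compositions augmented copies and keep the n_selected with highest loss."""
    if not 0 < n_selected <= n_compositions:
        raise ValueError("need 0 < n_selected <= n_compositions")
    candidates = [augment(sample) for _ in range(n_compositions)]
    # Sort by loss, highest first: a high loss stands in for high model uncertainty.
    candidates.sort(key=loss_fn, reverse=True)
    return candidates[:n_selected]


if __name__ == "__main__":
    rng = random.Random(0)

    def toy_augment(x: Sample) -> Sample:
        # Hypothetical augmentation: add uniform noise to every value.
        return [v + rng.uniform(-1.0, 1.0) for v in x]

    def toy_loss(x: Sample) -> float:
        # Hypothetical stand-in for a model's loss on x.
        return sum(v * v for v in x)

    kept = select_hardest([0.0, 0.0], toy_augment, toy_loss, n_compositions=8, n_selected=2)
    print(len(kept))
```

Here `n_compositions` and `n_selected` play the same roles as the arguments of the same names passed to `MuAugment` in the examples above: the first controls how many candidates are generated, the second how many are kept for training.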
RandAugment using Albumentations
MuarAugment also contains a straightforward implementation of RandAugment using Albumentations:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
# Assumed to be exported from the same module as BatchRandAugment and MuAugment
from muar.augmentations import AlbumentationsRandAugment

class RandAugmentDataset(Dataset):
    def __init__(self, N_TFMS=0, MAGN=0, stage='train', ...):
        ...
        # Only the training stage gets random augmentations
        if stage == 'train':
            self.rand_augment = AlbumentationsRandAugment(N_TFMS, MAGN)
        else:
            self.rand_augment = None

    def __getitem__(self, idx):
        ...
        transform = get_transform(self.rand_augment, self.stage, self.size)
        augmented = transform(image=image)['image']
        ...

def get_transform(rand_augment, stage='train', size=(28,28)):
    if stage == 'train':
        resize_tfm = [A.Resize(*size)]
        rand_tfms = rand_augment() # returns a list of transforms
        tensor_tfms = [A.Normalize(), ToTensorV2()]
        return A.Compose(resize_tfm + rand_tfms + tensor_tfms)
    ...
See the Colab notebook tutorial (#2) for more detail on AlbumentationsRandAugment.
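For intuition about what a RandAugment-style callable that "returns a list of transforms" does, the policy from [1] can be sketched in a few lines: sample a fixed number of operations uniformly at random, all sharing one magnitude. This is a toy illustration on plain floats rather than images, and it is not the AlbumentationsRandAugment implementation; `make_ops`, `rand_augment`, `apply_all`, and the maximum magnitude of 10 are assumptions made for the example.

```python
import random
from typing import Callable, List, Optional

Transform = Callable[[float], float]


def make_ops(magn: int, max_magn: int = 10) -> List[Transform]:
    """Build a toy op pool whose strength scales with magn / max_magn."""
    strength = magn / max_magn
    return [
        lambda x: x + strength,          # positive shift
        lambda x: x * (1.0 + strength),  # scale up
        lambda x: x - strength,          # negative shift
    ]


def rand_augment(n_tfms: int, magn: int, rng: Optional[random.Random] = None) -> List[Transform]:
    """Sample n_tfms ops uniformly (with replacement), all at the same magnitude."""
    if rng is None:
        rng = random.Random()
    return rng.choices(make_ops(magn), k=n_tfms)


def apply_all(x: float, tfms: List[Transform]) -> float:
    """Apply the sampled transforms in order."""
    for tfm in tfms:
        x = tfm(x)
    return x


if __name__ == "__main__":
    tfms = rand_augment(n_tfms=2, magn=5, rng=random.Random(0))
    print(apply_all(1.0, tfms))
```

The two arguments mirror `N_TFMS` and `MAGN` above: the search space reduces to the number of transforms and a single shared magnitude, which is what makes the policy cheap to tune.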
Tutorials
- MuAugment tutorial and implementation in a classification task (Colab Notebook)
- RandAugment tutorial in an end-to-end pipeline (Colab Notebook)
- Overview of data augmentation policy search algorithms (Medium)
Papers Referenced
[1] Cubuk, Ekin et al. "RandAugment: Practical data augmentation with no separate search," 2019, arXiv.
[2] Wu, Sen et al. "On the Generalization Effects of Linear Transformations in Data Augmentation," 2020, arXiv.
[3] Riba, Edgar et al. "Kornia: an Open Source Differentiable Computer Vision Library for PyTorch," 2019, arXiv.
Hashes for MuarAugment-1.1.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | e7cff04e827f09ac82c313f04ca4fc75c8858a67732136b5027c0ec8706a49a3
MD5 | 4fd0462582d796f81795c5ab3294af01
BLAKE2b-256 | 15177cabcc66bce41c56a83ad825aac794c708b636c8bd02afc86eca10f8d79b