State-of-the-art data augmentation search algorithms in PyTorch
MuarAugment
Description
MuarAugment offers a quick route to a state-of-the-art data augmentation pipeline.
It adapts two leading pipeline search algorithms, RandAugment[1] and the model uncertainty-based augmentation scheme[2] (called MuAugment here), and modifies them to work batch-wise on the GPU. Kornia[3] is used for batch-wise transforms and Albumentations for item-wise transforms.
If you want SOTA data augmentation pipelines without laborious trial and error, MuarAugment is the package for you.
How to use
You can install MuarAugment via pip (prefix the command with ! when running it inside a notebook cell):

    pip install muaraugment
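To confirm that the install worked, a quick check from Python can help. The sketch below assumes the package is importable as `muar`, the module name used by the import statements in the examples; `is_importable` is a hypothetical helper written for this illustration, not part of MuarAugment.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if a top-level module with this name can be found."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # "muar" is the module name used in this README's import statements.
    print("muar importable:", is_importable("muar"))
```

If this prints `False`, the package was most likely installed into a different Python environment than the one running the check.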
Examples
For MuAugment, simply modify the training logic and train as usual.
In PyTorch Lightning
    import pytorch_lightning as pl
    from muar.augmentations import BatchRandAugment, MuAugment

    class LitModule(pl.LightningModule):
        def __init__(self, n_tfms, magn, mean, std, n_compositions, n_selected):
            super().__init__()
            ...
            # Batch-wise RandAugment policy, wrapped by MuAugment
            rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
            self.mu_transform = MuAugment(rand_augment, n_compositions, n_selected)

        def training_step(self, batch, batch_idx):
            # Hand the model to MuAugment before transforming the batch
            self.mu_transform.setup(self)
            input, target = self.mu_transform((batch['input'], batch['target']))
            ...

    # model is a LitModule instance; datamodule supplies the training data
    trainer = pl.Trainer()
    trainer.fit(model, datamodule)
In pure PyTorch
    from muar.augmentations import BatchRandAugment, MuAugment

    def train_fn(model):
        rand_augment = BatchRandAugment(n_tfms, magn, mean, std)
        mu_transform = MuAugment(rand_augment, n_compositions, n_selected)

        for epoch in range(N_EPOCHS):
            for i, batch in enumerate(train_dataloader):
                # Hand the model to MuAugment before transforming the batch
                mu_transform.setup(model)
                input, target = mu_transform(batch)
                # forward pass, loss, and optimizer step follow as usual

    train_fn(model)
See the Colab notebook tutorial (#1) for more detail on implementing MuAugment.
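The role of the `n_compositions` and `n_selected` arguments can be illustrated without any deep learning framework. As described in [2], the scheme samples several random compositions of transforms and keeps the ones on which the model currently has the highest loss. The snippet below is a minimal sketch under that reading, not MuarAugment's actual implementation: `mu_select`, the toy transforms, and the toy loss are all hypothetical names introduced here.

```python
import random
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def mu_select(
    item: T,
    transforms: Sequence[Callable[[T], T]],
    loss_fn: Callable[[T], float],
    n_tfms: int,
    n_compositions: int,
    n_selected: int,
    rng: random.Random,
) -> List[T]:
    """Build n_compositions random augmentations of item and keep the
    n_selected ones with the highest loss (hypothetical helper)."""
    candidates = []
    for _ in range(n_compositions):
        augmented = item
        # One composition = n_tfms transforms drawn uniformly with replacement.
        for tfm in rng.choices(transforms, k=n_tfms):
            augmented = tfm(augmented)
        candidates.append(augmented)
    # Highest loss first, then keep only the n_selected hardest candidates.
    candidates.sort(key=loss_fn, reverse=True)
    return candidates[:n_selected]


if __name__ == "__main__":
    # Toy setup: the "image" is a number and the "model" prefers values near 0.
    toy_transforms = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]
    hardest = mu_select(1.0, toy_transforms, abs, n_tfms=2,
                        n_compositions=8, n_selected=3, rng=random.Random(0))
    print(hardest)
```

In MuarAugment the corresponding step runs batch-wise on the GPU, with the loss coming from the model passed to `setup`; the sketch only shows the sample-then-select structure.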
RandAugment using Albumentations
MuarAugment also contains a straightforward implementation of RandAugment using Albumentations:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from torch.utils.data import Dataset
    # Import path assumed to match the other classes; it is not shown in the original
    from muar.augmentations import AlbumentationsRandAugment

    class RandAugmentDataset(Dataset):
        def __init__(self, N_TFMS=0, MAGN=0, stage='train', ...):  # further arguments elided
            ...
            if stage == 'train':
                self.rand_augment = AlbumentationsRandAugment(N_TFMS, MAGN)
            else:
                self.rand_augment = None

        def __getitem__(self, idx):
            ...
            transform = get_transform(self.rand_augment, self.stage, self.size)
            augmented = transform(image=image)['image']
            ...

    def get_transform(rand_augment, stage='train', size=(28, 28)):
        if stage == 'train':
            resize_tfm = [A.Resize(*size)]
            rand_tfms = rand_augment()  # returns a list of transforms
            tensor_tfms = [A.Normalize(), ToTensorV2()]
            return A.Compose(resize_tfm + rand_tfms + tensor_tfms)
        ...
See the Colab notebook tutorial (#2) for more detail on AlbumentationsRandAugment.
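For intuition about what a RandAugment policy is, recall that the method in [1] has two hyperparameters: the number of transforms applied per image and a single shared magnitude. The sketch below mimics that sampling step in plain Python. `sample_policy`, the pool names, and the magnitude-to-parameter mappings are hypothetical and are not taken from MuarAugment or Albumentations.

```python
import random
from typing import Callable, Dict, List, Tuple

# Hypothetical pool: each entry maps a magnitude in [0, MAX_MAGN] to a
# transform-specific parameter value.
MAX_MAGN = 10
POOL: Dict[str, Callable[[int], float]] = {
    "rotate_degrees": lambda m: 30.0 * m / MAX_MAGN,
    "brightness_shift": lambda m: 0.5 * m / MAX_MAGN,
    "translate_fraction": lambda m: 0.3 * m / MAX_MAGN,
}


def sample_policy(n_tfms: int, magn: int, rng: random.Random) -> List[Tuple[str, float]]:
    """Draw n_tfms transform names uniformly at random (with replacement)
    and attach the parameter implied by the shared magnitude."""
    if not 0 <= magn <= MAX_MAGN:
        raise ValueError(f"magn must be in [0, {MAX_MAGN}]")
    names = rng.choices(sorted(POOL), k=n_tfms)
    return [(name, POOL[name](magn)) for name in names]


if __name__ == "__main__":
    print(sample_policy(n_tfms=2, magn=5, rng=random.Random(0)))
```

Because only these two numbers need tuning, a small grid search over them can stand in for a separate policy search phase, which is the point made in the title of [1].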
Tutorials
- MuAugment tutorial and implementation in a classification task (Colab Notebook)
- RandAugment tutorial in an end-to-end pipeline (Colab Notebook)
- Overview of data augmentation policy search algorithms (Medium)
Papers Referenced
- [1] Cubuk, Ekin et al. "RandAugment: Practical data augmentation with no separate search," 2019, arXiv.
- [2] Wu, Sen et al. "On the Generalization Effects of Linear Transformations in Data Augmentation," 2020, arXiv.
- [3] Riba, Edgar et al. "Kornia: an Open Source Differentiable Computer Vision Library for PyTorch," 2019, arXiv.
Project details
File details
Details for the file MuarAugment-1.1.1.tar.gz.
File metadata
- Download URL: MuarAugment-1.1.1.tar.gz
- Upload date:
- Size: 12.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.4.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | a61022ada173659991538fce108427c588f8757cb7b187f4649e334119966ae4
MD5 | 129981062ccc40538e58367bcfeff50d
BLAKE2b-256 | bba1811d25a12d62cf9d0bfd7e5142e0d6d83655f3d1ac00367a43b6138542fe
File details
Details for the file MuarAugment-1.1.1-py3-none-any.whl.
File metadata
- Download URL: MuarAugment-1.1.1-py3-none-any.whl
- Upload date:
- Size: 20.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.4.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.0 CPython/3.9.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | e7cff04e827f09ac82c313f04ca4fc75c8858a67732136b5027c0ec8706a49a3
MD5 | 4fd0462582d796f81795c5ab3294af01
BLAKE2b-256 | 15177cabcc66bce41c56a83ad825aac794c708b636c8bd02afc86eca10f8d79b