CutMixImageDataGenerator for Keras

CutMixImageDataGenerator (Keras)

Paper: CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
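For reference, the core CutMix operation from the paper can be sketched in NumPy. The sketch samples λ from Beta(α, α), cuts a box covering roughly 1 - λ of the image from a second batch, pastes it into the first batch, and mixes the one-hot labels in proportion to the actual pasted area. This is an illustrative sketch, not the code shipped in `cutmix-keras`: the function names `rand_bbox` and `cutmix_batch`, the default `alpha=1.0`, and the use of one box per batch are assumptions.

```python
import numpy as np


def rand_bbox(height, width, lam, rng):
    """Sample a box whose area is roughly (1 - lam) of the image, as in the CutMix paper."""
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h = int(height * cut_ratio)
    cut_w = int(width * cut_ratio)
    # Box centre is uniform over the image; the box is clipped at the borders.
    cy = rng.integers(height)
    cx = rng.integers(width)
    y1 = np.clip(cy - cut_h // 2, 0, height)
    y2 = np.clip(cy + cut_h // 2, 0, height)
    x1 = np.clip(cx - cut_w // 2, 0, width)
    x2 = np.clip(cx + cut_w // 2, 0, width)
    return y1, y2, x1, x2


def cutmix_batch(x_a, y_a, x_b, y_b, alpha=1.0, rng=None):
    """Paste one random box from batch B into batch A and mix the one-hot labels.

    x_*: float arrays of shape (N, H, W, C); y_*: one-hot arrays of shape (N, K).
    """
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    _, height, width, _ = x_a.shape
    y1, y2, x1, x2 = rand_bbox(height, width, lam, rng)

    mixed = x_a.copy()  # copy so the caller's batch A is not modified in place
    mixed[:, y1:y2, x1:x2, :] = x_b[:, y1:y2, x1:x2, :]

    # Recompute lambda from the clipped box so the labels match the real pixel ratio.
    lam = 1.0 - ((y2 - y1) * (x2 - x1)) / (height * width)
    labels = lam * y_a + (1.0 - lam) * y_b
    return mixed, labels


# Toy usage: batch A is all zeros and batch B all ones, so the pasted area is visible.
rng = np.random.default_rng(0)
x_a, y_a = np.zeros((2, 32, 32, 3)), np.eye(3)[[0, 1]]
x_b, y_b = np.ones((2, 32, 32, 3)), np.eye(3)[[2, 2]]
mixed, labels = cutmix_batch(x_a, y_a, x_b, y_b, rng=rng)
```

In the toy usage, the weight on batch B's class equals the fraction of pixels taken from batch B, which is the property the label mixing is meant to preserve.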


Install

$ pip install cutmix-keras

How To Use

class CutMixImageDataGenerator():
    def __init__(self, generator1, generator2, img_size, batch_size):
        self.batch_index = 0
        self.samples = generator1.samples
        self.class_indices = generator1.class_indices
        self.generator1 = generator1
        self.generator2 = generator2
        ...  # remaining initialization omitted
  • generator1 and generator2 must be built by applying the same flow method to the same ImageDataGenerator

  • generator1 and generator2 must both use shuffle=True.
    With shuffle=False, both yield the same images in the same order, so each image
    would be CutMixed with itself and no augmentation would take place.

  • Why are there two identical generators (generator1, generator2)?
    Two separately created generators are used to avoid a reference problem.
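The shuffle requirement can be illustrated without Keras. In this minimal sketch, `batches` is a hypothetical stand-in for a flow iterator and `paste_box` applies a CutMix-style paste with a fixed box; both names are assumptions, not part of the package. When neither stream is shuffled, the two iterators yield the same images in the same order, so every paste copies an image onto itself and the batch comes back unchanged.

```python
import numpy as np


def batches(x, y, batch_size, shuffle, seed):
    """Hypothetical stand-in for a Keras flow iterator: yields (x, y) batches for one epoch."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x)) if shuffle else np.arange(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]


def paste_box(x_a, y_a, x_b, y_b, box):
    """CutMix-style paste of a fixed box (y1, y2, x1, x2) from batch B into batch A."""
    y1, y2, x1, x2 = box
    mixed = x_a.copy()
    mixed[:, y1:y2, x1:x2, :] = x_b[:, y1:y2, x1:x2, :]
    height, width = x_a.shape[1:3]
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (height * width)
    return mixed, lam * y_a + (1.0 - lam) * y_b


x = np.arange(16 * 8 * 8, dtype=float).reshape(16, 8, 8, 1)  # 16 distinct toy images
y = np.eye(16)                                                # one class per image
box = (2, 6, 2, 6)

# shuffle=False: both streams are in the same order, so every paste is a no-op.
unchanged = all(
    np.array_equal(paste_box(xa, ya, xb, yb, box)[0], xa)
    for (xa, ya), (xb, yb) in zip(batches(x, y, 4, False, 0), batches(x, y, 4, False, 1))
)

# shuffle=True with independent seeds: paired images differ, so pastes change pixels.
changed = any(
    not np.array_equal(paste_box(xa, ya, xb, yb, box)[0], xa)
    for (xa, ya), (xb, yb) in zip(batches(x, y, 4, True, 0), batches(x, y, 4, True, 1))
)
```

By the same reasoning, passing the same `seed` to both `flow_from_dataframe` calls would presumably give both generators identical orders even with `shuffle=True`.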

Usage Example

# (some code) ...
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # or keras.preprocessing.image
from cutmix_keras import CutMixImageDataGenerator  # Import CutMix


train_datagen = ImageDataGenerator(
    rescale=1./255,
)

train_generator1 = train_datagen.flow_from_dataframe(
    dataframe=X_train,
    directory=IMG_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    x_col='X_Column',
    y_col='Y_Column',
    color_mode='rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=True,  # Required
)

train_generator2 = train_datagen.flow_from_dataframe(
    dataframe=X_train,
    directory=IMG_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    x_col='X_Column',
    y_col='Y_Column',
    color_mode='rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=True,  # Required
)

# CutMixImageDataGenerator
train_generator = CutMixImageDataGenerator(
    generator1=train_generator1,
    generator2=train_generator2,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)
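Before training, it can help to sanity-check one batch. The helper below is a hypothetical sketch, not part of `cutmix-keras`, and assumes a batch is an `(images, labels)` pair of NumPy arrays in channels-last layout, as Keras flow iterators return. It checks the shape implied by `IMG_SIZE`, `BATCH_SIZE` and `color_mode='rgb'`, the [0, 1] pixel range implied by `rescale=1./255`, and that each mixed label row is still non-negative and sums to 1.

```python
import numpy as np


def is_valid_cutmix_batch(images, labels, img_size, batch_size):
    """Return True if one (images, labels) batch looks like CutMix output.

    Hypothetical helper: assumes channels-last RGB images scaled to [0, 1]
    and categorical labels (one-hot rows mixed by CutMix).
    """
    shape_ok = images.ndim == 4 and images.shape[1:] == (img_size, img_size, 3)
    # The last batch of an epoch may be smaller than batch_size.
    size_ok = 0 < len(images) == len(labels) <= batch_size
    # rescale=1./255 maps 8-bit pixels into [0, 1].
    range_ok = images.min() >= 0.0 and images.max() <= 1.0
    # A convex mix of two one-hot rows is non-negative and sums to 1.
    labels_ok = labels.min() >= 0.0 and np.allclose(labels.sum(axis=1), 1.0)
    return bool(shape_ok and size_ok and range_ok and labels_ok)
```

If the wrapper supports `next()` like a Keras iterator (an assumption), a call such as `is_valid_cutmix_batch(*next(train_generator), IMG_SIZE, BATCH_SIZE)` would check the first batch.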


Built Distribution

cutmix_keras-1.0.0-py3-none-any.whl (3.6 kB)
