A library that includes Keras 3 preprocessing and augmentation layers

KerasAug

Description

KerasAug is a library that includes Keras 3 preprocessing and augmentation layers, providing support for various data types such as images, labels, bounding boxes, segmentation masks, and more.

[Figures: object_detection.gif, semantic_segmentation.gif]

[!NOTE] The GIFs show a YOLOV8-like pipeline for bounding boxes and segmentation masks. See docs/*.py for the scripts that generate them.

KerasAug aims to provide fast, robust and user-friendly preprocessing and augmentation layers, facilitating seamless integration with Keras 3 and tf.data.

The APIs largely follow torchvision, and the correctness of the layers has been verified through unit tests.

You can also try the demo app on Hugging Face (HF) Spaces.

Why KerasAug

  • 🚀 Supports many preprocessing & augmentation layers across all backends (JAX, TensorFlow and Torch).
  • 🧰 Seamlessly integrates with tf.data, offering a performant and scalable data pipeline.
  • 🔥 Largely follows the API design of torchvision.
  • 🙌 Depends only on Keras 3.

Installation

pip install keras keras-aug -U

[!IMPORTANT]
Make sure you have installed a supported backend for Keras.
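
Keras 3 reads the `KERAS_BACKEND` environment variable when `keras` is first imported, so the backend has to be chosen before that import. The snippet below is a minimal sketch of that step. The `select_backend` helper and the `SUPPORTED_BACKENDS` tuple are hypothetical names used for illustration and are not part of Keras or KerasAug.

```python
import os

# The three backends named in this README (illustrative constant, not a Keras API).
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name: str) -> str:
    """Set KERAS_BACKEND for this process (hypothetical helper)."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")

# Import Keras and KerasAug only after the variable is set, for example:
#   import keras
#   from keras_aug import layers as ka_layers
```

Changing the variable after `keras` has been imported has no effect on the running process, so this step belongs at the very top of a script or notebook.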

Quickstart

Rock, Paper and Scissors Image Classification

import keras
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_aug import layers as ka_layers

BATCH_SIZE = 64
NUM_CLASSES = 3
INPUT_SIZE = (128, 128)

# Create a `tf.data.Dataset`-compatible preprocessing pipeline.
# Note that this example works with all backends.
train_dataset, validation_dataset = tfds.load(
    "rock_paper_scissors", as_supervised=True, split=["train", "test"]
)
train_dataset = (
    train_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .shuffle(128)
    .map(ka_layers.vision.RandAugment())
    .map(ka_layers.vision.CutMix(num_classes=NUM_CLASSES))
    .map(ka_layers.vision.Rescale(scale=2.0, offset=-1))  # [0, 1] to [-1, 1]
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)
validation_dataset = (
    validation_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .map(ka_layers.vision.Rescale(scale=2.0, offset=-1))  # [0, 1] to [-1, 1]
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)

# Create a model using MobileNetV2 as the backbone.
backbone = keras.applications.MobileNetV2(
    input_shape=(*INPUT_SIZE, 3), include_top=False
)
backbone.trainable = False
inputs = keras.Input((*INPUT_SIZE, 3))
x = backbone(inputs)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = keras.Model(inputs, outputs)
model.summary()
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9),
    metrics=["accuracy"],
)

# Train and evaluate your model
model.fit(train_dataset, validation_data=validation_dataset, epochs=8)
model.evaluate(validation_dataset)

The above example runs with all backends (JAX, TensorFlow, Torch).
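
The pipeline above changes the pixel value range twice: the first `map` divides by 255 to reach [0, 1], and `Rescale(scale=2.0, offset=-1)` then moves the values to [-1, 1]. The NumPy sketch below reproduces only that arithmetic. It assumes `Rescale` computes `images * scale + offset`, as the inline comment in the example implies. The helper names `to_unit_range` and `rescale` are hypothetical and are not KerasAug functions.

```python
import numpy as np


def to_unit_range(pixels: np.ndarray) -> np.ndarray:
    """Mirror `tf.cast(images, "float32") / 255.0` from the first map step."""
    return pixels.astype("float32") / 255.0


def rescale(images: np.ndarray, scale: float = 2.0, offset: float = -1.0) -> np.ndarray:
    """Assumed behaviour of the Rescale layer: a linear map x * scale + offset."""
    return images * scale + offset


raw = np.array([0, 51, 255], dtype="uint8")  # darkest, a mid value, brightest
unit = to_unit_range(raw)                    # values in [0, 1]
signed = rescale(unit)                       # values in [-1, 1]

print(unit)
print(signed)
```

Keeping the images in [0, 1] until the last step lets the augmentation layers see one consistent range, and the final shift to [-1, 1] matches the input range that MobileNetV2 in `keras.applications` is normally fed.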

More Examples

Gradio App

gradio deploy

Citing KerasAug

@misc{chiu2023kerasaug,
  title={KerasAug},
  author={Hongyu, Chiu},
  year={2023},
  howpublished={\url{https://github.com/james77777778/keras-aug}},
}
