
A library that includes Keras 3 preprocessing and augmentation layers


KerasAug


Description

KerasAug is a library that includes Keras 3 preprocessing and augmentation layers, providing support for various data types such as images, labels, bounding boxes, segmentation masks, and more.

(Animated demos: object detection and semantic segmentation.)

[!NOTE] The GIFs are generated by the scripts in docs/*.py and show a YOLOv8-like augmentation pipeline for bounding boxes and segmentation masks.
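Supporting several data types means that a geometric augmentation has to transform all of them consistently. The following is a minimal NumPy sketch of that idea for a horizontal flip. It is not KerasAug's implementation: the helper `hflip_sample` and the absolute `xyxy` box format are assumptions chosen for illustration.

```python
import numpy as np


def hflip_sample(image, mask, boxes):
    """Hypothetical helper: horizontally flip an image, its mask and its boxes.

    image: (H, W, C) array; mask: (H, W) array of class ids;
    boxes: (N, 4) float array in absolute "xyxy" pixel format (assumed).
    """
    width = image.shape[1]
    # Pixel data (image and mask) is mirrored along the width axis.
    flipped_image = image[:, ::-1, :]
    flipped_mask = mask[:, ::-1]
    # Box x-coordinates are mirrored too; x_min and x_max swap roles.
    flipped_boxes = boxes.copy()
    flipped_boxes[:, 0] = width - boxes[:, 2]
    flipped_boxes[:, 2] = width - boxes[:, 0]
    return flipped_image, flipped_mask, flipped_boxes


# A 2x4 single-channel image whose left half is labeled as class 1.
image = np.arange(8, dtype="float32").reshape(2, 4, 1)
mask = np.array([[1, 1, 0, 0], [1, 1, 0, 0]])
boxes = np.array([[0.0, 0.0, 2.0, 2.0]])  # box around the class-1 region
new_image, new_mask, new_boxes = hflip_sample(image, mask, boxes)
print(new_boxes)  # [[2. 0. 4. 2.]]
```

After the flip, the box and the class-1 mask region both move to the right half of the image, so the annotations stay aligned with the pixels.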

KerasAug aims to provide fast, robust and user-friendly preprocessing and augmentation layers, facilitating seamless integration with Keras 3 and tf.data.

The APIs largely follow torchvision, and the correctness of the layers has been verified through unit tests.

You can also try the demo app on Hugging Face Spaces.

Why KerasAug

  • 🚀 Supports many preprocessing & augmentation layers across all backends (JAX, TensorFlow and Torch).
  • 🧰 Seamlessly integrates with tf.data, offering a performant and scalable data pipeline.
  • 🔥 Follows an API design closely modeled on torchvision.
  • 🙌 Depends only on Keras 3.

Installation

pip install keras keras-aug -U

[!IMPORTANT]
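Make sure you have installed a supported backend for Keras (JAX, TensorFlow, or PyTorch). The Quickstart below also uses tf.data and TensorFlow Datasets, so it additionally requires the tensorflow and tensorflow-datasets packages.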

Quickstart

Rock, Paper and Scissors Image Classification


import keras
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_aug import layers as ka_layers

BATCH_SIZE = 64
NUM_CLASSES = 3
INPUT_SIZE = (128, 128)

# Create a `tf.data.Dataset`-compatible preprocessing pipeline.
# Note that this example works with all backends.
train_dataset, validation_dataset = tfds.load(
    "rock_paper_scissors", as_supervised=True, split=["train", "test"]
)
train_dataset = (
    train_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .shuffle(128)
    .map(ka_layers.vision.RandAugment())
    .map(ka_layers.vision.CutMix(num_classes=NUM_CLASSES))
    .map(ka_layers.vision.Rescale(scale=2.0, offset=-1))  # [0, 1] to [-1, 1]
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)
validation_dataset = (
    validation_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .map(ka_layers.vision.Rescale(scale=2.0, offset=-1))  # [0, 1] to [-1, 1]
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)

# Create a model using MobileNetV2 as the backbone.
backbone = keras.applications.MobileNetV2(
    input_shape=(*INPUT_SIZE, 3), include_top=False
)
backbone.trainable = False
inputs = keras.Input((*INPUT_SIZE, 3))
x = backbone(inputs)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = keras.Model(inputs, outputs)
model.summary()
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9),
    metrics=["accuracy"],
)

# Train and evaluate your model
model.fit(train_dataset, validation_data=validation_dataset, epochs=8)
model.evaluate(validation_dataset)

The above example runs with all backends (JAX, TensorFlow, Torch).
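To show what the CutMix and Rescale steps of the training pipeline do to a batch, here is a plain NumPy sketch. It is an illustration, not KerasAug's code: `cutmix_sketch` is a hypothetical helper following the recipe from the original CutMix paper (paste a random box from another sample in the batch and mix the one-hot labels by the pasted area), and the rescaling line assumes `Rescale` computes `x * scale + offset`, which matches the `[0, 1] to [-1, 1]` comment in the example.

```python
import numpy as np


def cutmix_sketch(images, labels, rng, alpha=1.0):
    """Hypothetical NumPy sketch of CutMix on a single batch.

    images: (B, H, W, C) float array; labels: (B, num_classes) one-hot array.
    """
    batch_size, height, width, _ = images.shape
    # Mixing ratio drawn from Beta(alpha, alpha), as in the CutMix paper.
    lam = rng.beta(alpha, alpha)
    # Pair each sample with a randomly chosen partner from the same batch.
    perm = rng.permutation(batch_size)
    # Random box covering roughly (1 - lam) of the image area.
    cut_h = int(height * np.sqrt(1.0 - lam))
    cut_w = int(width * np.sqrt(1.0 - lam))
    cy, cx = rng.integers(height), rng.integers(width)
    y1, y2 = np.clip([cy - cut_h // 2, cy + cut_h // 2], 0, height)
    x1, x2 = np.clip([cx - cut_w // 2, cx + cut_w // 2], 0, width)
    # Paste the partner's pixels into the box.
    mixed_images = images.copy()
    mixed_images[:, y1:y2, x1:x2, :] = images[perm, y1:y2, x1:x2, :]
    # Recompute lam from the clipped box so labels match the pasted area.
    lam = 1.0 - ((y2 - y1) * (x2 - x1)) / (height * width)
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
    return mixed_images, mixed_labels


rng = np.random.default_rng(0)
images = rng.random((4, 8, 8, 3)).astype("float32")  # values in [0, 1)
labels = np.eye(3)[[0, 1, 2, 0]]  # one-hot labels, NUM_CLASSES = 3
mixed_images, mixed_labels = cutmix_sketch(images, labels, rng)

# Rescale(scale=2.0, offset=-1), assumed to compute x * scale + offset.
scaled_images = mixed_images * 2.0 - 1.0
```

Because CutMix draws its partner images from the same batch, it only makes sense after `.batch(...)`, which is where the example places it. Each mixed label row still sums to 1.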

More Examples

Gradio App

To deploy the Gradio demo app to Hugging Face Spaces, run:

gradio deploy

Citing KerasAug

@misc{chiu2023kerasaug,
  title={KerasAug},
  author={Chiu, Hongyu},
  year={2023},
  howpublished={\url{https://github.com/james77777778/keras-aug}},
}



Download files


Source Distribution

keras_aug-1.1.1.tar.gz (82.8 kB)

Built Distribution

keras_aug-1.1.1-py3-none-any.whl (146.1 kB)

File details

Details for the file keras_aug-1.1.1.tar.gz.

File metadata

  • Download URL: keras_aug-1.1.1.tar.gz
  • Size: 82.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.5

File hashes

Hashes for keras_aug-1.1.1.tar.gz
Algorithm Hash digest
SHA256 8b9c22c288498c50c17f740675e51fe7cd07f9dc5c24c5f0c79c588007d18bf1
MD5 6bfcdbff4f1975d2255d65dc08ce5a10
BLAKE2b-256 8268f61fe767218fcfa4f4a180c311416489162c81acce626b52b0220d2b63cf
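To check a downloaded file against the digests listed here, you can compute its SHA256 hash locally. Below is a minimal sketch using Python's standard `hashlib` module; the helper names `sha256_of_file` and `matches_expected` and the local file path are assumptions for illustration.

```python
import hashlib

# SHA256 digest listed above for keras_aug-1.1.1.tar.gz.
EXPECTED_SHA256 = "8b9c22c288498c50c17f740675e51fe7cd07f9dc5c24c5f0c79c588007d18bf1"


def sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_expected(path, expected=EXPECTED_SHA256):
    """Return True if the file's SHA256 digest equals the expected hex digest."""
    return sha256_of_file(path) == expected.lower()


# Usage (hypothetical local path):
# matches_expected("keras_aug-1.1.1.tar.gz")
```

Reading in chunks keeps memory use constant regardless of file size. pip's `pip hash <file>` command reports the same SHA256 digest if you prefer a command-line check.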


File details

Details for the file keras_aug-1.1.1-py3-none-any.whl.

File metadata

  • Download URL: keras_aug-1.1.1-py3-none-any.whl
  • Size: 146.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.5

File hashes

Hashes for keras_aug-1.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 74d0925c9846e8b7b9b8b94aec246fbb0a4ce7ddadfa1bc5831bc0f2c83dce63
MD5 e4a567c6b50a86e7d2af8abb21f6052f
BLAKE2b-256 1f6dfc4a1efea49d4d6ae30de6bc0690ac04a1aa30ff8bf3c0a0c995ce48b163

