A library that includes Keras 3 preprocessing and augmentation layers
KerasAug
Description
KerasAug is a library that includes Keras 3 preprocessing and augmentation layers, providing support for various data types such as images, labels, bounding boxes, segmentation masks, and more.
[!NOTE]
See `docs/*.py` for the GIF generation (a YOLOv8-like pipeline for bounding boxes and segmentation masks).
KerasAug aims to provide fast, robust and user-friendly preprocessing and augmentation layers that integrate seamlessly with Keras 3 and `tf.data.Dataset`.
The APIs largely follow `torchvision`, and the correctness of the layers has been verified through unit tests.
You can also check out the demo app on Hugging Face (HF).
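Handling several data types in one layer matters because geometric augmentations must move annotations together with the pixels. The NumPy sketch below illustrates this for a horizontal flip of a single image and its bounding boxes. `hflip_image_and_boxes` is a hypothetical helper written for illustration, not part of KerasAug, and it assumes absolute `xyxy` box coordinates (`[x_min, y_min, x_max, y_max]` in pixels).

```python
import numpy as np


def hflip_image_and_boxes(image, boxes):
    """Horizontally flip an image and its boxes together.

    Hypothetical helper for illustration only (not a KerasAug API).
    Assumes `image` has shape (height, width, channels) and `boxes` has
    shape (num_boxes, 4) in absolute "xyxy" pixel coordinates.
    """
    width = image.shape[1]
    # Reverse the column order to mirror the image left-to-right.
    flipped_image = image[:, ::-1, :]
    flipped_boxes = boxes.copy()
    # Mirror the x-coordinates; the old x_max becomes the new x_min and vice versa.
    flipped_boxes[:, 0] = width - boxes[:, 2]
    flipped_boxes[:, 2] = width - boxes[:, 0]
    return flipped_image, flipped_boxes


# A 4x6 RGB image with a bright stripe in the two leftmost columns.
image = np.zeros((4, 6, 3), dtype="float32")
image[:, 0:2, :] = 1.0
boxes = np.array([[0.0, 0.0, 2.0, 4.0]])  # box around the stripe

new_image, new_boxes = hflip_image_and_boxes(image, boxes)
print(new_boxes)  # [[4. 0. 6. 4.]]
```

After the flip, the stripe occupies the two rightmost columns and the box follows it; flipping the pixels alone would leave the box pointing at empty space.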
Installation
```bash
pip install keras keras-aug -U
```
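[!IMPORTANT]
Make sure you have installed a supported Keras backend (JAX, TensorFlow, or PyTorch). The Quickstart below also imports `tensorflow` and `tensorflow_datasets` for its `tf.data` input pipeline, so install `tensorflow` and `tensorflow-datasets` regardless of the backend you train with.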
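Keras 3 reads the backend name from the `KERAS_BACKEND` environment variable (or from the `"backend"` field of `~/.keras/keras.json`) and defaults to TensorFlow. A minimal sketch, assuming you want JAX, is to set the variable in Python before anything imports `keras`:

```python
import os

# Keras 3 picks its backend when `keras` is first imported, so this must run
# before that import, including indirect imports such as `import keras_aug`.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow", or "torch"
```

After this, `import keras` followed by `keras.backend.backend()` should report `"jax"`. Setting the variable in the shell (for example `KERAS_BACKEND=jax python train.py`) works the same way and keeps the script backend-agnostic.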
Quickstart
Rock, Paper and Scissors Image Classification
```python
import keras
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_aug import layers as ka_layers

BATCH_SIZE = 64
NUM_CLASSES = 3
INPUT_SIZE = (128, 128)

# Create a `tf.data.Dataset`-compatible preprocessing pipeline with all backends
train_dataset, validation_dataset = tfds.load(
    "rock_paper_scissors", as_supervised=True, split=["train", "test"]
)
train_dataset = (
    train_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .shuffle(128)
    .map(ka_layers.vision.RandAugment())
    .map(ka_layers.vision.CutMix(num_classes=NUM_CLASSES))
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)
validation_dataset = (
    validation_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)

# Create a CNN model
model = keras.models.Sequential(
    [
        keras.Input((*INPUT_SIZE, 3)),
        keras.layers.Conv2D(32, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Conv2D(64, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Conv2D(128, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Conv2D(256, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(2, 2),
        keras.layers.Flatten(),
        keras.layers.Dense(512, activation="relu"),
        keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ]
)
model.summary()
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.AdamW(),
    metrics=["accuracy"],
)

# Train your model
model.fit(train_dataset, validation_data=validation_dataset, epochs=8)
```
The example above runs with all Keras backends (JAX, TensorFlow, and PyTorch).
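The Quickstart one-hot encodes the labels and trains with `categorical_crossentropy` because CutMix produces soft labels: each mixed image is labeled in proportion to the area it takes from each source image. The NumPy sketch below follows the standard CutMix recipe (Yun et al., 2019) to show where those soft labels come from. It is an illustration under stated assumptions, not KerasAug's implementation; `cutmix_batch` and its `alpha`/`rng` parameters are hypothetical names chosen for this example.

```python
import numpy as np


def cutmix_batch(images, labels, alpha=1.0, rng=None):
    """Minimal CutMix sketch on a NumPy batch (illustrative, not KerasAug's code).

    images: (batch, height, width, channels) floats.
    labels: (batch, num_classes) one-hot or soft labels.
    """
    rng = np.random.default_rng() if rng is None else rng
    batch, height, width = images.shape[:3]
    # Pair every sample with another sample from the same batch.
    perm = rng.permutation(batch)
    lam = rng.beta(alpha, alpha)

    # Cut a box whose area is roughly (1 - lam) of the image, centered at random.
    cut_h = int(height * np.sqrt(1.0 - lam))
    cut_w = int(width * np.sqrt(1.0 - lam))
    cy, cx = rng.integers(height), rng.integers(width)
    y1, y2 = np.clip([cy - cut_h // 2, cy + cut_h // 2], 0, height)
    x1, x2 = np.clip([cx - cut_w // 2, cx + cut_w // 2], 0, width)

    # Paste the same region from the paired images.
    mixed = images.copy()
    mixed[:, y1:y2, x1:x2, :] = images[perm, y1:y2, x1:x2, :]

    # Recompute lam from the area actually pasted (the box may be clipped).
    lam = 1.0 - ((y2 - y1) * (x2 - x1)) / (height * width)
    mixed_labels = lam * labels + (1.0 - lam) * labels[perm]
    return mixed, mixed_labels


rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3)).astype("float32")
labels = np.eye(3, dtype="float32")[rng.integers(3, size=8)]
mixed, mixed_labels = cutmix_batch(images, labels, rng=rng)
print(mixed.shape)  # (8, 32, 32, 3)
print(mixed_labels.sum(axis=1))  # every row still sums to 1
```

Recomputing `lam` after clipping keeps the label weights equal to the pixel fractions that were actually swapped, which is why integer class labels with `sparse_categorical_crossentropy` would not fit this pipeline.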
More Examples
Gradio App
```bash
gradio deploy
```
Citing KerasAug
```bibtex
@misc{chiu2023kerasaug,
  title={KerasAug},
  author={Hongyu, Chiu},
  year={2023},
  howpublished={\url{https://github.com/james77777778/keras-aug}},
}
```