A library that includes Keras 3 preprocessing and augmentation layers
KerasAug
Description
KerasAug is a library that includes Keras 3 preprocessing and augmentation layers, providing support for various data types such as images, labels, bounding boxes, segmentation masks, and more.
[!NOTE] See docs/*.py for the GIF generation.
YOLOV8-like pipeline for bounding boxes and segmentation masks.
KerasAug aims to provide fast, robust and user-friendly preprocessing and augmentation layers, facilitating seamless integration with Keras 3 and tf.data.
The APIs largely follow torchvision, and the correctness of the layers has been verified through unit tests.
You can also try the demo app on Hugging Face (HF).
Why KerasAug
- 🚀 Supports many preprocessing & augmentation layers across all backends (JAX, TensorFlow and Torch).
- 🧰 Seamlessly integrates with tf.data, offering a performant and scalable data pipeline.
- 🔥 Follows the same API design as torchvision.
- 🙌 Depends only on Keras 3.
Installation
pip install keras keras-aug -U
[!IMPORTANT]
Make sure you have installed a supported backend for Keras.
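As a minimal sketch of what choosing a backend can look like: Keras 3 reads the `KERAS_BACKEND` environment variable when `keras` is first imported, so the variable has to be set beforehand. The helper `select_backend` and the constant `SUPPORTED_BACKENDS` below are hypothetical names introduced for illustration; they are not part of Keras or KerasAug.

```python
import os

# Backend names listed in this README. Keras 3 reads the choice from the
# KERAS_BACKEND environment variable when `keras` is first imported.
SUPPORTED_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Set KERAS_BACKEND (hypothetical helper); call before `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


backend = select_backend("jax")  # illustrative choice; any listed backend works
# import keras  # import Keras only after the variable is set
```

The chosen backend package itself (for example `jax`, `tensorflow` or `torch`) still has to be installed separately, since the `pip` command above installs only Keras and KerasAug.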
Quickstart
Rock, Paper and Scissors Image Classification
import keras
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_aug import layers as ka_layers

BATCH_SIZE = 64
NUM_CLASSES = 3
INPUT_SIZE = (128, 128)

# Create a `tf.data.Dataset`-compatible preprocessing pipeline.
# Note that this example works with all backends.
train_dataset, validation_dataset = tfds.load(
    "rock_paper_scissors", as_supervised=True, split=["train", "test"]
)
train_dataset = (
    train_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .shuffle(128)
    .map(ka_layers.vision.RandAugment())
    .map(ka_layers.vision.CutMix(num_classes=NUM_CLASSES))
    .map(ka_layers.vision.Rescale(scale=2.0, offset=-1))  # [0, 1] to [-1, 1]
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)
validation_dataset = (
    validation_dataset.batch(BATCH_SIZE)
    .map(
        lambda images, labels: {
            "images": tf.cast(images, "float32") / 255.0,
            "labels": tf.one_hot(labels, NUM_CLASSES),
        }
    )
    .map(ka_layers.vision.Resize(INPUT_SIZE))
    .map(ka_layers.vision.Rescale(scale=2.0, offset=-1))  # [0, 1] to [-1, 1]
    .map(lambda data: (data["images"], data["labels"]))
    .prefetch(tf.data.AUTOTUNE)
)

# Create a model using MobileNetV2 as the backbone.
backbone = keras.applications.MobileNetV2(
    input_shape=(*INPUT_SIZE, 3), include_top=False
)
backbone.trainable = False
inputs = keras.Input((*INPUT_SIZE, 3))
x = backbone(inputs)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = keras.Model(inputs, outputs)
model.summary()
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9),
    metrics=["accuracy"],
)

# Train and evaluate your model
model.fit(train_dataset, validation_data=validation_dataset, epochs=8)
model.evaluate(validation_dataset)
The above example runs with all backends (JAX, TensorFlow, Torch).
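Two steps in the pipeline above are plain arithmetic, and a small pure-Python sketch may make them easier to follow. It assumes that `Rescale` computes `x * scale + offset` (consistent with the `[0, 1] to [-1, 1]` comment) and that label mixing follows the usual CutMix formulation, where the labels are blended in proportion to the pasted area. The function names `rescale`, `one_hot`, `cutmix_lambda` and `mix_labels` are hypothetical and do not come from KerasAug, whose actual implementation may differ.

```python
def rescale(values, scale=2.0, offset=-1.0):
    """Apply x * scale + offset; the defaults map [0, 1] to [-1, 1]."""
    return [v * scale + offset for v in values]


def one_hot(label, num_classes):
    """Encode an integer class label as a one-hot vector."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def cutmix_lambda(box_height, box_width, image_height, image_width):
    """Fraction of the original image left visible after pasting a box."""
    return 1.0 - (box_height * box_width) / (image_height * image_width)


def mix_labels(label_a, label_b, lam):
    """Blend two one-hot labels: lam * a + (1 - lam) * b."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(label_a, label_b)]


# Pixel values already divided by 255, as in the first `map` call above.
rescaled = rescale([0.0, 0.5, 1.0])  # [-1.0, 0.0, 1.0]

# A 64x64 patch pasted into a 128x128 image covers a quarter of it.
lam = cutmix_lambda(64, 64, 128, 128)  # 0.75
mixed = mix_labels(one_hot(0, 3), one_hot(2, 3), lam)  # [0.75, 0.0, 0.25]
```

Because the mixed label is no longer a single integer class, the example one-hot encodes the labels before augmentation and trains with `categorical_crossentropy`.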
More Examples
Gradio App
gradio deploy
Citing KerasAug
@misc{chiu2023kerasaug,
title={KerasAug},
author={Hongyu, Chiu},
year={2023},
howpublished={\url{https://github.com/james77777778/keras-aug}},
}