keras-balanced-batch-generator: A Keras-compatible generator for creating balanced batches
Installation
pip install keras-balanced-batch-generator
Overview
This module implements an over-sampling algorithm to address the issue of class imbalance. It generates balanced batches, i.e., batches in which the number of samples from each class is on average the same. Generated batches are also shuffled.
The generator can be easily used with Keras models' fit method.
Currently, only NumPy arrays for single-input, single-output models are supported.
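The over-sampling idea described above can be illustrated with a minimal NumPy sketch. This is not the package's actual implementation; the function name balanced_batches and its internals are hypothetical. It assumes that y is a binary class matrix and that drawing a class uniformly for each batch slot, then a sample from that class with replacement, is an acceptable way to balance batches on average.

```python
import numpy as np


def balanced_batches(x, y, batch_size, seed=None):
    """Yield shuffled (x, y) batches whose classes are balanced on average.

    Illustrative sketch only (hypothetical helper, not the package's code).
    y must be a binary class matrix of shape (num_samples, num_classes).
    """
    rng = np.random.default_rng(seed)
    labels = y.argmax(axis=1)
    # One pool of sample indices per class that actually occurs in the data.
    pools = [np.flatnonzero(labels == c) for c in np.unique(labels)]
    while True:
        # Pick a class uniformly for every batch slot, then pick a sample
        # from that class with replacement, so rare classes are over-sampled.
        chosen = rng.integers(len(pools), size=batch_size)
        idx = np.array([rng.choice(pools[c]) for c in chosen])
        yield x[idx], y[idx]
```

Because every slot is drawn independently, the batches come out shuffled and the per-class counts are equal only in expectation, which matches the "on average" wording above.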
API
make_generator(x, y, batch_size,
categorical=True,
seed=None)
- x (numpy.ndarray) Input data. Must have the same length as y.
- y (numpy.ndarray) Target data. Must be a binary class matrix (i.e., shape (num_samples, num_classes)). You can use keras.utils.to_categorical to convert a class vector to a binary class matrix.
- batch_size (int) Batch size.
- categorical (bool) If true, generates binary class matrices (i.e., shape (num_samples, num_classes)) for batch targets. Otherwise, generates class vectors (i.e., shape (num_samples,)).
- seed Random seed (see the docs).
- Returns a Keras-compatible generator yielding batches as (x, y) tuples.
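The two target formats mentioned above can be shown with plain NumPy. The sketch below uses np.eye as a stand-in for keras.utils.to_categorical (an assumption made only to keep the snippet free of a Keras dependency); the variable names are illustrative.

```python
import numpy as np

num_classes = 3
class_vector = np.array([0, 2, 1, 2])  # class vector, shape (num_samples,)

# Binary class matrix: one one-hot row per sample,
# shape (num_samples, num_classes). This is the format expected for y.
class_matrix = np.eye(num_classes, dtype="float32")[class_vector]

# Back to a class vector, the target format described for categorical=False.
recovered = class_matrix.argmax(axis=1)
```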
Usage
import keras
from keras_balanced_batch_generator import make_generator
x = ...
y = ...
batch_size = ...
steps_per_epoch = ...
model = keras.models.Sequential(...)
generator = make_generator(x, y, batch_size)
model.fit(generator, steps_per_epoch=steps_per_epoch)
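The usage snippet leaves steps_per_epoch open. One common convention, assumed here rather than prescribed by the package, is to pick the number of batches that draws roughly as many samples as the dataset contains. The helper name steps_for_one_pass is hypothetical.

```python
import math


def steps_for_one_pass(num_samples, batch_size):
    """Return the number of batches that draws about num_samples samples."""
    return math.ceil(num_samples / batch_size)


# With 100 samples and a batch size of 16, this gives 7 steps.
steps_per_epoch = steps_for_one_pass(num_samples=100, batch_size=16)
```

Since over-sampling may repeat some samples and skip others, such an "epoch" is only nominal and does not guarantee that every sample is seen once.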
Example: Multiclass Classification
import numpy as np
import keras
from keras_balanced_batch_generator import make_generator
num_samples = 100
num_classes = 3
input_shape = (2,)
batch_size = 16
x = np.random.rand(num_samples, *input_shape)
y = np.random.randint(low=0, high=num_classes, size=num_samples)
y = keras.utils.to_categorical(y)
generator = make_generator(x, y, batch_size)
model = keras.models.Sequential()
model.add(keras.layers.Dense(32, input_shape=input_shape, activation='relu'))
model.add(keras.layers.Dense(num_classes, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(generator, steps_per_epoch=10, epochs=5)
Example: Binary Classification
import numpy as np
import keras
from keras_balanced_batch_generator import make_generator
num_samples = 100
num_classes = 2
input_shape = (2,)
batch_size = 16
x = np.random.rand(num_samples, *input_shape)
y = np.random.randint(low=0, high=num_classes, size=num_samples)
y = keras.utils.to_categorical(y)
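# categorical=False yields class vectors of shape (batch_size,), matching the single sigmoid output unit
generator = make_generator(x, y, batch_size, categorical=False)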
model = keras.models.Sequential()
model.add(keras.layers.Dense(32, input_shape=input_shape, activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['binary_accuracy'])
model.fit(generator, steps_per_epoch=10, epochs=5)
File details
Details for the file keras-balanced-batch-generator-0.0.3.tar.gz.
File metadata
- Download URL: keras-balanced-batch-generator-0.0.3.tar.gz
- Upload date:
- Size: 4.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | b074cec865b4afa2422a68368b84b7ec32bc7d7ba853564d873f92f94d8b3719 |
| MD5 | 19ecb2f4985026041cfb211dcd14b1dc |
| BLAKE2b-256 | 8381fa6d65eec8b79d06c658548240eae447746d35911ee89297ca53af6b92e3 |
File details
Details for the file keras_balanced_batch_generator-0.0.3-py3-none-any.whl.
File metadata
- Download URL: keras_balanced_batch_generator-0.0.3-py3-none-any.whl
- Upload date:
- Size: 4.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 451be1436210fae2cabed652714443031dca6089d8f7403916ae4077b36add35 |
| MD5 | 28a0b5b7b9b9f8ee11b29360e1fea9b8 |
| BLAKE2b-256 | 15511b9422c0eaffbcfdb1218d84ed7d19b0baf4812e311be77a8b7572ecb7b7 |