A Keras-compatible generator for creating balanced batches

keras-balanced-batch-generator: A Keras-compatible generator for creating balanced batches

PyPI MIT license

Installation

pip install keras-balanced-batch-generator

Overview

This module implements an over-sampling algorithm to address class imbalance. It generates balanced batches, i.e., batches in which each class contributes, on average, the same number of samples. The samples within each generated batch are also shuffled.
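To make the idea of over-sampling into balanced batches concrete, here is a minimal NumPy sketch. It is an illustration only and not the package's actual implementation: the function name `balanced_batches` and its sampling strategy (pick a class uniformly at random for each batch slot, then pick a random sample of that class, with replacement) are assumptions chosen for clarity.

```python
import numpy as np


def balanced_batches(x, y, batch_size, seed=None):
    """Yield (x, y) batches whose classes are balanced on average.

    Hypothetical sketch, not the package's code.
    `y` is a binary class matrix of shape (num_samples, num_classes).
    """
    rng = np.random.default_rng(seed)
    labels = y.argmax(axis=1)
    # Group sample indices by class, skipping classes that have no samples.
    indices_per_class = [np.flatnonzero(labels == c) for c in range(y.shape[1])]
    indices_per_class = [idx for idx in indices_per_class if idx.size > 0]
    while True:
        # For each batch slot, draw a class uniformly, then draw one sample
        # of that class. Rare classes are therefore drawn more often per
        # sample than frequent ones (over-sampling).
        chosen = rng.integers(len(indices_per_class), size=batch_size)
        batch_idx = np.array([rng.choice(indices_per_class[c]) for c in chosen])
        yield x[batch_idx], y[batch_idx]
```

Because the class is drawn independently for every slot, the batch is balanced in expectation rather than exactly, and the order of samples within the batch is already random.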

The generator can be passed directly to the fit method of a Keras model.

Currently, only NumPy arrays for single-input, single-output models are supported.

API

make_generator(x, y, batch_size,
               categorical=True,
               seed=None)
  • x (numpy.ndarray) Input data. Must have the same length as y.
  • y (numpy.ndarray) Target data. Must be a binary class matrix (i.e., shape (num_samples, num_classes)). You can use keras.utils.to_categorical to convert a class vector to a binary class matrix.
  • batch_size (int) Batch size.
  • categorical (bool) If true, generates binary class matrices (i.e., shape (num_samples, num_classes)) for batch targets. Otherwise, generates class vectors (i.e., shape (num_samples,)).
  • seed Random seed (see the docs).
  • Returns a Keras-compatible generator yielding batches as (x, y) tuples.
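The two target formats mentioned above (binary class matrix and class vector) can be illustrated with plain NumPy. This is only a sketch of the shapes involved: indexing `np.eye` is used here as a stand-in for `keras.utils.to_categorical` so that the snippet does not depend on Keras, and the label values are made up.

```python
import numpy as np

labels = np.array([0, 2, 1, 2])  # class vector, shape (num_samples,)
num_classes = 3

# Binary class matrix (one-hot rows), shape (num_samples, num_classes).
# This is the format expected for the `y` argument.
y_matrix = np.eye(num_classes, dtype="float32")[labels]

# Back to a class vector, the format of batch targets when categorical=False.
y_vector = y_matrix.argmax(axis=1)

print(y_matrix.shape)  # (4, 3)
print(y_vector.shape)  # (4,)
```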

Usage

import keras
from keras_balanced_batch_generator import make_generator

x = ...
y = ...
batch_size = ...
steps_per_epoch = ...
model = keras.models.Sequential(...)

generator = make_generator(x, y, batch_size)
model.fit(generator, steps_per_epoch=steps_per_epoch)
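The usage above leaves steps_per_epoch open. One common convention, shown here as an assumption rather than a recommendation from this package, is to pick it so that one epoch draws roughly as many samples as the dataset contains. The dataset size and batch size below are hypothetical.

```python
import math

num_samples = 1000  # hypothetical dataset size, e.g. len(x)
batch_size = 32     # hypothetical batch size

# Number of batches needed to draw at least num_samples samples.
steps_per_epoch = math.ceil(num_samples / batch_size)
print(steps_per_epoch)  # 32
```

Because the batches are over-sampled, an epoch defined this way does not visit every sample exactly once; it only fixes how many batches Keras draws before starting the next epoch.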

Example: Multiclass Classification

import numpy as np
import keras
from keras_balanced_batch_generator import make_generator

num_samples = 100
num_classes = 3
input_shape = (2,)
batch_size = 16

x = np.random.rand(num_samples, *input_shape)
y = np.random.randint(low=0, high=num_classes, size=num_samples)
y = keras.utils.to_categorical(y)

generator = make_generator(x, y, batch_size)

model = keras.models.Sequential()
model.add(keras.layers.Dense(32, input_shape=input_shape, activation='relu'))
model.add(keras.layers.Dense(num_classes, activation='softmax'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(generator, steps_per_epoch=10, epochs=5)

Example: Binary Classification

import numpy as np
import keras
from keras_balanced_batch_generator import make_generator

num_samples = 100
num_classes = 2
input_shape = (2,)
batch_size = 16

x = np.random.rand(num_samples, *input_shape)
y = np.random.randint(low=0, high=num_classes, size=num_samples)
y = keras.utils.to_categorical(y)

generator = make_generator(x, y, batch_size, categorical=False)

model = keras.models.Sequential()
model.add(keras.layers.Dense(32, input_shape=input_shape, activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['binary_accuracy'])
model.fit(generator, steps_per_epoch=10, epochs=5)
