
Commonly used code in Keras

Project description

simplified keras

Commonly used actions in Keras


General info

This package is a set of commonly used actions in Keras. At the moment it includes the following (this list may be outdated):

Libraries

  • Matplotlib - version 3.3.3
  • NumPy - version 1.19.4
  • TensorFlow - version 2.4.1
  • Pandas - version 1.1.5
  • Seaborn - version 0.11.1

Setup

  • Install from PyPI: pip install simplified-keras

Documentation

Main package

Generators

Default generators

from keras.preprocessing.image import ImageDataGenerator
from simplified_keras.generators import get_train_val_generators, get_val_test_generators

img_size = (48, 48)
img_datagen = ImageDataGenerator(rescale=1/255)

# same for get_val_test_generators
# default: data_dir='../data', color_mode='rgb', batch_size=128, class_mode='categorical'
train_generator, validation_generator = get_train_val_generators(img_datagen, data_dir='../data/normal',
                                                                 color_mode='grayscale', target_size=img_size)
val_generator, test_generator = get_val_test_generators(img_datagen, batch_size=32)

numpy_memmap_generator

from simplified_keras.generators import numpy_memmap_generator

# default batch_size=128, shuffle_array=True
train_gen = numpy_memmap_generator('imgs.npy', 'labels.npy', batch_size=64, shuffle_array=False)
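One way a memory-mapped batch generator like this can work is sketched below, assuming imgs.npy and labels.npy are plain .npy arrays whose first axis indexes the samples. `memmap_batches` and its `seed` parameter are hypothetical and not the library's implementation: `np.load(..., mmap_mode='r')` maps each file from disk, so only the rows of the current batch are read into memory.

```python
import numpy as np


def memmap_batches(images_path, labels_path, batch_size=128, shuffle_array=True, seed=None):
    """Hypothetical sketch: yield (images, labels) batches forever without loading the arrays into RAM."""
    # mmap_mode='r' maps the files read-only; nothing is loaded until rows are indexed
    images = np.load(images_path, mmap_mode='r')
    labels = np.load(labels_path, mmap_mode='r')
    assert len(images) == len(labels), "images and labels must have the same length"
    rng = np.random.default_rng(seed)
    n = len(images)
    while True:
        # a fresh permutation each pass when shuffling, otherwise sequential order
        order = rng.permutation(n) if shuffle_array else np.arange(n)
        for start in range(0, n, batch_size):
            # sorting the indices of one batch keeps disk reads closer to sequential;
            # the same indices are used for both arrays, so image/label pairs stay aligned
            idx = np.sort(order[start:start + batch_size])
            yield np.asarray(images[idx]), np.asarray(labels[idx])
```

It is called like the library function above, e.g. `memmap_batches('imgs.npy', 'labels.npy', batch_size=64, shuffle_array=False)`. Because the generator never ends, `model.fit` needs `steps_per_epoch` when training on it.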

Default callbacks

from simplified_keras.callbacks import get_default_callbacks

callbacks = get_default_callbacks('models/vgg16_classifier.h5', monitor='val_loss', verbose=0)

hist = model.fit(train_generator, steps_per_epoch=train_steps, validation_data=validation_generator,
                 validation_steps=valid_steps, epochs=100, callbacks=callbacks, verbose=2)

Signature:

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

def get_default_callbacks(model_path, monitor='val_acc', base_patience=3, lr_reduce_factor=0.5, min_lr=1e-7, verbose=1):
    return [
        ReduceLROnPlateau(monitor=monitor, factor=lr_reduce_factor, min_lr=min_lr, patience=base_patience, verbose=verbose),
        EarlyStopping(monitor=monitor, patience=(2 * base_patience + 1), verbose=verbose),
        ModelCheckpoint(monitor=monitor, filepath=model_path, save_best_only=True, verbose=verbose)
    ]
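With base_patience=3, the learning rate is reduced after 3 epochs without improvement and training stops after 2 * 3 + 1 = 7, so the learning rate normally gets two chances to drop before EarlyStopping fires. With the default monitor='val_acc', Keras's mode='auto' treats higher values as better because the monitored name contains 'acc'. The sketch below is a simplified pure-Python replay of that interaction, assuming a loss-like monitor (lower is better), ignoring Keras's min_delta and cooldown, and assuming a starting learning rate of 1e-3; `simulate_default_callbacks` is a hypothetical helper, not part of simplified_keras or Keras.

```python
def simulate_default_callbacks(val_losses, base_patience=3, lr=1e-3,
                               lr_reduce_factor=0.5, min_lr=1e-7):
    """Replay a simplified ReduceLROnPlateau + EarlyStopping pair on a val_loss curve.

    Lower is better; min_delta and cooldown are ignored.
    Returns ([(epoch, new_lr), ...], stop_epoch or None).
    """
    es_patience = 2 * base_patience + 1  # same formula as get_default_callbacks
    lr_best = es_best = float("inf")
    lr_wait = es_wait = 0
    reductions = []
    for epoch, loss in enumerate(val_losses):
        # ReduceLROnPlateau: counter resets on improvement and after each reduction
        if loss < lr_best:
            lr_best, lr_wait = loss, 0
        else:
            lr_wait += 1
            if lr_wait >= base_patience and lr > min_lr:
                lr = max(lr * lr_reduce_factor, min_lr)
                reductions.append((epoch, lr))
                lr_wait = 0
        # EarlyStopping: independent counter, unaffected by learning-rate changes
        if loss < es_best:
            es_best, es_wait = loss, 0
        else:
            es_wait += 1
            if es_wait >= es_patience:
                return reductions, epoch
    return reductions, None
```

For a curve that improves for three epochs and then stays flat, the learning rate is halved at epochs 5 and 8 and training stops at epoch 9 (counting from 0).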

Restore callbacks

Used to restore the callbacks' state when resuming paused training. The model should be loaded from the last checkpoint.

from simplified_keras.callbacks import restore_callbacks, get_default_callbacks

callbacks = get_default_callbacks('models/vgg16_classifier.h5', monitor='val_loss', verbose=0)
# evaluate() returns the loss first, followed by the compiled metrics
loss, acc = model.evaluate(val_ds)

# pass the value of the monitored metric (val_loss here; use acc when monitoring accuracy)
restore_callbacks(callbacks, loss)

Plots

Accuracy and Loss plot

from simplified_keras.plots.history_plots import plot_acc_and_loss

history = model.fit(train_gen, steps_per_epoch=train_steps, epochs=5, validation_data=val_gen,
                    validation_steps=val_steps, callbacks=callbacks)

fig = plot_acc_and_loss(history)

Result:

history.png
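A plot like this can be drawn from the dictionary stored in `History.history` with Matplotlib alone. A minimal sketch, assuming the accuracy metric was compiled as "acc" (pass acc_key="accuracy" if it was compiled as "accuracy"); `plot_history` is a hypothetical name, not the library's plot_acc_and_loss.

```python
import matplotlib.pyplot as plt


def plot_history(history_dict, acc_key="acc"):
    """Plot training/validation accuracy and loss side by side; returns the figure."""
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
    epochs = range(1, len(history_dict["loss"]) + 1)
    for ax, key, title in ((ax_acc, acc_key, "Accuracy"), (ax_loss, "loss", "Loss")):
        ax.plot(epochs, history_dict[key], label=f"train {key}")
        val_key = f"val_{key}"
        if val_key in history_dict:  # only present when validation data was passed to fit()
            ax.plot(epochs, history_dict[val_key], label=val_key)
        ax.set_title(title)
        ax.set_xlabel("epoch")
        ax.legend()
    fig.tight_layout()
    return fig
```

Usage: `fig = plot_history(history.history)`.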

Predictions with image plot

from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator
from simplified_keras.generators import get_train_val_generators
from simplified_keras.plots import plot_predictions_with_img

img_size = (48, 48)
img_datagen = ImageDataGenerator(rescale=1 / 255)

_, validation_generator = get_train_val_generators(img_datagen, data_dir='../data/normal',
                                                   color_mode='grayscale', target_size=img_size)
model = load_model('../models/standard_model.h5')

batch, labels = validation_generator.next()
preds = model.predict(batch)

named_labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
fig = plot_predictions_with_img(1, preds, labels, batch, named_labels, grayscale=True)

Result:

p.png
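The meaning of the first argument (1) is not documented here; the sketch below assumes it is the index of the image within the batch and that labels are one-hot (class_mode='categorical'). It shows one way to draw an image next to a bar chart of its predicted class probabilities. `plot_prediction` is a hypothetical name, not the library's function.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_prediction(index, preds, labels, images, class_names, grayscale=False):
    """Show images[index] beside a bar chart of preds[index]; true class in green, wrong prediction in red."""
    fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(8, 3))
    img = images[index]
    if grayscale:
        ax_img.imshow(np.squeeze(img), cmap="gray")  # drop the trailing channel axis of size 1
    else:
        ax_img.imshow(img)
    ax_img.axis("off")
    true_idx = int(np.argmax(labels[index]))  # labels assumed one-hot
    pred_idx = int(np.argmax(preds[index]))
    ax_img.set_title(f"true: {class_names[true_idx]}, predicted: {class_names[pred_idx]}")
    bars = ax_bar.bar(range(len(class_names)), preds[index], color="gray")
    bars[pred_idx].set_color("red")    # overwritten by green below when the prediction is correct
    bars[true_idx].set_color("green")
    ax_bar.set_xticks(range(len(class_names)))
    ax_bar.set_xticklabels(class_names)
    ax_bar.set_ylim(0, 1)
    return fig
```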

Histogram with CDF and image plot

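import cv2
from simplified_keras.plots import plot_gray_img_with_histogram
from simplified_keras.transformations import stretch_histogram

# src_train_path: directory holding your training images
img = cv2.imread(f'{src_train_path}/0/241.png', cv2.IMREAD_GRAYSCALE)
fig1 = plot_gray_img_with_histogram(img)
# the image is single-channel, so override the default color_mode='rgb'
img2 = stretch_histogram(img, color_mode='grayscale')
fig2 = plot_gray_img_with_histogram(img2)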

Result:

history1.png history2.png
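The histogram and cumulative distribution function (CDF) in such a plot can be computed with NumPy alone. A minimal sketch, assuming an 8-bit single-channel image; `gray_histogram_and_cdf` is a hypothetical name, and scaling the CDF to the histogram's peak (so both share one y-axis) is a common convention, not necessarily what plot_gray_img_with_histogram does.

```python
import numpy as np


def gray_histogram_and_cdf(img):
    """Return the 256-bin histogram of an 8-bit grayscale image and its CDF scaled to the histogram's peak."""
    # range=(0, 256) gives one bin per intensity value 0..255
    hist, _ = np.histogram(img.ravel(), bins=256, range=(0, 256))
    cdf = hist.cumsum()
    cdf_scaled = cdf * hist.max() / cdf[-1]
    return hist, cdf_scaled
```

After stretching, the CDF of a low-contrast image rises over a wider range of intensities instead of a narrow band.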

Confusion matrix plot

from simplified_keras.transformations import predictions_to_classes, one_hot_to_sparse
from simplified_keras.metrics import get_confusion_matrixes
from simplified_keras.plots import plot_confusion_matrix

predictions = model.predict(validation_images)
predicted_classes = predictions_to_classes(predictions)
sparse_labels = one_hot_to_sparse(validation_labels)

cm, cm_normalized = get_confusion_matrixes(predicted_classes, sparse_labels)
classes = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
f1 = plot_confusion_matrix(cm, classes)
f2 = plot_confusion_matrix(cm_normalized, classes, figsize=(10, 8))
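Seaborn (a listed dependency) can draw an annotated heatmap in one call, but the same kind of figure can also be built with Matplotlib only. A minimal sketch, assuming rows are true classes and columns are predicted classes (the orientation used by get_confusion_matrixes is not documented here); `plot_cm` is a hypothetical name.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_cm(cm, classes, figsize=(8, 6)):
    """Draw a confusion matrix as a heatmap with each cell's value printed on it."""
    cm = np.asarray(cm)
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(cm, cmap="Blues")
    fig.colorbar(im, ax=ax)
    ax.set_xticks(range(len(classes)))
    ax.set_xticklabels(classes)
    ax.set_yticks(range(len(classes)))
    ax.set_yticklabels(classes)
    ax.set_xlabel("predicted")
    ax.set_ylabel("true")
    # integers for raw counts, two decimals for a normalized matrix
    fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"
    threshold = cm.max() / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            # white text on dark cells, black on light ones
            ax.text(j, i, format(cm[i, j].item(), fmt), ha="center", va="center",
                    color="white" if cm[i, j] > threshold else "black")
    fig.tight_layout()
    return fig
```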


Metrics

Confusion matrix

from simplified_keras.transformations import predictions_to_classes, one_hot_to_sparse
from simplified_keras.metrics import get_confusion_matrixes

predictions = model.predict(validation_images)
predicted_classes = predictions_to_classes(predictions)
sparse_labels = one_hot_to_sparse(validation_labels)

# Returns two numpy arrays: standard and normalized
cm, cm_normalized = get_confusion_matrixes(predicted_classes, sparse_labels)
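For reference, both matrices can be built with plain NumPy. A minimal sketch, assuming rows index the true class and columns the predicted class (scikit-learn's convention; the library's orientation may differ) and that the normalized matrix divides each row by its sum; `confusion_matrices` is a hypothetical name.

```python
import numpy as np


def confusion_matrices(predicted_classes, true_classes, num_classes=None):
    """Return (counts, row-normalized) confusion matrices; rows = true, columns = predicted."""
    predicted_classes = np.asarray(predicted_classes)
    true_classes = np.asarray(true_classes)
    if num_classes is None:
        num_classes = int(max(predicted_classes.max(), true_classes.max())) + 1
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (true, predicted) pair repeats
    np.add.at(cm, (true_classes, predicted_classes), 1)
    row_sums = cm.sum(axis=1, keepdims=True)
    # classes with no samples get a row of zeros instead of a division by zero
    cm_normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape), where=row_sums != 0)
    return cm, cm_normalized
```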

Model Statistics

Calculates the following per-class statistics:

  • FP: false positives
  • FN: false negatives
  • TP: true positives
  • TN: true negatives
  • TPR: sensitivity / true positive rate
  • TNR: specificity / true negative rate
  • PPV: precision / positive predictive value
  • NPV: negative predictive value
  • FPR: fall-out / false positive rate
  • FNR: false negative rate
  • FDR: false discovery rate
  • ACC: accuracy for each class
  • many more, with new ones still being added

from simplified_keras.transformations import predictions_to_classes, one_hot_to_sparse
from simplified_keras.metrics import get_confusion_matrixes
from simplified_keras.metrics import get_model_statistics

predictions = model.predict(validation_images)
predicted_classes = predictions_to_classes(predictions)
sparse_labels = one_hot_to_sparse(validation_labels)

cm, cm_normalized = get_confusion_matrixes(predicted_classes, sparse_labels)

stats = get_model_statistics(cm)
classes = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
fig = stats.visualize(classes)
print(stats.TN) #[2890 3530 2874 2361 2591 3000 2661]

Visualization: stat-visualization.png
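All of these statistics follow from the confusion matrix by treating each class one-vs-rest. A minimal sketch, assuming rows are true classes and columns are predicted classes (if the library uses the opposite orientation, FP and FN swap; TP, TN and ACC do not change); `per_class_statistics` is a hypothetical name, not the library's get_model_statistics.

```python
import numpy as np


def per_class_statistics(cm):
    """Per-class one-vs-rest counts and rates from a confusion matrix (rows = true, columns = predicted)."""
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as the class, but actually another class
    fn = cm.sum(axis=1) - tp  # actually the class, but predicted as another
    tn = cm.sum() - (tp + fp + fn)
    with np.errstate(divide="ignore", invalid="ignore"):  # empty classes give nan instead of warnings
        return {
            "TP": tp, "FP": fp, "FN": fn, "TN": tn,
            "TPR": tp / (tp + fn),  # sensitivity
            "TNR": tn / (tn + fp),  # specificity
            "PPV": tp / (tp + fp),  # precision
            "NPV": tn / (tn + fn),
            "FPR": fp / (fp + tn),
            "FNR": fn / (tp + fn),
            "FDR": fp / (tp + fp),
            "ACC": (tp + tn) / cm.sum(),
        }
```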

Folder statistics

from simplified_keras.metrics import get_folders_statistics

stat = get_folders_statistics('../data/normal/train')
print(stat.nr_of_elements, stat.info) # 28709 {'0': 3995, '1': 436, '2': 4097, '3': 7215, '4': 4830, '5': 3171, '6': 4965}
fig = stat.bar_plot()
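Counts like these can be gathered with the standard library. A minimal sketch, assuming the layout that ImageDataGenerator.flow_from_directory expects (one subfolder per class) and counting only files directly inside each class folder; `count_files_per_class` is a hypothetical name.

```python
from pathlib import Path


def count_files_per_class(root):
    """Return (total number of files, {class folder name: file count}) for the subfolders of root."""
    counts = {
        sub.name: sum(1 for f in sub.iterdir() if f.is_file())
        for sub in sorted(Path(root).iterdir())
        if sub.is_dir()  # loose files in root are not a class and are skipped
    }
    return sum(counts.values()), counts
```

A strongly uneven result, such as class '1' above with 436 images against 7215 for class '3', is a hint to consider class weights or resampling.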


Model memory usage

from simplified_keras.metrics import get_model_memory_usage

batch_size = 64
# outputs usage in GB
usage = get_model_memory_usage(batch_size, model)
print(usage, 'GB') # 8.34 GB
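Estimates like this are back-of-the-envelope. A common approach (an assumption here; the library's exact formula may differ) multiplies the batch size by the number of values in every layer's output, adds the parameter count, and multiplies by the bytes per value (4 for float32). The sketch below does only that arithmetic on a list of per-layer output shapes, so it runs without TensorFlow; `estimate_memory_gb` is a hypothetical name, and for a Keras model the parameter count is available from `model.count_params()`.

```python
from math import prod


def estimate_memory_gb(batch_size, output_shapes, param_count, bytes_per_value=4):
    """Rough memory estimate in GB (1 GB = 1024**3 bytes).

    output_shapes: per-layer output shapes without the batch dimension
    param_count: total number of trainable + non-trainable weights
    Ignores gradients, optimizer state and framework overhead.
    """
    activations = batch_size * sum(prod(shape) for shape in output_shapes)
    total_bytes = bytes_per_value * (activations + param_count)
    return total_bytes / 1024 ** 3
```

For example, a single layer output of shape (256, 1024, 1024) at batch size 1 holds 2**28 float32 values, which is exactly 1 GB.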

Transformations

Convert predictions to classes array

from simplified_keras.transformations import predictions_to_classes

predictions = model.predict(validation_images)
predicted_classes = predictions_to_classes(predictions)
print(predicted_classes) #[6 3 3 ... 6 2 0]
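For a softmax output, converting predictions to classes presumably amounts to taking the index of the highest score in each row (an assumption about the library; a single sigmoid output would need a threshold instead). A minimal NumPy illustration with made-up scores:

```python
import numpy as np

# one row of class scores per sample, as returned by model.predict for a softmax output
predictions = np.array([[0.1, 0.7, 0.2],
                        [0.6, 0.3, 0.1]])
# the predicted class is the column with the highest score in each row
predicted_classes = np.argmax(predictions, axis=-1)
print(predicted_classes)  # [1 0]
```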

Convert one hot encoding to sparse

from simplified_keras.transformations import one_hot_to_sparse

sparse_labels = one_hot_to_sparse(validation_labels)
print(sparse_labels) #[6 6 6 ... 6 2 0]
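One-hot and sparse labels carry the same information, so the conversion can be checked with a round trip. A minimal NumPy sketch with made-up labels; for the forward direction, Keras also offers `tf.keras.utils.to_categorical`.

```python
import numpy as np

sparse = np.array([2, 0, 1, 2])
num_classes = 3
# one-hot: row i is all zeros except a 1 in column sparse[i]
one_hot = np.eye(num_classes, dtype=np.float32)[sparse]
# back to sparse: the column holding the 1 in each row
recovered = np.argmax(one_hot, axis=-1)
print(recovered)  # [2 0 1 2]
```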

Stretch histogram

from simplified_keras.transformations import stretch_histogram

# default color_mode='rgb'
stretch_histogram(image, color_mode='grayscale')
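A common form of histogram stretching is a linear min-max contrast stretch, which maps the darkest pixel to 0 and the brightest to 255. The sketch below assumes that is roughly what stretch_histogram does (the library may, for instance, clip percentiles instead); `stretch_contrast` is a hypothetical name. Note that stretching RGB channels independently, as done here, can shift the color balance.

```python
import numpy as np


def stretch_contrast(image):
    """Linearly map the image's [min, max] intensity range onto [0, 255] (uint8 output).

    HxW images are stretched as a whole; HxWxC images per channel.
    """
    img = image.astype(np.float64)
    lo = img.min(axis=(0, 1), keepdims=True)
    hi = img.max(axis=(0, 1), keepdims=True)
    span = np.where(hi > lo, hi - lo, 1.0)  # avoid division by zero on flat images
    return np.round((img - lo) / span * 255).astype(np.uint8)
```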

Unfreeze model

from simplified_keras.transformations import unfreeze_model
from tensorflow.keras.optimizers import Adam

# default params: optimizer=Adam(learning_rate=1e-5), metrics="acc"
unfreeze_model(model, optimizer=Adam(learning_rate=1e-4), metrics="accuracy")

Normalize histogram clahe

from simplified_keras.transformations import normalize_histogram_clahe

# default clip_limit=2.0, tile_grid_size=(8, 8), color_mode='rgb'
normalize_histogram_clahe(image)

Replace activations

Replaces all activation functions in the given model.

from simplified_keras.transformations import replace_activations
from tensorflow.keras.layers import LeakyReLU

l_relu = LeakyReLU()
replace_activations(model, l_relu)

Layers

Augmentation layers built on TensorFlow image operations

from simplified_keras.transformations import RandomSaturation, RandomHue, RandomBrightness
from tensorflow.keras import Sequential

# for details on the parameters, see the tf.image documentation
augment_layers = Sequential([
  RandomSaturation(0.5, 1.5),
  # max_delta must be in the range [0, 0.5]
  RandomHue(0.2),
  RandomBrightness(0.2)
])

PyPI

simplified-keras

TODO

  • Add tests and a lot of features :)

Development

Want to contribute? Great!

To fix a bug or enhance an existing module, follow these steps:

  • Fork the repo
  • Create a new branch (git checkout -b improve-feature)
  • Make the appropriate changes in the files
  • Verify that they are correct
  • Stage the changes (git add .)
  • Commit changes
  • Push to the branch (git push origin improve-feature)
  • Create a Pull Request

Status

The library is a work in progress.

Contact

albert.lis.1996@gmail.com - feel free to contact me!



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution

simplified_keras-0.0.22-py3-none-any.whl (16.8 kB)

Uploaded Python 3

File details

Details for the file simplified_keras-0.0.22-py3-none-any.whl.

File metadata

  • Download URL: simplified_keras-0.0.22-py3-none-any.whl
  • Upload date:
  • Size: 16.8 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/52.0.0.post20210125 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.8.3

File hashes

Hashes for simplified_keras-0.0.22-py3-none-any.whl
Algorithm Hash digest
SHA256 d2ca3e14bed7a7e40fd9a5fb22d7e358a6febc127010947b05d3bc86b61706a7
MD5 1835b2d0b4a8b82452500bc52a6e6b15
BLAKE2b-256 be5b06f8500dc4fa3f095c2a4a60a2e260cbfd727bd6a70619c3cb16153001ac

