
Cross-validation for keras models

Project description

Keras Cross-validation

keras-model-cv lets you cross-validate Keras models.
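Cross-validation here is k-fold style. The data is divided into k folds, and each fold serves once as the validation set while the model is trained on the remaining k-1 folds. The plain-Python sketch below illustrates only that index bookkeeping. The function name `kfold_indices` is hypothetical, and because its shuffle uses Python's `random` module it will not reproduce the exact folds of scikit-learn's `KFold`.

```python
import random


def kfold_indices(n_samples, n_splits, shuffle=False, seed=None):
    """Yield (train_idx, val_idx) pairs, one per fold (hypothetical helper)."""
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # As in scikit-learn's KFold, the first n_samples % n_splits folds get one extra sample.
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


for train_idx, val_idx in kfold_indices(10, n_splits=3, shuffle=True, seed=1234):
    print(len(train_idx), len(val_idx))  # 6 4, then 7 3, then 7 3
```

Every sample lands in exactly one validation fold, so each one is scored exactly once across the splits.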

Installation

pip install keras-model-cv

or install directly from the GitHub repository:

pip install git+https://github.com/dubovikmaster/keras-model-cv.git

Quickstart

from keras_model_cv import KerasCV
from sklearn.model_selection import KFold
import tensorflow as tf

tf.get_logger().setLevel("INFO")

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()


def build_model(hidden_units, dropout):
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(hidden_units, activation="relu"),
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


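# Hyperparameters for build_model; the keys match its argument names.
PARAMS = {'hidden_units': 16, 'dropout': .3}

if __name__ == '__main__':
    cv = KerasCV(
        build_model,                                         # function that builds and compiles the model
        KFold(n_splits=3, random_state=1234, shuffle=True),  # scikit-learn splitter that defines the folds
        PARAMS,                                              # hyperparameters passed to build_model
        preprocessor=tf.keras.layers.Normalization(),        # optional preprocessing layer
        save_history=True,                                   # write per-split results to disk (see below)
        directory='my_awesome_project',
        name='my_cv',
    )
    # verbose and epochs are standard Keras Model.fit options
    cv.fit(x_train, y_train, verbose=0, epochs=3)
    # Mean and standard deviation of the validation metrics across splits
    print(cv.get_cv_score())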
Out:
                  loss  accuracy
        mean  0.283194  0.919783
        std   0.004215  0.002887
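Conceptually, the quickstart trains one model per split and then summarizes the validation scores. The sketch below shows that loop under stated assumptions: a hypothetical majority-class "model" stands in for `build_model` and Keras training so it runs without TensorFlow, and the summary uses the sample standard deviation (ddof=1), which is what pandas' `std` computes by default. It is not keras-model-cv's actual implementation.

```python
import random
import statistics
from collections import Counter


def majority_class(labels):
    """Toy stand-in for build_model + fit: always predict the most common label."""
    return Counter(labels).most_common(1)[0][0]


def cross_validate(y, n_splits=3, seed=1234):
    """Return one validation-accuracy record per split (shuffled k-fold)."""
    idx = list(range(len(y)))
    random.Random(seed).shuffle(idx)
    # Round-robin assignment gives folds whose sizes differ by at most one.
    folds = [idx[i::n_splits] for i in range(n_splits)]
    scores = []
    for split, val_idx in enumerate(folds):
        held_out = set(val_idx)
        train_labels = [y[i] for i in idx if i not in held_out]
        prediction = majority_class(train_labels)  # "train" on the other k-1 folds
        correct = sum(y[i] == prediction for i in val_idx)  # score on the held-out fold
        scores.append({"split": split, "accuracy": correct / len(val_idx)})
    return scores


def summarize(scores, metric="accuracy"):
    """Mean and sample standard deviation (ddof=1) of one metric across splits."""
    values = [s[metric] for s in scores]
    return {"mean": statistics.mean(values), "std": statistics.stdev(values)}


scores = cross_validate([0, 0, 1, 0, 1, 0, 0, 1, 0])
print(scores)
print(summarize(scores))
```

A fresh model per split matters: reusing weights from a previous split would leak that split's validation data into training.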

You can also pass your own aggregation functions (for more info, see pandas.DataFrame.agg):

print(cv.get_cv_score(agg_func={'loss': min, 'accuracy': max}))
Out:
        loss        0.27959
        accuracy    0.92010
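A dict passed to `agg_func` maps each metric to its own function. Assuming `get_cv_score` aggregates the per-split validation scores (the `min`/`max` values above match the `get_split_scores()` table shown further down), the plain-Python sketch below mimics what a dict argument to `DataFrame.agg` does. `agg` here is a hypothetical helper, not part of the library.

```python
# Per-split validation scores copied from the get_split_scores() output in this README.
split_scores = {
    "loss": [0.282442, 0.290500, 0.279590],
    "accuracy": [0.9201, 0.9198, 0.9173],
}


def agg(columns, agg_func):
    """Apply one function per column, like DataFrame.agg with a dict argument (hypothetical helper)."""
    return {name: func(columns[name]) for name, func in agg_func.items()}


print(agg(split_scores, {"loss": min, "accuracy": max}))  # {'loss': 0.27959, 'accuracy': 0.9201}
```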

You can also get the full training history of every split as a pandas DataFrame:

cv.get_train_history()
Out:
               loss  accuracy  split  epochs
        0  0.957261  0.679375      0       1
        1  0.595646  0.809850      0       2
        2  0.541124  0.824850      0       3
        3  0.835493  0.722475      1       1
        4  0.574581  0.810925      1       2
        5  0.526098  0.829200      1       3
        6  0.813172  0.736200      2       1
        7  0.556871  0.816875      2       2
        8  0.512916  0.829550      2       3
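Because the history is in long format (one row per split and epoch), per-split summaries take a single group-by. With pandas, where `history = cv.get_train_history()`, something like `history.loc[history.groupby("split")["epochs"].idxmax()]` keeps each split's last epoch. The stdlib sketch below does the same on the rows shown above, using a hypothetical `final_epoch` helper.

```python
# Rows copied from the get_train_history() output above.
history = [
    {"loss": 0.957261, "accuracy": 0.679375, "split": 0, "epochs": 1},
    {"loss": 0.595646, "accuracy": 0.809850, "split": 0, "epochs": 2},
    {"loss": 0.541124, "accuracy": 0.824850, "split": 0, "epochs": 3},
    {"loss": 0.835493, "accuracy": 0.722475, "split": 1, "epochs": 1},
    {"loss": 0.574581, "accuracy": 0.810925, "split": 1, "epochs": 2},
    {"loss": 0.526098, "accuracy": 0.829200, "split": 1, "epochs": 3},
    {"loss": 0.813172, "accuracy": 0.736200, "split": 2, "epochs": 1},
    {"loss": 0.556871, "accuracy": 0.816875, "split": 2, "epochs": 2},
    {"loss": 0.512916, "accuracy": 0.829550, "split": 2, "epochs": 3},
]


def final_epoch(rows):
    """Return {split: row}, keeping the row with the highest epoch number per split."""
    last = {}
    for row in rows:
        current = last.get(row["split"])
        if current is None or row["epochs"] > current["epochs"]:
            last[row["split"]] = row
    return last


for split, row in sorted(final_epoch(history).items()):
    print(f"split {split}: epoch {row['epochs']}, loss {row['loss']}, accuracy {row['accuracy']}")
```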

You can also plot the training history with matplotlib:

cv.show_train_history()

What about per-split metrics?

cv.get_split_scores()
Out:
           accuracy      loss  split
        0    0.9201  0.282442      0
        1    0.9198  0.290500      1
        2    0.9173  0.279590      2

If save_history=True, the training history, validation metrics, and split information for each split are saved under directory/name. In our example:

my_awesome_project/
 |--my_cv/
      |--split_0/
      |    |--history.yml
      |    |--validation_metric.yml
      |    |--split_info.yml
      |--split_1/
      |--split_2/
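To read the saved files back, you first need their paths. The sketch below builds them with `pathlib` for the layout shown above. `split_files` is a hypothetical helper, and the file names are taken from the tree rather than from the library's API. The `.yml` files can then be loaded with any YAML parser, for example PyYAML's `yaml.safe_load`.

```python
from pathlib import Path


def split_files(directory, name, n_splits):
    """Map split number -> {file stem: path} for the layout shown above (hypothetical helper)."""
    root = Path(directory) / name
    stems = ("history", "validation_metric", "split_info")
    return {
        split: {stem: root / f"split_{split}" / f"{stem}.yml" for stem in stems}
        for split in range(n_splits)
    }


files = split_files("my_awesome_project", "my_cv", n_splits=3)
print(files[0]["history"].as_posix())  # my_awesome_project/my_cv/split_0/history.yml
```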

Licence

MIT license



Download files


Source Distribution

keras_model_cv-0.5.4.tar.gz (5.9 kB)

Uploaded Source

Built Distribution

keras_model_cv-0.5.4-py3-none-any.whl (6.2 kB)

Uploaded Python 3

File details

Details for the file keras_model_cv-0.5.4.tar.gz.

File metadata

  • Download URL: keras_model_cv-0.5.4.tar.gz
  • Size: 5.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for keras_model_cv-0.5.4.tar.gz
Algorithm Hash digest
SHA256 23983835c43f2af35eda7a308238b55e77000505678836201e0c80628110b37f
MD5 f4c44dbe311834da3add952519d3aada
BLAKE2b-256 b470a29dba9d33bd949a0ae5249d2fb8f498b9e20a4b665987d1204d0fd04688


File details

Details for the file keras_model_cv-0.5.4-py3-none-any.whl.

File hashes

Hashes for keras_model_cv-0.5.4-py3-none-any.whl
Algorithm Hash digest
SHA256 512c89101f1999b449f0bcb5655c7ef9e8eab7708ca7f3ce08ce62a9605f65e4
MD5 f48a287cc52a9f442ce8bea84ab370c6
BLAKE2b-256 12d59d06e9b433c953c7da7891fb007de86b4b92f16e1747f25674b1af4bf9bd

