Cross-validation for Keras models

Project description

Keras Cross-validation

keras-model-cv lets you cross-validate Keras models.

Installation

pip install keras-model-cv

or

pip install git+https://github.com/dubovikmaster/keras-model-cv.git

Quickstart

from keras_model_cv import KerasCV
from sklearn.model_selection import KFold
import tensorflow as tf

tf.get_logger().setLevel("INFO")

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()


def build_model(hidden_units, dropout):
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(hidden_units, activation="relu"),
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


PARAMS = {'hidden_units': 16, 'dropout': .3}

if __name__ == '__main__':
    cv = KerasCV(
        build_model,
        KFold(n_splits=3, random_state=1234, shuffle=True),
        PARAMS,
        preprocessor=tf.keras.layers.Normalization(),
        save_history=True,
        directory='my_awesome_project',
        name='my_cv',
    )
    cv.fit(x_train, y_train, verbose=0, epochs=3)
    print(cv.get_cv_score())
Out: 
                loss  accuracy
        mean  0.283194  0.919783
        std   0.004215  0.002887 
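Cross-validation with KFold(n_splits=3) trains a fresh model on two thirds of the data, scores it on the held-out third, repeats this once per split, and then summarizes the three scores. The sketch below spells that loop out in plain NumPy so it runs without TensorFlow. It is not KerasCV's implementation: kfold_indices is our own stand-in for sklearn's KFold(shuffle=True) (it does not reproduce sklearn's exact splits), and a toy nearest-centroid classifier takes the place of the Keras model returned by build_model. It also computes the normalization statistics on the training fold only, a common guard against leakage; we assume, but this README does not confirm, that KerasCV handles its preprocessor the same way.

```python
import numpy as np


def kfold_indices(n_samples, n_splits, seed):
    """Yield (train_idx, val_idx) pairs for shuffled k-fold cross-validation."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), n_splits)
    for k, val_idx in enumerate(folds):
        # Every sample outside fold k is used for training.
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != k])
        yield train_idx, val_idx


def cross_validate(x, y, n_splits=3, seed=1234):
    """Return per-split accuracies of a toy nearest-centroid classifier."""
    scores = []
    for train_idx, val_idx in kfold_indices(len(x), n_splits, seed):
        # Normalization statistics come from the training fold only,
        # so the validation fold does not leak into preprocessing.
        mean = x[train_idx].mean(axis=0)
        std = x[train_idx].std(axis=0) + 1e-7
        x_train = (x[train_idx] - mean) / std
        x_val = (x[val_idx] - mean) / std
        # "Fit" a fresh model for this split: one centroid per class.
        classes = np.unique(y[train_idx])
        centroids = np.stack([x_train[y[train_idx] == c].mean(axis=0) for c in classes])
        # Predict the class of the nearest centroid (squared Euclidean distance).
        dists = ((x_val[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        preds = classes[dists.argmin(axis=1)]
        scores.append(float((preds == y[val_idx]).mean()))
    return scores


# Usage: two well-separated 2-D blobs, 30 points each.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-3.0, 1.0, (30, 2)), rng.normal(3.0, 1.0, (30, 2))])
y = np.array([0] * 30 + [1] * 30)
per_split = cross_validate(x, y)
print("per split:", per_split)
print("mean:", np.mean(per_split), "std:", np.std(per_split, ddof=1))
```

Building a new model inside the loop is the reason the library takes a model-building function such as build_model rather than a model instance: weights learned on one split must not carry over into the next.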

You can also pass your own aggregation functions through agg_func (for more info, see pandas.DataFrame.agg):

print(cv.get_cv_score(agg_func={'loss': min, 'accuracy': max}))
Out:
        loss        0.27959
        accuracy    0.92010
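A dict maps each column to its own function, as in pandas.DataFrame.agg. The sketch below applies the same dict to per-split scores with plain pandas. It assumes that get_cv_score() forwards agg_func to DataFrame.agg over the per-split validation scores, which this README implies but does not state outright; the scores are the ones reported by cv.get_split_scores() further down. String names ("min", "max") are used in place of the built-ins; pandas accepts both.

```python
import pandas as pd

# Per-split validation scores, copied from the cv.get_split_scores() output in this README.
split_scores = pd.DataFrame(
    {
        "accuracy": [0.9201, 0.9198, 0.9173],
        "loss": [0.282442, 0.290500, 0.279590],
        "split": [0, 1, 2],
    }
)

# One reducing function per column: the result is a Series indexed by column name.
best = split_scores.agg({"loss": "min", "accuracy": "max"})
print(best)
```

With these scores the result is the minimum loss (0.27959) and the maximum accuracy (0.9201), the same values printed above.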

You can also get the full training history for every split as a pandas DataFrame:

cv.get_train_history()
Out:
             loss    accuracy    split  epochs
        0  0.957261  0.679375      0       1
        1  0.595646  0.809850      0       2
        2  0.541124  0.824850      0       3
        3  0.835493  0.722475      1       1
        4  0.574581  0.810925      1       2
        5  0.526098  0.829200      1       3
        6  0.813172  0.736200      2       1
        7  0.556871  0.816875      2       2
        8  0.512916  0.829550      2       3
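The history is in long format (one row per split and epoch), so ordinary pandas reshaping applies. The sketch below rebuilds the frame from the output above by hand so it runs without training anything; with a real KerasCV object you would start from history = cv.get_train_history() instead. It assumes the column names shown above (loss, accuracy, split, epochs).

```python
import pandas as pd

# Training history as printed above: one row per (split, epoch) pair.
history = pd.DataFrame(
    {
        "loss": [0.957261, 0.595646, 0.541124, 0.835493, 0.574581,
                 0.526098, 0.813172, 0.556871, 0.512916],
        "accuracy": [0.679375, 0.809850, 0.824850, 0.722475, 0.810925,
                     0.829200, 0.736200, 0.816875, 0.829550],
        "split": [0, 0, 0, 1, 1, 1, 2, 2, 2],
        "epochs": [1, 2, 3] * 3,
    }
)

# Wide view: one row per epoch, one column per split.
loss_by_split = history.pivot(index="epochs", columns="split", values="loss")
print(loss_by_split)

# Last-epoch row of each split (the highest epoch number within each group).
final = history.loc[history.groupby("split")["epochs"].idxmax()]
print(final[["split", "loss", "accuracy"]])
```

If matplotlib is installed, loss_by_split.plot() draws one line per split.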

You can also plot the training history with matplotlib:

cv.show_train_history()

What about per-split metrics?

cv.get_split_scores()
Out:
            accuracy   loss     split
        0    0.9201  0.282442      0
        1    0.9198  0.290500      1
        2    0.9173  0.279590      2
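Each row is one validation fold, so a mean/std summary in the same shape as the default get_cv_score() output can be recomputed with pandas. This is a sketch assuming get_cv_score() aggregates these per-split scores with pandas' "mean" and "std" (pandas' std is the sample standard deviation, ddof=1); the numbers are copied from the table above.

```python
import pandas as pd

# Per-split validation scores, copied from the cv.get_split_scores() output above.
split_scores = pd.DataFrame(
    {
        "accuracy": [0.9201, 0.9198, 0.9173],
        "loss": [0.282442, 0.290500, 0.279590],
        "split": [0, 1, 2],
    }
)

# Rows "mean" and "std", one column per metric. The split column is left out
# because averaging split indices is meaningless.
summary = split_scores[["loss", "accuracy"]].agg(["mean", "std"])
print(summary)
```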

If save_history=True, the training history, validation metrics, and split information for each split are saved under directory/name. In our example:

my_awesome_project/
 |--my_cv/
      |--split_0/
      |    |--history.yml
      |    |--validation_metric.yml
      |    |--split_info.yml
      |--split_1/
      |--split_2/
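The saved files can also be located programmatically. The helper below is hypothetical (saved_split_files is not part of keras-model-cv) and assumes exactly the directory/name/split_<i>/*.yml layout shown above. It only finds the files and does not interpret them, since the keys inside each file are not documented here; if PyYAML is installed, yaml.safe_load can read each one.

```python
from pathlib import Path


def saved_split_files(directory, name):
    """Return {split index: {file name: Path}} for the layout shown above.

    Hypothetical helper, not part of keras-model-cv.
    """
    root = Path(directory) / name
    if not root.is_dir():
        return {}
    found = {}
    for split_dir in root.glob("split_*"):
        suffix = split_dir.name.split("_", 1)[1]
        # Skip anything that is not a split_<integer> directory.
        if not split_dir.is_dir() or not suffix.isdigit():
            continue
        found[int(suffix)] = {p.name: p for p in sorted(split_dir.glob("*.yml"))}
    # Order by numeric split index, so split_10 comes after split_2.
    return dict(sorted(found.items()))


if __name__ == "__main__":
    for index, paths in saved_split_files("my_awesome_project", "my_cv").items():
        print(index, sorted(paths))
```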

License

MIT license


Download files


Source Distribution

keras_model_cv-0.2.1.tar.gz (5.4 kB)

Uploaded Source

Built Distribution

keras_model_cv-0.2.1-py3-none-any.whl (5.7 kB)

Uploaded Python 3

File details

Details for the file keras_model_cv-0.2.1.tar.gz.

File metadata

  • Download URL: keras_model_cv-0.2.1.tar.gz
  • Upload date:
  • Size: 5.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.6

File hashes

Hashes for keras_model_cv-0.2.1.tar.gz
Algorithm Hash digest
SHA256 07a9b807bf3854a0ac57045367db87525e022fbb988bce63ff988cce721b712e
MD5 886c5e5e89043bc29320e10832468be4
BLAKE2b-256 402e3bdd844d302b12be278c95faae350f9e63b02384e3289fd7da200f5ddcf9


File details

Details for the file keras_model_cv-0.2.1-py3-none-any.whl.

File hashes

Hashes for keras_model_cv-0.2.1-py3-none-any.whl
Algorithm Hash digest
SHA256 0ea4497730b5f15b175f8d3ee9a57301df87dc9a41652ca257bd376c95961c4a
MD5 1a96247f908ac75d7e5a521be3b772f0
BLAKE2b-256 8e362bf29aae8820c65223e32af500e9fe291a7660e0905e40649b7c0dbad3e3

