
Cross-validation for Keras models

Project description

Keras Cross-validation

keras-model-cv lets you cross-validate Keras models.

Installation

pip install keras-model-cv

or

pip install git+https://github.com/dubovikmaster/keras-model-cv.git

Quickstart

from keras_model_cv import KerasCV
from sklearn.model_selection import KFold
import tensorflow as tf

tf.get_logger().setLevel("INFO")

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()


def build_model(hidden_units, dropout):
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(hidden_units, activation="relu"),
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


PARAMS = {'hidden_units': 16, 'dropout': .3}

if __name__ == '__main__':
    cv = KerasCV(
        build_model,  # model-building function
        KFold(n_splits=3, random_state=1234, shuffle=True),  # splitting strategy
        PARAMS,  # hyperparameters passed to build_model
        preprocessor=tf.keras.layers.Normalization(),  # optional input preprocessing layer
        save_history=True,  # persist per-split history and metrics to disk
        directory='my_awesome_project',  # root directory for saved results
        name='my_cv',  # name of this cross-validation run
    )
    cv.fit(x_train, y_train, verbose=0, epochs=3)
    print(cv.get_cv_score())
Out: 
                loss  accuracy
        mean  0.283194  0.919783
        std   0.004215  0.002887 
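For reference, k-fold cross-validation of a Keras model boils down to a loop like the one below. This is a rough sketch of the general technique, not keras-model-cv's actual implementation; the helper name cross_validate is hypothetical.

from sklearn.model_selection import KFold


def cross_validate(build_fn, params, x, y, splitter, **fit_kwargs):
    """Train a fresh model on each fold and collect validation metrics."""
    scores = []
    for train_idx, val_idx in splitter.split(x):
        model = build_fn(**params)  # new model for every split
        model.fit(x[train_idx], y[train_idx], **fit_kwargs)
        loss, acc = model.evaluate(x[val_idx], y[val_idx], verbose=0)
        scores.append({'loss': loss, 'accuracy': acc})
    return scores


# e.g. cross_validate(build_model, PARAMS, x_train, y_train,
#                     KFold(n_splits=3, shuffle=True, random_state=1234),
#                     epochs=3, verbose=0)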

You can pass a different aggregation function (for more info, see pandas.DataFrame.agg):

print(cv.get_cv_score(agg_func={'loss': min, 'accuracy': max}))
Out:
        loss        0.27959
        accuracy    0.92010
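
Since agg_func appears to be forwarded to pandas, passing a list of aggregations should also work; a quick sketch, assuming the argument accepts the same forms as pandas.DataFrame.agg:

# Assumes agg_func accepts anything pandas.DataFrame.agg does
print(cv.get_cv_score(agg_func=['mean', 'std', 'min', 'max']))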

You can also get the full training history for each split as a pandas DataFrame:

cv.get_train_history()
Out:
             loss    accuracy    split  epochs
        0  0.957261  0.679375      0       1
        1  0.595646  0.809850      0       2
        2  0.541124  0.824850      0       3
        3  0.835493  0.722475      1       1
        4  0.574581  0.810925      1       2
        5  0.526098  0.829200      1       3
        6  0.813172  0.736200      2       1
        7  0.556871  0.816875      2       2
        8  0.512916  0.829550      2       3

You can plot the training history with matplotlib:

cv.show_train_history()
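
If you prefer to build the figure yourself, the DataFrame returned by get_train_history() plots easily with matplotlib. A minimal sketch, assuming the columns shown above (loss, accuracy, split, epochs):

import matplotlib.pyplot as plt

history = cv.get_train_history()
fig, ax = plt.subplots()
for split, grp in history.groupby('split'):
    ax.plot(grp['epochs'], grp['loss'], marker='o', label=f'split {split}')
ax.set_xlabel('epoch')
ax.set_ylabel('training loss')
ax.legend()
plt.show()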

What about metrics per split?

cv.get_split_scores()
Out:
            accuracy   loss     split
        0    0.9201  0.282442      0
        1    0.9198  0.290500      1
        2    0.9173  0.279590      2
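
These per-split scores are what get_cv_score() aggregates, so you can reproduce (or extend) the summary yourself, assuming the columns shown above:

scores = cv.get_split_scores()
print(scores[['loss', 'accuracy']].agg(['mean', 'std']))  # should match get_cv_score()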

If save_history=True, the training history, validation metrics, and split info are saved to the specified directory. In our example:

my_awesome_project/
 |--my_cv/
      |--split_0/
           |--history.yml
           |--validation_metric.yml
           |--split_info.yml
      |--split_1/
      |--split_2/
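
The saved files are plain YAML, so they can be read back with PyYAML. A small sketch, assuming the layout shown above (the exact keys inside each file depend on the library):

import yaml
from pathlib import Path

run_dir = Path('my_awesome_project') / 'my_cv'
for split_dir in sorted(run_dir.glob('split_*')):
    with open(split_dir / 'validation_metric.yml') as f:
        metrics = yaml.safe_load(f)
    print(split_dir.name, metrics)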

License

MIT license
