Cross-validation for Keras models
Keras Cross-validation
keras-model-cv allows you to cross-validate Keras models.
Installation
pip install keras-model-cv
or
pip install git+https://github.com/dubovikmaster/keras-model-cv.git
Quickstart
from keras_model_cv import KerasCV
from sklearn.model_selection import KFold
import tensorflow as tf
tf.get_logger().setLevel("INFO")
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

def build_model(hidden_units, dropout):
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(hidden_units, activation="relu"),
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


PARAMS = {'hidden_units': 16, 'dropout': .3}

if __name__ == '__main__':
    cv = KerasCV(
        build_model,
        KFold(n_splits=3, random_state=1234, shuffle=True),
        PARAMS,
        preprocessor=tf.keras.layers.Normalization(),
        save_history=True,
        directory='my_awesome_project',
        name='my_cv',
    )
    cv.fit(x_train, y_train, verbose=0, epochs=3)
    print(cv.get_cv_score())
Out:
          loss  accuracy
mean  0.283194  0.919783
std   0.004215  0.002887
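Conceptually, cross-validating a Keras model means building a fresh model for every split, training it on that split's training folds, and scoring it on the held-out fold. The sketch below shows that loop in plain Python. It is not KerasCV's implementation: `cross_validate` is a hypothetical name, and the sketch skips the `preprocessor`, history saving and directory handling that KerasCV provides. It assumes `build_fn(**params)` returns a compiled model with Keras-style `fit` and `evaluate(..., return_dict=True)` methods, which `tf.keras` models have.

```python
import pandas as pd


def cross_validate(build_fn, params, x, y, cv, **fit_kwargs):
    """Hypothetical K-fold loop: one freshly built model per split.

    Assumption: build_fn(**params) returns a compiled, Keras-style model whose
    evaluate(..., return_dict=True) returns {metric_name: value}.
    """
    rows = []
    for split, (train_idx, val_idx) in enumerate(cv.split(x)):
        # A new model per split, so no weights carry over between folds.
        model = build_fn(**params)
        model.fit(x[train_idx], y[train_idx], **fit_kwargs)
        metrics = model.evaluate(x[val_idx], y[val_idx], verbose=0, return_dict=True)
        rows.append({**metrics, "split": split})
    split_scores = pd.DataFrame(rows)
    # Summarize every metric column across splits (pandas std uses ddof=1).
    summary = split_scores.drop(columns="split").agg(["mean", "std"])
    return split_scores, summary
```

With the quickstart's objects, `cross_validate(build_model, PARAMS, x_train, y_train, KFold(n_splits=3, random_state=1234, shuffle=True), epochs=3, verbose=0)` would train three models, one per split; since the sketch has no preprocessing step, its scores are not expected to match KerasCV's.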
You can also pass your own aggregate functions, for example a different one per metric (for the accepted formats, see pandas.DataFrame.agg):
print(cv.get_cv_score(agg_func={'loss': min, 'accuracy': max}))
Out:
loss 0.27959
accuracy 0.92010
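Because the README points to `pandas.DataFrame.agg`, it is reasonable to assume `agg_func` takes the same forms that method accepts; that is an assumption, not documented behavior. The sketch applies `agg` directly to the per-split scores from the `get_split_scores()` example in this README: a list of functions yields one row per function (like the default mean/std table), while a dict maps each column to its own function and yields a Series.

```python
import pandas as pd

# Per-split validation scores, copied from the get_split_scores() example in this README.
metrics = pd.DataFrame(
    {"loss": [0.282442, 0.290500, 0.279590], "accuracy": [0.9201, 0.9198, 0.9173]}
)

# List of functions: a DataFrame with one row per function.
summary = metrics.agg(["mean", "std"])

# Dict of column -> function: a Series with one value per column.
best = metrics.agg({"loss": "min", "accuracy": "max"})
print(best)
```

Here `best` holds loss 0.27959 and accuracy 0.9201, the same values as the output above. The README passes the builtins `min` and `max`; string names such as `"min"` are the form recent pandas versions recommend.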
You can also get the full training history of every split as a pandas DataFrame:
cv.get_train_history()
Out:
       loss  accuracy  split  epochs
0  0.957261  0.679375      0       1
1  0.595646  0.809850      0       2
2  0.541124  0.824850      0       3
3  0.835493  0.722475      1       1
4  0.574581  0.810925      1       2
5  0.526098  0.829200      1       3
6  0.813172  0.736200      2       1
7  0.556871  0.816875      2       2
8  0.512916  0.829550      2       3
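The history comes back in long format (one row per split and epoch), so ordinary pandas operations apply. As a sketch, assuming the column names shown above, the snippet rebuilds that table by hand, averages each metric across splits per epoch, and pivots the loss into one column per split.

```python
import pandas as pd

# The get_train_history() output above, rebuilt by hand (one row per split/epoch).
history = pd.DataFrame(
    {
        "loss": [0.957261, 0.595646, 0.541124, 0.835493, 0.574581,
                 0.526098, 0.813172, 0.556871, 0.512916],
        "accuracy": [0.679375, 0.809850, 0.824850, 0.722475, 0.810925,
                     0.829200, 0.736200, 0.816875, 0.829550],
        "split": [0, 0, 0, 1, 1, 1, 2, 2, 2],
        "epochs": [1, 2, 3, 1, 2, 3, 1, 2, 3],
    }
)

# Mean and std of each metric across splits, per epoch (MultiIndex columns).
per_epoch = history.groupby("epochs")[["loss", "accuracy"]].agg(["mean", "std"])

# Wide view of the loss: one row per epoch, one column per split.
loss_by_split = history.pivot(index="epochs", columns="split", values="loss")
print(per_epoch)
print(loss_by_split)
```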
You can also plot the training history with matplotlib:
cv.show_train_history()
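If you want to build the figure yourself, for instance to change its layout or save it to a file, the history DataFrame can be plotted directly. `plot_history` below is a hypothetical helper, not part of keras-model-cv, and it is not how `show_train_history()` draws its plot; it only assumes the long-format columns shown in the `get_train_history()` output (`split`, `epochs`, and one column per metric).

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_history(history: pd.DataFrame, metrics=("loss", "accuracy")):
    """Hypothetical helper: one subplot per metric, one line per split."""
    fig, axes = plt.subplots(
        1, len(metrics), figsize=(5 * len(metrics), 4), squeeze=False
    )
    for ax, metric in zip(axes[0], metrics):
        for split, group in history.groupby("split"):
            ax.plot(group["epochs"], group[metric], marker="o", label=f"split {split}")
        ax.set_xlabel("epoch")
        ax.set_title(metric)
        ax.legend()
    fig.tight_layout()
    return fig
```

For example, `plot_history(cv.get_train_history()).savefig("train_history.png")` would write the figure to disk instead of showing it.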
What about metrics per split?
cv.get_split_scores()
Out:
   accuracy      loss  split
0    0.9201  0.282442      0
1    0.9198  0.290500      1
2    0.9173  0.279590      2
If save_history=True, the training history, validation metrics and split information are saved to the specified directory.
In our example:
my_awesome_project/
|--my_cv/
   |--split_0/
   |  |--history.yml
   |  |--validation_metric.yml
   |  |--split_info.yml
   |--split_1/
   |--split_2/
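To inspect a finished run later, the saved files can be read back with PyYAML. The sketch below assumes only the directory layout shown above; `load_saved_cv` is a hypothetical helper, and because the schema of the .yml files is not documented here, it returns their parsed contents as-is.

```python
from pathlib import Path

import yaml  # PyYAML


def load_saved_cv(directory, name):
    """Hypothetical reader: {split folder: {file stem: parsed YAML}}."""
    root = Path(directory) / name
    split_dirs = [p for p in root.iterdir() if p.is_dir() and p.name.startswith("split_")]
    # Numeric order, so split_2 comes before split_10.
    split_dirs.sort(key=lambda p: int(p.name.split("_", 1)[1]))
    return {
        split_dir.name: {
            f.stem: yaml.safe_load(f.read_text())
            for f in sorted(split_dir.glob("*.yml"))
        }
        for split_dir in split_dirs
    }
```

With the example above, `load_saved_cv("my_awesome_project", "my_cv")["split_0"]["validation_metric"]` would hold the parsed contents of `split_0/validation_metric.yml`.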
Licence
MIT license