Cross-validation for keras models
Keras Cross-validation
keras-model-cv lets you cross-validate a Keras model.
Installation
pip install keras-model-cv
or
pip install git+https://github.com/dubovikmaster/keras-model-cv.git
Quickstart
from keras_model_cv import KerasCV
from sklearn.model_selection import KFold
import tensorflow as tf

tf.get_logger().setLevel("INFO")

# Load MNIST; only the training part is cross-validated below.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()


def build_model(hidden_units, dropout):
    # Returns a freshly compiled model built from the given hyperparameters.
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(hidden_units, activation="relu"),
            tf.keras.layers.Dropout(dropout),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


# Keyword arguments passed to build_model.
PARAMS = {'hidden_units': 16, 'dropout': .3}

if __name__ == '__main__':
    cv = KerasCV(
        build_model,
        KFold(n_splits=3, random_state=1234, shuffle=True),
        PARAMS,
        preprocessor=tf.keras.layers.Normalization(),
        save_history=True,
        directory='my_awesome_project',
        name='my_cv',
    )
    cv.fit(x_train, y_train, verbose=0, epochs=3)
    print(cv.get_cv_score())
Out:
          loss  accuracy
mean  0.283194  0.919783
std   0.004215  0.002887
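To make the mean and std rows easier to interpret, here is a minimal sketch of the general k-fold procedure in plain Python. It is not the package's implementation: the helpers `k_fold_indices`, `cross_validate` and `toy_score` are hypothetical names introduced for illustration, and the use of the sample standard deviation is an assumption.

```python
import random
from statistics import mean, stdev


def k_fold_indices(n_samples, n_splits, seed):
    """Yield (train_idx, val_idx) pairs for shuffled k-fold splitting."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [
        n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
        for i in range(n_splits)
    ]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


def cross_validate(build_and_score, n_samples, n_splits=3, seed=1234):
    """Call build_and_score(train_idx, val_idx) once per fold and aggregate."""
    scores = [
        build_and_score(train_idx, val_idx)
        for train_idx, val_idx in k_fold_indices(n_samples, n_splits, seed)
    ]
    # stdev is the sample standard deviation (an assumption for this sketch).
    return {"scores": scores, "mean": mean(scores), "std": stdev(scores)}


def toy_score(train_idx, val_idx):
    # Stand-in for "build a fresh model, fit on the training indices,
    # evaluate on the validation indices". Here the score is simply the
    # fraction of samples held out for validation.
    return len(val_idx) / (len(train_idx) + len(val_idx))


result = cross_validate(toy_score, n_samples=10, n_splits=3)
print(result["scores"])
```

In a real run, `toy_score` would be replaced by a function that builds a new model for each fold, so that no weights leak from one split to the next.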
You can pass a different aggregation function through agg_func (for details, see pandas.DataFrame.agg):
print(cv.get_cv_score(agg_func={'loss': min, 'accuracy': max}))
Out:
loss        0.27959
accuracy    0.92010
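As a rough illustration of what a per-column aggregation dictionary does, the sketch below applies one function per metric in plain Python. The `aggregate` helper is hypothetical and only mimics the dictionary form of pandas.DataFrame.agg. The numbers are copied from the per-split scores table shown in this README.

```python
# Per-split validation scores, copied from the per-split table in this README.
split_scores = {
    "loss": [0.282442, 0.290500, 0.279590],
    "accuracy": [0.9201, 0.9198, 0.9173],
}


def aggregate(columns, agg_func):
    """Apply one aggregation function per column name (hypothetical helper)."""
    return {name: func(columns[name]) for name, func in agg_func.items()}


best = aggregate(split_scores, {"loss": min, "accuracy": max})
print(best)  # {'loss': 0.27959, 'accuracy': 0.9201}
```

Any callable that reduces a sequence to one value fits the same pattern, for example `statistics.median`.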
You can also get the full training history for every split as a pandas DataFrame:
cv.get_train_history()
Out:
       loss  accuracy  split  epochs
0  0.957261  0.679375      0       1
1  0.595646  0.809850      0       2
2  0.541124  0.824850      0       3
3  0.835493  0.722475      1       1
4  0.574581  0.810925      1       2
5  0.526098  0.829200      1       3
6  0.813172  0.736200      2       1
7  0.556871  0.816875      2       2
8  0.512916  0.829550      2       3
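A history table in this shape is easy to post-process. The sketch below is one possible way to pull out the last recorded epoch of each split. It uses plain tuples copied from the output above instead of the DataFrame, and `final_epoch_per_split` is a hypothetical helper, not part of the package.

```python
# Rows copied from the training-history output above:
# (loss, accuracy, split, epoch)
history = [
    (0.957261, 0.679375, 0, 1),
    (0.595646, 0.809850, 0, 2),
    (0.541124, 0.824850, 0, 3),
    (0.835493, 0.722475, 1, 1),
    (0.574581, 0.810925, 1, 2),
    (0.526098, 0.829200, 1, 3),
    (0.813172, 0.736200, 2, 1),
    (0.556871, 0.816875, 2, 2),
    (0.512916, 0.829550, 2, 3),
]


def final_epoch_per_split(rows):
    """Keep, for every split, the row with the highest epoch number."""
    last = {}
    for loss, accuracy, split, epoch in rows:
        if split not in last or epoch > last[split]["epoch"]:
            last[split] = {"epoch": epoch, "loss": loss, "accuracy": accuracy}
    return last


final = final_epoch_per_split(history)
for split, row in sorted(final.items()):
    print(split, row)
```

These are training metrics, so they are not expected to equal the validation scores reported per split.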
You can plot the training history with matplotlib:
cv.show_train_history()
What about metrics per split?
cv.get_split_scores()
Out:
   accuracy      loss  split
0    0.9201  0.282442      0
1    0.9198  0.290500      1
2    0.9173  0.279590      2
If save_history=True, the training history, the validation metrics, and information about each split are saved to the specified directory.
In our example:
my_awesome_project/
|--my_cv/
   |--split_0/
      |--history.yml
      |--validation_metric.yml
      |--split_info.yml
   |--split_1/
   |--split_2/
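To check what a run wrote to disk, the saved files can be listed with pathlib. This is a minimal sketch that assumes the layout shown above. `list_split_files` is a hypothetical helper, and the demo recreates the layout with empty files in a temporary directory rather than reading real results. It does not parse the YAML files, because their contents are not described here.

```python
import tempfile
from pathlib import Path


def list_split_files(directory, name):
    """Map each split_* folder name to the sorted file names inside it."""
    root = Path(directory) / name
    return {
        split_dir.name: sorted(p.name for p in split_dir.iterdir() if p.is_file())
        # Lexicographic order: fine for a few splits, but split_10 sorts before split_2.
        for split_dir in sorted(root.glob("split_*"))
        if split_dir.is_dir()
    }


# Demo: rebuild the documented layout with empty placeholder files.
expected_files = ["history.yml", "split_info.yml", "validation_metric.yml"]
with tempfile.TemporaryDirectory() as tmp:
    for i in range(3):
        split_dir = Path(tmp) / "my_cv" / f"split_{i}"
        split_dir.mkdir(parents=True)
        for file_name in expected_files:
            (split_dir / file_name).touch()
    demo = list_split_files(tmp, "my_cv")

print(demo)
```

After a real run, calling `list_split_files("my_awesome_project", "my_cv")` would list the files of the example above.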
Licence
MIT license
File details
Details for the file keras_model_cv-0.5.4.tar.gz.
File metadata
- Download URL: keras_model_cv-0.5.4.tar.gz
- Upload date:
- Size: 5.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 23983835c43f2af35eda7a308238b55e77000505678836201e0c80628110b37f |
| MD5 | f4c44dbe311834da3add952519d3aada |
| BLAKE2b-256 | b470a29dba9d33bd949a0ae5249d2fb8f498b9e20a4b665987d1204d0fd04688 |
File details
Details for the file keras_model_cv-0.5.4-py3-none-any.whl.
File metadata
- Download URL: keras_model_cv-0.5.4-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 512c89101f1999b449f0bcb5655c7ef9e8eab7708ca7f3ce08ce62a9605f65e4 |
| MD5 | f48a287cc52a9f442ce8bea84ab370c6 |
| BLAKE2b-256 | 12d59d06e9b433c953c7da7891fb007de86b4b92f16e1747f25674b1af4bf9bd |