Inspect machine learning models
Model Inspector
`model_inspector` aims to help you train better `scikit-learn` models by providing insights into their behavior.
Use
To use `model_inspector`, create an `Inspector` object from a `scikit-learn` model, a feature DataFrame `X`, and a target Series `y` by calling `get_inspector`, as illustrated below.
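import sklearn.datasets
from sklearn.ensemble import RandomForestRegressor
from model_inspector import get_inspector
# Load the diabetes data as a feature DataFrame X and a target Series y
X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
# Fit a model, then bundle it with the same data in an inspector
rfr = RandomForestRegressor().fit(X, y)
inspector = get_inspector(rfr, X, y)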
You can then use various methods of `inspector` to learn about your model.
inspector.permutation_importance()
s5 0.492319
bmi 0.486712
bp 0.138480
s6 0.095193
s3 0.077982
age 0.072870
s2 0.066400
s1 0.049787
s4 0.026105
sex 0.024565
dtype: float64
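For readers who want to see what this ranking measures, here is a minimal sketch that computes a similar Series directly with scikit-learn's `sklearn.inspection.permutation_importance`, which shuffles one column at a time and records the mean drop in the model's score. It is not necessarily how `model_inspector` computes its values: the helper name `permutation_importance_series`, the `n_repeats` value, and scoring on the training data are assumptions, and the numbers will differ from the output above because the shuffling is random.

```python
import pandas as pd
import sklearn.datasets
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance


def permutation_importance_series(estimator, X, y, n_repeats=5, random_state=0):
    """Hypothetical helper: mean score drop per shuffled column, largest first."""
    result = permutation_importance(
        estimator, X, y, n_repeats=n_repeats, random_state=random_state
    )
    # result.importances_mean holds one value per column of X, in column order
    return pd.Series(result.importances_mean, index=X.columns).sort_values(
        ascending=False
    )


X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
rfr = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)
importances = permutation_importance_series(rfr, X, y)
print(importances)
```

Because the default scoring for a regressor is R², each value is the average decrease in R² when that feature's values are permuted.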
ax = inspector.plot_permutation_importance()
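# Names of the two features with the highest permutation importance
most_important_features = inspector.permutation_importance().index[:2]
# One plot for each feature, plus the pair passed together for a two-way plot
axes = inspector.plot_partial_dependence(
    features=[*most_important_features, most_important_features]
)
# Widen the figure so the panels fit side by side
axes[0, 0].get_figure().set_size_inches(12, 3)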
inspector.show_correlation()
| | age | sex | bmi | bp | s1 | s2 | s3 | s4 | s5 | s6 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|
| age | 1.00 | 0.17 | 0.19 | 0.34 | 0.26 | 0.22 | -0.08 | 0.20 | 0.27 | 0.30 | 0.19 |
| sex | 0.17 | 1.00 | 0.09 | 0.24 | 0.04 | 0.14 | -0.38 | 0.33 | 0.15 | 0.21 | 0.04 |
| bmi | 0.19 | 0.09 | 1.00 | 0.40 | 0.25 | 0.26 | -0.37 | 0.41 | 0.45 | 0.39 | 0.59 |
| bp | 0.34 | 0.24 | 0.40 | 1.00 | 0.24 | 0.19 | -0.18 | 0.26 | 0.39 | 0.39 | 0.44 |
| s1 | 0.26 | 0.04 | 0.25 | 0.24 | 1.00 | 0.90 | 0.05 | 0.54 | 0.52 | 0.33 | 0.21 |
| s2 | 0.22 | 0.14 | 0.26 | 0.19 | 0.90 | 1.00 | -0.20 | 0.66 | 0.32 | 0.29 | 0.17 |
| s3 | -0.08 | -0.38 | -0.37 | -0.18 | 0.05 | -0.20 | 1.00 | -0.74 | -0.40 | -0.27 | -0.39 |
| s4 | 0.20 | 0.33 | 0.41 | 0.26 | 0.54 | 0.66 | -0.74 | 1.00 | 0.62 | 0.42 | 0.43 |
| s5 | 0.27 | 0.15 | 0.45 | 0.39 | 0.52 | 0.32 | -0.40 | 0.62 | 1.00 | 0.46 | 0.57 |
| s6 | 0.30 | 0.21 | 0.39 | 0.39 | 0.33 | 0.29 | -0.27 | 0.42 | 0.46 | 1.00 | 0.38 |
| target | 0.19 | 0.04 | 0.59 | 0.44 | 0.21 | 0.17 | -0.39 | 0.43 | 0.57 | 0.38 | 1.00 |
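A table with the same layout can be computed with pandas alone by appending the target to the features and calling `DataFrame.corr`. This sketch assumes that `show_correlation` reports Pearson correlations (pandas' default method), which the values above are consistent with; it leaves out the color styling of the rendered table.

```python
import pandas as pd
import sklearn.datasets

X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)

# Append the target as a final column (y.name is "target"), then compute
# pairwise Pearson correlations and round to two decimals for display.
corr = pd.concat([X, y], axis=1).corr().round(2)
print(corr)
```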
ax = inspector.plot_feature_clusters()
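Feature cluster dendrograms group features whose values move together. The sketch below follows the recipe taught in the FastAI courses credited in the Acknowledgments: use one minus the Spearman rank correlation as a distance between columns, cluster with average linkage, and draw a dendrogram. Whether `plot_feature_clusters` uses exactly this distance and linkage is an assumption; the helper `feature_cluster_linkage` is hypothetical.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
import sklearn.datasets
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform


def feature_cluster_linkage(X):
    """Hypothetical helper: average-linkage clustering of columns on 1 - Spearman."""
    corr = X.corr(method="spearman").to_numpy()
    dist = 1 - corr
    dist = (dist + dist.T) / 2  # guard against floating-point asymmetry
    np.fill_diagonal(dist, 0.0)  # squareform requires an exactly zero diagonal
    return linkage(squareform(dist), method="average")


X, _ = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
Z = feature_cluster_linkage(X)

fig, ax = plt.subplots(figsize=(6, 4))
tree = dendrogram(Z, labels=list(X.columns), orientation="left", ax=ax)
plt.close(fig)
```

Note that with `1 - corr`, strongly negatively correlated features (such as `s3` and `s4` in the table above) end up far apart; using `1 - abs(corr)` instead would treat them as close.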
The methods available for a given inspector depend on the types of its estimator and its target `y`. The `methods` attribute lists them:
inspector.methods
['permutation_importance',
'plot_dependence',
'plot_feature_clusters',
'plot_permutation_importance',
'plot_pred_vs_act',
'plot_residuals',
'show_correlation']
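The README does not describe how `model_inspector` decides which methods to expose. One plausible way to implement this kind of type-based dispatch is sketched below: label the problem from the estimator type and, for classifiers, from the target type, so that a method list could then be chosen per label. The helper `problem_type` and its return strings are hypothetical; only `is_regressor`, `is_classifier`, and `type_of_target` come from scikit-learn.

```python
import numpy as np
from sklearn.base import is_classifier, is_regressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.utils.multiclass import type_of_target


def problem_type(estimator, y):
    """Hypothetical helper: label the problem so methods can be chosen per type."""
    if is_regressor(estimator):
        # Check the estimator first: type_of_target reports a float target that
        # holds only whole numbers as "multiclass", not "continuous".
        return "regression"
    if is_classifier(estimator):
        target_type = type_of_target(y)
        if target_type in ("binary", "multiclass"):
            return f"{target_type} classification"
    raise ValueError("unsupported estimator/target combination")


print(problem_type(RandomForestRegressor(), np.array([151.0, 75.0, 141.0])))
print(problem_type(LogisticRegression(), np.array([0, 1, 1, 0])))
```

Checking the estimator before the target avoids misclassifying regression targets that happen to be integer-valued.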
Install
pip install model_inspector
Alternatives
The most similar library to `model_inspector` that I am aware of is Yellowbrick. Both are machine learning visualization libraries designed to extend `scikit-learn`.
Yellowbrick is designed around `Visualizer` objects. Each `Visualizer` corresponds to a single type of visualization, and the `Visualizer` interface is similar to the `scikit-learn` transformer and estimator interfaces.
`model_inspector` takes a different approach. It is designed around `Inspector` objects that bundle together a `scikit-learn` model, an `X` feature DataFrame, and a `y` target Series. The `Inspector` object identifies the visualization types appropriate for the specific model and dataset in question and exposes corresponding methods, making it easy to visualize a given model for a given dataset in a variety of ways.
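To make the design difference concrete, here is a deliberately tiny, hypothetical class (`ToyInspector`, not part of `model_inspector`) that stores a fitted model together with its `X` and `y`, so each method needs no further arguments. Its method names only loosely echo `plot_residuals` and `plot_pred_vs_act` from the list above and are assumptions. The real `Inspector` also chooses which methods to expose; this sketch leaves that out.

```python
from dataclasses import dataclass

import pandas as pd
import sklearn.datasets
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestRegressor


@dataclass
class ToyInspector:
    """Hypothetical illustration: store the model and data once, reuse them everywhere."""

    estimator: BaseEstimator
    X: pd.DataFrame
    y: pd.Series

    def residuals(self) -> pd.Series:
        # Actual minus predicted, computed from the stored model and data
        return self.y - self.estimator.predict(self.X)

    def pred_vs_act(self) -> pd.DataFrame:
        # Predictions and actual values side by side, aligned on y's index
        return pd.DataFrame(
            {"predicted": self.estimator.predict(self.X), "actual": self.y}
        )


X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
rfr = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)
toy = ToyInspector(rfr, X, y)
print(toy.pred_vs_act().head())
```

By contrast, a Visualizer-style API would typically receive the data in each `fit` or `score` call rather than holding it for the lifetime of the object.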
Acknowledgments
Many aspects of this library were inspired by FastAI courses, including bundling a model together with data in a class and providing specific visualization methods such as feature importance bar plots, feature cluster dendrograms, tree diagrams, waterfall plots, and partial dependence plots. This library's primary contribution is to make all of these methods available in a single convenient interface.