

sklearn-cv-pandas

RandomizedSearchCV/GridSearchCV with pandas.DataFrame interface

Why would I want this?

  • I usually prepare features as a pandas.DataFrame
  • scikit-learn input should be array-like: https://scikit-learn.org/stable/glossary.html#term-array-like
  • Although array-like includes pandas.DataFrame, there are some issues:
    • It does not support the pandas nullable Int64 data type
    • The fitted model does not remember which columns it was trained on

Solution

  • Provides GridSearchCV / RandomizedSearchCV with a pandas.DataFrame interface
    • Internally preprocesses the DataFrame so that sklearn can consume it
  • The fit methods now return this package's own Model object, which
    • stores the column names
    • provides a pandas.DataFrame interface for prediction
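The two ideas above (preprocess the DataFrame, then return a model that remembers its columns) can be illustrated with a minimal sketch. The class DataFrameModel and the helper to_float_array are hypothetical and are not part of sklearn_cv_pandas; the sketch assumes that "preprocessing" means selecting the stored columns and casting them to float64, which turns pandas nullable Int64 values (including pd.NA) into an ordinary numpy array with NaN for missing values.

```python
import numpy as np
import pandas as pd


def to_float_array(df: pd.DataFrame, columns: list) -> np.ndarray:
    """Hypothetical helper: select `columns` and cast them to float64.

    Casting a nullable Int64 column to float64 maps pd.NA to np.nan,
    so the result is a plain numeric numpy array.
    """
    return df[columns].astype("float64").to_numpy()


class DataFrameModel:
    """Hypothetical wrapper that remembers which columns were used for fitting."""

    def __init__(self, estimator, feature_columns, target_column):
        self.estimator = estimator  # any object with fit(x, y) and predict(x)
        self.feature_columns = list(feature_columns)
        self.target_column = target_column

    def fit(self, df: pd.DataFrame) -> "DataFrameModel":
        x = to_float_array(df, self.feature_columns)
        y = to_float_array(df, [self.target_column]).ravel()
        self.estimator.fit(x, y)
        return self

    def predict(self, df: pd.DataFrame) -> np.ndarray:
        # Only the stored feature columns are used, in the stored order.
        return self.estimator.predict(to_float_array(df, self.feature_columns))
```

Any estimator with fit/predict, such as sklearn.linear_model.Lasso, could be passed as `estimator`. Note that many scikit-learn estimators still reject NaN, so a real implementation would also need a policy for missing values.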

Installation

pip install sklearn_cv_pandas

Usage

Configure CV object

Instantiate the CV object in the same way as the original scikit-learn ones.

from scipy import stats
from sklearn import linear_model
from sklearn_cv_pandas import RandomizedSearchCV

estimator = linear_model.Lasso()
param_dist = dict(alpha=stats.loguniform(1e-5, 10))
# scikit-learn exposes error metrics as negated scorers, e.g. "neg_mean_absolute_error"
cv = RandomizedSearchCV(estimator, param_dist, scoring="neg_mean_absolute_error")
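scikit-learn scorers follow a "greater is better" convention, so error metrics are registered under negated names such as "neg_mean_absolute_error". The sketch below uses plain scikit-learn (not sklearn_cv_pandas) on a small synthetic dataset to show that this scorer returns the negative of mean_absolute_error; it assumes the CV wrapper passes `scoring` through to scikit-learn unchanged.

```python
import numpy as np
from sklearn import linear_model
from sklearn.metrics import get_scorer, mean_absolute_error

# Small synthetic regression problem.
rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))
y = x @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=50)

model = linear_model.Lasso(alpha=0.01).fit(x, y)

# A scorer is called as scorer(estimator, X, y) and negates the error,
# so a larger (less negative) value means a better model.
scorer = get_scorer("neg_mean_absolute_error")
score = scorer(model, x, y)
mae = mean_absolute_error(y, model.predict(x))
print(score, mae)
```

Because of this convention, the best score reported by a search using this scorer is negative, and the best model is the one whose score is closest to zero.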

fit with pandas.DataFrame

Our CV objects have two new methods, fit_holdout_pandas and fit_cv_pandas. The original fit method takes x and y as separate arrays (e.g. numpy.array). Instead, you pass a single pandas.DataFrame together with the names of the feature columns for x (feature_columns) and the name of the target column for y (target_column).

model = cv.fit_cv_pandas(
    df, target_column="y", feature_columns=["x{}".format(i) for i in range(100)], n_fold=5
)
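For comparison, this is roughly what the fit_cv_pandas call above saves you from writing with plain scikit-learn: the feature and target columns are pulled out of the DataFrame by hand and passed as separate x and y. Treating n_fold as the counterpart of scikit-learn's `cv` argument is an assumption, and the synthetic `df` (10 features instead of 100) is made up for illustration.

```python
import numpy as np
import pandas as pd
from scipy import stats
from sklearn import linear_model
from sklearn.model_selection import RandomizedSearchCV

# Synthetic data with feature columns x0..x9 and a target column y.
rng = np.random.default_rng(0)
feature_columns = ["x{}".format(i) for i in range(10)]
df = pd.DataFrame(rng.normal(size=(200, 10)), columns=feature_columns)
df["y"] = 3.0 * df["x0"] - 2.0 * df["x1"] + rng.normal(scale=0.1, size=200)

# Plain scikit-learn: split the DataFrame into x and y yourself.
x = df[feature_columns].to_numpy()
y = df["y"].to_numpy()

cv = RandomizedSearchCV(
    linear_model.Lasso(max_iter=10000),
    dict(alpha=stats.loguniform(1e-5, 10)),
    n_iter=10,
    scoring="neg_mean_absolute_error",
    cv=5,  # assumed counterpart of n_fold=5
    random_state=0,
)
cv.fit(x, y)
best = cv.best_estimator_

# The fitted estimator knows nothing about column names, so the caller must
# select the same feature columns, in the same order, at prediction time.
pred = best.predict(df[feature_columns].to_numpy())
```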

predict with pandas.DataFrame

Prediction also works with a pandas.DataFrame. The model returned by fit_holdout_pandas or fit_cv_pandas stores feature_columns and target_column, so you can pass a pandas.DataFrame directly to its predict method.

model.predict(df)
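Storing feature_columns is what lets predict accept a DataFrame whose columns come in a different order or include extra columns such as the target. A minimal pandas-only sketch of that selection step, using a hypothetical helper select_features (not part of sklearn_cv_pandas) that also reports missing columns by name:

```python
import pandas as pd


def select_features(df: pd.DataFrame, feature_columns: list) -> pd.DataFrame:
    """Hypothetical helper: return df restricted to feature_columns, in that order."""
    missing = [c for c in feature_columns if c not in df.columns]
    if missing:
        raise KeyError("missing feature columns: {}".format(missing))
    return df[feature_columns]


feature_columns = ["x0", "x1", "x2"]
# Columns are shuffled and the target column y is present; neither matters.
new_df = pd.DataFrame({"y": [0.0, 1.0], "x2": [5, 6], "x0": [1, 2], "x1": [3, 4]})
x = select_features(new_df, feature_columns)
print(list(x.columns))  # ['x0', 'x1', 'x2']
```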


Download files

Source distribution: sklearn_cv_pandas-0.0.12.tar.gz (6.7 kB)

Built distribution: sklearn_cv_pandas-0.0.12-py3-none-any.whl (6.4 kB, Python 3)

File details

Details for the file sklearn_cv_pandas-0.0.12.tar.gz.

File metadata

  • Size: 6.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.6

File hashes

  Algorithm     Hash digest
  SHA256        2332de99179047c65265b49a991e51324982a96e992050fe7ec2be8b30ece363
  MD5           c61543c5f9d56c38a5d66b52ca1f625c
  BLAKE2b-256   052cb1b023e8f187f264bd2cef41e91c0e9f5dbd88b9ecee7216263e91de1666

File details

Details for the file sklearn_cv_pandas-0.0.12-py3-none-any.whl.

File hashes

  Algorithm     Hash digest
  SHA256        1ce37900b15b9faf98c11b6e415ab13e5b0fcfe27557e7471a56df1aea144b91
  MD5           c476a8ac43aaeb3e26f51e38cec76090
  BLAKE2b-256   10999378c1b940380eef4f459f9dd5902496a16bfd5e2236d5e514ee841ac89a
