
Optimize ML pipelines for tabular data using a hierarchical optimization method

Project description


autotab

Optimizes the machine-learning pipeline for any model on tabular datasets using a hierarchical optimization method.
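The hierarchical (parent/child) idea can be sketched in plain Python. This is a conceptual illustration only, not autotab's implementation: a parent loop proposes the pipeline structure (which model, which transformation), and for each proposal a child loop searches that model's hyperparameters. All names here (`toy_score`, the model and transformation lists) are made up for the sketch.

```python
import random

# Toy search spaces: the parent picks a pipeline structure,
# the child tunes hyperparameters for that structure.
MODELS = {
    "ridge": {"alpha": [0.01, 0.1, 1.0, 10.0]},
    "tree": {"max_depth": [2, 4, 8, 16]},
}
TRANSFORMS = ["minmax", "zscore", "none"]

def toy_score(model, transform, params):
    """Stand-in objective (lower is better); a real run would
    train the model and return e.g. validation MSE."""
    base = {"ridge": 1.0, "tree": 1.5}[model]
    penalty = 0.1 if transform == "none" else 0.0
    return base + penalty + 0.01 * sum(params.values())

def optimize(parent_iters=10, child_iters=5, seed=0):
    rng = random.Random(seed)
    best = None
    for _ in range(parent_iters):            # parent: pipeline structure
        model = rng.choice(list(MODELS))
        transform = rng.choice(TRANSFORMS)
        for _ in range(child_iters):         # child: hyperparameters
            params = {k: rng.choice(v) for k, v in MODELS[model].items()}
            score = toy_score(model, transform, params)
            if best is None or score < best[0]:
                best = (score, model, transform, params)
    return best
```

In autotab the parent and child searches can use different algorithms (e.g. `parent_algorithm='bayes'` with `child_algorithm='random'`), as shown in the example below.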

Installation

This package can be installed from PyPI using pip with the following command:

pip install autotab

or from GitHub for the latest code:

python -m pip install git+https://github.com/Sara-Iftikhar/autotab.git

or using the setup file: go to the folder where this repository was downloaded and run

python setup.py install

Example

This example can also be run as a notebook in Colab.

from ai4water.datasets import busan_beach
from skopt.plots import plot_objective
from autotab import OptimizePipeline

data = busan_beach()
input_features = data.columns.tolist()[0:-1]
output_features = data.columns.tolist()[-1:]

transformations = ['minmax', 'zscore', 'log', 'log10', 'sqrt', 'robust', 'quantile', 'none', 'scale']

pl = OptimizePipeline(
    inputs_to_transform=input_features,
    parent_iterations=400,
    child_iterations=20,
    parent_algorithm='bayes',
    child_algorithm='random',
    cv_parent_hpo=True,
    eval_metric='mse',
    monitor=['r2', 'nse'],
    input_transformations=transformations,
    output_transformations=transformations,
    models=[
        "LinearRegression",
        "LassoLars",
        "Lasso",
        "RandomForestRegressor",
        "HistGradientBoostingRegressor",
        "CatBoostRegressor",
        "XGBRegressor",
        "LGBMRegressor",
        "GradientBoostingRegressor",
        "ExtraTreeRegressor",
        "ExtraTreesRegressor",
    ],
    input_features=input_features,
    output_features=output_features,
    cross_validator={"KFold": {"n_splits": 5}},
    split_random=True,
)
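The transformation names passed above refer to standard feature scalings. As a rough illustration in plain Python (not autotab's internals), 'minmax' and 'zscore' behave like:

```python
import math

def minmax(xs):
    """Scale values linearly into [0, 1]."""
    lo, hi = min(xs), max(xs)
    return [(x - lo) / (hi - lo) for x in xs]

def zscore(xs):
    """Center to mean 0 and scale to unit (population) standard deviation."""
    mean = sum(xs) / len(xs)
    std = math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))
    return [(x - mean) / std for x in xs]
```

During optimization, autotab treats the choice of transformation for each feature as part of the search space.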

get version information

pl._version_info()

perform optimization

results = pl.fit(data=data, process_results=False)

print optimization report

print(pl.report())

show convergence and other diagnostic plots

pl.optimizer_._plot_convergence(save=False)
pl.optimizer_._plot_parallel_coords(figsize=(16, 8), save=False)
_ = pl.optimizer_._plot_distributions(save=False)
pl.optimizer_.plot_importance(save=False)
pl.optimizer_.plot_importance(save=False, plot_type="bar")
_ = plot_objective(results)
pl.optimizer_._plot_evaluations(save=False)
pl.optimizer_._plot_edf(save=False)
pl.dumbbell_plot(data=data)
pl.dumbbell_plot(data=data, metric_name='r2')
pl.taylor_plot(data=data, save=False, figsize=(6,6))
pl.compare_models()
pl.compare_models(plot_type="bar_chart")
pl.compare_models("r2", plot_type="bar_chart")

get best pipeline with respect to evaluation metric

pl.get_best_pipeline_by_metric('r2')

build, fit, and evaluate the best pipeline

model = pl.bfe_best_model_from_scratch(data=data)
pl.evaluate_model(model, data=data)
pl.evaluate_model(model, data=data, metric_name='nse')
pl.evaluate_model(model, data=data, metric_name='r2')
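The metric names used above ('mse', 'nse') follow their standard definitions; a minimal sketch in plain Python (these helper functions are for illustration, not part of autotab's API):

```python
def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def nse(y_true, y_pred):
    """Nash-Sutcliffe efficiency: 1 - SSE / variance of observations.
    1.0 is a perfect fit; 0.0 is no better than predicting the mean."""
    mean = sum(y_true) / len(y_true)
    sse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    sst = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - sse / sst
```

Since `eval_metric='mse'` was set above, the pipeline search minimizes MSE, while 'r2' and 'nse' are only monitored.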

get best pipeline with respect to R²

pl.get_best_pipeline_by_metric('r2')
model = pl.bfe_best_model_from_scratch(data=data, metric_name='r2')
pl.evaluate_model(model, data=data, metric_name='r2')
print(f"all results are saved in {pl.path} folder")

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

autotab-0.11.tar.gz (28.0 kB)

Uploaded Source

File details

Details for the file autotab-0.11.tar.gz.

File metadata

  • Download URL: autotab-0.11.tar.gz
  • Upload date:
  • Size: 28.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.0 CPython/3.9.12

File hashes

Hashes for autotab-0.11.tar.gz

  • SHA256: 3aabe86a56135ed5dc4f4b4c95f5c291d04146e8f4e6184571eb2ae78b96da60
  • MD5: e8deb88319eadfbf2717a63ca1876192
  • BLAKE2b-256: ded279a0ce118e90466cbc8bd71641a08df10bd7c113dd5773e99cf13afbf32e

