Automated ML model training
Xingu for automated ML model training
Xingu is a framework of three classes that provides a standard way to organize and run machine learning training pipelines.
Notebooks are useful during exploratory data analysis (EDA), but once the model is ready to become a product, use the classes proposed by Xingu to organize database interactions (queries), data cleanup, feature engineering, hyperparameter optimization, the training algorithm, general and custom metrics computation, and post-processing of estimations.
Install
pip install git+https://github.com/avibrazil/xingu
Use Xingu to Train a Model
Check that your project has the necessary files:
$ find
dataproviders/
dataproviders/my_dataprovider.py
estimators/
estimators/myrandomestimator.py
Train with DataProviders id_of_my_dataprovider1 and id_of_my_dataprovider2, both defined in dataproviders/my_dataprovider.py:
$ xingu \
--dps id_of_my_dataprovider1,id_of_my_dataprovider2 \
--datalake-athena "awsathena+rest://athena.us..." \
--query-cache-path data \
--trained-models-path models \
--debug
Procedures defined by Xingu
Steps marked with 💫 are where you put your code. Everything else is ready-to-use Xingu boilerplate.
Coach.team_train():
Train various Models, in parallel whenever possible.
- Coach.team_train_parallel() (background, parallelism controlled by PARALLEL_TRAIN_MAX_WORKERS):
  - Coach.team_load() (for prerequisite models not trained in this session)
  - Per DataProvider requested to be trained:
    - Coach.team_train_member() (background):
      - Model.fit() calls:
        - 💫 DataProvider.get_dataset_sources_for_train(): returns a dict of queries
        - Model.data_sources_to_data(sources)
        - 💫 DataProvider.clean_data_for_train(dict of DataFrames)
        - 💫 DataProvider.feature_engineering_for_train(DataFrame)
        - 💫 DataProvider.last_pre_process_for_train(DataFrame)
        - 💫 DataProvider.data_split_for_train(DataFrame): returns a tuple of DataFrames
        - Model.hyperparam_optimize() (decides where the hyperparameters come from):
          - 💫 DataProvider.get_estimator_features_list()
          - 💫 DataProvider.get_target()
          - 💫 DataProvider.get_estimator_optimization_search_space()
          - 💫 DataProvider.get_estimator_hyperparameters()
          - 💫 Estimator.hyperparam_optimize() (SKOpt, GridSearch, etc.)
          - 💫 Estimator.hyperparam_exchange()
        - 💫 Estimator.fit()
        - 💫 DataProvider.post_process_after_train()
- Coach.post_train_parallel() (background, only if POST_PROCESS=true):
  - Per trained Model (parallelism controlled by PARALLEL_POST_PROCESS_MAX_WORKERS):
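The training flow above can be illustrated without Xingu itself. The dependency-free sketch below calls the 💫 hooks of one DataProvider in the listed order and trains several DataProviders through a bounded thread pool, in the spirit of PARALLEL_TRAIN_MAX_WORKERS. It is not Xingu's implementation. ToyDataProvider, WindowMeanEstimator, fit_sketch() and team_train_sketch() are hypothetical names, plain lists of dicts stand in for the DataFrames the real hooks receive, and query execution is faked. The rule "use the DataProvider's hyperparameters if given, otherwise search" is only a guess at what Model.hyperparam_optimize() decides. Estimator.hyperparam_exchange() and the post-train steps are omitted.

```python
"""Hypothetical sketch of the Model.fit() call order; not Xingu's code."""
from concurrent.futures import ThreadPoolExecutor


class WindowMeanEstimator:
    """Toy Estimator: predicts the mean target of the last `window` train rows."""

    def __init__(self, window=None):
        self.window = window  # None means "use all rows"
        self.mean = None

    @staticmethod
    def _window_mean(rows, target, window):
        tail = rows[-window:] if window else rows
        return sum(r[target] for r in tail) / len(tail)

    def hyperparam_optimize(self, search_space, target, train, val):
        # Exhaustive grid search scored by mean absolute error on `val`.
        def val_mae(window):
            mean = self._window_mean(train, target, window)
            return sum(abs(r[target] - mean) for r in val) / len(val)
        return {"window": min(search_space["window"], key=val_mae)}

    def fit(self, features, target, train):
        # This toy model ignores `features`; a real Estimator would use them.
        self.mean = self._window_mean(train, target, self.window)
        return self

    def predict(self, rows):
        return [self.mean for _ in rows]


class ToyDataProvider:
    """Hypothetical DataProvider; only the hook names come from the README."""

    def __init__(self, dp_id, rows):
        self.id = dp_id
        self.rows = rows  # pretend result of running the query

    def get_dataset_sources_for_train(self):
        return {"main": f"SELECT x, y FROM {self.id}"}  # dict of queries

    def clean_data_for_train(self, datasets):
        return [r for r in datasets["main"] if r["y"] is not None]

    def feature_engineering_for_train(self, data):
        return [dict(r, x2=r["x"] ** 2) for r in data]

    def last_pre_process_for_train(self, data):
        return data

    def data_split_for_train(self, data):
        cut = int(len(data) * 0.75)
        return data[:cut], data[cut:]  # (train, validation)

    def get_estimator_features_list(self):
        return ["x", "x2"]

    def get_target(self):
        return "y"

    def get_estimator_optimization_search_space(self):
        return {"window": [1, 3, 6]}

    def get_estimator_hyperparameters(self):
        return None  # None: ask the Estimator to search instead

    def post_process_after_train(self, estimator):
        pass


def fit_sketch(dp):
    """Run the training hooks of one DataProvider in the listed order."""
    sources = dp.get_dataset_sources_for_train()
    datasets = {name: dp.rows for name in sources}  # fake data_sources_to_data()
    data = dp.clean_data_for_train(datasets)
    data = dp.feature_engineering_for_train(data)
    data = dp.last_pre_process_for_train(data)
    train, val = dp.data_split_for_train(data)
    features, target = dp.get_estimator_features_list(), dp.get_target()
    space = dp.get_estimator_optimization_search_space()
    params = dp.get_estimator_hyperparameters()
    if params is None:
        params = WindowMeanEstimator().hyperparam_optimize(space, target, train, val)
    estimator = WindowMeanEstimator(**params).fit(features, target, train)
    dp.post_process_after_train(estimator)
    return estimator


def team_train_sketch(dps, max_workers=2):
    """Train several DataProviders with bounded parallelism."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return dict(zip([dp.id for dp in dps], pool.map(fit_sketch, dps)))


rows = [{"x": i, "y": 2.0 * i} for i in range(8)] + [{"x": 8, "y": None}]
models = team_train_sketch([ToyDataProvider("dp1", rows), ToyDataProvider("dp2", rows)])
```

With this data the grid search picks window=1, because the most recent training row is closest to the validation rows. Each DataProvider trains independently, which is what makes per-DataProvider parallelism safe here.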
Coach.team_batch_predict():
Load various pre-trained Models from storage and use them to compute estimations over data returned by a predefined SQL query. The batch predict SQL query is defined in the DataProvider, and this process runs it against the database to fetch the data.
- Coach.team_load() (for all requested DataProviders and their prerequisites)
- Per loaded Model:
  - Coach.single_batch_predict() (background):
    - Model.batch_predict():
      - 💫 DataProvider.get_dataset_sources_for_batch_predict()
      - Model.data_sources_to_data()
      - 💫 DataProvider.clean_data_for_batch_predict()
      - 💫 DataProvider.feature_engineering_for_batch_predict()
      - 💫 DataProvider.last_pre_process_for_batch_predict()
      - Model.predict_proba() or Model.predict() (see below)
    - Model.compute_and_save_metrics(channel=batch_predict) (see below)
    - Model.save_batch_predict_estimations()
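The batch predict order can be sketched in the same hedged way. This is not Xingu's code. ToyBatchDataProvider, ToyEstimator, batch_predict_sketch(), and the run_query and save callables are all hypothetical. The two summary statistics are only placeholders for whatever Model.compute_and_save_metrics(channel=batch_predict) really computes. Lists of dicts again replace DataFrames.

```python
"""Hypothetical sketch of the Model.batch_predict() call order; not Xingu's code."""


class ToyEstimator:
    """Stands in for a pre-trained Estimator loaded from storage."""

    def predict(self, rows):
        return [2.0 * r["x"] for r in rows]


class ToyBatchDataProvider:
    """Hypothetical DataProvider; only the hook names come from the README."""

    def get_dataset_sources_for_batch_predict(self):
        return {"main": "SELECT x FROM new_data"}  # illustrative query

    def clean_data_for_batch_predict(self, datasets):
        return [r for r in datasets["main"] if r["x"] is not None]

    def feature_engineering_for_batch_predict(self, data):
        return data

    def last_pre_process_for_batch_predict(self, data):
        return data


def batch_predict_sketch(dp, estimator, run_query, save):
    """Fetch, prepare and estimate new data, then compute metrics and save."""
    sources = dp.get_dataset_sources_for_batch_predict()
    # Stand-in for Model.data_sources_to_data(): run each query.
    datasets = {name: run_query(sql) for name, sql in sources.items()}
    data = dp.clean_data_for_batch_predict(datasets)
    data = dp.feature_engineering_for_batch_predict(data)
    data = dp.last_pre_process_for_batch_predict(data)
    estimations = estimator.predict(data)
    # Placeholder for compute_and_save_metrics(channel=batch_predict).
    metrics = {
        "count": len(estimations),
        "mean": sum(estimations) / len(estimations) if estimations else None,
    }
    save(estimations, metrics)  # stand-in for save_batch_predict_estimations()
    return estimations, metrics


fake_db = {"SELECT x FROM new_data": [{"x": 1}, {"x": None}, {"x": 3}]}
saved = []
estimations, metrics = batch_predict_sketch(
    ToyBatchDataProvider(), ToyEstimator(), fake_db.__getitem__,
    lambda est, met: saved.append((est, met)),
)
```

Passing run_query and save in as callables keeps the sketch testable without a database. It reflects only the ordering of the steps, not how Xingu connects to storage.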
Model.predict() and Model.predict_proba():
- Model.generic_predict() calls:
  - 💫 DataProvider.pre_process_for_predict() or DataProvider.pre_process_for_predict_proba()
  - 💫 DataProvider.get_estimator_features_list()
  - 💫 Estimator.predict() or Estimator.predict_proba()
  - 💫 DataProvider.post_process_after_predict() or DataProvider.post_process_after_predict_proba()
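Both entry points go through Model.generic_predict(), and the two paths differ only in which hook variant is called at each step. One way to express that is to pick each hook by name with getattr(), sketched below under assumptions. The classes, generic_predict_sketch(), and the (rows, estimations) signature of the post-processing hooks are hypothetical. Only the hook names come from the list above.

```python
"""Hypothetical sketch of a generic_predict()-style dispatcher; not Xingu's code."""


class ToyEstimator:
    """Toy classifier whose only input feature is a precomputed score."""

    def predict(self, X):
        return [1 if row["score"] >= 0.5 else 0 for row in X]

    def predict_proba(self, X):
        return [[1 - row["score"], row["score"]] for row in X]


class ToyDataProvider:
    """Hypothetical DataProvider; only the hook names come from the README."""

    def pre_process_for_predict(self, rows):
        return rows

    def pre_process_for_predict_proba(self, rows):
        return rows

    def get_estimator_features_list(self):
        return ["score"]

    def post_process_after_predict(self, rows, estimations):
        return estimations

    def post_process_after_predict_proba(self, rows, estimations):
        # Keep only the probability of the positive class.
        return [proba[1] for proba in estimations]


def generic_predict_sketch(dp, estimator, rows, method="predict"):
    """Run one pipeline for either 'predict' or 'predict_proba'."""
    if method not in ("predict", "predict_proba"):
        raise ValueError(f"unsupported method: {method}")
    rows = getattr(dp, f"pre_process_for_{method}")(rows)
    features = dp.get_estimator_features_list()
    # Hand the estimator only the declared feature columns.
    X = [{name: row[name] for name in features} for row in rows]
    estimations = getattr(estimator, method)(X)
    return getattr(dp, f"post_process_after_{method}")(rows, estimations)


rows = [{"id": 1, "score": 0.8}, {"id": 2, "score": 0.2}]
labels = generic_predict_sketch(ToyDataProvider(), ToyEstimator(), rows)
probas = generic_predict_sketch(ToyDataProvider(), ToyEstimator(), rows, "predict_proba")
```

Restricting the estimator input to get_estimator_features_list() means extra columns such as "id" can travel with the rows for post-processing without leaking into the model.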
Model.compute_and_save_metrics():
A subsystem that computes various metrics, graphics and transformations over a facet of the data.
It runs right after a Model is trained and also during a batch predict.
Predicted data is computed by Model.trainsets_predict() and Model.batch_predict() before Model.compute_and_save_metrics() is called.
- Model.save_model_metrics() calls:
  - Model.compute_model_metrics() calls:
    - Model.compute_trainsets_model_metrics() calls:
      - All Model.compute_trainsets_model_metrics_{NAME}()
      - All 💫 DataProvider.compute_trainsets_model_metrics_{NAME}()
    - Model.compute_batch_model_metrics() calls:
      - All Model.compute_batch_model_metrics_{NAME}()
      - All 💫 DataProvider.compute_batch_model_metrics_{NAME}()
    - Model.compute_global_model_metrics() calls:
      - All Model.compute_global_model_metrics_{NAME}()
      - All 💫 DataProvider.compute_global_model_metrics_{NAME}()
  - Model.render_model_plots() calls:
    - Model.render_trainsets_model_plots() calls:
      - All Model.render_trainsets_model_plots_{NAME}()
      - All 💫 DataProvider.render_trainsets_model_plots_{NAME}()
    - Model.render_batch_model_plots() calls:
      - All Model.render_batch_model_plots_{NAME}()
      - All 💫 DataProvider.render_batch_model_plots_{NAME}()
    - Model.render_global_model_plots() calls:
      - All Model.render_global_model_plots_{NAME}()
      - All 💫 DataProvider.render_global_model_plots_{NAME}()
- Model.save_estimation_metrics() calls:
  - Model.compute_estimation_metrics() calls:
    - All Model.compute_estimation_metrics_{NAME}()
    - All 💫 DataProvider.compute_estimation_metrics_{NAME}()
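The "All ..._{NAME}()" entries mean that every method matching a naming pattern is called, on the Model as well as on the DataProvider. This suggests that a new metric or plot can be added just by defining a method with the right prefix. The sketch below shows one way such discovery can work, using dir() and getattr(). call_all_with_prefix(), ToyModel, ToyDataProvider and the merge rule (the later object wins on a duplicate NAME) are hypothetical. Xingu's actual mechanism and method signatures may differ.

```python
"""Hypothetical sketch of calling every method named PREFIX + NAME."""


def call_all_with_prefix(objects, prefix, *args):
    """Call each method whose name is prefix + NAME on every object.

    Returns {NAME: result}. If two objects define the same NAME,
    the result from the later object in `objects` is kept.
    """
    results = {}
    for obj in objects:
        for attr in sorted(dir(obj)):
            if not attr.startswith(prefix) or attr == prefix:
                continue
            method = getattr(obj, attr)
            if callable(method):
                results[attr[len(prefix):]] = method(*args)
    return results


class ToyModel:
    """Stands in for a Model with a built-in metric."""

    def compute_trainsets_model_metrics_rows(self, y_true, y_pred):
        return len(y_true)


class ToyDataProvider:
    """Stands in for a DataProvider with a custom metric."""

    def compute_trainsets_model_metrics_mae(self, y_true, y_pred):
        return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


metrics = call_all_with_prefix(
    [ToyModel(), ToyDataProvider()],
    "compute_trainsets_model_metrics_",
    [1.0, 2.0, 3.0],
    [1.0, 2.0, 5.0],
)
```

Sorting the attribute names makes the call order deterministic. Listing the Model before the DataProvider lets a DataProvider override a built-in metric of the same NAME.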