
Project description

clique-ml

A selective ensemble for predictive time-series models that tests new additions to prevent downgrades in performance.

This code was written and tested against a CUDA 12.2 environment; if you run into compatibility issues, try setting up a venv using cuda-venv.sh.

Usage

Setup

pip install -U clique-ml
import clique

Training

Create a list of models to train. Any class that implements fit(), predict(), and get_params() is supported:

import xgboost as xgb
import lightgbm as lgb
import catboost as cat
import tensorflow as tf

models = [
    xgb.XGBRegressor(...),
    lgb.LGBMRegressor(...),
    cat.CatBoostRegressor(...),
    tf.keras.Sequential(...),
]
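Since only those three methods are required, custom models can participate too. A minimal sketch of a conforming class (MeanRegressor is a hypothetical illustration, not part of clique-ml):

```python
class MeanRegressor:
    """Toy model satisfying the fit()/predict()/get_params() interface:
    it predicts the mean of the training targets for every input."""

    def __init__(self, fallback=0.0):
        self.fallback = fallback
        self.mean_ = None

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y) if len(y) else self.fallback
        return self

    def predict(self, X):
        return [self.mean_] * len(X)

    def get_params(self, deep=True):
        return {"fallback": self.fallback}
```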

Data is automatically split for training, testing, and validation, so simply pass the models, inputs (X), and targets (y) to train_ensemble():

X, y = ... # preprocessed data; 20% is set aside for validation, and the rest is trained on using k-folds

ensemble = clique.train_ensemble(models, X, y, folds=5, limit=3) # instance of clique.SelectiveEnsemble

folds sets n_splits for scikit-learn's TimeSeriesSplit class, which is used to implement k-folds here. For a single split, pass folds=1.
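For intuition, the expanding-window folds that TimeSeriesSplit produces can be sketched in pure Python over bare indices (this mirrors scikit-learn's default test-size behaviour and is an illustration, not clique-ml's actual code):

```python
def time_series_splits(n_samples, n_splits):
    """Yield (train, test) index lists: each fold trains on an
    expanding prefix and tests on the block immediately after it."""
    test_size = n_samples // (n_splits + 1)
    for i in range(1, n_splits + 1):
        train_end = n_samples - (n_splits - i + 1) * test_size
        train = list(range(0, train_end))
        test = list(range(train_end, train_end + test_size))
        yield train, test
```

Note that test indices always come after the training indices, so no fold ever trains on future data.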

limit sets a soft target for how many models to include in the ensemble. Once that limit is exceeded, the ensemble rejects new models that would raise its mean score.

By default, the ensemble trains using 5 folds and no size limit.
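The selection rule this describes can be sketched as follows (an assumption about the behaviour, not the library's source; lower scores are better):

```python
def selective_add(scores, candidate_score, limit=None):
    """Sketch of the acceptance rule: below the limit every model is
    accepted; past it, a candidate is rejected unless its score would
    lower the ensemble's mean."""
    if limit is None or len(scores) < limit:
        return True
    mean = sum(scores) / len(scores)
    return candidate_score < mean
```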

Evaluation

train_ensemble() will output the results of each sub-model's training on every fold:

Pre-training setup...Complete (0.0s)
Model 1/5: Fold 1/5: Stopped: PredictionError: Model is guessing a constant value. -- 3      
Model 2/5: Fold 1/5: Stopped: PredictionError: Model is guessing a constant value. -- 3       
Model 3/5: Fold 1/5: Accepted with score: 0.03233311 (0.1s) (CatBoostRegressor_1731893049_0)          
Model 3/5: Fold 2/5: Accepted with score: 0.02314115 (0.0s) (CatBoostRegressor_1731893050_1)          
Model 3/5: Fold 3/5: Accepted with score: 0.01777214 (0.0s) (CatBoostRegressor_1731893050_2)
...      
Model 5/5: Fold 2/5: Rejected with score: 0.97019375 (0.3s)                            
Model 5/5: Fold 3/5: Rejected with score: 0.41385662 (1.4s)                         
Model 5/5: Fold 4/5: Rejected with score: 0.41153231 (0.8s)          
Model 5/5: Fold 5/5: Rejected with score: 0.40335007 (1.6s)

Once trained, details of the final ensemble can be reviewed with:

print(ensemble) # <SelectiveEnsemble (5 model(s); mean: 0.03389993; best: 0.03321487; limit: 3)>

Or:

print(len(ensemble)) # 5
print(ensemble.mean_score) # 0.033899934449981864
print(ensemble.best_score) # 0.033214874389494775
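With a hypothetical list of sub-model scores (lower is better), these two properties reduce to an average and a minimum:

```python
# Hypothetical sub-model scores, not taken from a real run.
scores = [0.0332, 0.0339, 0.0341, 0.0335, 0.0348]

mean_score = sum(scores) / len(scores)  # average across sub-models
best_score = min(scores)                # lowest (best) score
```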

Pruning

Since SelectiveEnsemble has to accept the first N models to establish a mean, frontloading with weaker models may cause oversaturation, even when limit is set.

To remedy this, call SelectiveEnsemble.prune():

pruned = ensemble.prune()

This returns a copy of the ensemble with all sub-models scoring above the mean removed (lower scores are better).

If a limit is passed in, above-mean models are removed repeatedly until the ensemble shrinks to that limit:

triumvirate = ensemble.prune(3)
print(len(triumvirate)) # 3 (or fewer)

This recursion happens automatically whenever SelectiveEnsemble.limit is set, whether manually or by train_ensemble().
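The pruning described above can be sketched over a bare list of scores (an assumption about the behaviour, not the library's source):

```python
def prune_scores(scores, limit=None):
    """Sketch of pruning: drop every score above the mean; with a
    limit, repeat until at most `limit` scores remain."""
    pruned = [s for s in scores if s <= sum(scores) / len(scores)]
    while limit is not None and len(pruned) > limit:
        mean = sum(pruned) / len(pruned)
        kept = [s for s in pruned if s <= mean]
        if len(kept) == len(pruned):
            break  # all remaining scores are equal; cannot shrink further
        pruned = kept
    return pruned
```

Because each pass removes every above-mean score at once, the result can end up smaller than the requested limit.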

Deployment

To make predictions, simply call:

predictions = ensemble.predict(...) # with a new set of inputs

This returns the mean prediction across all sub-models for each input.

If you wish to continue training on an existing ensemble, use:

existing = clique.load_ensemble(X_test=X, y_test=y) # test data must be passed in for new model evaluation
updated = clique.train_ensemble(models, X, y, ensemble=existing)

Note that if a limit is set on the existing ensemble, it will also be set and enforced on the updated one.

Project details


Download files


Source Distribution

clique_ml-0.0.5.tar.gz (7.6 kB)

Uploaded Source

Built Distribution


clique_ml-0.0.5-py3-none-any.whl (6.9 kB)

Uploaded Python 3

File details

Details for the file clique_ml-0.0.5.tar.gz.

File metadata

  • Download URL: clique_ml-0.0.5.tar.gz
  • Upload date:
  • Size: 7.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for clique_ml-0.0.5.tar.gz

  • SHA256: 28a1379bc6f570d2afc7e89f6ab6d56d0649a38028ae55522658e6b10e8f74dc
  • MD5: c706a62391883e21aa4a77a4891abed5
  • BLAKE2b-256: 9afddb9ada78515cad3eae270418ca14f6397081b8f68e174bf4e625dfac50d8


File details

Details for the file clique_ml-0.0.5-py3-none-any.whl.

File metadata

  • Download URL: clique_ml-0.0.5-py3-none-any.whl
  • Upload date:
  • Size: 6.9 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.10.12

File hashes

Hashes for clique_ml-0.0.5-py3-none-any.whl

  • SHA256: 643366b29ad1e4d9475c80e925b3242c49d6ac9718d72744b3a5dc0b50d0e284
  • MD5: 5258948f54b6a41fad3b62e42f8ef4b0
  • BLAKE2b-256: a7029e2ac61fc28ac8b79d87fffadc2033fb5876471ba9f032e8552c8a8dfdb9

