Explores time information to train a robust random forest
time-robust-forest
A proof-of-concept model that uses timestamp information to train a random forest with better out-of-distribution generalization.
Installation
pip install -U time-robust-forest
How to use it
A classifier and a regressor are available under time_robust_forest.models. They follow the sklearn interface, so you can quickly fit and use a model:
from time_robust_forest.models import TimeForestClassifier
features = ["x_1", "x_2"]
time_column = "periods"
target = "y"
model = TimeForestClassifier(time_column=time_column)
model.fit(training_data[features + [time_column]], training_data[target])
predictions = model.predict_proba(test_data[features])[:, 1]
Only a few arguments differ from a traditional random forest:
- time_column: the column from the input dataframe containing the time periods the model will iterate over to find the best splits (default: "period")
- min_sample_periods: the number of examples in every period the model needs to keep while it splits.
- period_criterion: how the performance in every period is going to be aggregated. Options: {"avg": average, "max": maximum, the worst case}. (default: "avg")
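To make the last two arguments concrete, here is a minimal sketch of how per-period scores could be aggregated and how a per-period sample floor could be checked. It is not the library's internal code: the helper names `aggregate_period_scores` and `keeps_enough_samples` are hypothetical, the numbers are made up, and the sketch assumes the per-period score is loss-like (lower is better), so that "max" corresponds to the worst period.

```python
from statistics import mean


def aggregate_period_scores(scores_by_period, period_criterion="avg"):
    """Combine one score per period into a single number for a candidate split.

    Hypothetical helper. Assumes lower is better (e.g. an impurity),
    so "max" lets the worst period decide.
    """
    scores = list(scores_by_period.values())
    if period_criterion == "avg":
        return mean(scores)
    if period_criterion == "max":
        return max(scores)
    raise ValueError(f"unknown period_criterion: {period_criterion!r}")


def keeps_enough_samples(counts_by_period, min_sample_periods):
    """Return True if every period still holds at least min_sample_periods examples."""
    return all(count >= min_sample_periods for count in counts_by_period.values())


# Made-up impurity of one candidate split, measured separately in each period.
impurity = {2019: 0.25, 2020: 0.5, 2021: 0.75}
avg_score = aggregate_period_scores(impurity, "avg")
worst_score = aggregate_period_scores(impurity, "max")

# Made-up node sizes per period after the split; 2021 falls below the floor.
split_is_allowed = keeps_enough_samples({2019: 12, 2020: 7, 2021: 3}, min_sample_periods=5)
```

With "avg", a split that is excellent in most periods can compensate for one bad period; with "max", the split is judged only by the period where it performs worst.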
To use the environment-wise optimization:
from sklearn.metrics import make_scorer, roc_auc_score
from time_robust_forest.hyper_opt import env_wise_hyper_opt
from time_robust_forest.models import TimeForestClassifier

# Candidate values for each hyperparameter
params_grid = {"n_estimators": [30, 60, 120],
               "max_depth": [5, 10],
               "min_impurity_decrease": [1e-1, 1e-3, 0],
               "min_sample_periods": [5, 10, 30],
               "period_criterion": ["max", "avg"]}
model = TimeForestClassifier(time_column=time_column)
# features, time_column and target are defined as in the first example
opt_param = env_wise_hyper_opt(training_data[features + [time_column]],
                               training_data[target],
                               model,
                               time_column,
                               params_grid,
                               cv=5,
                               scorer=make_scorer(roc_auc_score,
                                                  needs_proba=True))
Make sure you have a good choice for the time column
Don't simply use a raw timestamp column from the dataset. Make it discrete first, and make sure every period contains a reasonable number of data points. Example: use the year if you have 3+ years of data. Note that how you discretize the timestamp is itself a modeling choice you can optimize.
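As an illustration of this advice, the sketch below turns raw timestamps into yearly periods and flags periods with too few data points. It uses only the standard library; the helper names `to_year_periods` and `sparse_periods`, the sample dates, and the threshold of 3 are assumptions made for the example, not part of the package.

```python
from collections import Counter
from datetime import datetime


def to_year_periods(timestamps):
    """Discretize timestamps into one period per calendar year."""
    return [ts.year for ts in timestamps]


def sparse_periods(periods, min_count):
    """Return the sorted periods that hold fewer than min_count data points."""
    counts = Counter(periods)
    return sorted(period for period, n in counts.items() if n < min_count)


# Made-up timestamps: three rows in 2019, three in 2020, one in 2021.
timestamps = [
    datetime(2019, 3, 1), datetime(2019, 8, 15), datetime(2019, 11, 2),
    datetime(2020, 1, 20), datetime(2020, 6, 5), datetime(2020, 12, 30),
    datetime(2021, 2, 14),
]
periods = to_year_periods(timestamps)
too_small = sparse_periods(periods, min_count=3)
```

Here 2021 is flagged because it holds a single row. In that situation you could merge it with a neighboring period or pick a coarser granularity before passing the column as time_column.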
Random segments
Selecting randomly from multiple time columns
The user can pass a list instead of a string as the time_column argument. The model then selects one column at random from the list when building each of the n_estimators estimators.
from time_robust_forest.models import TimeForestClassifier
features = ["x_1", "x_2"]
time_columns = ["periods", "periods_2"]
target = "y"
model = TimeForestClassifier(time_column=time_columns)
model.fit(training_data[features + time_columns], training_data[target])
predictions = model.predict_proba(test_data[features])[:, 1]
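The per-estimator selection described above can be pictured with the following sketch. It only illustrates the idea of drawing one time column per estimator; `pick_time_columns` and its `seed` parameter are hypothetical, and the library's actual sampling code may differ.

```python
import random


def pick_time_columns(time_columns, n_estimators, seed=0):
    """Draw one time column per estimator, uniformly at random.

    Hypothetical helper; a seeded generator keeps the draw reproducible.
    """
    rng = random.Random(seed)
    return [rng.choice(time_columns) for _ in range(n_estimators)]


# One entry per estimator, each being either "periods" or "periods_2".
chosen = pick_time_columns(["periods", "periods_2"], n_estimators=5)
```

Because each estimator sees a single notion of "period", the forest as a whole mixes trees that were regularized against different segmentations of time.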
Generating random segments from a timestamp column
The user can define a maximum number of segments (random_segments), and the model will split the data using the timestamp information. In the following example, the model segments the data into 1, 2, 3, ..., 10 parts. For every estimator, it randomly picks one of the ten resulting columns and uses it as the time_column. In this case, the time_column argument should point to the raw timestamp information.
from time_robust_forest.models import TimeForestClassifier
features = ["x_1", "x_2"]
time_column = "time_stamp"
target = "y"
model = TimeForestClassifier(time_column=time_column, random_segments=10)
model.fit(training_data[features + [time_column]], training_data[target])
predictions = model.predict_proba(test_data[features])[:, 1]
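The sketch below shows one way such candidate segmentations could be generated. It is an assumption-laden illustration, not the package's implementation: the helpers `segment_labels` and `build_segment_columns` are hypothetical, timestamps are plain numbers, and segments are assumed to hold roughly equal numbers of rows in time order (the source does not say whether the library uses equal-count or equal-width segments).

```python
import random


def segment_labels(timestamps, n_segments):
    """Label each row with a segment id in 0..n_segments-1.

    Rows are ranked by timestamp and cut into n_segments groups of
    (roughly) equal size. Equal-count cutting is an assumption.
    """
    order = sorted(range(len(timestamps)), key=lambda i: timestamps[i])
    labels = [0] * len(timestamps)
    for rank, idx in enumerate(order):
        labels[idx] = rank * n_segments // len(timestamps)
    return labels


def build_segment_columns(timestamps, random_segments):
    """Build one candidate period column for each of 1..random_segments parts."""
    return {n: segment_labels(timestamps, n) for n in range(1, random_segments + 1)}


# Made-up numeric timestamps for ten rows.
timestamps = [5, 1, 9, 3, 7, 2, 8, 4, 6, 0]
columns = build_segment_columns(timestamps, random_segments=10)

# For one estimator, pick one of the ten candidate columns at random.
rng = random.Random(0)
chosen_n = rng.choice(sorted(columns))
time_column_for_estimator = columns[chosen_n]
```

With a single segment every row shares the same period, which behaves like an ordinary random forest; larger segment counts ask each split to hold up across finer slices of time.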
License
This project is licensed under the terms of the BSD-3 license. See LICENSE for more details.
Citation
@misc{time-robust-forest,
author = {Moneda, Luis},
title = {Time Robust Forest model},
year = {2021},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/lgmoneda/time-robust-forest}}
}