Growing Robust Decision Trees
Project description
GROOT: Growing Robust Trees
Growing Robust Trees (GROOT) is an algorithm that fits binary-classification decision trees so that they are robust against user-specified adversarial examples. The algorithm closely resembles those used for fitting regular decision trees (e.g. CART) but changes the splitting criterion and the way samples propagate when creating a split.
This repository contains the module groot, which implements GROOT as a Scikit-learn compatible classifier, an adversary for model evaluation, and convenience functions for importing datasets. For documentation see https://groot.cyber-analytics.nl
Simple example
To train and evaluate GROOT on a toy dataset against an attacker that can move samples by 0.5 in each direction, use the following code:
from groot.adversary import DecisionTreeAdversary
from groot.model import GrootTree
from sklearn.datasets import make_moons
X, y = make_moons(noise=0.3, random_state=0)
X_test, y_test = make_moons(noise=0.3, random_state=1)
attack_model = [0.5, 0.5]  # attacker can perturb each feature by up to 0.5
is_numerical = [True, True]  # both features are numerical
tree = GrootTree(attack_model=attack_model, is_numerical=is_numerical, random_state=0)
tree.fit(X, y)
accuracy = tree.score(X_test, y_test)
adversarial_accuracy = DecisionTreeAdversary(tree, "groot").adversarial_accuracy(X_test, y_test)
print("Accuracy:", accuracy)
print("Adversarial Accuracy:", adversarial_accuracy)
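The attack_model above gives the attacker an L-infinity budget of 0.5 per feature: a test sample counts as correctly classified under attack only if no point inside that box flips the prediction. A minimal, library-free sketch of what that box means for an axis-aligned split (the threshold and variable names here are illustrative, not part of groot):

```python
import numpy as np

# Per-feature attacker budget, matching attack_model = [0.5, 0.5] above.
epsilon = np.array([0.5, 0.5])

x = np.array([0.2, -0.1])  # one sample

# Corners of the L-infinity box the attacker may move x within.
corners = np.array([[x[0] + sx * epsilon[0], x[1] + sy * epsilon[1]]
                    for sx in (-1, 1) for sy in (-1, 1)])

# For an axis-aligned tree, deciding whether the attacker can push a
# sample to either side of a split reduces to an interval check:
threshold = 0.5  # a hypothetical split on feature 0
can_cross = bool(x[0] - epsilon[0] <= threshold < x[0] + epsilon[0])
print(can_cross)  # True: the sample can end up on either side of the split
```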
Installation
groot can be installed from PyPI:
pip install groot-trees
To use Kantchelian's MILP attack you need Gurobi installed along with its Python package:
python -m pip install -i https://pypi.gurobi.com gurobipy
Specific dependency versions
To reproduce our experiments with exact package versions, clone the repository and run:
pip install -r requirements.txt
We recommend using virtual environments.
Reproducing 'Efficient Training of Robust Decision Trees Against Adversarial Examples' (article)
To reproduce the results from the paper we provide generate_k_fold_results.py, a script that takes the trained models (in JSON format) and generates tables and figures. The resulting figures are written under /out/.
To not only generate the results but also retrain all models, we include the scripts train_kfold_models.py and fit_chen_xgboost.py. The first script runs the algorithms in parallel for each dataset and outputs to /out/trees/ and /out/forests/. Warning: this script can take a long time to run (about a day given 16 cores). The second script specifically trains the Chen et al. boosting ensembles. /out/results.zip contains all results from when we ran the scripts.
To experiment on image datasets we provide the script image_experiments.py, which fits models and outputs the results. In this script, set the dataset variable to 'mnist' or 'fmnist' to switch between the two datasets.
The scripts summarize_datasets.py and visualize_threat_models.py output some figures we used in the text.
Implementation details
The TREANT implementation (groot.treant.py) is copied almost completely from the TREANT authors at https://github.com/gtolomei/treant, with small modifications to better interface with the experiments. The heuristic by Chen et al. runs in the GROOT code, only with a different score function. This score function can be enabled by setting chen_heuristic=True on a GrootTree before calling .fit(X, y). The provably robust boosting implementation comes almost completely from their code at https://github.com/max-andr/provably-robust-boosting, with a small wrapper around it (groot.provably_robust_boosting.wrapper.py). When we recorded the runtimes, we turned off all parallel options in the @jit annotations in that code. The implementation of Chen et al. boosting can be found in their own repository at https://github.com/chenhongge/RobustTrees, from which we need to compile and copy the xgboost binary to the current directory. The script fit_chen_xgboost.py then calls this binary and uses its command line interface to fit all models.
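Both GROOT's criterion and the Chen et al. heuristic share the idea of scoring a split by its impurity under worst-case sample movement: samples within the attacker's budget of the threshold can be pushed to either side. The sketch below illustrates that idea in plain numpy; it is a simplified illustration under our own assumptions, not the exact GROOT or Chen et al. criterion (the real algorithms reason per sample and per class rather than moving the ambiguous block as a whole):

```python
import numpy as np

def gini(labels):
    """Gini impurity of a binary label array."""
    if len(labels) == 0:
        return 0.0
    p = np.mean(labels)
    return 2 * p * (1 - p)

def worst_case_split_score(X_col, y, threshold, epsilon):
    """Score a split pessimistically: samples within epsilon of the
    threshold can be moved to either side by the attacker, so count
    them on whichever side yields the highest (worst) impurity.
    Simplified illustration only, not the real GROOT criterion."""
    left = X_col + epsilon <= threshold   # always ends up left
    right = X_col - epsilon > threshold   # always ends up right
    movable = ~(left | right)             # attacker decides

    worst = 0.0
    nothing = np.zeros_like(movable)
    for to_left in (True, False):         # try both adversarial assignments
        l = y[left | (movable if to_left else nothing)]
        r = y[right | (nothing if to_left else movable)]
        score = (len(l) * gini(l) + len(r) * gini(r)) / len(y)
        worst = max(worst, score)
    return worst
```

With no attacker budget the split below is perfect (score 0.0), but with epsilon = 0.2 the two middle samples become movable and the worst-case impurity rises:

```python
X_col = np.array([0.1, 0.4, 0.6, 0.9])
y = np.array([0, 0, 1, 1])
print(worst_case_split_score(X_col, y, 0.5, 0.0))  # 0.0
print(worst_case_split_score(X_col, y, 0.5, 0.2))  # higher: about 0.333
```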
Important note on TREANT
To encode L-infinity norms correctly we had to modify TREANT to NOT apply rules recursively. This means we added a single break statement in the treant.Attacker.__compute_attack() method. If you are planning on using TREANT with recursive attacker rules then you should remove this statement or use TREANT's unmodified code at https://github.com/gtolomei/treant .
Contact
For any questions or comments please create an issue or contact me directly.
Project details
Download files
Download the file for your platform.
Source Distribution
File details
Details for the file groot-trees-0.0.3.tar.gz.
File metadata
- Download URL: groot-trees-0.0.3.tar.gz
- Upload date:
- Size: 79.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.6.0 pkginfo/1.7.1 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.1
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3f6024687c42ec5bd988aa6f37c1b92ef27c2b9106fbda52e0865e4e71a3d168 |
| MD5 | 61c8aa34f072811a3fc2f1beebfc13ea |
| BLAKE2b-256 | 11fd146f4a61545a97d23704b8db6a448d8e02114208a043ce6444b7622567ca |