PyTorch Implementation for MetaTree: Learning a Decision Tree Algorithm with Transformers
🌲 MetaTree 🌲
Learning a Decision Tree Algorithm with Transformers (Zhuang et al. 2024).
MetaTree is a transformer-based decision tree algorithm. It is trained on the outputs of classical decision tree algorithms (the greedy algorithm CART and the optimal algorithm GOSDT) so that the trees it generates generalize better.
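To make "greedy algorithm" concrete, here is a minimal pure-Python sketch of the kind of split search CART performs at a single node: try every feature and candidate threshold, and keep the one with the lowest weighted Gini impurity. This illustrates the classical technique only; it is not MetaTree's code. The function names `gini` and `best_greedy_split` are hypothetical, and so is the choice of midpoints between distinct values as candidate thresholds.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 for a pure or empty list)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_greedy_split(X, y):
    """Return (feature_index, threshold, weighted_gini) of the best single split.

    X is a list of rows (lists of floats), y a list of class labels.
    A sample goes left when x[feature] <= threshold (an assumed convention).
    If no split improves on the unsplit node, feature and threshold are None.
    """
    n = len(y)
    best = (None, None, gini(y))  # baseline: impurity of the unsplit node
    for feature in range(len(X[0])):
        values = sorted({row[feature] for row in X})
        # Candidate thresholds: midpoints between consecutive distinct values.
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2.0
            left = [label for row, label in zip(X, y) if row[feature] <= threshold]
            right = [label for row, label in zip(X, y) if row[feature] > threshold]
            score = (len(left) * gini(left) + len(right) * gini(right)) / n
            if score < best[2]:
                best = (feature, threshold, score)
    return best


# Toy data: feature 1 separates the two classes, feature 0 does not.
X_toy = [[0.0, 1.0], [1.0, 2.0], [0.0, 8.0], [1.0, 9.0]]
y_toy = [0, 0, 1, 1]
feature, threshold, score = best_greedy_split(X_toy, y_toy)
print(feature, threshold, score)
```

A full CART tree applies this search recursively to the left and right subsets. Because each split is chosen locally, the result can be suboptimal for the tree as a whole, which is the gap that optimal methods such as GOSDT address.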
Quickstart -- use MetaTree to generate decision tree models
The model is available at https://huggingface.co/yzhuang/MetaTree
- Install metatreelib:
pip install metatreelib
# Alternatively, clone the repository and install it in editable mode:
# pip install -e .
# Or install directly from GitHub:
# pip install git+https://github.com/EvanZhuang/MetaTree
- Use MetaTree on your datasets to generate a decision tree model
import random
import imodels  # pip install imodels
import numpy as np
import sklearn.model_selection
import torch
from sklearn.metrics import accuracy_score
from transformers import AutoConfig
from metatree.model_metatree import LlamaForMetaTree as MetaTree
from metatree.decision_tree_class import DecisionTree, DecisionTreeForest
from metatree.run_train import preprocess_dimension_patch
# Initialize Model
model_name_or_path = "yzhuang/MetaTree"
config = AutoConfig.from_pretrained(model_name_or_path)
model = MetaTree.from_pretrained(
model_name_or_path,
config=config,
)
decision_tree_forest = DecisionTreeForest()
# Load Datasets
X, y, feature_names = imodels.get_clean_dataset('fico', data_source='imodels')
print("Dataset Shapes X={}, y={}, Num of Classes={}".format(X.shape, y.shape, len(set(y))))
seed = 42  # assumed value; any fixed integer makes the split reproducible
train_idx, test_idx = sklearn.model_selection.train_test_split(range(X.shape[0]), test_size=0.3, random_state=seed)
# Dimension Subsampling
feature_idx = np.random.choice(X.shape[1], 10, replace=False)
X = X[:, feature_idx]
test_X, test_y = X[test_idx], y[test_idx]
# Sample Training Data
subset_idx = random.sample(train_idx, 256)
train_X, train_y = X[subset_idx], y[subset_idx]
input_x = torch.tensor(train_X, dtype=torch.float32)
input_y = torch.nn.functional.one_hot(torch.tensor(train_y)).float()
batch = {"input_x": input_x, "input_y": input_y, "input_y_clean": input_y}
batch = preprocess_dimension_patch(batch, n_feature=10, n_class=10)
model.depth = 2
outputs = model.generate_decision_tree(batch['input_x'], batch['input_y'], depth=model.depth)
decision_tree_forest.add_tree(DecisionTree(auto_dims=outputs.metatree_dimensions, auto_thresholds=outputs.tentative_splits, input_x=batch['input_x'], input_y=batch['input_y'], depth=model.depth))
print("Decision Tree Features: ", [x.argmax(dim=-1) for x in outputs.metatree_dimensions])
print("Decision Tree Thresholds: ", outputs.tentative_splits)
- Inference with the decision tree model
tree_pred = decision_tree_forest.predict(torch.tensor(test_X, dtype=torch.float32))
accuracy = accuracy_score(test_y, tree_pred.argmax(dim=-1).squeeze(0))
print("MetaTree Test Accuracy: ", accuracy)
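The printed features and thresholds fully describe a depth-limited tree. As a rough illustration of how such a tree turns a sample into a prediction, the sketch below routes a sample through a complete binary tree in plain Python. The breadth-first node layout, the `<=` go-left rule, the per-leaf labels, and the name `predict_with_tree` are all assumptions made for this example; they are not taken from the `DecisionTree` class, whose internal conventions may differ.

```python
def predict_with_tree(x, features, thresholds, leaf_labels):
    """Route one sample through a complete binary tree and return a leaf label.

    features[i] and thresholds[i] describe internal node i in breadth-first
    order: node 0 is the root, and the children of node i are 2*i+1 and 2*i+2.
    leaf_labels holds one label per leaf, ordered left to right.
    """
    n_internal = len(features)
    node = 0
    while node < n_internal:
        go_left = x[features[node]] <= thresholds[node]
        node = 2 * node + 1 if go_left else 2 * node + 2
    # Leaves are numbered right after the internal nodes.
    return leaf_labels[node - n_internal]


# A hypothetical depth-2 tree: 3 internal nodes, 4 leaves.
features = [0, 1, 1]
thresholds = [0.5, 2.0, 7.0]
leaf_labels = [0, 1, 0, 1]

samples = [[0.2, 3.0], [0.9, 5.0], [0.1, 1.0]]
predictions = [predict_with_tree(x, features, thresholds, leaf_labels) for x in samples]
print(predictions)
```

A depth-2 tree needs only two comparisons per sample, so the resulting model is cheap to evaluate and easy to inspect by hand.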
Example Usage
A complete example of using MetaTree is provided in the accompanying notebook.
Questions?
If you have any questions related to the code or the paper, feel free to reach out to us at y5zhuang@ucsd.edu.
Citation
If you find our paper and code useful, please cite us:
@misc{zhuang2024learning,
      title={Learning a Decision Tree Algorithm with Transformers},
      author={Yufan Zhuang and Liyuan Liu and Chandan Singh and Jingbo Shang and Jianfeng Gao},
      year={2024},
      eprint={2402.03774},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}