
PyTorch Implementation for MetaTree: Learning a Decision Tree Algorithm with Transformers


🌲 MetaTree 🌲

Learning a Decision Tree Algorithm with Transformers (Zhuang et al. 2024).

MetaTree is a transformer-based decision tree algorithm. It learns from classical decision tree algorithms (the greedy algorithm CART and the optimal algorithm GOSDT) to achieve better generalization.
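For readers unfamiliar with the greedy baseline, the sketch below shows the core step of CART-style tree induction: scan every feature and candidate threshold and pick the split with the lowest weighted Gini impurity. The helpers `gini` and `best_gini_split` are hypothetical and written only for illustration. They are not part of metatreelib, and this is not how MetaTree itself chooses splits.

```python
import numpy as np


def gini(labels: np.ndarray) -> float:
    """Gini impurity of a 1-D array of integer class labels."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - float(np.sum(p ** 2))


def best_gini_split(X: np.ndarray, y: np.ndarray):
    """Greedy CART-style search (illustrative only).

    Returns (feature, threshold, weighted_gini) for the split
    X[:, feature] <= threshold with the lowest weighted impurity.
    """
    n_samples, n_features = X.shape
    best = (None, None, float("inf"))
    for j in range(n_features):
        values = np.unique(X[:, j])
        # Candidate thresholds: midpoints between consecutive distinct values.
        for t in (values[:-1] + values[1:]) / 2.0:
            left = X[:, j] <= t
            n_left = int(left.sum())
            score = (n_left * gini(y[left]) + (n_samples - n_left) * gini(y[~left])) / n_samples
            if score < best[2]:
                best = (j, float(t), float(score))
    return best


X = np.array([[1.0, 5.0], [2.0, 4.0], [3.0, 1.0], [4.0, 0.0]])
y = np.array([0, 0, 1, 1])
print(best_gini_split(X, y))  # → (0, 2.5, 0.0)
```

A greedy learner applies this search recursively to each child node; an optimal method such as GOSDT instead searches over whole trees, which is what MetaTree is trained to imitate alongside CART.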

Quickstart -- use MetaTree to generate decision tree models

The pretrained model is available at https://huggingface.co/yzhuang/MetaTree

  1. Install metatreelib:

```bash
pip install metatreelib
# Alternatively, clone the repository and run `pip install -e .`,
# or install directly from GitHub:
# pip install git+https://github.com/EvanZhuang/MetaTree
```
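Note that the distribution is installed as `metatreelib` but imported as `metatree`. A minimal sketch for checking which version is installed, using only the standard library's `importlib.metadata` (the helper name `installed_version` is hypothetical):

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "metatreelib") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


v = installed_version()
print(f"metatreelib {v}" if v else "metatreelib is not installed")
```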
  2. Use MetaTree on your dataset to generate a decision tree model:

```python
import random

import numpy as np
import sklearn.model_selection
import torch
import imodels  # pip install imodels
from transformers import AutoConfig

from metatree.model_metatree import LlamaForMetaTree as MetaTree
from metatree.decision_tree_class import DecisionTree, DecisionTreeForest
from metatree.run_train import preprocess_dimension_patch

# Fix random seeds so the split and subsampling are reproducible (any integer works)
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Initialize the model
model_name_or_path = "yzhuang/MetaTree"

config = AutoConfig.from_pretrained(model_name_or_path)
model = MetaTree.from_pretrained(
    model_name_or_path,
    config=config,
)
decision_tree_forest = DecisionTreeForest()

# Load the dataset
X, y, feature_names = imodels.get_clean_dataset('fico', data_source='imodels')

print("Dataset Shapes X={}, y={}, Num of Classes={}".format(X.shape, y.shape, len(set(y))))

train_idx, test_idx = sklearn.model_selection.train_test_split(range(X.shape[0]), test_size=0.3, random_state=seed)

# Dimension subsampling: keep 10 randomly chosen features
feature_idx = np.random.choice(X.shape[1], 10, replace=False)
X = X[:, feature_idx]

test_X, test_y = X[test_idx], y[test_idx]

# Sample 256 training points to feed to the model
subset_idx = random.sample(train_idx, 256)
train_X, train_y = X[subset_idx], y[subset_idx]

input_x = torch.tensor(train_X, dtype=torch.float32)
# one_hot requires integer (int64) class labels
input_y = torch.nn.functional.one_hot(torch.tensor(train_y).long()).float()

# Prepare the batch for the model (n_feature=10 matches the subsampled feature count)
batch = {"input_x": input_x, "input_y": input_y, "input_y_clean": input_y}
batch = preprocess_dimension_patch(batch, n_feature=10, n_class=10)

# Generate a depth-2 decision tree and add it to the forest
model.depth = 2
outputs = model.generate_decision_tree(batch['input_x'], batch['input_y'], depth=model.depth)
decision_tree_forest.add_tree(DecisionTree(auto_dims=outputs.metatree_dimensions, auto_thresholds=outputs.tentative_splits, input_x=batch['input_x'], input_y=batch['input_y'], depth=model.depth))

print("Decision Tree Features: ", [x.argmax(dim=-1) for x in outputs.metatree_dimensions])
print("Decision Tree Thresholds: ", outputs.tentative_splits)
```
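Because of the dimension-subsampling step, the feature indices printed above refer to the 10 sampled columns of `X`, not to the original dataset. The sketch below maps them back to the original names via `feature_idx` and `feature_names`. It assumes you first convert each node's `argmax` result to plain Python integers (for example with `.tolist()`); the exact tensor shapes returned by metatreelib are not documented here, so the helper `map_to_original_names` and the toy values are hypothetical.

```python
from typing import List, Sequence


def map_to_original_names(
    chosen: Sequence[int],
    feature_idx: Sequence[int],
    feature_names: Sequence[str],
) -> List[str]:
    """Translate column indices of the subsampled matrix X[:, feature_idx]
    back to names in the original, un-subsampled feature list."""
    return [feature_names[feature_idx[j]] for j in chosen]


# Toy stand-ins for the quickstart variables (hypothetical values).
feature_names = ["f0", "f1", "f2", "f3", "f4"]
feature_idx = [4, 1, 3]  # columns kept by dimension subsampling
chosen = [2, 0, 0]       # e.g. argmax indices for three tree nodes
print(map_to_original_names(chosen, feature_idx, feature_names))  # → ['f3', 'f4', 'f4']
```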
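  3. Run inference with the decision tree model:

```python
from sklearn.metrics import accuracy_score

# Score the held-out test set; argmax over the last dimension gives the predicted class
tree_pred = decision_tree_forest.predict(torch.tensor(test_X, dtype=torch.float32))

accuracy = accuracy_score(test_y, tree_pred.argmax(dim=-1).squeeze(0))
print("MetaTree Test Accuracy: ", accuracy)
```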
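To make the inference step concrete, the sketch below routes samples through a complete binary tree of depth 2, the same depth used in the quickstart. The representation (internal nodes in heap order, one feature index and threshold per node, one class label per leaf) and the helper `route_sample` are hypothetical simplifications for illustration. They are not the internal format of metatreelib's `DecisionTree` or `DecisionTreeForest`, which return class scores rather than a single label.

```python
from typing import Sequence


def route_sample(
    x: Sequence[float],
    features: Sequence[int],
    thresholds: Sequence[float],
    leaf_class: Sequence[int],
    depth: int,
) -> int:
    """Route one sample through a complete binary tree of the given depth.

    Internal nodes are stored in heap order (node i has children 2i+1 and
    2i+2); a sample goes left when x[feature] <= threshold. leaf_class holds
    one predicted class per leaf, ordered left to right.
    """
    node = 0
    for _ in range(depth):
        go_left = x[features[node]] <= thresholds[node]
        node = 2 * node + 1 if go_left else 2 * node + 2
    # Leaves occupy heap positions 2**depth - 1 .. 2**(depth + 1) - 2.
    return leaf_class[node - (2 ** depth - 1)]


# A hypothetical depth-2 tree: 3 internal nodes, 4 leaves.
features = [0, 1, 1]
thresholds = [0.5, 0.3, 0.7]
leaf_class = [0, 1, 1, 0]

X_test = [[0.2, 0.1], [0.2, 0.9], [0.8, 0.5], [0.8, 0.9]]
y_test = [0, 1, 1, 0]
preds = [route_sample(x, features, thresholds, leaf_class, depth=2) for x in X_test]
accuracy = sum(p == t for p, t in zip(preds, y_test)) / len(y_test)
print(preds, accuracy)  # → [0, 1, 1, 0] 1.0
```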

Example Usage

A complete example of using MetaTree is provided in the notebook.

Questions?

If you have any questions related to the code or the paper, feel free to reach out to us at y5zhuang@ucsd.edu.

Citation

If you find our paper and code useful, please cite us:

@misc{zhuang2024learning,
      title={Learning a Decision Tree Algorithm with Transformers}, 
      author={Yufan Zhuang and Liyuan Liu and Chandan Singh and Jingbo Shang and Jianfeng Gao},
      year={2024},
      eprint={2402.03774},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}


Download files


Source Distribution

metatreelib-0.1.1.tar.gz (52.8 kB)

Uploaded Source

Built Distribution

metatreelib-0.1.1-py3-none-any.whl (63.7 kB)

Uploaded Python 3

File details

Details for the file metatreelib-0.1.1.tar.gz.

File metadata

  • Download URL: metatreelib-0.1.1.tar.gz
  • Upload date:
  • Size: 52.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.8

File hashes

Hashes for metatreelib-0.1.1.tar.gz
Algorithm Hash digest
SHA256 1b5727651dbea766b2225df2f26b8b1d334e9b7ce8fabb3afb836e3ecb00f9bc
MD5 e9ac1f81a87c0993a2184cc5cc32769c
BLAKE2b-256 027ccdf81a4d059206231ace80a5280bbdf4ca7cdb660e7987c580defbb18dc5


File details

Details for the file metatreelib-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: metatreelib-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 63.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/5.1.1 CPython/3.11.8

File hashes

Hashes for metatreelib-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 efc7e0581ec6e57cfd88fd8ac1d7bfb083405b1be3a2779c562aeb8d03bc41be
MD5 9aa3525ef5025d0a03f4da461db5f48e
BLAKE2b-256 3f7f9ea20f954e6430942a03ab2c6749893079941dd36e9d055bb3f667fe1655
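A minimal sketch for checking a downloaded file against the SHA256 digests listed above, using only the standard library's `hashlib`. The helpers `sha256_of` and `verify` are hypothetical, and the local filename in the usage line is an assumption about where you saved the download.

```python
import hashlib
from pathlib import Path


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the hex SHA256 digest of a file, reading it in chunks."""
    h = hashlib.sha256()
    with Path(path).open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify(path: str, expected_sha256: str) -> bool:
    """Return True if the file's SHA256 matches the expected hex digest."""
    return sha256_of(path) == expected_sha256.strip().lower()


# SHA256 digest published above for the wheel.
WHEEL_SHA256 = "efc7e0581ec6e57cfd88fd8ac1d7bfb083405b1be3a2779c562aeb8d03bc41be"
# verify("metatreelib-0.1.1-py3-none-any.whl", WHEEL_SHA256)
```

Alternatively, pip can enforce this itself in hash-checking mode, via a requirements file line such as `metatreelib==0.1.1 --hash=sha256:<digest>`.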

