Tree prompting
Code for the paper "Tree Prompting: Efficient Task Adaptation without Fine-Tuning".
Tree Prompting uses training examples to learn a decision tree of prompts for classification, yielding higher accuracy and better efficiency than baseline ensembles.
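To make the idea concrete, here is a minimal, dependency-free sketch of how a tree over prompt outputs could be learned. It is not the treeprompt implementation (the quickstart suggests the package fits a scikit-learn tree, exposed as m.clf_); the helper names gini, majority, split, build_tree and predict, and the toy data, are all hypothetical. Each row of X holds the 0/1 answers of three imaginary prompts to one example, and the tree greedily picks the prompt whose answer best separates the labels.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a non-empty list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def majority(labels):
    """Most common label (the first one seen wins ties)."""
    return Counter(labels).most_common(1)[0][0]


def split(X, y, j, value):
    """Rows and labels of the examples where prompt j answered `value`."""
    rows = [(xi, yi) for xi, yi in zip(X, y) if xi[j] == value]
    return [r[0] for r in rows], [r[1] for r in rows]


def build_tree(X, y, depth=0, max_depth=2):
    """Greedily choose, at each node, the prompt with the lowest weighted Gini."""
    if depth == max_depth or len(set(y)) == 1:
        return {"leaf": majority(y)}
    best_score, best_j = None, None
    for j in range(len(X[0])):
        (_, y0), (_, y1) = split(X, y, j, 0), split(X, y, j, 1)
        if not y0 or not y1:
            continue  # this prompt does not separate the examples at all
        score = (len(y0) * gini(y0) + len(y1) * gini(y1)) / len(y)
        if best_score is None or score < best_score:
            best_score, best_j = score, j
    if best_j is None:
        return {"leaf": majority(y)}
    return {
        "prompt": best_j,
        0: build_tree(*split(X, y, best_j, 0), depth + 1, max_depth),
        1: build_tree(*split(X, y, best_j, 1), depth + 1, max_depth),
    }


def predict(tree, x):
    """Return (label, number of prompts consulted on the way to the leaf)."""
    calls = 0
    while "leaf" not in tree:
        tree = tree[x[tree["prompt"]]]
        calls += 1
    return tree["leaf"], calls


# Toy data (made up): the label is 1 only when prompts 0 and 2 both answer 1.
X = [[0, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 1],
     [1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1]]
y = [0, 0, 0, 0, 0, 0, 1, 1]

tree = build_tree(X, y)
print(tree)
print(predict(tree, [1, 0, 1]))
```

In this sketch, prediction only consults the prompts along one root-to-leaf path, which is one plausible reading of the efficiency claim above: an example that the root prompt already settles needs a single prompt evaluation rather than one per prompt.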
Quickstart
Installation: pip install treeprompt (or clone this repo and run pip install -e .)
from treeprompt.treeprompt import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))
# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"
# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # 'cache_prompt_features_dir/gp2',
)
m.fit(dset_train["text"], dset_train["label"])
# compute accuracy
preds = m.predict(dset_val['text'])
# convert labels to an array so the comparison is element-wise
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == np.array(dset_val['label'])))  # -> 0.7
# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51
# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()
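The prompts and verbalizer arguments above can be read as follows: each prompt is appended to the input text, and the verbalizer maps each class label to the continuation string that stands for it. The sketch below illustrates one plausible way a single prompt could be turned into a binary feature, by picking the label whose verbalizer string the language model scores highest. This is an assumption about the general technique, not a description of treeprompt's internals; prompt_feature is a hypothetical helper and the scores are made-up numbers.

```python
def prompt_feature(scores, verbalizer):
    """Return the label whose verbalizer string has the highest score."""
    return max(verbalizer, key=lambda label: scores[verbalizer[label]])


verbalizer = {0: " Negative.", 1: " Positive."}

# Hypothetical log-probabilities a language model might assign to each
# continuation of "<review text> This movie is"; the values are invented.
scores = {" Negative.": -2.3, " Positive.": -0.9}

feature = prompt_feature(scores, verbalizer)
print(feature)
```

Under this reading, running every prompt over every training example yields a table of such features, which is the kind of input a decision tree can then be fit on.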
Reference:
@misc{ch2022augmenting,
    title={Tree Prompting: Efficient Task Adaptation without Fine-Tuning},
    year={2023},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}
Reproducing experiments
See the full code for reproducing all experiments in the paper at https://github.com/csinva/tree-prompt-experiments