Tree Prompting

Project description

Code for the paper "Tree Prompting: Efficient Task Adaptation without Fine-Tuning".

Tree Prompting uses training examples to learn a tree of prompts for classification, yielding higher accuracy and better efficiency than baseline prompt ensembles.
Quickstart
Installation: `pip install treeprompt` (or clone this repo and run `pip install -e .`)
```python
from treeprompt.treeprompt import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))

# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"

# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # e.g. 'cache_prompt_features_dir/gpt2'
)
m.fit(dset_train["text"], dset_train["label"])

# compute accuracy
preds = m.predict(dset_val['text'])
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == dset_val['label']))  # -> 0.7

# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51

# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()
```
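Conceptually, each prompt's verbalized answer becomes a binary feature, and a standard decision tree is fit over those features to route examples through prompts. The mechanism can be sketched with synthetic prompt outputs, no LLM required; everything below is an illustrative toy, not the library's internals:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n_examples, n_prompts = 100, 5

# Simulated 0/1 answers from each prompt. In Tree Prompting these would come
# from comparing the LLM's logits for the verbalizer tokens.
prompt_outputs = rng.integers(0, 2, size=(n_examples, n_prompts))

# Ground-truth labels mostly agree with the first prompt's answer
# (roughly 10% label noise flips the agreement).
labels = prompt_outputs[:, 0] ^ (rng.random(n_examples) < 0.1)

# A small tree over prompt outputs plays the role of the learned prompt tree.
clf = DecisionTreeClassifier(max_leaf_nodes=4, random_state=0)
clf.fit(prompt_outputs, labels)
print("train acc:", clf.score(prompt_outputs, labels))
```

Because each tree split only consults one prompt, inference for an example follows a single root-to-leaf path, which is why the method can be cheaper than ensembling all prompts.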
Reference:

```bibtex
@misc{ch2022augmenting,
  title={Tree Prompting: Efficient Task Adaptation without Fine-Tuning},
  year={2023},
  archivePrefix={arXiv},
  primaryClass={cs.AI}
}
```
Reproducing experiments
See the full code for reproducing all experiments in the paper at https://github.com/csinva/tree-prompt-experiments
File details
Details for the file treeprompt-0.0.1.tar.gz.
File metadata
- Download URL: treeprompt-0.0.1.tar.gz
- Upload date:
- Size: 11.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 76e8d74438c2d7cf030566d6f99e4d6e05c7be95358c54ef53b7da46ce27e0b6 |
| MD5 | 074e5dd09f6023431eedae61f4284218 |
| BLAKE2b-256 | 717a4aeabc06fb6dd61487eb89914eb507ac43aeab93e06dc5e4cd5ae9503237 |
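To check a downloaded file against the digests listed here, you can hash it locally. A minimal helper (the function name is ours, not part of the package):

```python
import hashlib

def sha256_of_file(path: str, chunk_size: int = 1 << 16) -> str:
    """Stream a file through SHA-256 and return its hex digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# e.g. sha256_of_file("treeprompt-0.0.1.tar.gz") should equal the SHA256 above
```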
File details
Details for the file treeprompt-0.0.1-py3-none-any.whl.
File metadata
- Download URL: treeprompt-0.0.1-py3-none-any.whl
- Upload date:
- Size: 11.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | a0ccb885afb316917e70156ee7672b93c789c47da0200b4884c3a973bb6237e2 |
| MD5 | 8cc763eb2b0ee8141869e308eb90eaff |
| BLAKE2b-256 | d1b06652aa08edb12570f11af7b00cbe90b4d9abfdc3f989d5dfa235a43011d2 |