Tree Prompting

Code for the paper "Tree Prompting: Efficient Task Adaptation without Fine-Tuning".

Tree Prompting uses training examples to learn a tree of prompts for classification, yielding higher accuracy and better efficiency than baseline ensembles.
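The core idea can be sketched without a language model: each prompt acts as a weak binary classifier, its output becomes one feature, and a standard decision tree is fit over those features. The sketch below simulates prompt outputs as noisy copies of the labels; the per-prompt accuracies and data are made up for illustration and are not from the paper.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n_examples, n_prompts = 200, 5
y = rng.integers(0, 2, size=n_examples)  # true binary labels

# Simulate each prompt's verbalized decision as a noisy copy of the label,
# with per-prompt accuracies ranging from ~55% to ~75% (illustrative values).
prompt_accs = np.linspace(0.55, 0.75, n_prompts)
X = np.stack(
    [np.where(rng.random(n_examples) < acc, y, 1 - y) for acc in prompt_accs],
    axis=1,
)

# The decision tree routes each example through a sequence of prompt
# outcomes, combining several weak prompts into a stronger classifier.
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
tree_acc = clf.score(X, y)
best_single = max((X[:, j] == y).mean() for j in range(n_prompts))
print(f"best single prompt: {best_single:.2f}  tree of prompts: {tree_acc:.2f}")
```

On the training data the shallow tree typically matches or beats the best individual prompt, which mirrors the per-prompt accuracy comparison printed in the quickstart.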

Quickstart

Installation: pip install treeprompt (or clone this repo and pip install -e .)

from treeprompt.treeprompt import TreePromptClassifier
import datasets
import numpy as np
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# set up data
rng = np.random.default_rng(seed=42)
dset_train = datasets.load_dataset('rotten_tomatoes')['train']
dset_train = dset_train.select(rng.choice(
    len(dset_train), size=100, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(rng.choice(
    len(dset_val), size=100, replace=False))

# set up arguments
prompts = [
    "This movie is",
    " Positive or Negative? The movie was",
    " The sentiment of the movie was",
    " The plot of the movie was really",
    " The acting in the movie was",
]
verbalizer = {0: " Negative.", 1: " Positive."}
checkpoint = "gpt2"

# fit model
m = TreePromptClassifier(
    checkpoint=checkpoint,
    prompts=prompts,
    verbalizer=verbalizer,
    cache_prompt_features_dir=None,  # e.g. 'cache_prompt_features_dir/gpt2'
)
m.fit(dset_train["text"], dset_train["label"])


# compute accuracy
preds = m.predict(dset_val['text'])
print('\nTree-Prompt acc (val) ->',
      np.mean(preds == dset_val['label']))  # -> 0.7

# compare to accuracy for individual prompts
for i, prompt in enumerate(prompts):
    print(i, prompt, '->', m.prompt_accs_[i])  # -> 0.65, 0.5, 0.5, 0.56, 0.51

# visualize decision tree
plot_tree(
    m.clf_,
    fontsize=10,
    feature_names=m.feature_names_,
    class_names=list(verbalizer.values()),
    filled=True,
)
plt.show()
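Because m.clf_ is a scikit-learn decision tree (it is what plot_tree receives above), the learned prompt routing can also be printed as plain-text rules with sklearn's export_text. Since running the real pipeline needs GPT-2, the snippet below demonstrates this on a stand-in tree fit to synthetic binary prompt features; the feature names and data are illustrative.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(50, 3))  # stand-in binary prompt outputs
y = (X[:, 0] & X[:, 2]).astype(int)   # label depends on prompts 0 and 2 only

clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
rules = export_text(clf, feature_names=["prompt_0", "prompt_1", "prompt_2"])
print(rules)
```

Each root-to-leaf path reads as a sequence of prompt outcomes ending in a class decision, the same structure that plot_tree draws graphically.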

Reference:

@misc{ch2022augmenting,
    title={Tree Prompting: Efficient Task Adaptation without Fine-Tuning},
    year={2023},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}

Reproducing experiments

See the full code for reproducing all experiments in the paper at https://github.com/csinva/tree-prompt-experiments
