CAPPr: zero-shot text classification using autoregressive language models
Perform zero-shot text classification by estimating the probability that an inputted completion comes after an inputted prompt. Hence the name:
Completion
After
Prompt
Probability
The method is fleshed out in my question on Cross Validated.
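For intuition, here is a rough sketch of the idea with a HuggingFace model. This is a simplified illustration, not CAPPr's actual implementation (it skips the library's caching, batching, and prior features): score each completion by the average log-probability of its tokens conditioned on the prompt, then predict the highest-scoring one.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

def avg_completion_log_prob(prompt: str, completion: str) -> float:
    # assumes the prompt's tokenization is a prefix of the full string's tokenization
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids
    with torch.no_grad():
        log_probs = model(full_ids).logits.log_softmax(dim=-1)
    # log P(token i | tokens < i) for each completion token
    completion_log_probs = [
        log_probs[0, i - 1, full_ids[0, i]].item()
        for i in range(prompt_ids.shape[1], full_ids.shape[1])
    ]
    return sum(completion_log_probs) / len(completion_log_probs)

prompt = "Which planet is closer to the Sun: Mercury or Earth?"
scores = {c: avg_completion_log_prob(prompt, " " + c) for c in ("Mercury", "Earth")}
print(max(scores, key=scores.get))
# 'Mercury', matching the predict example below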
Usage
Use a model from the OpenAI API
Specifically, this model must be compatible with the /v1/completions endpoint.
from cappr.openai.classify import predict
prompt = """
Tweet about a movie: "Oppenheimer was pretty good. But 3 hrs...cmon Nolan."
This tweet contains the following criticism:
""".strip("\n")
class_names = ("bad message", "too long", "unfunny")
preds = predict(
    prompts=[prompt],
    completions=class_names,
    model="text-ada-001",
)
print(preds)
# ['too long']
Notice that the completions can contain many tokens.
Extract the final answer from a step-by-step completion
Step-by-step and chain-of-thought prompts are highly effective ways to get an LLM to "reason" about more complex tasks. But if you need a structured output, a step-by-step completion is unwieldy. Use CAPPr to extract the final answer from these types of completions, given a list of possible answers.
See this idea in action here in the docs. CAPPr is 100% guaranteed to return an output from the list of answers.
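For instance, a sketch of this pattern might look like the following. The question, reasoning text, and answer list are made up for illustration; in practice the step-by-step completion would come from an earlier LLM call.
from transformers import AutoModelForCausalLM, AutoTokenizer
from cappr.huggingface.classify import predict

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# pretend this step-by-step completion came from a previous LLM call
step_by_step = (
    "3 apples cost $6, so each apple costs $2. "
    "5 apples therefore cost 5 * $2 = $10."
)
prompt = (
    "Q: If 3 apples cost $6, how much do 5 apples cost?\n"
    f"Reasoning: {step_by_step}\n"
    "The final answer is"
)
answers = ("$8", "$10", "$12")

preds = predict(
    prompts=[prompt],
    completions=answers,
    model_and_tokenizer=(model, tokenizer),
)
print(preds)  # always one of the three answers, e.g. ['$10']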
Use a model from the HuggingFace model hub
Specifically, this model must be loadable via transformers.AutoModelForCausalLM.from_pretrained.
from transformers import AutoModelForCausalLM, AutoTokenizer
from cappr.huggingface.classify import predict
prompt = "Which planet is closer to the Sun: Mercury or Earth?"
class_names = ("Mercury", "Earth")
# load model and tokenizer
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
preds = predict(
    prompts=[prompt],
    completions=class_names,
    model_and_tokenizer=(model, tokenizer),
)
print(preds)
# ['Mercury']
For an example with Llama 2, see the notebook demos/llama2.ipynb.
So far, CAPPr has been tested for correctness on the following models:
- GPT-2
- GPT-J
- Llama
- Llama 2 (chat and raw).
Run in batches
Let's use a HuggingFace model for this example because it's free. And let's predict probabilities instead of just the class.
from transformers import AutoModelForCausalLM, AutoTokenizer
from cappr.huggingface.classify import predict_proba
prompts = [
"Stephen Curry is a",
"Martina Navratilova was a",
"Dexter, from the TV Series Dexter's Laboratory, is a",
"LeBron James is a",
]
# each of the prompts could be completed with one of these:
class_names = ("basketball player", "tennis player", "scientist")
# say I expect most of my data to be about scientists
prior = (1/6, 1/6, 2/3)
# load model and tokenizer
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
pred_probs = predict_proba(
    prompts=prompts,
    completions=class_names,
    model_and_tokenizer=(model, tokenizer),
    batch_size=32,  # whatever fits on your CPU/GPU
    prior=prior,
)
# pred_probs[i,j] = probability that prompts[i] is classified as class_names[j]
print(pred_probs.round(1))
# [[0.5 0.3 0.2]
# [0.3 0.6 0.2]
# [0.1 0.1 0.8]
# [0.8 0.2 0. ]]
# for each prompt, which completion is most likely?
pred_class_idxs = pred_probs.argmax(axis=1)
print([class_names[pred_class_idx] for pred_class_idx in pred_class_idxs])
# ['basketball player',
# 'tennis player',
# 'scientist',
# 'basketball player']
Run in batches, where each prompt has a different set of possible completions
Again, let's use a HuggingFace model to predict probabilities.
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from cappr import Example
from cappr.huggingface.classify import predict_proba_examples
examples = [
    Example(
        prompt="Jodie Foster played",
        completions=("Clarice Starling", "Trinity in The Matrix"),
    ),
    Example(
        prompt="Batman, from Batman: The Animated Series, was played by",
        completions=("Pete Holmes", "Kevin Conroy", "Spongebob!"),
        prior=(1/3, 2/3, 0),
    ),
]
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
pred_probs = predict_proba_examples(examples, model_and_tokenizer=(model, tokenizer))
# pred_probs[i][j] = probability that examples[i].prompt is classified as
# examples[i].completions[j]
print([example_pred_probs.round(2) for example_pred_probs in pred_probs])
# [array([0.7, 0.3]),
# array([0.03, 0.97, 0. ])]
# for each example, which completion is most likely?
pred_class_idxs = [np.argmax(example_pred_probs) for example_pred_probs in pred_probs]
print(
    [
        example.completions[pred_class_idx]
        for example, pred_class_idx in zip(examples, pred_class_idxs)
    ]
)
# ['Clarice Starling',
# 'Kevin Conroy']
More examples are linked here in the documentation.
See demos/superglue/copa.ipynb for a demonstration of a slightly harder classification task.
Documentation
Setup
If you intend to use OpenAI models, sign up for the OpenAI API here, and then set the environment variable OPENAI_API_KEY. For zero-shot classification, OpenAI models are currently far ahead of others. But using them will cost you 💰!
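For example, one way to set the environment variable from a Python session (the key value below is just a placeholder):
import os
os.environ["OPENAI_API_KEY"] = "sk-..."  # or export OPENAI_API_KEY in your shell instead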
Install with pip:
pip install cappr
(Optional) Install requirements for HuggingFace models
pip install cappr[hf]
(Optional) Install requirements for running demos
pip install cappr[demos]
Motivation
Create a more usable zero-shot text classification interface than classification via sampling (CVS).
Short
With CVS, your job is to write up your classification task in a prompt string, and then write custom code to post-process arbitrary completion/output strings.
With CAPPr, your job starts and stops at writing up your classification task as a {prompt}{end_of_prompt}{completion} string.
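For example, for a made-up sentiment task (CAPPr's default end_of_prompt is a single space):
prompt = "Tonight's dinner was bland."
end_of_prompt = " "
completions = ("positive review", "negative review")
# CAPPr scores these full strings and picks the more probable completion:
# "Tonight's dinner was bland. positive review"
# "Tonight's dinner was bland. negative review"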
Long
Please see this page of the documentation.
Unstudied
I'm curious to see how much easier estimation/discrimination is than generation. In demos/superglue/copa.ipynb, CVS using OpenAI's text-curie-001 is less than 50% accurate, while CAPPr is 80% accurate.
Honest
Keep myself busy
Results
Statistical performance
Not too shabby. TODO: summary table comparing CVS vs. CAPPr vs. few-shot methods like SetFit and PET.
Computational performance
One concern was that CAPPr requires as many model() calls as there are classes. But in the CAPPr scheme, we can simply cache each attention block's keys and values for the prompts. This feature is already supported by AutoModelForCausalLMs. See this code for the implementation. Note that this caching is not implemented for OpenAI models, as I can't control their backend. This means that when running cappr.openai functions, you'll be on the cappr (no cache) line :-(
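For intuition, here is a minimal sketch of this kind of caching with a HuggingFace model (not CAPPr's actual code): run the prompt through the model once, then reuse its attention keys/values (past_key_values) when scoring each candidate completion.
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# run the prompt through the model once, keeping its attention keys/values
prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_out = model(prompt_ids, use_cache=True)

for completion in (" Paris", " London"):
    completion_ids = tokenizer(completion, return_tensors="pt").input_ids
    with torch.no_grad():
        # only the completion tokens are processed; the prompt isn't re-run
        out = model(
            completion_ids,
            past_key_values=copy.deepcopy(prompt_out.past_key_values),
        )
    # out.logits can now be used to score this completion's tokens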
Figure 1: COPA dataset, repeating the choices to simulate multi-class classification tasks. GPT-2 (small) was run on a Tesla K80 GPU (whatever was free in Google Colab in March 2023). 96 classification inputs were processed in batches of size 32. Each point in the graph is a median of 5 runs. For classification via sampling (CVS), exactly 4 tokens were generated for each prompt, which is the number of tokens in '\n\nAnswer A'. 1-token times are also shown. But for COPA (and other multiple-choice style prompts), that may result in lower zero-shot accuracy, as most of the sampled choices come after the first token.
Related work
There are many papers where averaging token log-probabilities is a useful subroutine. Here are some papers which focus on this idea.
While benchmarking this method on the Winograd Schema Challenge, I found that this paper is very similar:
Trinh, Trieu H., and Quoc V. Le. "A simple method for commonsense reasoning." arXiv preprint arXiv:1806.02847 (2018).
PET with multiple masks also aggregates token probabilities to do prompt-completion classification, but these probabilities are assumed to come from masked language models like BERT.
Schick, Timo, and Hinrich Schütze. "It's not just size that matters: Small language models are also few-shot learners." arXiv preprint arXiv:2009.07118 (2020).
Contributing
TODO
Local development
Setup
- Create a new Python 3.8+ environment using venv. Activate it.
- Clone the repo (or fork it and clone that):
  git clone https://github.com/kddubey/cappr.git
- cd to the repo and install this package in editable mode, along with development requirements (ensure your venv is activated):
  python -m pip install -e .[dev]
VS Code extensions for development
- autoDocstring. Use the numpy format.
- Set Python formatting to black.
- Rewrap. Enable Auto Wrap.
Testing
pytest
Note that a few small, dummy models will be downloaded to your computer if you don't have them already.
Docs
To locally build docs (I'm on Windows lol), run
cd docs
make.bat html
To preview these docs, open docs/build/html/index.html.
Docs are automatically built when code is merged to main.
Release
Bump the version, and then create a new release on GitHub. A new version of the package will then be automatically published on PyPI.
Todo
(**) = I'm currently working on this or will work on it really soon. Expect it in the next release or two.
Code
- Factor out the discount feature in cappr.openai.classify.predict_proba into cappr.utils.classify._predict_proba
- Small CPU speed-ups
  - For constant-completions input, vectorize cappr.utils.classify.agg_log_probs
  - For examples input, if # completions per prompt is constant, vectorize cappr.utils.classify.posterior_prob
- HuggingFace transformers.AutoModelForCausalLM
  - Support as many of them as possible
    - GPT-2
    - GPT-J
    - Llama
    - Llama 2
    - Llama 2 chat
    - Vicuna
    - PaLM
    - T5
  - If all completions are single tokens, just run inference once (**)
  - Optimize backend to enable greater scaling wrt # completions/classes
  - Get it working on GPU, check that it's faster than sampling
  - Get to the bottom of why it's slower w/o batching (**)
  - Allow non-' ' end_of_prompt. I'm not sure how helpful that is.
  - Factor out repeated code b/t classify and classify_no_cache
  - Support Inference Endpoints?
  - Support TensorFlow models?
  - Support priming, as in: cache it. See backprompt
- User conveniences (**)
  - Make progress bars optional, since inference often isn't batched
  - Accept string input and return string instead of list
  - Factor out input checks (on prompts and completions)
- (for me) Auto-enforced code formatting b/c it's getting time-consuming
- Allow for multi-label classification
- Pass normalize as an argument to predict_proba functions
  - For huggingface, add note that you'll get faster results by passing all labels at once (assuming prompt is identical for each label)
- Fill in missing or non-numpy docstrings
- Testing
  - Test cappr.huggingface.classify_no_cache by comparing to results w/o batching!
  - For heavily quantized models, only test that pred probs are w/in 1e-2 atol
  - Increase test cases
  - Test input checks
  - Test cappr.openai.api
Research
Evaluate on more datasets, and understand its relative advantages and disadvantages vs other classification methods.
- (Llama2 + CAPPr) vs (Llama2 + CVS) vs (Llama2 chat + CAPPr) vs (Llama2 chat + CVS) (**)
- RAFT benchmark
  - Zero-shot training scores
  - Submit zero-shot test predictions
  - Few-shot (priming) training scores
  - Submit few-shot test predictions
- Create a user guide, build a table of results comparing competing approaches on statistical performance, cost, and computation
- Evaluate a CoT/SbS prompt -> CAPPr to pull out the answer
- Make a computational comparison to sampling
  - Assume I have full freedom to decide how inference works. Demo w/ GPT-2. Process inputs in batches.
  - Process inputs 1-by-1
- More SuperGLUE tasks?
- Calibration
  - Is the prior actually effective? Downsample and see
  - Calibration curves
- Finetune smaller, cheaper model and compare against zero-shot w/ davinci
  - e.g., GPT-2 from huggingface, text-ada-001
  - Again, compare against sampling
- Evaluate a bigger model like GPT-J