Transformer/LLM-based zero and few-shot classification in scikit-learn pipelines
stormtrooper
Transformer-based zero/few shot learning components for scikit-learn pipelines.
New in version 0.4.0 🔥
- You can now use OpenAI's chat models with blazing fast ⚡ async inference.
New in version 0.3.0 🌟
- SetFit is now part of the library and can be used in scikit-learn workflows.
Example
pip install stormtrooper
class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
"God came down to earth to save us.",
"A new nebula was recently discovered in the proximity of the Oort cloud."
]
Zero-shot learning
For zero-shot learning you can use zero-shot models:
from stormtrooper import ZeroShotClassifier
classifier = ZeroShotClassifier().fit(None, class_labels)
Generative models (GPT, Llama):
from stormtrooper import GenerativeZeroShotClassifier
# You can hand-craft prompts if it suits you better, but
# a default prompt is already available
prompt = """
### System:
You are a literary expert tasked with labeling texts according to
their content.
Please follow the user's instructions as precisely as you can.
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
classifier = GenerativeZeroShotClassifier(prompt=prompt).fit(None, class_labels)
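The `{classes}` and `{X}` placeholders in the template get filled in for each document. Conceptually (a plain `str.format` sketch, not stormtrooper's exact internals) this looks like:

```python
# Sketch of how a prompt template's placeholders might be filled.
# Illustrative only; not stormtrooper's internal implementation.
template = (
    "Your task will be to classify a text document into one "
    "of the following classes: {classes}. "
    "Classify the following piece of text: '{X}'"
)

filled = template.format(
    classes=", ".join(["atheism/christianity", "astronomy/space"]),
    X="God came down to earth to save us.",
)
print(filled)
```

Any custom prompt you pass should therefore keep both placeholders intact.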
Text2Text models (T5): If you are running low on resources, I would personally recommend T5.
from stormtrooper import Text2TextZeroShotClassifier
# You can define a custom prompt, but a default one is available
prompt = "..."
classifier = Text2TextZeroShotClassifier(prompt=prompt).fit(None, class_labels)
predictions = classifier.predict(example_texts)
assert list(predictions) == ["atheism/christianity", "astronomy/space"]
OpenAI models: You can now use OpenAI's chat LLMs in stormtrooper workflows.
from stormtrooper import OpenAIZeroShotClassifier
classifier = OpenAIZeroShotClassifier("gpt-4").fit(None, class_labels)
predictions = classifier.predict(example_texts)
assert list(predictions) == ["atheism/christianity", "astronomy/space"]
Few-Shot Learning
For few-shot tasks you can only use Generative, Text2Text, OpenAI (i.e. promptable) or SetFit models.
from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier, SetFitFewShotClassifier
classifier = SetFitFewShotClassifier().fit(example_texts, class_labels)
predictions = classifier.predict(["Calvinists believe in predestination."])
assert list(predictions) == ["atheism/christianity"]
Fuzzy Matching
Generative and Text2Text models will by default fuzzy match results to the closest class label. You can disable this behavior by specifying fuzzy_match=False.
If you want to speed up fuzzy matching, you should install python-Levenshtein.
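The idea behind fuzzy matching can be sketched with the standard library's difflib (this is just an illustration of the concept, not stormtrooper's actual matcher):

```python
import difflib

def fuzzy_match(generated: str, class_labels: list[str]) -> str:
    """Map a model's free-form output to the closest known class label.
    Hypothetical helper for illustration only."""
    # Score similarity between the generated text and each label,
    # then pick the label with the highest similarity ratio.
    return max(
        class_labels,
        key=lambda label: difflib.SequenceMatcher(
            None, generated.lower(), label.lower()
        ).ratio(),
    )

labels = ["atheism/christianity", "astronomy/space"]
print(fuzzy_match("Atheism / Christianity.", labels))  # → "atheism/christianity"
```

This is why generative classifiers can still return a valid label even when the underlying model replies with slightly different wording or punctuation.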
Inference on GPU
From version 0.2.2 you can run models on GPU. You can specify the device when initializing a model:
classifier = Text2TextZeroShotClassifier(device="cuda:0")
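What makes these components usable in scikit-learn pipelines is that they follow the standard estimator contract (fit returns self, predict maps texts to labels). A minimal sketch of that contract, using a hypothetical keyword-based stand-in instead of a real transformer model:

```python
class KeywordZeroShotClassifier:
    """Toy stand-in following the same fit/predict contract as
    stormtrooper's zero-shot components. Hypothetical, for illustration."""

    def fit(self, X, y):
        # Like stormtrooper's zero-shot components, no training data is
        # needed: X may be None, y is the list of candidate class labels.
        self.classes_ = list(y)
        return self

    def predict(self, X):
        # Naive scoring: count how many of the label's words occur
        # in the text (a real component would use a language model).
        def score(text, label):
            words = set(text.lower().split())
            return sum(w in words for w in label.lower().replace("/", " ").split())

        return [max(self.classes_, key=lambda c: score(text, c)) for text in X]

classifier = KeywordZeroShotClassifier().fit(None, ["space", "religion"])
print(classifier.predict(["a new nebula in space was discovered"]))  # → ["space"]
```

Because the interface matches, such a component can be dropped into an sklearn `Pipeline` step like any other estimator.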