Transformer-based zero and few-shot classification in scikit-learn pipelines
stormtrooper
Transformer-based zero/few shot learning components for scikit-learn pipelines.
Example
pip install stormtrooper
class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
"God came down to earth to save us.",
"A new nebula was recently discovered in the proximity of the Oort cloud."
]
Zero-shot learning
For zero-shot learning you can use zero-shot models:
from stormtrooper import ZeroShotClassifier
classifier = ZeroShotClassifier().fit(None, class_labels)
Generative models (GPT, Llama):
from stormtrooper import GenerativeZeroShotClassifier
# You can hand-craft prompts if it suits you better, but
# a default prompt is already available
prompt = """
### System:
You are a literary expert tasked with labeling texts according to
their content.
Please follow the user's instructions as precisely as you can.
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
classifier = GenerativeZeroShotClassifier(prompt=prompt).fit(None, class_labels)
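To make the role of the two placeholders in a custom prompt concrete, here is a minimal sketch of how such a template could be filled in. It assumes plain `str.format` substitution, with `{classes}` receiving the comma-separated labels and `{X}` the document. `fill_prompt` is a hypothetical helper and not part of stormtrooper, whose own substitution logic may differ.

```python
def fill_prompt(template: str, text: str, labels: list[str]) -> str:
    """Hypothetical helper: substitute the labels and the document into a template."""
    # Assumption: labels are joined with commas; the library may format them differently.
    return template.format(classes=", ".join(labels), X=text)


template = (
    "Your task will be to classify a text document into one "
    "of the following classes: {classes}.\n"
    "Classify the following piece of text:\n"
    "'{X}'"
)
filled = fill_prompt(
    template,
    "God came down to earth to save us.",
    ["atheism/christianity", "astronomy/space"],
)
print(filled)
```

If a template is filled this way, any literal curly braces in a hand-written prompt would need to be doubled (`{{` and `}}`) so they are not mistaken for placeholders.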
Text2Text models (T5): If you are running low on resources, I would personally recommend T5.
from stormtrooper import Text2TextZeroShotClassifier
# You can define a custom prompt, but a default one is available
prompt = "..."
classifier = Text2TextZeroShotClassifier(prompt=prompt).fit(None, class_labels)
predictions = classifier.predict(example_texts)
assert list(predictions) == ["atheism/christianity", "astronomy/space"]
Few-Shot Learning
Few-shot tasks support only Generative and Text2Text (a.k.a. promptable) models.
from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier
classifier = Text2TextFewShotClassifier().fit(example_texts, class_labels)
predictions = classifier.predict(["Calvinists believe in predestination."])
assert list(predictions) == ["atheism/christianity"]
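Few-shot prompting in general works by showing the model a handful of labeled examples ahead of the text it should classify. The sketch below illustrates that general idea only. `build_few_shot_prompt` and its prompt layout are hypothetical and are not taken from stormtrooper's implementation, which may arrange its prompts differently.

```python
def build_few_shot_prompt(examples, labels, classes, text):
    """Hypothetical helper: lay out labeled examples ahead of the text to classify."""
    lines = [f"Classify each text into one of: {', '.join(classes)}.", ""]
    for example, label in zip(examples, labels):
        lines += [f"Text: {example}", f"Label: {label}", ""]
    # The final entry leaves the label blank for the model to complete.
    lines += [f"Text: {text}", "Label:"]
    return "\n".join(lines)


class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
    "God came down to earth to save us.",
    "A new nebula was recently discovered in the proximity of the Oort cloud.",
]
few_shot_prompt = build_few_shot_prompt(
    example_texts, class_labels, class_labels, "Calvinists believe in predestination."
)
print(few_shot_prompt)
```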
Fuzzy Matching
By default, models fuzzy match their output to the closest class label. You can disable this behavior by specifying fuzzy_match=False.
To speed up fuzzy matching, install python-Levenshtein.
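Fuzzy matching matters because a promptable model may answer with something close to a label but not identical to it (extra spaces, trailing punctuation, different casing). The following is a minimal sketch of the idea using the standard library's `difflib`; it is not stormtrooper's own matching code, and `closest_label` is a hypothetical helper.

```python
import difflib


def closest_label(raw_output: str, labels: list[str]) -> str:
    """Hypothetical helper: map free-form model output to the most similar label."""
    if not labels:
        raise ValueError("At least one class label is required.")
    # Compare case-insensitively, but return the label in its original form.
    lowered = {label.lower(): label for label in labels}
    # cutoff=0.0 means the best candidate is always returned, however weak the match.
    matches = difflib.get_close_matches(
        raw_output.strip().lower(), list(lowered), n=1, cutoff=0.0
    )
    return lowered[matches[0]]


class_labels = ["atheism/christianity", "astronomy/space"]
print(closest_label("astronomy / space.", class_labels))
print(closest_label("Atheism and christianity", class_labels))
```

Because the best candidate is always returned, a completely unrelated answer is still mapped to some label. That is the trade-off behind being able to switch the behavior off.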
Inference on GPU
From version 0.2.2 you can run models on GPU. You can specify the device when initializing a model:
classifier = Text2TextZeroShotClassifier(device="cuda:0")