Transformer-based zero and few-shot classification in scikit-learn pipelines
stormtrooper
Transformer-based zero/few shot learning components for scikit-learn pipelines.
Example
pip install stormtrooper
class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
"God came down to earth to save us.",
"A new nebula was recently discovered in the proximity of the Oort cloud."
]
Zero-shot learning
For zero-shot learning you can use zero-shot models:
from stormtrooper import ZeroShotClassifier
classifier = ZeroShotClassifier().fit(None, class_labels)
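In the zero-shot call above, fit receives None in place of training texts and only the candidate labels as y. The toy class below illustrates that calling convention in plain Python. The class name ToyZeroShotClassifier is hypothetical. It scores labels by keyword overlap rather than with a transformer, so it is only a sketch of the interface shape (fit stores the labels and returns self, and predict maps texts to labels), not of how stormtrooper works internally.

```python
import re


class ToyZeroShotClassifier:
    """Hypothetical stand-in that mimics the fit/predict shape shown above.

    It scores labels by keyword overlap; it is NOT how stormtrooper works
    internally (stormtrooper uses transformer models).
    """

    def fit(self, X, y):
        # Zero-shot: there is no training data, so X is ignored (None is fine)
        # and y is simply the list of candidate class labels.
        self.classes_ = list(y)
        return self  # returning self allows .fit(...).predict(...) chaining

    def predict(self, X):
        predictions = []
        for text in X:
            words = set(re.findall(r"[a-z]+", text.lower()))
            # Count how many "/"-separated parts of each label occur in the text.
            scores = [
                sum(part in words for part in label.lower().split("/"))
                for label in self.classes_
            ]
            # Ties (including all-zero scores) go to the first label.
            predictions.append(self.classes_[scores.index(max(scores))])
        return predictions


labels = ["atheism/christianity", "astronomy/space"]
clf = ToyZeroShotClassifier().fit(None, labels)
print(clf.predict(["Christianity is a religion.", "Space is big."]))
# prints ['atheism/christianity', 'astronomy/space']
```

Returning self from fit follows the scikit-learn estimator convention, which is what makes the one-line fit(...).predict(...) chaining in this README possible.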
Generative models (GPT, Llama):
from stormtrooper import GenerativeZeroShotClassifier
# You can hand-craft prompts if it suits you better, but
# a default prompt is already available
prompt = """
### System:
You are a literary expert tasked with labeling texts according to
their content.
Please follow the user's instructions as precisely as you can.
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
classifier = GenerativeZeroShotClassifier(prompt=prompt).fit(None, class_labels)
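The custom prompt above contains two placeholders, {classes} and {X}. A minimal sketch of how such a template could be filled with Python's str.format is shown below. The helper name fill_prompt is hypothetical, and joining the labels with ", " is an assumption; stormtrooper may render the class list differently.

```python
from __future__ import annotations


def fill_prompt(template: str, classes: list[str], text: str) -> str:
    """Hypothetical helper: substitute {classes} and {X} into a prompt template.

    Joining the labels with ", " is an assumption; stormtrooper's own
    formatting may differ.
    """
    # str.format only parses the template, so braces inside `text` are safe;
    # literal braces in the template itself would need doubling ({{ and }}).
    return template.format(classes=", ".join(classes), X=text)


template = (
    "Classify the text into one of the following classes: {classes}.\n"
    "Classify the following piece of text:\n'{X}'"
)
print(fill_prompt(template, ["atheism/christianity", "astronomy/space"],
                  "God came down to earth to save us."))
```

If you hand-craft a prompt that is filled this way, any literal curly braces in it would have to be escaped as {{ and }}.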
Text2Text models (T5). If you are running low on resources, I would personally recommend T5:
from stormtrooper import Text2TextZeroShotClassifier
# You can define a custom prompt, but a default one is available
prompt = "..."
classifier = Text2TextZeroShotClassifier(prompt=prompt).fit(None, class_labels)
predictions = classifier.predict(example_texts)
assert list(predictions) == ["atheism/christianity", "astronomy/space"]
Few-Shot Learning
For few-shot tasks, you can only use Generative and Text2Text (a.k.a. promptable) models.
from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier
classifier = Text2TextFewShotClassifier().fit(example_texts, class_labels)
predictions = classifier.predict(["Calvinists believe in predestination."])
assert list(predictions) == ["atheism/christianity"]
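In the few-shot call above, fit receives real example texts, and class_labels lines up one-to-one with example_texts, so each text comes with its label. One common way promptable models use such examples is to place them in the prompt before the query. The sketch below shows that idea; the helper name build_few_shot_prompt and the "Text:/Label:" layout are hypothetical and not stormtrooper's actual template.

```python
def build_few_shot_prompt(examples, labels, text):
    """Hypothetical helper: turn labelled examples plus a query into one prompt.

    The "Text:/Label:" layout is illustrative, not stormtrooper's template.
    """
    if len(examples) != len(labels):
        raise ValueError("Each example needs exactly one label.")
    # One "shot" per labelled example, in the order given.
    shots = "\n".join(f"Text: {x}\nLabel: {y}" for x, y in zip(examples, labels))
    # The query goes last, with the label left for the model to complete.
    return f"{shots}\nText: {text}\nLabel:"


example_texts = [
    "God came down to earth to save us.",
    "A new nebula was recently discovered in the proximity of the Oort cloud.",
]
class_labels = ["atheism/christianity", "astronomy/space"]
print(build_few_shot_prompt(example_texts, class_labels,
                            "Calvinists believe in predestination."))
```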
Fuzzy Matching
By default, models fuzzy match their output to the closest class label. You can disable this behavior by specifying fuzzy_match=False.
If you want faster fuzzy matching, install python-Levenshtein.
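The idea behind fuzzy matching is to map a model's free-form answer (for example " Astronomy/Space.") onto the nearest allowed label. A minimal pure-Python sketch using Levenshtein edit distance is shown below. The helper names levenshtein and closest_label are hypothetical, and stormtrooper's actual similarity measure and normalization may differ (python-Levenshtein only speeds up this kind of computation).

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance: minimum number of insertions, deletions, substitutions."""
    prev = list(range(len(b) + 1))  # distances from "" to each prefix of b
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # delete ca
                curr[j - 1] + 1,           # insert cb
                prev[j - 1] + (ca != cb),  # substitute (free if equal)
            ))
        prev = curr
    return prev[-1]


def closest_label(output: str, labels: list) -> str:
    """Hypothetical helper: return the label nearest to a raw model output."""
    cleaned = output.strip().lower()
    return min(labels, key=lambda label: levenshtein(cleaned, label.lower()))


class_labels = ["atheism/christianity", "astronomy/space"]
print(closest_label(" Astronomy/Space.\n", class_labels))  # prints astronomy/space
```

Without this step, an answer with stray punctuation or different casing would not equal any label exactly, which is presumably why fuzzy matching is on by default.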