Transformer-based zero and few-shot classification in scikit-learn pipelines
Project description
stormtrooper
Transformer-based zero/few shot learning components for scikit-learn pipelines.
New in version 0.3.0 🌟
- SetFit is now part of the library and can be used in scikit-learn workflows.
Example
pip install stormtrooper
class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
"God came down to earth to save us.",
"A new nebula was recently discovered in the proximity of the Oort cloud."
]
Zero-shot learning
For zero-shot learning you can use zero-shot models:
from stormtrooper import ZeroShotClassifier
classifier = ZeroShotClassifier().fit(None, class_labels)
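The fitted classifier can then label unseen documents with predict(). A minimal sketch, reusing the example texts defined above:
predictions = classifier.predict(example_texts)
# Each prediction is one of the class labels that were passed to fit()
print(predictions)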
Generative models (GPT, Llama):
from stormtrooper import GenerativeZeroShotClassifier
# You can hand-craft prompts if it suits you better, but
# a default prompt is already available
prompt = """
### System:
You are a literary expert tasked with labeling texts according to
their content.
Please follow the user's instructions as precisely as you can.
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
classifier = GenerativeZeroShotClassifier(prompt=prompt).fit(None, class_labels)
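Prediction then works the same way as with the other classifiers; the {classes} and {X} placeholders in the prompt are filled in for each document at inference time. A minimal sketch:
predictions = classifier.predict(example_texts)
# The generated answer is mapped back to one of the provided class labels
print(predictions)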
Text2Text models (T5): If you are running low on resources, I would personally recommend T5.
from stormtrooper import Text2TextZeroShotClassifier
# You can define a custom prompt, but a default one is available
prompt = "..."
classifier = Text2TextZeroShotClassifier(prompt=prompt).fit(None, class_labels)
predictions = classifier.predict(example_texts)
assert list(predictions) == ["atheism/christianity", "astronomy/space"]
Few-Shot Learning
For few-shot tasks you can only use Generative, Text2Text (a.k.a. promptable) or SetFit models.
from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier
from stormtrooper.setfit import SetFitFewShotClassifier
classifier = SetFitFewShotClassifier().fit(example_texts, class_labels)
predictions = classifier.predict(["Calvinists believe in predestination."])
assert list(predictions) == ["atheism/christianity"]
Fuzzy Matching
By default, generative and text2text models fuzzy match their outputs to the closest class label; you can disable this behavior by specifying fuzzy_match=False.
If you want a fuzzy matching speedup, you should install python-Levenshtein.
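For example, to keep the raw generated output instead of snapping it to a label, a minimal sketch (this assumes fuzzy_match is passed to the constructor):
# fuzzy_match=False returns the model's output as-is, even if it is not an exact class label
classifier = Text2TextZeroShotClassifier(fuzzy_match=False).fit(None, class_labels)
predictions = classifier.predict(example_texts)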
Inference on GPU
From version 0.2.2 you can run models on GPU. You can specify the device when initializing a model:
classifier = Text2TextZeroShotClassifier(device="cuda:0")
File details
Details for the file stormtrooper-0.3.1.tar.gz.
File metadata
- Download URL: stormtrooper-0.3.1.tar.gz
- Upload date:
- Size: 8.7 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.3.2 CPython/3.10.8 Linux/5.15.0-79-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 74d6feaa0c810b3f34b29cd51fe461a93dd9b65c8ff031b88d487cf58a64e009
MD5 | 5409023850e7e60bd7b7e3a3f67193ab
BLAKE2b-256 | 18937d85daa65c6549582336e423910563ae0cc99b6026c5aa90a430a28ecdd9
File details
Details for the file stormtrooper-0.3.1-py3-none-any.whl.
File metadata
- Download URL: stormtrooper-0.3.1-py3-none-any.whl
- Upload date:
- Size: 11.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.3.2 CPython/3.10.8 Linux/5.15.0-79-generic
File hashes
Algorithm | Hash digest
---|---
SHA256 | 4957a7ca8142325a266472cd44794f43f8970f5fc26f5feaace32ada29d0e99d
MD5 | c8e8d43d0ed48e14892418fe5f9a5c8d
BLAKE2b-256 | d46e2cabd1330664e4ccb8ef20cf943d42d7fef6ca50f297f7d67e5c2ca24655