
Transformer-based zero and few-shot classification in scikit-learn pipelines


stormtrooper


Transformer-based zero/few shot learning components for scikit-learn pipelines.

Example

Install the package from PyPI:

pip install stormtrooper

The examples below use these class labels and texts:

class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
    "God came down to earth to save us.",
    "A new nebula was recently discovered in the proximity of the Oort cloud."
]

Zero-shot learning

For zero-shot learning you can use dedicated zero-shot models. No training texts are needed, so pass None in place of X and the class labels as y:

from stormtrooper import ZeroShotClassifier
classifier = ZeroShotClassifier().fit(None, class_labels)
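To illustrate the scikit-learn convention behind fit(None, class_labels), here is a toy estimator that stores the candidate labels in fit and scores every text against every label in predict. This is a sketch under stated assumptions, not stormtrooper's implementation: the class name ToyZeroShotClassifier is hypothetical, and its word-overlap score is a naive stand-in for the transformer model a real zero-shot classifier would use.

```python
class ToyZeroShotClassifier:
    """Hypothetical estimator showing the fit(None, labels) convention.

    A transformer-based zero-shot classifier would score each (text, label)
    pair with a pretrained model; a naive word-overlap score stands in here.
    """

    def fit(self, X, y):
        # X is ignored: zero-shot classification needs no labelled training texts.
        # y is the list of candidate class labels.
        self.classes_ = list(y)
        return self  # returning self enables .fit(...).predict(...) chaining

    def _score(self, text, label):
        # Stand-in for a model score: how many label words appear in the text.
        text_words = set(text.lower().replace(".", " ").split())
        label_words = set(label.lower().replace("/", " ").split())
        return len(text_words & label_words)

    def predict(self, X):
        # For each text, choose the candidate label with the highest score.
        return [max(self.classes_, key=lambda label: self._score(text, label)) for text in X]


clf = ToyZeroShotClassifier().fit(None, ["atheism/christianity", "astronomy/space"])
print(clf.predict(["Space is big.", "Christianity is a religion."]))
# ['astronomy/space', 'atheism/christianity']
```

A production estimator would also subclass scikit-learn's BaseEstimator and ClassifierMixin so that it supports cloning and grid search inside pipelines.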

Generative models (GPT, Llama):

from stormtrooper import GenerativeZeroShotClassifier
# You can hand-craft prompts if it suits you better, but
# a default prompt is already available
prompt = """
### System:
You are a literary expert tasked with labeling texts according to
their content.
Please follow the user's instructions as precisely as you can.
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
classifier = GenerativeZeroShotClassifier(prompt=prompt).fit(None, class_labels)
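The template above uses two placeholders: {classes} for the label set and {X} for the text being classified. A minimal sketch of how such a template could be filled with Python's str.format, assuming the labels are joined with commas (stormtrooper's own formatting may differ); fill_prompt is a hypothetical helper, not part of the library.

```python
def fill_prompt(template: str, classes: list[str], text: str) -> str:
    """Hypothetical helper: substitute the label set and the document into a prompt."""
    # Assumption: labels are rendered as a comma-separated list.
    return template.format(classes=", ".join(classes), X=text)


template = "Classify into one of: {classes}.\nText: '{X}'\nLabel:"
example = fill_prompt(template, ["atheism/christianity", "astronomy/space"], "God came down to earth.")
print(example)
```

Under this str.format assumption, any literal braces in a hand-written prompt would have to be doubled ({{ and }}) so they are not mistaken for placeholders.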

Text2Text models (T5). If you are running low on resources, I would personally recommend T5:

from stormtrooper import Text2TextZeroShotClassifier
# A default prompt is available; to use your own, pass it in:
# Text2TextZeroShotClassifier(prompt="...")
classifier = Text2TextZeroShotClassifier().fit(None, class_labels)
predictions = classifier.predict(example_texts)

assert list(predictions) == ["atheism/christianity", "astronomy/space"]

Few-Shot Learning

For few-shot tasks you can only use Generative and Text2Text (a.k.a. promptable) models.

from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier

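# Few-shot models learn from labelled examples: class_labels[i] is the label of example_texts[i]
classifier = Text2TextFewShotClassifier().fit(example_texts, class_labels)
predictions = classifier.predict(["Calvinists believe in predestination."])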

assert list(predictions) == ["atheism/christianity"]
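In few-shot mode, fit receives labelled example texts instead of None, and a promptable model can be shown those examples alongside the new text. One way to lay out such a prompt is sketched below; the layout and the build_few_shot_prompt name are assumptions for illustration, not stormtrooper's actual few-shot template.

```python
def build_few_shot_prompt(examples: list[str], labels: list[str], query: str) -> str:
    """Hypothetical few-shot prompt: list the labelled examples, then the query."""
    if len(examples) != len(labels):
        raise ValueError("each example text needs exactly one label")
    classes = sorted(set(labels))  # the label set is inferred from the examples
    lines = [f"Classify the text into one of: {', '.join(classes)}.", ""]
    for text, label in zip(examples, labels):
        lines.append(f"Text: {text}")
        lines.append(f"Label: {label}")
        lines.append("")
    lines.append(f"Text: {query}")
    lines.append("Label:")  # the model is expected to complete this line
    return "\n".join(lines)


prompt = build_few_shot_prompt(
    ["God came down to earth to save us.",
     "A new nebula was recently discovered in the proximity of the Oort cloud."],
    ["atheism/christianity", "astronomy/space"],
    "Calvinists believe in predestination.",
)
print(prompt)
```

Ending the prompt with an open "Label:" line nudges the model to answer with a label only; whatever it returns can still be mapped to a valid class by fuzzy matching.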

Fuzzy Matching

By default, models fuzzy-match their output to the closest class label. You can disable this behavior by passing fuzzy_match=False.

To speed up fuzzy matching, install python-Levenshtein.
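Fuzzy matching matters because a generative model may answer with text that is close to, but not exactly, one of the labels (for example with different capitalization or extra spaces). The sketch below snaps a raw answer to the most similar label using the standard library's difflib.SequenceMatcher. This similarity measure is a stand-in chosen for illustration; stormtrooper's own matcher (which can use python-Levenshtein) may score strings differently, and closest_label is a hypothetical name.

```python
from difflib import SequenceMatcher


def closest_label(answer: str, labels: list[str]) -> str:
    """Return the label most similar to a free-form model answer (illustrative sketch)."""
    normalized = answer.strip().lower()

    def similarity(label: str) -> float:
        # Ratio in [0, 1]; 1.0 means identical after lower-casing.
        return SequenceMatcher(None, normalized, label.lower()).ratio()

    return max(labels, key=similarity)


labels = ["atheism/christianity", "astronomy/space"]
print(closest_label("Astronomy / Space", labels))  # astronomy/space
```

This always returns some label, which mirrors the default behavior described above; disabling fuzzy matching would instead leave the raw model answer untouched.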


Download files

Download the file for your platform.

Source Distribution

stormtrooper-0.2.0.tar.gz (7.3 kB, source)

Built Distribution

stormtrooper-0.2.0-py3-none-any.whl (9.5 kB, Python 3)

File details

Details for the file stormtrooper-0.2.0.tar.gz.

File metadata

  • Download URL: stormtrooper-0.2.0.tar.gz
  • Upload date:
  • Size: 7.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.10.8 Linux/5.15.0-78-generic

File hashes

Hashes for stormtrooper-0.2.0.tar.gz
Algorithm     Hash digest
SHA256        9c6a67a98ec4428d14de7dca3face11d2ed6ff624a4f11741fc9c6b9a31bd999
MD5           79b3fad87c90e372b16e99be2e1e3ae3
BLAKE2b-256   886fb37f1049379fe9969ea26975ec9d55266eb22e59a84d1f060acfbca94d3d
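The SHA256 digest listed for the source distribution can be used to check that a downloaded archive was not corrupted or altered. A minimal sketch using Python's hashlib; it assumes the archive was saved to the current directory under its original name, and sha256_of is a hypothetical helper name.

```python
import hashlib
import os


def sha256_of(path: str, chunk_size: int = 65536) -> str:
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b"" (end of file)
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Digest of stormtrooper-0.2.0.tar.gz, copied from the hash table for that file.
EXPECTED_SHA256 = "9c6a67a98ec4428d14de7dca3face11d2ed6ff624a4f11741fc9c6b9a31bd999"
archive = "stormtrooper-0.2.0.tar.gz"  # assumed download location (current directory)

if os.path.exists(archive):
    status = "OK" if sha256_of(archive) == EXPECTED_SHA256 else "MISMATCH"
    print(f"{archive}: {status}")
```

Reading in chunks keeps memory use constant regardless of file size, although for a 7.3 kB archive reading it in one go would work just as well.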

File details

Details for the file stormtrooper-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: stormtrooper-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 9.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.10.8 Linux/5.15.0-78-generic

File hashes

Hashes for stormtrooper-0.2.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        b659aa32d2a84f33f435ac1beda2b0a35fbfd2661ce79edeec717335be7ca420
MD5           3d15beb56de194ee509688751456d51a
BLAKE2b-256   8bd5520d0a880242b0884a94b70f3c5c27e1ed9397ab2dec80538aac5d344781
