
Transformer/LLM-based zero and few-shot classification in scikit-learn pipelines


stormtrooper


Transformer-based zero/few shot learning components for scikit-learn pipelines.


New in version 0.4.0 🔥

  • You can now use OpenAI's chat models with blazing fast ⚡ async inference.

New in version 0.3.0 🌟

  • SetFit is now part of the library and can be used in scikit-learn workflows.

Example

pip install stormtrooper

class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
    "God came down to earth to save us.",
    "A new nebula was recently discovered in the proximity of the Oort cloud."
]

Zero-shot learning

For zero-shot learning you can use dedicated zero-shot classification models:

from stormtrooper import ZeroShotClassifier
classifier = ZeroShotClassifier().fit(None, class_labels)

Generative models (GPT, Llama):

from stormtrooper import GenerativeZeroShotClassifier
# You can hand-craft prompts if it suits you better, but
# a default prompt is already available
prompt = """
### System:
You are a literary expert tasked with labeling texts according to
their content.
Please follow the user's instructions as precisely as you can.
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
classifier = GenerativeZeroShotClassifier(prompt=prompt).fit(None, class_labels)
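To make the role of the {classes} and {X} placeholders concrete, here is a minimal sketch of how such a template could be filled in for one document. The helper build_prompt is hypothetical, and joining the class labels with ", " is an assumption; stormtrooper may format the class list differently internally.

```python
# Hypothetical helper: fills the {classes} and {X} placeholders of a prompt
# template. Joining labels with ", " is an assumption for illustration only.
TEMPLATE = """
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Classify the following piece of text:
'{X}'
### Assistant:
"""


def build_prompt(template: str, classes: list[str], text: str) -> str:
    """Substitute the class list and the document into the template."""
    # str.format only parses the template, so braces inside `text` are safe.
    return template.format(classes=", ".join(classes), X=text)


prompt = build_prompt(
    TEMPLATE,
    ["atheism/christianity", "astronomy/space"],
    "God came down to earth to save us.",
)
print(prompt)
```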

Text2Text models (T5): If you are running low on resources, I would personally recommend T5.

from stormtrooper import Text2TextZeroShotClassifier
# You can define a custom prompt, but a default one is available
prompt = "..."
classifier = Text2TextZeroShotClassifier(prompt=prompt).fit(None, class_labels)
predictions = classifier.predict(example_texts)

assert list(predictions) == ["atheism/christianity", "astronomy/space"]

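OpenAI models: You can now use OpenAI's chat LLMs in stormtrooper workflows. These require an OpenAI API key, typically provided through the OPENAI_API_KEY environment variable.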

from stormtrooper import OpenAIZeroShotClassifier

classifier = OpenAIZeroShotClassifier("gpt-4").fit(None, class_labels)
predictions = classifier.predict(example_texts)

assert list(predictions) == ["atheism/christianity", "astronomy/space"]
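Async inference pays off because each chat request spends most of its time waiting on the network, so many requests can be in flight at once. The sketch below illustrates that idea with asyncio only; classify_one is a hypothetical stand-in that sleeps instead of calling the OpenAI API, and its keyword rule is a dummy, not how the model decides. It is not stormtrooper's implementation.

```python
import asyncio


async def classify_one(text: str, classes: list[str]) -> str:
    """Hypothetical stand-in for a single chat-completion request.

    asyncio.sleep simulates network latency so the example runs offline.
    """
    await asyncio.sleep(0.1)
    # Dummy rule purely for illustration; a real model chooses the label.
    return classes[1] if "nebula" in text.lower() else classes[0]


async def classify_all(texts, classes, max_concurrency=8):
    # Cap the number of concurrent requests (e.g. to respect rate limits).
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(text):
        async with semaphore:
            return await classify_one(text, classes)

    # gather runs the requests concurrently and keeps results in input order.
    return await asyncio.gather(*(bounded(t) for t in texts))


texts = [
    "God came down to earth to save us.",
    "A new nebula was recently discovered in the proximity of the Oort cloud.",
]
labels = asyncio.run(classify_all(texts, ["atheism/christianity", "astronomy/space"]))
print(labels)
```

With 20 texts and max_concurrency=8, the simulated run takes roughly three rounds of latency instead of twenty.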

Few-Shot Learning

For few-shot tasks you can only use promptable models (Generative, Text2Text and OpenAI) or SetFit models.

from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier, SetFitFewShotClassifier

classifier = SetFitFewShotClassifier().fit(example_texts, class_labels)
predictions = classifier.predict(["Calvinists believe in predestination."])

assert list(predictions) == ["atheism/christianity"]
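For promptable models, few-shot learning generally means placing labelled examples inside the prompt before the text to classify. The sketch below shows one way to lay that out; build_few_shot_prompt and its "Text:/Label:" layout are hypothetical and are not stormtrooper's built-in few-shot prompt.

```python
# Hypothetical few-shot prompt builder; the layout is an illustrative
# assumption, not the format stormtrooper uses.
def build_few_shot_prompt(examples, classes, text):
    """Render (text, label) demonstrations followed by the query text."""
    lines = [f"Classify texts into one of: {', '.join(classes)}.", ""]
    for example_text, label in examples:
        lines.append(f"Text: {example_text}")
        lines.append(f"Label: {label}")
        lines.append("")
    # End with an open "Label:" slot for the model to complete.
    lines.append(f"Text: {text}")
    lines.append("Label:")
    return "\n".join(lines)


classes = ["atheism/christianity", "astronomy/space"]
examples = list(zip(
    [
        "God came down to earth to save us.",
        "A new nebula was recently discovered in the proximity of the Oort cloud.",
    ],
    classes,  # one label per example text, in the same order
))
few_shot_prompt = build_few_shot_prompt(
    examples, classes, "Calvinists believe in predestination."
)
print(few_shot_prompt)
```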

Fuzzy Matching

By default, generative and Text2Text models fuzzy-match their outputs to the closest class label. You can disable this behavior by specifying fuzzy_match=False.
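Fuzzy matching is needed because generated text rarely reproduces a label character for character. A minimal sketch of the idea, using the standard library's difflib similarity ratio as a stand-in metric (stormtrooper's own scoring may differ; fuzzy_match here is a hypothetical helper):

```python
import difflib


def fuzzy_match(output: str, classes: list[str]) -> str:
    """Map free-form model output to the most similar class label.

    Uses difflib.SequenceMatcher.ratio() for illustration only.
    """
    cleaned = output.strip().lower()
    scores = [
        difflib.SequenceMatcher(None, cleaned, label.lower()).ratio()
        for label in classes
    ]
    # On ties, max returns the first best-scoring label.
    best = max(range(len(classes)), key=scores.__getitem__)
    return classes[best]


classes = ["atheism/christianity", "astronomy/space"]
print(fuzzy_match("Astronomy / Space.", classes))
```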

To speed up fuzzy matching, install python-Levenshtein.
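python-Levenshtein provides a C implementation of edit distance. The sketch below shows the usual optional-dependency pattern: use the fast Levenshtein.distance when the package is installed and fall back to a pure-Python dynamic-programming version otherwise. It illustrates why the package helps; it does not show how stormtrooper calls it internally.

```python
try:
    # Fast C implementation from python-Levenshtein, if installed.
    from Levenshtein import distance as levenshtein_distance
except ImportError:
    def levenshtein_distance(a: str, b: str) -> int:
        """Pure-Python edit distance (insertions, deletions, substitutions)."""
        if len(a) < len(b):
            a, b = b, a  # keep the DP row as short as possible
        previous = list(range(len(b) + 1))
        for i, ca in enumerate(a, start=1):
            current = [i]
            for j, cb in enumerate(b, start=1):
                current.append(min(
                    previous[j] + 1,               # deletion
                    current[j - 1] + 1,            # insertion
                    previous[j - 1] + (ca != cb),  # substitution
                ))
            previous = current
        return previous[-1]


print(levenshtein_distance("kitten", "sitting"))
```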

Inference on GPU

Since version 0.2.2, you can run models on a GPU. Specify the device when initializing a model:

from stormtrooper import Text2TextZeroShotClassifier
classifier = Text2TextZeroShotClassifier(device="cuda:0")
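Hard-coding "cuda:0" fails on machines without a GPU. A small sketch for choosing the device at runtime, assuming the models run on PyTorch and that the device parameter also accepts "cpu" (both are assumptions here; pick_device is a hypothetical helper):

```python
def pick_device() -> str:
    """Return "cuda:0" when PyTorch reports a usable GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"  # no PyTorch installed, so no CUDA either
    return "cuda:0" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(device)
# Assumed usage: classifier = Text2TextZeroShotClassifier(device=device)
```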



Download files


Source Distribution

stormtrooper-0.4.0.tar.gz (14.8 kB)

Uploaded Source

Built Distribution

stormtrooper-0.4.0-py3-none-any.whl (19.0 kB)

Uploaded Python 3

File details

Details for the file stormtrooper-0.4.0.tar.gz.

File metadata

  • Download URL: stormtrooper-0.4.0.tar.gz
  • Upload date:
  • Size: 14.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.10.8 Linux/5.15.0-83-generic

File hashes

Hashes for stormtrooper-0.4.0.tar.gz
Algorithm Hash digest
SHA256 ef0a5b87586127b495542afc2eeb94867fe8a8c05e3364ab5220fa2c618bb7b1
MD5 8d901697ff1f67fb6e359504ed8540c4
BLAKE2b-256 7724de7adf1f325c3081d4bc7727e2291f9c34b5c449d9462b2debac71b2b584
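The digests above let you check that a downloaded file was not corrupted or tampered with. A minimal sketch using Python's hashlib to compare the source archive against the published SHA256 value (the helper names are illustrative):

```python
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "ef0a5b87586127b495542afc2eeb94867fe8a8c05e3364ab5220fa2c618bb7b1"


def sha256_of_chunks(chunks) -> str:
    """Hash an iterable of byte chunks incrementally."""
    h = hashlib.sha256()
    for chunk in chunks:
        h.update(chunk)
    return h.hexdigest()


def sha256_of_file(path, chunk_size: int = 1 << 16) -> str:
    """Hash a file in fixed-size chunks so large files need little memory."""
    with open(path, "rb") as f:
        return sha256_of_chunks(iter(lambda: f.read(chunk_size), b""))


archive = Path("stormtrooper-0.4.0.tar.gz")
if archive.exists():
    print("OK" if sha256_of_file(archive) == EXPECTED_SHA256 else "MISMATCH")
```

pip can enforce the same check automatically in hash-checking mode (--require-hashes).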


File details

Details for the file stormtrooper-0.4.0-py3-none-any.whl.

File metadata

  • Download URL: stormtrooper-0.4.0-py3-none-any.whl
  • Upload date:
  • Size: 19.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.10.8 Linux/5.15.0-83-generic

File hashes

Hashes for stormtrooper-0.4.0-py3-none-any.whl
Algorithm Hash digest
SHA256 c0e5e9dab94ceb68ac8272c04e81e51af18b117d98617dfd46273e6e5307af2d
MD5 c2bc3552f6e10ecc04655ec5405ee93e
BLAKE2b-256 3c3a823e367f18fe42b78366795bb451d50b0192c96012c293cc9d40cad22098

