
Transformer/LLM-based zero and few-shot classification in scikit-learn pipelines


stormtrooper


Zero- and few-shot learning components for scikit-learn pipelines, built on large language models and transformers.

Documentation

New in 1.0.0

Trooper

With the new Trooper interface you no longer have to specify which model type you want to use. Stormtrooper automatically detects the model type from the model name you pass in.

from stormtrooper import Trooper

# This loads a SetFit model
model = Trooper("all-MiniLM-L6-v2")

# This loads an OpenAI model
model = Trooper("gpt-4")

# This loads a Text2Text model
model = Trooper("google/flan-t5-base")

Unified zero and few-shot classification

You no longer have to specify whether a model should be a few-shot or a zero-shot classifier when initialising it. If you do not pass any training examples, the model is assumed to be zero-shot.

# This is a zero-shot model
model.fit(None, ["dog", "cat"])

# This is a few-shot model
model.fit(["he was a good boy", "just lay down on my laptop"], ["dog", "cat"])
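The calling convention above can be sketched with a toy class. This only illustrates how `fit(None, labels)` and `fit(texts, labels)` can be told apart; it is not Stormtrooper's implementation, and `ToyClassifier` and its attributes are hypothetical names.

```python
from typing import Optional, Sequence


class ToyClassifier:
    """Hypothetical stand-in that only records which mode fit() selected."""

    def fit(self, X: Optional[Sequence[str]], y: Sequence[str]) -> "ToyClassifier":
        # The distinct labels define the set of classes in both modes.
        self.classes_ = sorted(set(y))
        if X is None:
            # No training examples: treat the model as zero-shot.
            self.mode_ = "zero-shot"
            self.examples_ = []
        else:
            if len(X) != len(y):
                raise ValueError("Each example text needs exactly one label.")
            # Training examples given: keep (text, label) pairs for few-shot use.
            self.mode_ = "few-shot"
            self.examples_ = list(zip(X, y))
        return self


zero = ToyClassifier().fit(None, ["dog", "cat"])
few = ToyClassifier().fit(
    ["he was a good boy", "just lay down on my laptop"], ["dog", "cat"]
)
print(zero.mode_, zero.classes_)
print(few.mode_, few.examples_)
```

In the few-shot call, the label list must line up one-to-one with the example texts, while in the zero-shot call it simply lists the candidate classes.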

Model types

You can use all sorts of transformer models for few and zero-shot classification in Stormtrooper.

  1. Instruction fine-tuned generative models, e.g. Trooper("HuggingFaceH4/zephyr-7b-beta")
  2. Encoder models with SetFit, e.g. Trooper("all-MiniLM-L6-v2")
  3. Text2Text models, e.g. Trooper("google/flan-t5-base")
  4. OpenAI models, e.g. Trooper("gpt-4")
  5. NLI models, e.g. Trooper("facebook/bart-large-mnli")

Example usage

Find more in our docs.

pip install stormtrooper

from stormtrooper import Trooper

class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
    "God came down to earth to save us.",
    "A new nebula was recently discovered in the proximity of the Oort cloud."
]
new_texts = ["God bless the railway workers", "The frigate is ready to launch from the spaceport"]

# Zero-shot classification
model = Trooper("google/flan-t5-base")
model.fit(None, class_labels)
model.predict(new_texts)
# ["atheism/christianity", "astronomy/space"]

# Few-shot classification
model = Trooper("google/flan-t5-base")
model.fit(example_texts, class_labels)
model.predict(new_texts)
# ["atheism/christianity", "astronomy/space"]

Fuzzy Matching

By default, generative and text2text models fuzzy match their output to the closest class label. You can disable this behavior by specifying fuzzy_match=False.

To speed up fuzzy matching, install python-Levenshtein.
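To make "matching to the closest class label" concrete, here is a minimal sketch using `difflib` from the Python standard library. It is an assumption-laden illustration, not the matching routine Stormtrooper uses (which may rely on Levenshtein distance instead), and `closest_label` is a hypothetical helper name.

```python
from difflib import get_close_matches


def closest_label(raw_output: str, labels: list[str]) -> str:
    """Map free-form model output to the most similar class label."""
    # Compare case-insensitively, but return the label in its original form.
    lowered = {label.lower(): label for label in labels}
    # cutoff=0.0 means some label is always returned, even for poor matches.
    match = get_close_matches(
        raw_output.strip().lower(), list(lowered), n=1, cutoff=0.0
    )
    return lowered[match[0]]


labels = ["atheism/christianity", "astronomy/space"]
print(closest_label("Astronomy / Space.", labels))
```

Without a step like this, a generated answer that differs from a label by spacing, casing or punctuation would not count as any known class.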

Inference on GPU

From version 0.2.2 you can run models on GPU. You can specify the device when initializing a model:

classifier = Trooper("all-MiniLM-L6-v2", device="cuda:0")

Inference on multiple GPUs

You can spread a model across multiple devices by using the device_map argument. Devices are filled in order of priority: GPU -> CPU + RAM -> Disk. Note that this only works with text2text and generative models.

model = Trooper("HuggingFaceH4/zephyr-7b-beta", device_map="auto")
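The priority order GPU -> CPU + RAM -> Disk can be pictured with a toy greedy placement. This is only a sketch of the idea under simplified assumptions (one size number per layer, one capacity per device); it is not how Stormtrooper or its underlying libraries compute a device map, and `place_layers`, the layer names and the sizes are all hypothetical.

```python
def place_layers(layer_sizes, capacities):
    """Greedily assign layers to devices in priority order.

    layer_sizes: {layer name: size}, in model order (hypothetical units).
    capacities: {device name: capacity}, listed from highest to lowest priority.
    """
    devices = list(capacities.items())
    placement = {}
    idx, used = 0, 0
    for layer, size in layer_sizes.items():
        # Move to the next device once the current one cannot hold this layer.
        while idx < len(devices) and used + size > devices[idx][1]:
            idx, used = idx + 1, 0
        if idx == len(devices):
            raise MemoryError(f"No device left for {layer}")
        placement[layer] = devices[idx][0]
        used += size
    return placement


layers = {"embed": 2, "block_0": 3, "block_1": 3, "head": 2}
print(place_layers(layers, {"cuda:0": 5, "cpu": 4, "disk": 100}))
```

Layers land on the fastest device until it is full and only then spill over to slower ones, which is why larger models may run noticeably slower when part of them ends up on CPU or disk.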


