Transformer/LLM-based zero and few-shot classification in scikit-learn pipelines
stormtrooper
Zero/few shot learning components for scikit-learn pipelines with large-language models and transformers.
New in 1.0.0
Trooper
The brand new Trooper
interface means you no longer have to specify which model type you wish to use.
Stormtrooper will automatically detect the model type from the specified name.
from stormtrooper import Trooper
# This loads a setfit model
model = Trooper("all-MiniLM-L6-v2")
# This loads an OpenAI model
model = Trooper("gpt-4")
# This loads a Text2Text model
model = Trooper("google/flan-t5-base")
Unified zero and few-shot classification
You no longer have to specify whether a model should be a few-shot or a zero-shot classifier when initialising it. If you do not pass any training examples, the model is automatically treated as zero-shot.
# This is a zero-shot model
model.fit(None, ["dog", "cat"])
# This is a few-shot model
model.fit(["he was a good boy", "just lay down on my laptop"], ["dog", "cat"])
Model types
You can use all sorts of transformer models for few and zero-shot classification in Stormtrooper.
- Instruction fine-tuned generative models, e.g. Trooper("HuggingFaceH4/zephyr-7b-beta")
- Encoder models with SetFit, e.g. Trooper("all-MiniLM-L6-v2")
- Text2Text models, e.g. Trooper("google/flan-t5-base")
- OpenAI models, e.g. Trooper("gpt-4")
- NLI models, e.g. Trooper("facebook/bart-large-mnli")
Example usage
Find more in our docs.
pip install stormtrooper
from stormtrooper import Trooper
class_labels = ["atheism/christianity", "astronomy/space"]
example_texts = [
"God came down to earth to save us.",
"A new nebula was recently discovered in the proximity of the Oort cloud."
]
new_texts = ["God bless the railway workers", "The frigate is ready to launch from the spaceport"]
# Zero-shot classification
model = Trooper("google/flan-t5-base")
model.fit(None, class_labels)
model.predict(new_texts)
# ["atheism/christianity", "astronomy/space"]
# Few-shot classification
model = Trooper("google/flan-t5-base")
model.fit(example_texts, class_labels)
model.predict(new_texts)
# ["atheism/christianity", "astronomy/space"]
Fuzzy Matching
Generative and text2text models will, by default, fuzzy match their raw output to the closest class label. You can disable this behavior by passing fuzzy_match=False.
To speed up fuzzy matching, install python-Levenshtein.
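The idea of snapping a generated string to the closest class label can be sketched with the standard library's difflib (stormtrooper's own matching may differ, and python-Levenshtein would be faster):

```python
import difflib

def fuzzy_match(generated: str, class_labels: list[str]) -> str:
    # Map a raw model output to the most similar class label.
    # cutoff=0.0 guarantees a match as long as class_labels is non-empty.
    matches = difflib.get_close_matches(generated, class_labels, n=1, cutoff=0.0)
    return matches[0]

labels = ["atheism/christianity", "astronomy/space"]
# A noisy generation still maps onto the intended label:
print(fuzzy_match("astronomy / space!", labels))  # astronomy/space
```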
Inference on GPU
From version 0.2.2 you can run models on GPU. You can specify the device when initializing a model:
classifier = Trooper("all-MiniLM-L6-v2", device="cuda:0")
Inference on multiple GPUs
You can also spread a model across multiple devices, in order of device priority (GPU -> CPU + RAM -> disk), by using the device_map
argument.
Note that this only works with text2text and generative models.
model = Trooper("HuggingFaceH4/zephyr-7b-beta", device_map="auto")