Extract token-level probabilities from LLMs for classification-type outputs.
Project description
TokenProbs
Extract token-level probability scores from generative language models (GLMs) without fine-tuning. It is often useful to request probability assessments for binary or multi-class outcomes, but generated text alone is not well-suited to this task. Instead, use LogitExtractor to obtain label probabilities without fine-tuning.
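Conceptually, the package reads the model's next-token distribution at the answer position and renormalizes it over the candidate label tokens. The sketch below illustrates this idea with the plain Hugging Face transformers API; the model, prompt, and tokenization choices are illustrative only and are not TokenProbs internals.
# Conceptual sketch (not TokenProbs itself): score candidate label tokens by
# softmaxing the next-token logits of a causal LM, then renormalizing over labels.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # small model, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "The sentiment of 'shares closed even on the day' is"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits[0, -1]   # logits for the next token
probs = torch.softmax(logits, dim=-1)

labels = ["positive", "neutral", "negative"]
# Use the first sub-token of each label (tokenization is context-dependent)
label_ids = [tokenizer.encode(" " + w)[0] for w in labels]
scores = {w: probs[i].item() for w, i in zip(labels, label_ids)}
total = sum(scores.values())
print({w: s / total for w, s in scores.items()})  # normalized label probabilities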
Installation
conda create -n TokenProbs python=3.9
pip3 install TokenProbs
Troubleshooting
If you receive a "CUDA Setup failed despite GPU being available" error, identify the location of the CUDA driver (typically found under /usr/local/) and run the following commands. The example below is for cuda-12.3:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3 # change 12.3 to appropriate location
export BNB_CUDA_VERSION=123 # 123 (i.e., 12.3) also needs to be changed
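After setting these variables, a quick check from Python (assuming PyTorch is installed) confirms that the GPU and CUDA runtime are visible:
# Sanity check that PyTorch can see the GPU and its CUDA runtime
import torch
print(torch.cuda.is_available())  # should print True if the GPU is visible
print(torch.version.cuda)         # CUDA version PyTorch was built against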
Usage
from TokenProbs import LogitExtractor
extractor = LogitExtractor(
model_name = 'mistralai/Mistral-7B-Instruct-v0.1',
quantization="8bit" # None = Full precision, "4bit" also suported
)
# Test sentence
sentence = "AAPL shares were up in morning trading, but closed even on the day."
# Prompt sentence
prompt = \
"""Instructions: What is the sentiment of this news article? Select from {positive/neutral/negative}.
\nInput: %text_input
Answer:"""
prompted_sentence = prompt.replace("%text_input",sentence)
# Provide tokens to extract (can be TokenIDs or strings)
pred_tokens = ['positive','neutral','negative']
# Extract normalized token probabilities
probabilities = extractor.logit_extraction(
input_data = prompted_sentence,
tokens = pred_tokens,
batch_size=1
)
print(f"Probabilities: {probabilities}")
Probabilities: {'positive': 0.7, 'neutral': 0.2, 'negative': 0.1}
# Compare to text output
text_output = extractor.text_generation(prompted_sentence, batch_size=1)
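Because logit_extraction returns normalized scores per label (as in the printed example above), the predicted class can be read off as the highest-probability token. A minimal sketch, assuming the dict output shown above:
# Assuming `probabilities` is a dict mapping each label to its normalized probability
predicted_label = max(probabilities, key=probabilities.get)
print(f"Predicted label: {predicted_label}")  # e.g., 'positive'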
Additional Features
LogitExtractor also provides functionality for Low-Rank Adaptation (LoRA) fine-tuning, tailored to extracting logit scores for next-token predictions.
from datasets import load_dataset
from TokenProbs import LogitExtractor
# Load dataset
dataset = load_dataset("financial_phrasebank",'sentences_50agree')['train']
# Apply training and test split
dataset = dataset.train_test_split(seed=42)
train = dataset['train']
# Convert class labels to text
labels = [{0:'negative',1:'neutral',2:'positive'}[i] for i in train['label']]
# Get sentences
prompted_sentences = [prompt.replace("%text_input",sent) for sent in train['sentence']]
# Add labels to prompted sentences
training_texts = [prompted_sentences[i] + labels[i] for i in range(len(labels))]
# Load model
extractor = LogitExtractor(
model_name = 'mistralai/Mistral-7B-Instruct-v0.1',
quantization="8bit"
)
# Set up the SFTTrainer
extractor.trainer_setup(
train_ds = training_texts, #either a dataloader object or text list
response_seq = "\nAnswer:", # Train only on the text that follows "\nAnswer:"
# Input can be a text string or a list of token IDs. Be careful: tokens can differ based on context.
lora_alpha=16,
lora_rank=32,
lora_dropout=0.1
)
extractor.trainer.train()
# Push model to huggingface
extractor.trainer.model.push_to_hub('<HF_USERNAME>/<MODEL_NAME>')
# Load model later
trained_model = LogitExtractor(
model_name = '<HF_USERNAME>/<MODEL_NAME>',
quantization="8bit"
)
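To sanity-check the fine-tuned model, the held-out split can be scored with logit_extraction. The sketch below reuses prompt, dataset, and pred_tokens from the snippets above; it assumes logit_extraction accepts a list of prompts and returns one probability dict per input, and the batch size and accuracy computation are illustrative.
# Build prompted test sentences and text labels (same mapping as for training)
test = dataset['test']
test_prompts = [prompt.replace("%text_input", sent) for sent in test['sentence']]
test_labels = [{0:'negative',1:'neutral',2:'positive'}[i] for i in test['label']]

# Assumption: a list of inputs yields one probability dict per prompt
test_probs = trained_model.logit_extraction(
    input_data = test_prompts,
    tokens = pred_tokens,
    batch_size = 8
)

# Take the highest-probability label for each sentence and compute accuracy
preds = [max(p, key=p.get) for p in test_probs]
accuracy = sum(p == y for p, y in zip(preds, test_labels)) / len(test_labels)
print(f"Test accuracy: {accuracy:.3f}")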
Examples
Coming soon.
File details
Details for the file TokenProbs-1.0.0.tar.gz.
File metadata
- Download URL: TokenProbs-1.0.0.tar.gz
- Upload date:
- Size: 5.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | c977d927794b7a688f6ae9765be0a22e0593fca38f603f9ca2bb7a36a167b45d
MD5 | daaafd72133853e63b1f5b3cf5c02b37
BLAKE2b-256 | 99d6ce1ce457ff2be06444bcc5d4d979885d3c50ca5df0423c8182c2d523e31b
File details
Details for the file TokenProbs-1.0.0-py3-none-any.whl.
File metadata
- Download URL: TokenProbs-1.0.0-py3-none-any.whl
- Upload date:
- Size: 6.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/5.0.0 CPython/3.10.10
File hashes
Algorithm | Hash digest
---|---
SHA256 | 1c74ebbfb8c7fe3fc5657243481e70e7090acf3bad241cccaa211e6d1bb9efd9
MD5 | 967782e46d905ac94b57ac7beba9ed33
BLAKE2b-256 | 3f1f12058bfb856572943855b6065a7d7b54244c13525501e7f1164e569ca5e1