Extract token-level probabilities from LLMs for classification-type outputs.

TokenProbs

Extract token-level probability scores from generative language models (GLMs) without fine-tuning. It is often useful to request a probability assessment over binary or multi-class outcomes. GLMs, which return generated text rather than class probabilities, are not well suited to this task. Instead, use LogitExtractor to obtain label probabilities without fine-tuning.
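
To make "label probabilities" concrete, here is a minimal sketch of one common way to turn next-token logits into class probabilities: keep only the logits of the candidate label tokens and apply a softmax over them. Whether LogitExtractor normalizes in exactly this way is an assumption. The function name `label_probabilities` and the logit values are hypothetical and are not part of the TokenProbs API.

```python
import math

def label_probabilities(label_logits):
    """Softmax over the logits of the candidate label tokens only (illustrative helper)."""
    # Subtract the largest logit before exponentiating for numerical stability
    max_logit = max(label_logits.values())
    exps = {label: math.exp(logit - max_logit) for label, logit in label_logits.items()}
    total = sum(exps.values())
    # Divide by the sum so the label probabilities add up to 1
    return {label: value / total for label, value in exps.items()}

# Hypothetical next-token logits for the three label tokens
example_logits = {"positive": 2.0, "neutral": 0.5, "negative": -1.0}
example_probs = label_probabilities(example_logits)
print(example_probs)
```

Restricting the softmax to the label tokens means the scores sum to 1 across the classes of interest, regardless of how much probability mass the model assigns to other tokens.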

Installation

Install with pip:

conda create -n TokenProbs python=3.9
conda activate TokenProbs
pip3 install TokenProbs 

Install from the GitHub repository:

conda create -n TokenProbs python=3.9
conda activate TokenProbs

git clone https://github.com/francescoafabozzi/TokenProbs.git
cd TokenProbs
pip3 install -e . # Install in editable mode 

Troubleshooting

If you receive the error "CUDA Setup failed despite GPU being available.", identify the location of the CUDA driver, typically found under /usr/local/, and run the following commands from the command line. The example below shows this for cuda-12.3:

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3 # change 12.3 to appropriate location
export BNB_CUDA_VERSION=123 # 123 (i.e., 12.3) also needs to be changed
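
The mapping from the directory name to the two values above can be sketched as a small helper. This is only an illustration of the pattern shown in the commands (12.3 becomes 123). The function name `bnb_cuda_exports` is hypothetical, and the directory passed in is an assumption about where CUDA is installed on your machine.

```python
import re

def bnb_cuda_exports(cuda_dir):
    """Build the two export lines for a CUDA directory such as /usr/local/cuda-12.3."""
    match = re.search(r"cuda-(\d+)\.(\d+)", cuda_dir)
    if match is None:
        raise ValueError(f"No CUDA version found in {cuda_dir!r}")
    major, minor = match.groups()
    return [
        # Same path form as the manual command above
        f"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{cuda_dir}",
        # Version digits without the dot, e.g. 12.3 -> 123
        f"export BNB_CUDA_VERSION={major}{minor}",
    ]

export_lines = bnb_cuda_exports("/usr/local/cuda-12.3")
print("\n".join(export_lines))
```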

Usage

from TokenProbs import LogitExtractor

extractor = LogitExtractor(
    model_name = 'mistralai/Mistral-7B-Instruct-v0.1',
    quantization="8bit" # None = Full precision, "4bit" also supported
)

# Test sentence
sentence = "AAPL shares were up in morning trading, but closed even on the day."

# Prompt sentence
prompt = \
"""Instructions: What is the sentiment of this news article? Select from {positive/neutral/negative}.
\nInput: %text_input
Answer:"""

prompted_sentence = prompt.replace("%text_input",sentence)

# Provide tokens to extract (can be TokenIDs or strings)
pred_tokens = ['positive','neutral','negative']


# Extract normalized token probabilities
probabilities = extractor.logit_extraction(
    input_data = prompted_sentence,
    tokens = pred_tokens,
    batch_size=1
)

print(f"Probabilities: {probabilities}")
# Probabilities: {'positive': 0.7, 'neutral': 0.2, 'negative': 0.1}

# Compare to text output
text_output = extractor.text_generation(prompted_sentence, batch_size=1)
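
When comparing against the generated text, it can help to reduce the probability dictionary to a single predicted label plus a rough confidence margin. The sketch below assumes the dictionary has the shape printed in the usage example. The helper name `top_label_and_margin` is hypothetical and not part of TokenProbs.

```python
def top_label_and_margin(probabilities):
    """Return the most likely label and its lead over the runner-up."""
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    top_label, top_prob = ranked[0]
    # With a single label there is no runner-up, so the margin is the probability itself
    runner_up_prob = ranked[1][1] if len(ranked) > 1 else 0.0
    return top_label, top_prob - runner_up_prob

# Values copied from the usage example output
probabilities = {"positive": 0.7, "neutral": 0.2, "negative": 0.1}
label, margin = top_label_and_margin(probabilities)
print(label, margin)
```

A small margin flags inputs where the model is close to undecided, which a single generated word would hide.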

Additional Features

LogitExtractor also provides functionality for applying Low-rank Adaptation (LoRA) fine-tuning tailored to extracting logit scores for next-token predictions.

Below is an example of fine-tuning Mistral on Financial Phrasebank, a financial sentiment classification dataset.

from datasets import load_dataset
from TokenProbs import LogitExtractor

# Load dataset
dataset = load_dataset("financial_phrasebank",'sentences_50agree')['train']
# Apply training and test split
dataset = dataset.train_test_split(seed=42)
train = dataset['train']

# Convert class labels to text
labels = [{0:'negative',1:'neutral',2:'positive'}[i] for i in train['label']]
# Get sentences 
prompted_sentences = [prompt.replace("%text_input",sent) for sent in train['sentence']]

# Add labels to prompted sentences
training_texts = [prompted_sentences[i] + labels[i] for i in range(len(labels))]

# Load model
extractor = LogitExtractor(
    model_name = 'mistralai/Mistral-7B-Instruct-v0.1',
    quantization="8bit"
)

# Set up SFTTrainer
extractor.trainer_setup(
    train_ds = training_texts, #either a dataloader object or text list
    response_seq = "\nAnswer:", # Tells trainer to train only on text following "\nAnswer:"
    # Input can be text string or list of TokenIDs. Be careful, tokens can differ based on context.
    lora_alpha=16,
    lora_rank=32,
    lora_dropout=0.1
)
extractor.trainer.train()
# Push model to the Hugging Face Hub
extractor.trainer.model.push_to_hub('<HF_USERNAME>/<MODEL_NAME>')

# Load model later
trained_model = LogitExtractor(
    model_name = '<HF_USERNAME>/<MODEL_NAME>',
    quantization="8bit"
)
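
The example above creates a test split but does not use it. One way to evaluate a fine-tuned model is to extract label probabilities for the held-out sentences and score the highest-probability label against the gold label. The sketch below covers only the scoring step. It assumes one probability dictionary per sentence in the form shown in the usage example. The helper name `accuracy` and the sample values are hypothetical.

```python
def accuracy(prob_dicts, gold_labels):
    """Fraction of examples whose highest-probability label matches the gold label."""
    if len(prob_dicts) != len(gold_labels):
        raise ValueError("prob_dicts and gold_labels must have the same length")
    if not gold_labels:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(
        max(probs, key=probs.get) == gold
        for probs, gold in zip(prob_dicts, gold_labels)
    )
    return correct / len(gold_labels)

# Hypothetical probabilities for three held-out sentences
held_out_probs = [
    {"positive": 0.7, "neutral": 0.2, "negative": 0.1},
    {"positive": 0.1, "neutral": 0.6, "negative": 0.3},
    {"positive": 0.5, "neutral": 0.3, "negative": 0.2},
]
held_out_gold = ["positive", "neutral", "negative"]
score = accuracy(held_out_probs, held_out_gold)
print(f"Accuracy: {score:.2f}")
```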

