
Extract token-level probabilities from LLMs for classification-type outputs.


TokenProbs

Extract token-level probability scores from generative language models (GLMs) without fine-tuning. It is often useful to ask a model for a probability assessment of binary or multi-class outcomes, but the free-text output of GLMs is not well suited to this task. Instead, use LogitExtractor to obtain label probabilities directly, without fine-tuning.
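To make the idea concrete, here is a minimal sketch of how label probabilities can be derived from next-token logits: keep only the logits of the candidate label tokens and apply a softmax over them. The function name normalize_label_logits and the logit values are made up for illustration, and the assumption that normalization is a softmax over the label tokens only may not match what LogitExtractor does internally.

```python
import math

def normalize_label_logits(label_logits):
    """Softmax restricted to the candidate labels (hypothetical helper)."""
    # Subtract the maximum logit for numerical stability
    max_logit = max(label_logits.values())
    exps = {label: math.exp(value - max_logit) for label, value in label_logits.items()}
    total = sum(exps.values())
    return {label: e / total for label, e in exps.items()}

# Made-up next-token logits for the three label tokens
example_logits = {"positive": 2.0, "neutral": 0.5, "negative": -1.0}
probs = normalize_label_logits(example_logits)
print(probs)
```

Because the softmax is taken over the label tokens only, the resulting values sum to 1 even though the model's full vocabulary contains many other tokens.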

Installation

conda create -n GenCasting python=3.9
conda activate GenCasting
pip3 install TokenProbs

Troubleshooting

If you receive a "CUDA Setup failed despite GPU being available" error, identify the location of the CUDA driver, typically found under /usr/local/, and run the following commands from the command line. The example below uses cuda-12.3:

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3 # change 12.3 to appropriate location
export BNB_CUDA_VERSION=123 # 123 (i.e., 12.3) also needs to be changed
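Both values in the exports follow from the CUDA directory name. The following is a minimal sketch, assuming toolkits are installed as /usr/local/cuda-<major>.<minor>; the helper names bnb_cuda_version and find_cuda_dirs are hypothetical and not part of TokenProbs.

```python
import glob
import os
import re

def bnb_cuda_version(cuda_dir):
    """Turn a path such as /usr/local/cuda-12.3 into '123'; return None if no version is found."""
    match = re.search(r"cuda-(\d+)\.(\d+)$", cuda_dir.rstrip("/"))
    return match.group(1) + match.group(2) if match else None

def find_cuda_dirs(root="/usr/local"):
    """List versioned CUDA directories under root (assumed layout)."""
    return sorted(glob.glob(os.path.join(root, "cuda-*")))

# Print the export lines for every versioned CUDA directory that is found
for path in find_cuda_dirs():
    version = bnb_cuda_version(path)
    if version:
        print(f"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{path}")
        print(f"export BNB_CUDA_VERSION={version}")
```

If several versions are listed, pick the one that matches the installed driver and run only that pair of exports.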

Usage

from TokenProbs import LogitExtractor

extractor = LogitExtractor(
    model_name = 'mistralai/Mistral-7B-Instruct-v0.1',
    quantization="8bit" # None = full precision, "4bit" also supported
)

# Test sentence
sentence = "AAPL shares were up in morning trading, but closed even on the day."

# Prompt sentence
prompt = \
"""Instructions: What is the sentiment of this news article? Select from {positive/neutral/negative}.
\nInput: %text_input
Answer:"""

prompted_sentence = prompt.replace("%text_input",sentence)

# Provide tokens to extract (can be TokenIDs or strings)
pred_tokens = ['positive','neutral','negative']


# Extract normalized token probabilities
probabilities = extractor.logit_extraction(
    input_data = prompted_sentence,
    tokens = pred_tokens,
    batch_size=1
)

print(f"Probabilities: {probabilities}")
# Probabilities: {'positive': 0.7, 'neutral': 0.2, 'negative': 0.1}

# Compare to text output
text_output = extractor.text_generation(prompted_sentence, batch_size=1)
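To turn the extracted probabilities into a single prediction, take the label with the highest probability. This is a minimal sketch that assumes the result for one input is a dictionary shaped like the printed output above; top_label is a hypothetical helper, not part of TokenProbs, and the return format for batched inputs may differ.

```python
def top_label(probabilities):
    """Return (label, probability) for the most likely label (hypothetical helper)."""
    label = max(probabilities, key=probabilities.get)
    return label, probabilities[label]

# Dictionary copied from the example output above
example = {'positive': 0.7, 'neutral': 0.2, 'negative': 0.1}
label, confidence = top_label(example)
print(label, confidence)
```

Keeping the probability alongside the label makes it possible to apply a confidence threshold later instead of always accepting the top label.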

Additional Features

LogitExtractor also provides functionality for applying Low-rank Adaptation (LoRA) fine-tuning tailored to extracting logit scores for next-token predictions.
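For context, LoRA keeps the pretrained weight matrix W frozen and learns a low-rank update, so the effective weight is W + (alpha / r) * B A. The sketch below illustrates this arithmetic in plain Python with toy matrices. It assumes the conventional scaling alpha / r, which may not be exactly what the library applies; lora_scaling and apply_lora are hypothetical names, and the toy matrices use rank 1 for readability.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def lora_scaling(lora_alpha, lora_rank):
    """Conventional LoRA scaling factor alpha / r (assumed here)."""
    return lora_alpha / lora_rank

def apply_lora(w, a, b, scaling):
    """Return W + scaling * (B @ A) without modifying W."""
    delta = matmul(b, a)  # (d_out x r) @ (r x d_in)
    return [
        [w_ij + scaling * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]

w = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 weight
a = [[1.0, 2.0]]              # trainable 1x2 matrix A (toy rank 1)
b = [[0.5], [0.0]]            # trainable 2x1 matrix B

scaling = lora_scaling(lora_alpha=16, lora_rank=32)
w_eff = apply_lora(w, a, b, scaling)
print(scaling, w_eff)
```

Only A and B are trained, which is why LoRA needs far fewer trainable parameters than full fine-tuning.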

from datasets import load_dataset
from TokenProbs import LogitExtractor

# Load dataset
dataset = load_dataset("financial_phrasebank",'sentences_50agree')['train']
# Apply training and test split
dataset = dataset.train_test_split(seed=42)
train = dataset['train']

# Convert class labels to text
labels = [{0:'negative',1:'neutral',2:'positive'}[i] for i in train['label']]
# Insert each sentence into the prompt template (`prompt` is defined in the Usage section)
prompted_sentences = [prompt.replace("%text_input",sent) for sent in train['sentence']]

# Add labels to prompted sentences
training_texts = [prompted_sentences[i] + labels[i] for i in range(len(labels))]

# Load model
extractor = LogitExtractor(
    model_name = 'mistralai/Mistral-7B-Instruct-v0.1',
    quantization="8bit"
)

# Set up SFTTrainer
extractor.trainer_setup(
    train_ds = training_texts, # either a dataloader object or a list of texts
    response_seq = "\nAnswer:", # Tells the trainer to train only on the text following "\nAnswer:"
    # response_seq can be a text string or a list of TokenIDs. Be careful: tokens can differ based on context.
    lora_alpha=16,
    lora_rank=32,
    lora_dropout=0.1
)
extractor.trainer.train()
# Push model to the Hugging Face Hub
extractor.trainer.model.push_to_hub('<HF_USERNAME>/<MODEL_NAME>')

# Load model later
trained_model = LogitExtractor(
    model_name = '<HF_USERNAME>/<MODEL_NAME>',
    quantization="8bit"
)
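The response_seq argument marks where the prompt ends and the trained-on answer begins. The sketch below shows this split at the string level, assuming the training texts are built as prompt plus label as above. split_on_response is a hypothetical helper; the actual trainer works on token IDs, where the same text can tokenize differently depending on context.

```python
def split_on_response(text, response_seq="\nAnswer:"):
    """Split a training text into (prompt part, completion) at the last response_seq."""
    head, sep, tail = text.rpartition(response_seq)
    if not sep:
        raise ValueError("response_seq not found in training text")
    return head + sep, tail

# Toy training text in the same shape as the README's prompt plus label
training_text = (
    "Instructions: What is the sentiment of this news article?\n"
    "\nInput: Profit rose compared with last year."
    "\nAnswer:positive"
)
prompt_part, completion = split_on_response(training_text)
print(repr(completion))
```

Under this setup only the completion part (here the label text) would contribute to the training loss, while the prompt part serves as context.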

Examples

Coming soon.
