Extract token-level probabilities from LLMs for classification-type outputs.
TokenProbs
Extract token-level probability scores from generative language models (GLMs) without fine-tuning. It is often useful to ask a model for a probability assessment over binary or multi-class outcomes, but GLMs are not well suited to this task out of the box. Instead, use LogitExtractor to obtain label probabilities directly.
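As a rough illustration of what "label probabilities" means here, the sketch below renormalizes next-token logits over only the candidate label tokens using a softmax. The logit values and the `label_probabilities` helper are hypothetical, and this is an assumption about the general technique, not a description of how LogitExtractor is implemented internally.

```python
import math

def label_probabilities(label_logits):
    """Softmax restricted to the candidate label tokens.

    label_logits: dict mapping a label string to its raw next-token logit.
    """
    # Subtract the maximum logit for numerical stability.
    max_logit = max(label_logits.values())
    exps = {label: math.exp(logit - max_logit) for label, logit in label_logits.items()}
    total = sum(exps.values())
    return {label: value / total for label, value in exps.items()}

# Hypothetical logits for the three sentiment labels.
example_logits = {"positive": 2.0, "neutral": 0.75, "negative": 0.05}
probs = label_probabilities(example_logits)
print(probs)
```

Because only the label tokens enter the normalization, the returned values sum to 1 even though the model's full vocabulary distribution assigns mass to many other tokens.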
Installation
Install with pip:
conda create -n TokenProbs python=3.11 # Note: not available for 3.13
conda activate TokenProbs
pip3 install TokenProbs
Install via GitHub repository:
conda create -n TokenProbs python=3.12 # Note: not available for 3.13
conda activate TokenProbs
git clone https://github.com/francescoafabozzi/TokenProbs.git
cd TokenProbs
pip3 install -e . # Install in editable mode
Usage
See examples/FinancialPhrasebank.ipynb for an example of using LogitExtractor to extract token-level probabilities for a sentiment classification task.
from TokenProbs import LogitExtractor
extractor = LogitExtractor(
    model_name='mistralai/Mistral-7B-Instruct-v0.1',
    quantization="8bit"  # None = full precision; "4bit" is also supported
)
# Test sentence
sentence = "AAPL shares were up in morning trading, but closed even on the day."
# Prompt sentence
prompt = \
"""Instructions: What is the sentiment of this news article? Select from {positive/neutral/negative}.
\nInput: %text_input
Answer:"""
prompted_sentence = prompt.replace("%text_input",sentence)
# Provide tokens to extract (can be TokenIDs or strings)
pred_tokens = ['positive','neutral','negative']
# Extract normalized token probabilities
probabilities = extractor.logit_extraction(
    input_data=prompted_sentence,
    tokens=pred_tokens,
    batch_size=1
)
print(f"Probabilities: {probabilities}")
# Example output:
# Probabilities: {'positive': 0.7, 'neutral': 0.2, 'negative': 0.1}
# Compare to text output
text_output = extractor.text_generation(prompted_sentence, batch_size=1)
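Once the probabilities are available, a predicted label can be read off the dictionary. The sketch below assumes the output has the dictionary shape shown in the example output above; `predict_label`, its `min_confidence` threshold, and the `"uncertain"` fallback are hypothetical helpers for illustration and are not part of TokenProbs.

```python
def predict_label(probabilities, min_confidence=0.5, fallback="uncertain"):
    """Return the most probable label, or a fallback if the model is not confident."""
    label = max(probabilities, key=probabilities.get)
    if probabilities[label] < min_confidence:
        return fallback
    return label

# Dictionary in the same shape as the example output above.
example = {'positive': 0.7, 'neutral': 0.2, 'negative': 0.1}
print(predict_label(example))  # positive
```

Keeping the threshold explicit makes it easy to route low-confidence inputs to manual review instead of forcing a label.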
Troubleshooting Installation
Import Errors due to torch
If you receive import errors due to torch, a specific torch version may be required. Follow the steps below:
Step 1: Identify the CUDA version (for GPU users):
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0
In this case, the CUDA version is 12.3.
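If you prefer to read the version programmatically, the release number can be pulled out of the nvcc output with a regular expression. This is a minimal sketch that parses the sample output shown above; `parse_cuda_version` is a hypothetical helper name.

```python
import re

# Sample output copied from the nvcc --version example above.
NVCC_OUTPUT = """nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Wed_Nov_22_10:17:15_PST_2023
Cuda compilation tools, release 12.3, V12.3.107
Build cuda_12.3.r12.3/compiler.33567101_0"""

def parse_cuda_version(nvcc_output):
    """Return (major, minor) from nvcc --version output, or None if not found."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))

print(parse_cuda_version(NVCC_OUTPUT))  # (12, 3)
```

On a machine with nvcc on the PATH, the same function can be fed the live output from `subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout`.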
Step 2: Navigate to the PyTorch website and select the build that matches your CUDA version.
If there is no build for your exact CUDA version (there is none for 12.3), select the nearest lower version (here, 12.1).
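The selection rule in Step 2 (take the highest available build that does not exceed the installed CUDA version) can be sketched as follows. The `pick_torch_cuda_tag` helper and the `available_builds` list are assumptions for illustration; check the PyTorch website for the builds actually offered.

```python
def pick_torch_cuda_tag(installed, available):
    """Pick the newest build not newer than the installed CUDA version.

    installed: (major, minor) tuple for the local CUDA toolkit.
    available: list of (major, minor) tuples for the builds on offer.
    Returns a tag such as "cu121", or None if no build is old enough.
    """
    candidates = [version for version in available if version <= installed]
    if not candidates:
        return None
    major, minor = max(candidates)
    return f"cu{major}{minor}"

# Assumed list of available builds, for illustration only.
available_builds = [(11, 8), (12, 1)]
print(pick_torch_cuda_tag((12, 3), available_builds))  # cu121
```

The resulting tag is the suffix of the index URL used when reinstalling torch.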
Step 3: Uninstall torch and reinstall it using the matching CUDA build:
pip3 uninstall torch
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
Issues with bitsandbytes
If you receive the error CUDA Setup failed despite GPU being available., locate the CUDA installation, typically found under /usr/local/, and run the following commands from the command line. The example below uses cuda-12.3:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3 # change 12.3 to appropriate location
export BNB_CUDA_VERSION=123 # 123 (i.e., 12.3) also needs to be changed
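Both values above are derived from the same CUDA directory name, so they can be generated together to avoid a mismatch. The sketch below is a hypothetical helper (`bnb_exports` is not part of TokenProbs or bitsandbytes) that assumes the directory follows the `cuda-<major>.<minor>` naming used in the example.

```python
import re

def bnb_exports(cuda_dir):
    """Build the two export lines for a CUDA directory such as /usr/local/cuda-12.3."""
    match = re.search(r"cuda-(\d+)\.(\d+)", cuda_dir)
    if match is None:
        raise ValueError(f"Could not find a CUDA version in {cuda_dir!r}")
    # BNB_CUDA_VERSION is the version with the dot removed, e.g. 12.3 -> 123.
    version = match.group(1) + match.group(2)
    return [
        f"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{cuda_dir}",
        f"export BNB_CUDA_VERSION={version}",
    ]

for line in bnb_exports("/usr/local/cuda-12.3"):
    print(line)
```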