A benchmark for LLM calibration on human populations.
folktexts
Folktexts is a Python package to evaluate and benchmark the calibration of large language models. It enables using any transformers model as a classifier for tabular data tasks, and extracting risk score estimates from the model's output log-odds.
Several benchmark tasks are provided based on data from the American Community Survey. Namely, each prediction task from the popular folktables package is made available as a natural-language prompting task.
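To illustrate what extracting a risk score from output log-odds can look like, here is a minimal sketch. It is not the folktexts implementation: the function name `risk_score_from_logits` and the example logit values are hypothetical, and it assumes a yes/no question where the model's next-token logits for the two answer options are already available.

```python
import math


def risk_score_from_logits(logit_yes: float, logit_no: float) -> float:
    """Probability of the positive answer, renormalized over the two answer options.

    A softmax restricted to two options equals a sigmoid of the
    logit difference, which is the log-odds of the positive answer.
    """
    log_odds = logit_yes - logit_no
    # Numerically stable sigmoid: avoid exponentiating a large positive number
    if log_odds >= 0:
        return 1.0 / (1.0 + math.exp(-log_odds))
    exp_term = math.exp(log_odds)
    return exp_term / (1.0 + exp_term)


# Hypothetical logits for the "yes" and "no" answer tokens
score = risk_score_from_logits(logit_yes=2.0, logit_no=0.5)
print(f"risk score: {score:.3f}")
```

Restricting the normalization to the answer options discards probability mass the model places on unrelated tokens, so the two scores always sum to one.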
Package documentation can be found here.
Installing
Install package from PyPI:
pip install folktexts
Basic setup
- Create conda environment
conda create -n folktexts python=3.11
conda activate folktexts
- Install folktexts package
pip install folktexts
- Create models, datasets, and results folders
mkdir results
mkdir models
mkdir datasets
- Download transformers model and tokenizer into models folder
python -m folktexts.cli.download_models --model "google/gemma-2b" --save-dir models
- Run benchmark
python -m folktexts.cli.run_acs_benchmark --results-dir results --data-dir datasets --task-name "ACSIncome" --model models/google--gemma-2b
Run python -m folktexts.cli.run_acs_benchmark --help to get a list of all available benchmark flags.
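To run the benchmark on several tasks in a row, the command above can be assembled from a short script. The following is a sketch under stated assumptions: the helper names `build_benchmark_command` and `run_benchmarks` are hypothetical, only the four flags shown above are used, and task names other than "ACSIncome" are taken from folktables and assumed to be accepted by the CLI.

```python
import subprocess
import sys


def build_benchmark_command(task_name, model_path,
                            results_dir="results", data_dir="datasets"):
    """Build the argument list for one benchmark run (hypothetical helper)."""
    return [
        sys.executable, "-m", "folktexts.cli.run_acs_benchmark",
        "--results-dir", results_dir,
        "--data-dir", data_dir,
        "--task-name", task_name,
        "--model", model_path,
    ]


def run_benchmarks(task_names, model_path, dry_run=True):
    """Print each command; execute it as well when dry_run is False."""
    for task_name in task_names:
        cmd = build_benchmark_command(task_name, model_path)
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError if the benchmark fails
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # "ACSPublicCoverage" is a folktables task name, assumed to be valid here
    run_benchmarks(["ACSIncome", "ACSPublicCoverage"], "models/google--gemma-2b")
```

The default `dry_run=True` only prints the commands, which makes it easy to check the arguments before starting a long run.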
Usage
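from transformers import AutoModelForCausalLM, AutoTokenizer
from folktexts.acs import ACSDataset, ACSTaskMetadata
# NOTE: this import path for LLMClassifier is an assumption; check the package documentation
from folktexts.classifier import LLMClassifier

# Load the model and tokenizer downloaded in the setup step
# (assumes the folder is in a format that from_pretrained can read)
model_path = "models/google--gemma-2b"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

acs_task_name = "ACSIncome"
# Create an object that classifies data using an LLM
clf = LLMClassifier(
    model=model,
    tokenizer=tokenizer,
    task=ACSTaskMetadata.get_task(acs_task_name),
)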
# Use a dataset or feed in your own data
dataset = ACSDataset(acs_task_name)
# Get risk score predictions out of the model
y_scores = clf.predict_proba(dataset)
# Optionally, fit the decision threshold on a small portion of the data
clf.fit(dataset[0:100])
# ...in order to get more accurate binary predictions
clf.predict(dataset)
# Compute a variety of evaluation metrics on calibration and accuracy
from folktexts.benchmark import CalibrationBenchmark
benchmark_results = CalibrationBenchmark(clf, dataset, results_dir="results").run()
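As an illustration of the kind of calibration metric such a benchmark can report, here is a minimal sketch of the expected calibration error (ECE) with equal-width bins. It is not the folktexts implementation: the function name `expected_calibration_error` is hypothetical, and it assumes binary labels in {0, 1} and risk scores in [0, 1].

```python
def expected_calibration_error(y_true, y_scores, n_bins=10):
    """Weighted average gap between mean score and positive rate per score bin."""
    bins = [[] for _ in range(n_bins)]
    for label, score in zip(y_true, y_scores):
        # Equal-width bins over [0, 1]; a score of exactly 1.0 goes in the last bin
        idx = min(int(score * n_bins), n_bins - 1)
        bins[idx].append((label, score))

    n_samples = len(y_true)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        positive_rate = sum(label for label, _ in members) / len(members)
        mean_score = sum(score for _, score in members) / len(members)
        # Weight each bin's calibration gap by its share of the samples
        ece += (len(members) / n_samples) * abs(positive_rate - mean_score)
    return ece


# Toy example: confident scores that are slightly off from the observed outcomes
labels = [0, 0, 1, 1]
scores = [0.1, 0.1, 0.9, 0.9]
print(f"ECE: {expected_calibration_error(labels, scores):.3f}")
```

An ECE of zero means that, within every bin, the average predicted score matches the observed fraction of positive labels; the value depends on the number of bins chosen.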
License and terms of use
Code licensed under the MIT license.
The American Community Survey (ACS) Public Use Microdata Sample (PUMS) is governed by the U.S. Census Bureau terms of service.