A Python package for LLM Chain Ensembles!
LLM Chain Ensembles
This repository contains a Python package and the code for the paper *LLM Chain Ensembles for Scalable and Accurate Data Annotation*. An LLM chain ensemble passes data through a sequence of LLMs, using uncertainty estimates to select the subset of data each model labels. This reduces zero-shot prediction costs by exposing only a limited amount of data to the high-cost models at the end of the chain, and it can also improve performance.
```mermaid
---
title: LLM Zero-Shot Prediction with LLM Chain Ensemble
---
flowchart LR
A(Data) --> B{LLM 1}
B --Lowest 1/3 Confidence--> C(Forwarded Data)
B --High Confidence Examples--> G(Labeled Subset 1)
C --> D{LLM 2}
D --> H(Labeled Subset 2)
D --Lowest 1/2 Confidence--> E(Forwarded Data)
E --> F{LLM 3}
F --> O(Labeled Subset 3)
G --> M(Rank Based Ensemble)
H --> M
O --> M
M --> Z(Final Label)
```
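The forwarding rule in the diagram is easy to express directly. Below is a minimal sketch of that step, not the package's internal implementation: it ranks examples by the model's confidence and forwards the least confident fraction to the next link.

```python
import numpy as np

def split_by_confidence(confidences, forward_frac=1/3):
    """Split indices into (keep_here, forward_to_next_link) by confidence.

    A sketch of the forwarding rule in the diagram above; the package's
    actual uncertainty estimates and splitting logic may differ.
    """
    order = np.argsort(confidences)              # ascending confidence
    n_forward = int(len(confidences) * forward_frac)
    return order[n_forward:], order[:n_forward]  # confident examples stay, rest forward

keep_idx, forward_idx = split_by_confidence(np.array([0.91, 0.23, 0.74, 0.55, 0.88, 0.40]))
```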
Getting Started
Installation
Install the most recent release with pip.

```bash
pip install chain-ensembles
```
We recommend using uv to manage your Python environment and packages.

```bash
uv venv
uv pip install chain-ensembles
```
Authentication
The package's classes HuggingFaceLink and OpenAILink require API tokens for authentication, which are read from environment variables. If you are running from the terminal, set the following environment variables.
```bash
$ export HF_TOKEN=your_hf_token_here
$ export OPENAI_API_TOKEN=your_openai_token_here
```
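If you are working from a notebook rather than a shell, you can set the same variables from Python before importing the package (the token values below are placeholders):

```python
import os

# Placeholders: substitute your real tokens.
os.environ["HF_TOKEN"] = "your_hf_token_here"
os.environ["OPENAI_API_TOKEN"] = "your_openai_token_here"
```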
Example Scripts
There are example scripts available in the scripts/ directory of this repository; we used them to drive the experiments for the paper. These scripts can be run directly or serve as a reference for any developer hoping to build their own chain ensembles! To run the scripts, be sure to clone the repository and install the chain_ensembles package. See the Scripts section below for details.
LLM Chain Ensembles
In this section, we will cover the basic functionality provided by the package.
Using LLM Links To Label Data
The smallest data-labeling class is a link! We provide two sources for data-labeling links:
- Open-source models available on Hugging Face via the transformers package.
- Closed-source models available through the OpenAI API.
HuggingFaceLink Example
For Hugging Face models, you initialize the HuggingFaceLink, call .load_model() to load the model, and call .get_labels() on your prompts to query the model. Here's a brief example.
```python
from chain_ensembles import HuggingFaceLink
from transformers import QuantoConfig, AutoModelForCausalLM

labels = ["against", "for", "neutral"]

llama_link = HuggingFaceLink(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_class=AutoModelForCausalLM,
    labels=labels,
    quantization_config=QuantoConfig("int2"),
)

llama_link.load_model()
prompts = ["Classify the stance toward something. I'm against something"]
data_out = llama_link.get_labels(prompts)
```
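The exact schema of data_out depends on the package version; assuming it is a pandas DataFrame with one row per prompt (an assumption, not a documented guarantee), a quick inspection looks like:

```python
# Assumes get_labels returns a pandas DataFrame; inspect what the
# link actually produced rather than relying on a fixed schema.
print(data_out.columns.tolist())
print(data_out.head())
```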
OpenAILink Example

```python
from chain_ensembles import OpenAILink

labels = ["against", "for", "neutral"]
gpt4_link = OpenAILink(model_name="gpt-4o", labels=labels)
prompts = ["Classify the stance toward something. I'm against something"]
data_out = gpt4_link.get_labels(prompts)
```
Chaining LLMs
Putting it all together now!
```python
import pandas as pd
from transformers import T5ForConditionalGeneration, AutoModelForCausalLM, QuantoConfig
from chain_ensembles import HuggingFaceLink, OpenAILink, LLMChain

data_df = pd.DataFrame({
    "prompts": ["Classify the stance toward something. I'm against something"] * 12,
    "Stance": ["against"] * 12,
})

labels = ["for", "against", "neutral"]

llama_link = HuggingFaceLink(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_class=AutoModelForCausalLM,
    labels=labels,
    quantization_config=QuantoConfig("int8"),
)

flan_link = HuggingFaceLink(
    model_name="google/flan-ul2",
    model_class=T5ForConditionalGeneration,
    labels=labels,
    quantization_config=QuantoConfig("int8"),
)

gpt4_link = OpenAILink(model_name="gpt-4o", labels=labels)

llm_chain = LLMChain(chain_list=[llama_link, flan_link, gpt4_link])

# One chain-of-thought flag per link in the chain.
CoT_setting = [False, False, False]
data_out = llm_chain.run_chain(data_df, "./chain_out", CoT_setting)
```
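run_chain also writes intermediate results under the output directory ("./chain_out" above). The returned schema is not documented here, so inspect it first; the commented accuracy check below assumes a hypothetical final-label column name:

```python
# Inspect the chain output; the exact columns are not documented here.
print(data_out.columns.tolist())

# Hypothetical quick check, assuming a "final_label" column (our guess
# at the schema, not a documented name):
# accuracy = (data_out["final_label"] == data_df["Stance"].values).mean()
```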
Scripts
To run the scripts, be sure to clone the repository and install the chain_ensembles package.
Chain Example Script
To run an example chain ensemble, we provide a Python script in chain_example.py that runs a chain of Llama3-8B-Instruct, Flan-UL2, and GPT-4o. The script has three command line arguments:
- `-d` The dataset to use: "SemEval2016", "misinfo", or "ibc". The SemEval2016 dataset is used for the stance detection task, the misinfo dataset for the misinformation detection task, and the ibc dataset for the ideology detection task.
- `-o` The directory to output the chain results.
- `-n` The number of samples to select from the dataset. Specify -1 to use the entire dataset. Default is 10.
```bash
$ python scripts/chain_example.py -d SemEval2016 -o ./chain_out -n 20
```
Labeling Entire Datasets and Simulating Chain Ensembles
We also provide an interface for labeling entire datasets with LLMs and analyzing the results post hoc to run simulations of chain ensembles. To label entire datasets, we provide a Python script in llm_label.py that takes the following command line arguments.
- `-m` The model to use to label the dataset: "flan-ul2", "llama3-8B-instruct", "phi3-medium", "mistral-7B-instruct", "gpt-4o", or "gpt-4o-mini".
- `-d` The dataset to label: "SemEval2016", "misinfo", or "ibc".
- `-o` The output directory to save labeled data.
- `-n` The number of samples to select from the dataset. Specify -1 to use the entire dataset. Default is 10.
- `-q` Model quantization. Supported values are "8" and "4" for 8-bit and 4-bit respectively; any other value (e.g. 0) loads the model in full precision. Default is 8.
```bash
python ./scripts/llm_label.py -m llama3-8B-instruct -d "SemEval2016" -o "./llama_test/" -n -1 -q 0
```
Once you have fully labeled datasets, you can run a simple simulation across them in a separate notebook or Python script using the functions provided by the chain_sim.py module. Set up the script with the following code.
```python
import pandas as pd
from chain_ensembles import get_permutations, chain_dataframes, backward_pass

llama_link = pd.read_pickle("./llama_test/results_df.pkl")
flan_link = pd.read_pickle("./flan_test/results_df.pkl")
gpt_link = pd.read_pickle("./gpt_test/results_df.pkl")

links = [llama_link, flan_link, gpt_link]
names = ["llama", "flan", "gpt"]
```
To run a single chain ensemble iteration, use the chain_dataframes and backward_pass functions.
```python
chained_df = chain_dataframes(links, "Stance")
ensembled_df = backward_pass(chained_df, len(links))
```
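From here you can score the ensembled labels against the ground truth. The column name in the commented line is an assumption about chain_sim.py's output schema, so inspect ensembled_df first:

```python
print(ensembled_df.head())  # check the actual output columns

# Hypothetical scoring, assuming an "ensemble_label" column (our guess):
# accuracy = (ensembled_df["ensemble_label"] == ensembled_df["Stance"]).mean()
```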
To run a chain ensemble for all length-3 permutations of the links, use the get_permutations function.

```python
sim_results_df = get_permutations(links, names, 3, "Stance", n_trails=20, backward=True)
```
Note: We assume all datasets used in the simulation come from the labeling script, and it is critical that the label dataframes have their columns in the same order. If you use the llm_label.py script, this is handled for you automatically.
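If your label dataframes come from anywhere other than llm_label.py, a small sanity check (our own helper, not part of the package) can catch misaligned columns before chaining:

```python
def check_column_alignment(dfs):
    """Assert that all label dataframes share the same column order."""
    reference = list(dfs[0].columns)
    for df in dfs[1:]:
        assert list(df.columns) == reference, "Column order differs across links"

check_column_alignment(links)
```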