
A Python package for LLM chain ensembles!

Project description

LLM Chain Ensembles

This repository contains a Python package and the code for the paper LLM Chain Ensembles for Scalable and Accurate Data Annotation. LLM chain ensembles use a sequence of LLMs to label subsets of data selected using uncertainty estimates. This method reduces zero-shot prediction costs by exposing only limited data to the high-cost models at the end of the chain, and it can also improve performance.

---
title: LLM Zero-Shot Prediction with LLM Chain Ensemble
---

flowchart LR
    A(Data) --> B{LLM 1}
    B --Lowest 1/3 Confidence--> C(Forwarded Data)
    B --High Confidence Examples--> G(Labeled Subset 1)
    C --> D{LLM 2}
    D --> H(Labeled Subset 2)
    D --Lowest 1/2 Confidence--> E(Forwarded Data)
    E --> F{LLM 3}
    F --> O(Labeled Subset 3)
    G --> M(Rank Based Ensemble)
    H --> M
    O --> M
    M --> Z(Final Label)
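The forwarding rule in the diagram can be sketched in a few lines of Python: score each example's label distribution by one minus its normalized entropy, then pass the lowest-confidence third to the next link. This is an illustrative sketch under assumed toy distributions, not the package's implementation.

```python
import math

def confidence(probs):
    # Confidence as one minus normalized entropy of the label distribution:
    # peaked distributions score near 1, flat ones near 0.
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return 1.0 - entropy / math.log(len(probs))

# Made-up label distributions from a hypothetical first link.
batch = {
    "ex1": [0.90, 0.05, 0.05],  # confident
    "ex2": [0.40, 0.35, 0.25],  # uncertain
    "ex3": [0.80, 0.15, 0.05],  # fairly confident
}

scores = {k: confidence(v) for k, v in batch.items()}
ranked = sorted(scores, key=scores.get)   # least confident first
n_forward = len(ranked) // 3              # lowest third moves on
forwarded = ranked[:n_forward]            # sent to the next link
kept = ranked[n_forward:]                 # labeled by this link
```

Each later link repeats the same split on whatever was forwarded to it, so only the hardest examples ever reach the most expensive model.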

Getting Started

Installation

Install the most recent release with pip.

pip install chain-ensembles

We recommend using uv to manage your Python environment and packages.

uv venv
uv pip install chain-ensembles

Authentication

The package's HuggingFaceLink and OpenAILink classes require API tokens for authentication, which are read from environment variables. If you are running from the terminal, set the following variables.

$ export HF_TOKEN=your_hf_token_here
$ export OPENAI_API_TOKEN=your_openai_token_here
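In Python, a quick pre-flight check (a minimal sketch, not part of the package) can confirm both variables are set before any link is constructed:

```python
import os

# Pre-flight check (illustrative, not part of the package): list any
# missing credentials before constructing HuggingFaceLink / OpenAILink.
missing = [name for name in ("HF_TOKEN", "OPENAI_API_TOKEN")
           if not os.environ.get(name)]
if missing:
    print("Set these environment variables first:", ", ".join(missing))
```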

Example Scripts

There are some example scripts available in the scripts/ directory of this repository. We used these to drive the experiments for the paper. The scripts can be run directly or serve as a reference for anyone building their own chain ensembles! To run them, clone the repository and install the chain_ensembles package.

LLM Chain Ensembles

In this section, we will cover the basic functionality provided by the package.

Using LLM Links To Label Data

The smallest data-labeling unit is a link! We provide two kinds of data-labeling links:

HuggingFaceLink Example

For Hugging Face models, initialize a HuggingFaceLink, call .load_model() to load the model, and call .get_labels() on your prompts to query the model. Here's a brief example.

from chain_ensembles import HuggingFaceLink
from transformers import QuantoConfig, AutoModelForCausalLM

labels = ["against", "for", "neutral"]
llama_link = HuggingFaceLink(
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_class = AutoModelForCausalLM,
    labels = labels,
    quantization_config = QuantoConfig(weights="int2")
)
llama_link.load_model()
prompts = ["Classify the stance toward something. I'm against something"]
data_out = llama_link.get_labels(prompts)

OpenAILink Example

from chain_ensembles import OpenAILink

labels = ["against", "for", "neutral"]
gpt4_link = OpenAILink(model_name="gpt-4o", labels=labels)

prompts = ["Classify the stance toward something. I'm against something"]
data_out = gpt4_link.get_labels(prompts)

Chaining LLMs

Putting it all together now!

import pandas as pd
from transformers import T5ForConditionalGeneration, AutoModelForCausalLM, QuantoConfig

from chain_ensembles import HuggingFaceLink, OpenAILink, LLMChain

data_df = pd.DataFrame({
    "prompts": ["Classify the stance toward something. I'm against something"]*12, 
    "Stance": ["against"]*12
})

labels = ["for", "against", "neutral"]

llama_link = HuggingFaceLink(
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct", 
    model_class = AutoModelForCausalLM, 
    labels = labels,
    quantization_config = QuantoConfig(weights="int8")
)
flan_link = HuggingFaceLink(
    model_name = "google/flan-ul2", 
    model_class = T5ForConditionalGeneration, 
    labels = labels, 
    quantization_config = QuantoConfig(weights="int8")
)
gpt4_link = OpenAILink(model_name="gpt-4o", labels = labels)

llm_chain = LLMChain(chain_list=[llama_link, flan_link, gpt4_link])
CoT_setting = [False, False, False]
data_out = llm_chain.run_chain(data_df, "./chain_out", CoT_setting)
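After a run, you will usually want to score the chain's labels against the gold column. Assuming data_out comes back as a pandas DataFrame carrying a final ensemble label (the final_label column name below is hypothetical; substitute whatever your version produces), the check is one line. A toy stand-in frame keeps the sketch self-contained:

```python
import pandas as pd

# Toy stand-in for a chain run's output; "final_label" is a hypothetical
# column name, not necessarily what run_chain produces.
data_out = pd.DataFrame({
    "Stance": ["against"] * 10 + ["for"] * 2,
    "final_label": ["against"] * 9 + ["neutral"] + ["for"] * 2,
})

# Element-wise comparison, then the mean of the boolean mask.
accuracy = (data_out["final_label"] == data_out["Stance"]).mean()
print(f"chain accuracy: {accuracy:.2f}")
```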

Scripts

To run the scripts be sure to clone the repository and install the chain_ensembles package.

Chain Example Script

To run an example chain ensemble, we provide a Python script, chain_example.py, that runs a chain of Llama3-8B-Instruct, Flan-UL2, and GPT-4o. The script has three command line arguments:

  • -d The dataset to use: "SemEval2016", "misinfo", or "ibc". SemEval2016 is used for the stance detection task, misinfo for the misinformation detection task, and ibc for the ideology detection task.
  • -o The directory to output the chain results.
  • -n The number of samples to select from said dataset. Specify -1 to use the entire dataset. Default is 10.
$ python scripts/chain_example.py -d SemEval2016 -o ./chain_out -n 20

Labeling Entire Datasets and Simulating Chain Ensembles

We also provide an interface for labeling entire datasets with LLMs and analyzing the results post hoc to simulate chain ensembles. To label entire datasets, we provide a Python script, llm_label.py, that runs with the following command line arguments.

  • -m The model to use to label the dataset. Select from "flan-ul2", "llama3-8B-instruct", "phi3-medium", "mistral-7B-instruct", "gpt-4o", and "gpt-4o-mini".
  • -d The dataset to label. The user is expected to enter "SemEval2016", "misinfo" or "ibc."
  • -o The output directory to save labeled data.
  • -n The number of samples to select from said dataset. Specify -1 to use the entire dataset. Default is 10.
  • -q Specify model quantization. Supported values are "8" and "4" for 8-bit and 4-bit quantization respectively; any other value (e.g., -q 0) loads the model in full precision. Default is 8.
python ./scripts/llm_label.py -m llama3-8B-instruct -d "SemEval2016" -o "./llama_test/" -n -1 -q 0

Once you have fully labeled datasets, you can run a simple simulation across them in a separate notebook or Python script using the functions provided by the chain_sim.py module. Set up the script with the following code.

import pandas as pd
from chain_ensembles import get_permutations, chain_dataframes, backward_pass

llama_link = pd.read_pickle("./llama_test/results_df.pkl")
flan_link =  pd.read_pickle("./flan_test/results_df.pkl")
gpt_link =  pd.read_pickle("./gpt_test/results_df.pkl")

links = [llama_link, flan_link, gpt_link]
names = ["llama", "flan", "gpt", "mistral", "phi"]

To run a single chain ensemble iteration, use the chain_dataframes and backward_pass functions.

chained_df = chain_dataframes(links, "Stance")
ensembled_df = backward_pass(chained_df, len(links))

To run a chain ensemble over all length-3 permutations of the links, use the get_permutations function.

sim_results_df = get_permutations(links, names, 3, "Stance", n_trails=20, backward = True)
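With the five names above, a length-3 sweep covers 5 × 4 × 3 = 60 orderings. The count can be sanity-checked with the standard library, independent of the package:

```python
from itertools import permutations

names = ["llama", "flan", "gpt", "mistral", "phi"]
orderings = list(permutations(names, 3))  # every ordered choice of 3 links
print(len(orderings))  # 5 * 4 * 3 = 60
```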

Note: We assume all datasets used in the simulation come from the labeling function. It is critical that the dataset columns are in the same order as produced by the labeler. If you use the llm_label.py script, this is handled for you automatically.

Project details


Download files

Download the file for your platform.

Source Distribution

chain_ensembles-0.1.2.tar.gz (2.7 MB view details)

Uploaded Source

Built Distribution


chain_ensembles-0.1.2-py3-none-any.whl (18.7 kB view details)

Uploaded Python 3

File details

Details for the file chain_ensembles-0.1.2.tar.gz.

File metadata

  • Download URL: chain_ensembles-0.1.2.tar.gz
  • Size: 2.7 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.5.6

File hashes

Hashes for chain_ensembles-0.1.2.tar.gz

  • SHA256: 70a5e2bcfb7bdb370af1a8c6fdb6005843e865ad2d5e908ce37648413c4d21df
  • MD5: 52e12f1da0c32c4d7cd9deed6ca040bf
  • BLAKE2b-256: ed47ac614a5c91a66d07aa3da3f366b8da5824764947596af61bd4bddfd0bb49


File details

Details for the file chain_ensembles-0.1.2-py3-none-any.whl.

File metadata

File hashes

Hashes for chain_ensembles-0.1.2-py3-none-any.whl

  • SHA256: de792fa90b3d0c2ab7f3b256e1383361e5394011e593a8abc711f5e4e44120ec
  • MD5: e5ecba9af4fec91db31d8958909c069b
  • BLAKE2b-256: 2e49cfcf18ad2c60d40b91d742b9a00a114617054828316e730de124c0109088

