
GSPL: Gen-Selective Pseudo Labeling for Domain Adaptation, built on 🤗 Datasets and the Serverless Inference API.


GSPL: Gen-Selective Pseudo Labeler


Introduction

The Gen-Selective Pseudo Labeler (GSPL) is a tool designed to generate multiple queries about the same document and apply selection techniques using transformers. It leverages large language models (LLMs) as a judge, improving the labeling process through generation followed by selection. The project is particularly useful where multi-query generation and selection are required to improve the quality and relevance of labeled data.
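At its core, this is a generate-then-select loop: sample several candidate queries for a document, then have a judge model pick the best one. The sketch below illustrates the idea in plain Python with a stubbed completion call; it is a conceptual illustration, not GSPL's internal implementation.

import json
import random

def complete(prompt: str) -> str:
    # Stand-in for an LLM call; returns JSON as the system prompts shown
    # later in this README require. Replace with a real completion client.
    return json.dumps({"query": f"candidate question {random.randint(0, 99)}"})

def generate_then_select(document: str, n: int = 3) -> str:
    # 1. Generation: sample n candidate queries for the same document.
    candidates = [json.loads(complete(document))["query"] for _ in range(n)]
    # 2. Selection: an LLM judge picks the best candidate given the source
    #    (GSPL's selection prompt uses a "best_query" key; simplified here).
    judge_prompt = f"Queries: {candidates}\nSource text: {document}"
    return json.loads(complete(judge_prompt))["query"]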

Features

  • Multi-query Generation: Generate multiple queries for a single document using powerful LLMs.
  • Selective Labeling: Apply selection techniques to choose the best query from the generated set.
  • Parallel Processing: Utilize multi-threading to efficiently process large datasets.
  • Integration with Hugging Face Transformers: Seamlessly connect with transformer models hosted on Hugging Face.
  • Customizable Templates: Apply custom chat templates to format queries and responses.

Requirements

To use GSPL, you need the following dependencies installed:

  • Python 3.7+
  • datasets
  • tqdm
  • langdetect
  • tiktoken
  • python-dotenv
  • concurrent.futures (part of the Python standard library; no installation required)

You can install the required dependencies using the following command:

pip install datasets tqdm langdetect python-dotenv tiktoken
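Since python-dotenv is among the dependencies, the API key can presumably be kept in a local .env file rather than hard-coded. A minimal sketch (the HF_API_KEY variable name is an assumption, not something GSPL mandates):

# Load the Inference API key from a local .env file containing a line
# such as HF_API_KEY=your_api_key (variable name chosen for illustration).
import os
from dotenv import load_dotenv

load_dotenv()
api_key = os.environ["HF_API_KEY"]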

Installation

To install GSPL, clone the repository and navigate to the project directory:

git clone https://github.com/yourusername/gspl.git
cd gspl

Usage

Initialization

To initialize the GSPL class, provide the API key, dataset, and optional parameters for dataset split, streaming, and rate limits.

from gspl import GSPL

api_key = "your_api_key"  # Hugging Face Inference API token
dataset = "your_dataset"  # dataset to label

gspl_instance = GSPL(api_key=api_key, dataset=dataset)
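The dataset parameter is documented below as a 🤗 Dataset, while the longer example further down passes a dataset identifier string; assuming the former, a dataset can be loaded with 🤗 Datasets first (the "imdb" identifier is only a placeholder):

from datasets import load_dataset

# Any 🤗 dataset works here; "imdb" is just a placeholder identifier.
dataset = load_dataset("imdb", split="train")
gspl_instance = GSPL(api_key=api_key, dataset=dataset)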

Methods

__init__(self, api_key: str, dataset: Dataset, split: Optional[str] = "train", streaming: Optional[bool] = False, rpm: Optional[int] = 30) -> None

Initializes the GSPL class with the specified parameters.

  • Parameters:
    • api_key (str): API key for the completion client.
    • dataset (Dataset): The dataset to be labeled.
    • split (str, optional): The dataset split to be used (default is "train").
    • streaming (bool, optional): Whether to stream the dataset (default is False).
    • rpm (int, optional): Requests per minute limit for rate limiting (default is 30); see the illustrative sketch below.
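For illustration, a budget of rpm=30 amounts to at most one request every two seconds. A generic throttle looks like the sketch below; this is not GSPL's internal rate limiter, only an illustration of the rpm semantics.

import time

RPM = 30
MIN_INTERVAL = 60.0 / RPM  # 2.0 seconds between requests at rpm=30

_last_request = 0.0

def throttled(fn, *args, **kwargs):
    # Sleep just long enough to stay within the requests-per-minute budget.
    global _last_request
    wait = MIN_INTERVAL - (time.monotonic() - _last_request)
    if wait > 0:
        time.sleep(wait)
    _last_request = time.monotonic()
    return fn(*args, **kwargs)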

__call__(self, column: str, completion_system_prompt: str, selection_system_prompt: str, payload: dict, save_to: str, api_url: Optional[str] = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct", response_key: Optional[str] = "query", validate_payload: Optional[bool] = False)

Executes the labeling process by applying chat templates, generating completions, and selecting the best responses; an end-to-end sketch follows the parameter list.

  • Parameters:
    • column (str): Column name containing the query text.
    • completion_system_prompt (str): System prompt for generating completions.
    • selection_system_prompt (str): System prompt for selecting the best responses.
    • payload (dict): Payload to be sent to the completion API.
    • save_to (str): File path to save the labeled dataset.
    • api_url (str, optional): API URL for the completion client (default is "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct").
    • response_key (str, optional): Key to extract the response from the completion client (default is "query").
    • validate_payload (bool, optional): Whether to validate the payload before sending (default is False).
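Based on the signature above, a single call runs the whole pipeline end to end. A sketch with placeholder prompts and paths (the column name "text" and the minimal payload are assumptions):

gspl_instance(
    column="text",                    # column holding the documents
    completion_system_prompt="...",   # prompt that elicits candidate queries
    selection_system_prompt="...",    # prompt for the LLM judge
    payload={"parameters": {"max_new_tokens": 250}},
    save_to="labeled.parquet",
)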

Example Usage

from gspl import GSPL

completion_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>You are an AI assistant specialized in creating targeted questions for documents to build a domain-specific dataset for embedding model training. Your task is to:

    1. Analyze the given document or text.
    2. Generate a highly specific question that directly relates to the main content or key information in the document.
    3. Ensure the question is tailored to retrieve the document's content when used as a search query.
    4. Format the question as a JSON object with a single "query" key.
    5. Provide only the JSON object as output, without any additional text, introductions, or conclusions.

    Remember:
    - The question should be precise and relevant to the document's core information.
    - Avoid generic questions; focus on unique aspects of the given text.
    - Ensure the JSON is valid and can be parsed in Python.
    - Do not include any explanations or additional text outside the JSON object.<|eot_id|><|start_header_id|>user<|end_header_id|>
    {document}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """

selection_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>You are an AI assistant specialized in selecting the most appropriate question from a batch of queries to match specific content. Your task is to:

    1. Analyze the given content or document.
    2. Review the provided batch of queries.
    3. Select the question that best relates to the main content or key information in the document.
    4. Ensure the selected question is highly specific and tailored to retrieve the document's content when used as a search query.
    5. Format the selected question as a JSON object with a single "best_query" key.
    6. Provide only the JSON object as output, without any additional text, introductions, or conclusions.

    Remember:
    - The selected question should be the most precise and relevant to the document's core information.
    - Prioritize questions that focus on unique aspects of the given text.
    - Ensure the JSON is valid and can be parsed in Python.
    - Do not include any explanations or additional text outside the JSON object.<|eot_id|><|start_header_id|>user<|end_header_id|>
    {queries}
    Source text : {document}
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """

payload = {
    "parameters": {
        "temperature": 0.9,
        "return_full_text": False,
        "max_new_tokens": 250,
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95
    },
    "options": {
        "use_cache": False,
        "wait_for_model": True
    }
}

dataset = {"datasetId": "your_dataset_id"}  # placeholder for your 🤗 dataset identifier
output_file = "labeled_dataset.parquet"     # destination for the final export

gspl = GSPL(
    api_key="api_key",
    dataset=dataset["datasetId"],
    split="train"
)

gspl.apply_chat_template(
    template=completion_template,
    output="inputs",    # column that receives the rendered prompts
    document="output",  # maps the {document} placeholder to the "output" column
)

gspl.label(
    payload=payload,
    output="queries"    # column that receives the generated candidate queries
)

gspl.apply_chat_template(
    template=selection_template,
    output="inputs",    # overwrites "inputs" with the rendered selection prompts
    queries="queries",  # maps the {queries} placeholder to the candidates
    document="output"
)

gspl.select(
    payload=payload,    # the judge writes the winning query to the "query" column
)

gspl.to_parquet(
    filepath=output_file
)
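After to_parquet, the labeled data can be inspected with any Parquet reader. A sketch using pandas (not a GSPL dependency), with column names following the defaults used above ("output" for the source documents, "queries" for the candidates, "query" for the judged winner):

import pandas as pd

df = pd.read_parquet(output_file)
print(df[["output", "queries", "query"]].head())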

Methods Explained

apply_chat_template(self, template: str, output: Optional[str] = "inputs", **kwargs_columns: str) -> None

Applies a chat template to a dataset in parallel using multiple processes; a sketch of the placeholder mapping follows the parameter list.

  • Parameters:
    • template (str): The template prompt to use for the chat template.
    • output (str, optional): The name of the column where the processed template will be saved (default is "inputs").
    • **kwargs_columns (str): Mapping of template placeholder names to dataset column names.
  • Returns:
    • None: self.dataset is updated in place with the chat template applied.
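The keyword mapping presumably works like standard str.format substitution: each keyword names a template placeholder, and each value names a dataset column. A sketch for a single row (the row content is invented for illustration):

# One row of a hypothetical dataset whose documents live in the "output" column.
row = {"output": "Article 12 requires annual disclosure of ..."}
template = "Summarize the following text: {document}"

# document="output" maps the {document} placeholder to the "output" column.
rendered = template.format(document=row["output"])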

completion(self, payload: dict, api_url: Optional[str] = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct", response_key: Optional[str] = "query", validate_payload: bool = False) -> str

Generates a completion response using the completion client; a raw Inference API sketch follows the parameter list.

  • Parameters:
    • payload (dict): The payload to be sent to the completion API.
    • api_url (str, optional): The API URL for the completion client (default is "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct").
    • response_key (str, optional): The key to extract the response from the completion client (default is "query").
    • validate_payload (bool, optional): Whether to validate the payload before sending (default is False).
  • Returns:
    • str: The completion response from the API.
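Under the hood this targets the Hugging Face Serverless Inference API. A raw equivalent of such a call with requests (GSPL wraps this for you; the response shape follows the public text-generation convention of the API):

import requests

API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
api_key = "your_api_key"  # placeholder token
headers = {"Authorization": f"Bearer {api_key}"}

body = {
    "inputs": "rendered prompt here",
    "parameters": {"max_new_tokens": 250, "return_full_text": False},
    "options": {"wait_for_model": True},
}
out = requests.post(API_URL, headers=headers, json=body).json()
generated = out[0]["generated_text"]  # text generation returns a list of dicts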

label(self, payload: dict, api_url: Optional[str] = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct", inputs_column: Optional[str] = "inputs", response_key: Optional[str] = "query", output: Optional[str] = "queries", num_return_sequences: Optional[int] = 3, token_threshold: Optional[int] = 600, validate_payload: bool = False) -> Dataset

Labels the dataset using the completion client; a token-counting sketch follows the parameter list.

  • Parameters:
    • payload (dict): The payload to be sent to the completion API.
    • api_url (str, optional): The API URL for the completion client (default is "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct").
    • inputs_column (str, optional): The name of the column containing the input text (default is "inputs").
    • response_key (str, optional): The key to extract the response from the completion client (default is "query").
    • output (str, optional): The name of the column to store the output (default is "queries").
    • num_return_sequences (int, optional): The number of response sequences to generate for each input (default is 3).
    • token_threshold (int, optional): The token threshold for the completion response (default is 600).
    • validate_payload (bool, optional): Whether to validate the payload before sending (default is False).
  • Returns:
    • Dataset: The labeled dataset.
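tiktoken is a declared dependency, so token_threshold is presumably enforced by counting tokens with it. A sketch of that kind of check (the cl100k_base encoding and the exact policy are assumptions, not documented GSPL behavior):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # encoding choice is an assumption

def within_threshold(text: str, token_threshold: int = 600) -> bool:
    # Count tokens and compare against the threshold for a response.
    return len(enc.encode(text)) <= token_threshold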

select(self, payload: dict, api_url: Optional[str] = "https://api-inference.huggingface.co/models/microsoft/Phi-3-medium-4k-instruct", inputs_column: Optional[str] = "inputs", response_key: Optional[str] = "query", output: Optional[str] = "query", validate_payload: bool = False) -> None

Selects outputs from the labeled dataset.

  • Parameters:
    • payload (dict): The payload to be sent to the completion API.
    • api_url (str, optional): The API URL for the completion client (default is "https://api-inference.huggingface.co/models/microsoft/Phi-3-medium-4k-instruct").
    • inputs_column (str, optional): The name of the column containing the input text (default is "inputs").
    • response_key (str, optional): The key to extract the response from the completion client (default is "query").
    • output (str, optional): The name of the column to store the output (default is "query").
    • validate_payload (bool, optional): Whether to validate the payload before sending (default is False).
  • Returns:
    • None

to_parquet(self, filepath: str) -> None

Saves the labeled dataset to a Parquet file.

  • Parameters:
    • filepath (str): The file path to save the Parquet file.
  • Returns:
    • None

Citing this project

If you use this code in your research, please use the following BibTeX entry.

@misc{louisbrulenaudet2024,
	author = {Louis Brulé Naudet},
	title = {GSPL: Gen-Selective Pseudo Labeler},
	howpublished = {\url{https://github.com/louisbrulenaudet/gspl}},
	year = {2024}
}

Feedback

If you have any feedback, please reach out at louisbrulenaudet@icloud.com.

