Find and diagnose hallucination-prone prompts in your grounded Large Language Model applications with Anchor-GPT!

Project description

Find hallucination-prone prompts and use them to fine-tune / ground your LLM

Why Anchor GPT?

Because you can't get ground-truth answers for every prompt, and fine-tuning / grounding with the right data gives much better results. We compared fine-tuning side by side with prompts sampled randomly and with CoreSet (the core algorithm of anchor-gpt), and the results speak for themselves 👇


Accuracy on a sample of the MMLU test dataset for a LLaMA model fine-tuned on 1,000 data points sampled from the Alpaca dataset using either random sampling or CoreSet
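CoreSet selection is, in essence, greedy k-center sampling over prompt embeddings: each pick is the point farthest from everything already selected, so the sample covers the embedding space instead of clustering around common prompts. Below is a minimal sketch of that idea, assuming `embeddings` is an (n, d) NumPy array; it illustrates the algorithm, not anchor-gpt's internal implementation.

import numpy as np

def coreset_sample(embeddings: np.ndarray, k: int) -> list[int]:
    """Greedy k-center selection: each pick is the embedding farthest
    from everything already selected."""
    selected = [0]  # seed with an arbitrary first point
    # Distance from every point to its nearest selected point so far
    min_dist = np.linalg.norm(embeddings - embeddings[0], axis=1)
    for _ in range(k - 1):
        next_idx = int(np.argmax(min_dist))
        selected.append(next_idx)
        new_dist = np.linalg.norm(embeddings - embeddings[next_idx], axis=1)
        min_dist = np.minimum(min_dist, new_dist)
    return selected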

Installation

pip install anchor-gpt

Step by Step

  1. Use the prompt logger to log your prompts and their grounding scores:
from anchor_gpt import PromptLogger, Prompt

# Your regular grounding process
prompt_embeddings = embedding_model.encode(prompt)
index_response = my_index_endpoint.find_neighbors(
    queries=prompt_embeddings,
    num_neighbors=10,
)

# Collect the grounding documents and their distances to the prompt
grounding_data = []
grounding_distances = []
for grounding_index, grounding_distance in index_response:
    grounding_data.append(my_index_endpoint.get(grounding_index))
    grounding_distances.append(grounding_distance)

grounded_prompt = build_prompt(prompt, grounding_data)

# Call your LLM
chat_response = my_llm.chat(grounded_prompt, temperature=0.1)


# Log the prompt
prompt_logger = PromptLogger()
my_prompt = prompt_logger.log(Prompt(
    text=prompt,
    response=chat_response,
    scores={'grounding_distances': grounding_distances},
    embeddings=prompt_embeddings,
))
  2. Add additional scores, like user feedback, asynchronously:
my_prompt.update_scores({'user_feedback': 0.8})
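Scores can arrive long after the prompt was logged. For instance, you might map a thumbs-up/down from your UI onto the 0-1 scale used for retrieval later; a hedged sketch (the helper and the mapping are illustrative, not part of anchor-gpt):

def on_user_feedback(logged_prompt, thumbs_up: bool):
    # Map binary UI feedback onto a 0-1 score; update_scores adds it
    # to the scores already stored for the prompt.
    logged_prompt.update_scores({'user_feedback': 1.0 if thumbs_up else 0.0})

on_user_feedback(my_prompt, thumbs_up=True)  # records user_feedback = 1.0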
  3. Retrieve the worst-performing prompts to fine-tune your model or improve your grounding database:
# Define a custom prompt scoring method
def retriever(store, threshold):
    def prompt_average_score(prompt):
        return 0.2 * prompt.scores['grounding_distances'][0] + 0.8 * prompt.scores['user_feedback']

    return list(filter(lambda x: prompt_average_score(x) > threshold, store.select_prompts()))

# Retrieve the ones above a threshold
worst_prompts = prompt_logger.retrieve(retriever, 0.5)
# Remove near duplicates to only keep what matters (here, at most 100 prompts)
deduped_prompts = prompt_logger.deduplicate(worst_prompts, 100)

# Add the right answers to your grounding DB to better answer those prompts next time
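What "adding the right answers" looks like depends on your stack. A hedged sketch, where `write_answer_for` and `vector_store.upsert` are placeholders for your own curation step and vector database client:

# Illustrative only: write (or correct) an answer for each flagged prompt
# and index it so future neighbor searches can ground against it.
for prompt in deduped_prompts:
    curated_answer = write_answer_for(prompt.text)  # e.g. a human-in-the-loop review
    vector_store.upsert(
        embedding=prompt.embeddings,
        metadata={'text': prompt.text, 'answer': curated_answer},
    )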

Example in a chat service

from flask import Flask, request, jsonify
from anchor_gpt import PromptLogger, Prompt

app = Flask(__name__)
prompt_logger = PromptLogger()

# Your regular chat endpoint with logging enabled
@app.route("/chat", methods=["POST"])
def chat():
    # Do your grounding as normal:
    prompt_embeddings = model.encode(request.json["prompt"])
    vector_store_results = vector_store.query(prompt_embeddings, top_k=10)

    grounded_prompt = build_prompt(request.json["prompt"], vector_store_results)
    chat_response = my_llm.chat(grounded_prompt, temperature=0.1)

    # Then log the prompt with the response, scores and embeddings.
    # Prompts are stored locally in a SQLite database.
    prompt_logger.log(Prompt(
        text=request.json["prompt"],
        response=chat_response,
        scores={'grounding_distances': [r.distance for r in vector_store_results]},
        embeddings=prompt_embeddings,
    ))

    return chat_response

# A new hallucination retrieval endpoint to get the worst prompts from your LLM
@app.route("/hallucinations", methods=["GET"])
def hallucinations():
    def retriever(store, threshold):
        def prompt_average_score(prompt):
            return prompt.scores['grounding_distances'][0]

        return list(filter(lambda x: prompt_average_score(x) > threshold, store.select_prompts()))

    # Retrieve a list of the prompts with the greatest distance from your grounding data
    worst_prompts = prompt_logger.retrieve(retriever, 0.5)

    # Remove near duplicates and only keep 10 prompts
    deduped_prompts = prompt_logger.deduplicate(worst_prompts, 10)

    # Clean up the store
    prompt_logger.store.purge()

    return jsonify([{'text': p.text, 'response': p.response} for p in deduped_prompts])
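With both routes in place, the review loop is a plain HTTP round trip. For example, assuming the service runs on localhost:5000:

import requests

# Ask questions through the grounded chat endpoint as usual...
requests.post("http://localhost:5000/chat", json={"prompt": "What is CoreSet sampling?"})

# ...then periodically pull the prompts that strayed farthest from the
# grounding data and feed them back into fine-tuning or the grounding DB.
for item in requests.get("http://localhost:5000/hallucinations").json():
    print(item["text"], "->", item["response"])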

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

anchor-gpt-0.0.2.tar.gz (13.3 kB)

Uploaded Source

Built Distribution

anchor_gpt-0.0.2-py3-none-any.whl (12.6 kB)

Uploaded Python 3

File details

Details for the file anchor-gpt-0.0.2.tar.gz.

File metadata

  • Download URL: anchor-gpt-0.0.2.tar.gz
  • Upload date:
  • Size: 13.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for anchor-gpt-0.0.2.tar.gz
Algorithm Hash digest
SHA256 2fe96f31246529311c4698743a5c536a4ca9a9613f0271556448874e584a0d89
MD5 609e2a2f017ec99986c6c46bcf60433c
BLAKE2b-256 877e8058d65279d94a0dc15a835e7bafd5c4007708b7628a5686782d2bd71068


File details

Details for the file anchor_gpt-0.0.2-py3-none-any.whl.

File metadata

  • Download URL: anchor_gpt-0.0.2-py3-none-any.whl
  • Upload date:
  • Size: 12.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.9.6

File hashes

Hashes for anchor_gpt-0.0.2-py3-none-any.whl
Algorithm Hash digest
SHA256 768ad210273cda2d693335ebd6c6cf525842a9a3ac0cf1bbce581ecfbfd32cfc
MD5 6c70c9223d07de5ca0fafac629b55eec
BLAKE2b-256 ec2ae96f3b9f9a5b8a95c25b377bea4b339956c03a113605cf2cf1493e430184

