Diagnose hallucinations and find the prompts that cause them in your grounded Large Language Model with Anchor-GPT!
Project description
Find hallucination-prone prompts and use them to fine-tune / ground your LLM.
Why Anchor GPT?
Because you can't get ground-truth answers for every prompt, and fine-tuning / grounding with the right data gives much better results. We compared fine-tuning side by side with randomly sampled prompts and with prompts selected by CoreSet (the core algorithm of anchor-gpt), and the results speak for themselves 👇
Figure: accuracy on a sample of the MMLU test dataset of a fine-tuned LLaMA model, using 1,000 data points sampled from the Alpaca dataset with either random sampling or CoreSet.
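CoreSet-style selection greedily picks the prompts that are farthest apart in embedding space, so a small fine-tuning set covers the whole space instead of piling up on near-duplicate prompts. A minimal sketch of the idea with numpy (illustrative only, not the anchor-gpt API):

import numpy as np

def coreset_sample(embeddings, k, seed=0):
    # Greedy k-center selection: repeatedly pick the point farthest
    # from everything selected so far.
    rng = np.random.default_rng(seed)
    selected = [int(rng.integers(len(embeddings)))]
    # Distance from every point to its nearest selected point
    dists = np.linalg.norm(embeddings - embeddings[selected[0]], axis=1)
    for _ in range(k - 1):
        nxt = int(dists.argmax())  # farthest remaining point
        selected.append(nxt)
        dists = np.minimum(dists, np.linalg.norm(embeddings - embeddings[nxt], axis=1))
    return selected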
Installation
pip install anchor-gpt
Step by Step
- Use the prompt logger to log your prompts and their grounding scores
from anchor_gpt import PromptLogger, Prompt
# Your regular grounding process
prompt_embeddings = embedding_model.encode(prompt)
index_response = my_index_endpoint.find_neighbors(
    queries=prompt_embeddings,
    num_neighbors=10,
)
grounding_data = []
grounding_distances = []
for grounding_index, grounding_distance in index_response:
    grounding_data.append(my_index_endpoint.get(grounding_index))
    grounding_distances.append(grounding_distance)
grounded_prompt = build_prompt(prompt, grounding_data)
# Call your LLM
chat_response = my_llm.chat(grounded_prompt, temperature=0.1)
# Log the prompt
prompt_logger = PromptLogger()
my_prompt = prompt_logger.log(Prompt(
    text=prompt,
    response=chat_response,
    scores={'grounding_distances': grounding_distances},
    embeddings=prompt_embeddings,
))
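The snippet above assumes your own embedding_model, my_index_endpoint, and my_llm; build_prompt is also left to you. A minimal version might simply prepend the retrieved context (hypothetical helper, not part of anchor-gpt):

def build_prompt(prompt, grounding_data):
    # Prepend the retrieved documents as context so the model answers
    # from them instead of hallucinating.
    context = "\n\n".join(str(doc) for doc in grounding_data)
    return (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {prompt}"
    )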
- Add additional scores like user feedback asynchronously
my_prompt.update_scores({'user_feedback': 0.8})
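If your UI collects thumbs-up / thumbs-down rather than a numeric score, normalize it before logging. A hypothetical helper (not part of anchor-gpt), following the convention of the retrieval examples below where higher scores flag worse prompts:

def feedback_to_score(thumbs_up):
    # Higher scores mean worse prompts in the retrievers below,
    # so a thumbs-down maps to 1.0 and a thumbs-up to 0.0.
    return 0.0 if thumbs_up else 1.0

my_prompt.update_scores({'user_feedback': feedback_to_score(thumbs_up=False)})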
- Retrieve the worst-performing prompts to fine-tune your model or improve your grounding database
# Define a custom prompt scoring method
def retriever(store, threshold):
    def prompt_average_score(prompt):
        return 0.2 * prompt.scores['grounding_distances'][0] + 0.8 * prompt.scores['user_feedback']
    return list(filter(lambda x: prompt_average_score(x) > threshold, store.select_prompts()))
# Retrieve the ones above a threshold
worst_prompts = prompt_logger.retrieve(retriever, 0.5)
# Remove near duplicates to only keep what matters
deduped_prompts = prompt_logger.deduplicate(worst_prompts, 100)
# Add the right answers to your grounding DB to better answer those prompts next time
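deduplicate presumably works off the logged embeddings; as a mental model, near-duplicate removal can be sketched like this (illustrative, not the library's actual implementation):

import numpy as np

def dedupe_by_embedding(prompts, max_keep, sim_threshold=0.9):
    # Greedily keep a prompt only if its embedding is not too similar
    # (cosine similarity above sim_threshold) to one already kept.
    kept, kept_vecs = [], []
    for p in prompts:
        v = np.asarray(p.embeddings, dtype=float).ravel()
        v = v / (np.linalg.norm(v) or 1.0)
        if all(float(np.dot(v, k)) < sim_threshold for k in kept_vecs):
            kept.append(p)
            kept_vecs.append(v)
        if len(kept) == max_keep:
            break
    return kept

prompt_logger.deduplicate(worst_prompts, 100) plays the same role, keeping at most 100 representative prompts.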
Example in a chat service
from anchor_gpt import PromptLogger, Prompt
prompt_logger = PromptLogger()
# Your regular chat endpoint with logging enabled
@app.route("/chat", methods=["POST"])
def chat():
    # Do your grounding as normal:
    prompt = request.json["prompt"]
    prompt_embeddings = model.encode(prompt)
    vector_store_results = vector_store.query(prompt_embeddings, top_k=10)
    grounded_prompt = build_prompt(prompt, vector_store_results)
    chat_response = my_llm.chat(grounded_prompt, temperature=0.1)

    # Then log the prompt with the response, scores and embeddings.
    # Prompts are stored locally in a SQLite database.
    prompt_logger.log(Prompt(
        text=prompt,
        response=chat_response,
        scores={'grounding_distances': [r.distance for r in vector_store_results]},
        embeddings=prompt_embeddings,
    ))
    return chat_response
# A new hallucination retrieval endpoint to get the worst prompts from your LLM
@app.route("/hallucinations", methods=["GET"])
def hallucinations():
    def retriever(store, threshold):
        def prompt_average_score(prompt):
            return prompt.scores['grounding_distances'][0]
        return list(filter(lambda x: prompt_average_score(x) > threshold, store.select_prompts()))

    # Retrieve the prompts with the greatest distance from your grounding data
    worst_prompts = prompt_logger.retrieve(retriever, 0.5)

    # Remove near duplicates and keep only 10 prompts
    deduped_prompts = prompt_logger.deduplicate(worst_prompts, 10)

    # Clean up the store
    prompt_logger.store.purge()

    return jsonify([{'text': p.text, 'response': p.response} for p in deduped_prompts])
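With the app running (assuming it is served on localhost:5000), exercising both endpoints from a client is straightforward:

import requests

BASE = "http://localhost:5000"  # wherever the Flask app is served

# Ask a question through the grounded chat endpoint
reply = requests.post(f"{BASE}/chat", json={"prompt": "What is your refund policy?"})
print(reply.text)

# Periodically pull the most hallucination-prone prompts for review
for p in requests.get(f"{BASE}/hallucinations").json():
    print(p["text"], "->", p["response"])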
Project details
File details
Details for the file anchor-gpt-0.0.2.tar.gz.
File metadata
- Download URL: anchor-gpt-0.0.2.tar.gz
- Upload date:
- Size: 13.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 2fe96f31246529311c4698743a5c536a4ca9a9613f0271556448874e584a0d89
MD5 | 609e2a2f017ec99986c6c46bcf60433c
BLAKE2b-256 | 877e8058d65279d94a0dc15a835e7bafd5c4007708b7628a5686782d2bd71068
File details
Details for the file anchor_gpt-0.0.2-py3-none-any.whl.
File metadata
- Download URL: anchor_gpt-0.0.2-py3-none-any.whl
- Upload date:
- Size: 12.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.6
File hashes
Algorithm | Hash digest
---|---
SHA256 | 768ad210273cda2d693335ebd6c6cf525842a9a3ac0cf1bbce581ecfbfd32cfc
MD5 | 6c70c9223d07de5ca0fafac629b55eec
BLAKE2b-256 | ec2ae96f3b9f9a5b8a95c25b377bea4b339956c03a113605cf2cf1493e430184