Retrieval-backed LLMs for math education
llm-math-education: Retrieval augmented generation for middle-school math question answering and hint generation
How can we incorporate trusted, external math knowledge in generated answers to student questions?
llm-math-education is a Python package that implements basic retrieval augmented generation (RAG) and contains prompts for two primary use cases: general math question answering (QA) and hint generation. It is currently designed to work only with the OpenAI generative chat API.
This project is hosted on GitHub. Feel free to open an issue with questions, comments, or requests.
Installation
The llm-math-education package is available on PyPI.
pip install llm-math-education
Usage
We assume that OPENAI_API_KEY is provided as an environment variable or set via openai.api_key = your_api_key.
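For example, one way to make the key available to the current process (the key value shown is a placeholder, not a real key):

```python
import os

# Set the key for this process if it isn't already exported in your shell.
# "sk-your-key-here" is a placeholder for your actual OpenAI API key.
os.environ.setdefault("OPENAI_API_KEY", "sk-your-key-here")
```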
Preliminary setup: specify a directory in which to save the embedding database.
from pathlib import Path
demo_dir = Path("data") / "demo"
demo_dir.mkdir(exist_ok=True)
We'll use llm-math-education to answer a student question.
student_question = "How do I identify common factors?"
These usage examples can be seen together in src/usage_demo.py.
Acquiring textbook data for retrieval augmented generation
To do retrieval augmented generation, we need data. We'll use an OpenStax Pre-algebra textbook as our retrieval data.
Note: the llm_math_education.openstax module relies on requests and beautifulsoup4, which are not listed as dependencies. Install them yourself with pip if you want to download and parse OpenStax textbooks.
from llm_math_education import openstax
prealgebra_textbook_url = "https://openstax.org/books/prealgebra-2e/pages/1-introduction"
textbook_data = openstax.cache_openstax_textbook_contents(prealgebra_textbook_url, demo_dir / "openstax")
df = openstax.get_subsection_dataframe(textbook_data)
>>> df.columns
Index(['title', 'content', 'index', 'chapter', 'section'], dtype='object')
The parsing code is probably very brittle; it has only been tested with the Pre-algebra textbook.
Creating an embedding lookup database from a dataframe
from llm_math_education import retrieval
db_name = "openstax_prealgebra"
text_column_to_embed = "content"
openstax_db = retrieval.RetrievalDb(demo_dir, db_name, text_column_to_embed, df)
openstax_db.create_embeddings()
openstax_db.save_df()
Loading an existing embedding database
Here, we compute the "distance" in embedding space between the student question and the documents in the database.
openstax_db = retrieval.RetrievalDb(demo_dir, "openstax_prealgebra", "content")
distances = openstax_db.compute_string_distances(student_question)
>>> distances
[0.21348877 0.24298186 0.25825211 ... 0.25500673 0.24491884 0.22458498]
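Lower distances indicate closer matches in embedding space. The most relevant documents can then be recovered with an argsort; this sketch assumes the distances come back as a NumPy array (as printed above) and is illustrative, not part of the package API:

```python
import numpy as np

# Example distances, as returned by compute_string_distances.
distances = np.array([0.2135, 0.2430, 0.2583, 0.2550, 0.2449, 0.2246])
top_k = 2
# Smallest distances first: these index the most relevant documents.
best_indices = np.argsort(distances)[:top_k]
print(best_indices)  # -> [0 5]
```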
Using the database to do retrieval augmented generation
Defining a retrieval strategy
from llm_math_education import retrieval_strategies
db_info = retrieval.DbInfo(
openstax_db,
max_texts=1,
)
strategy = retrieval_strategies.MappedEmbeddingRetrievalStrategy(
{
"openstax_section": db_info,
},
)
Each key in the dictionary passed to the MappedEmbeddingRetrievalStrategy names a slot to be filled in the prompt, in Python string formatting notation.
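The slot filling follows standard Python str.format semantics; a minimal illustration (the template and filler text here are hypothetical):

```python
# A prompt template with two named slots, in Python string formatting notation.
template = (
    "Answer this question: {user_query}\n"
    "Reference this text in your answer:\n{openstax_section}"
)
# The retrieval strategy supplies the text for its slot; the user supplies the query.
filled = template.format(
    user_query="How do I identify common factors?",
    openstax_section="(retrieved textbook text)",
)
```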
Starting a chat conversation with RAG
We'll use a PromptManager to build chat messages from a prompt, a retrieval strategy, and a user query.
from llm_math_education import prompt_utils
pm = prompt_utils.PromptManager()
pm.set_retrieval_strategy(strategy)
pm.set_intro_messages(
[
{
"role": "user",
"content": """Answer this question: {user_query}
Reference this text in your answer:
{openstax_section}""",
},
],
)
messages = pm.build_query(student_question)
>>> messages
[{'role': 'user',
  'content': 'Answer this question: How do I identify common factors?\n'
             'Reference this text in your answer:\n'
             'We will now look at an expression containing a product that is raised to a power. Look for a pattern. The exponent applies to each of the factors. This leads to the Product to a Power Property for Exponents. An example with numbers helps to verify this property:'}]
We can pass the formatted messages to the OpenAI API.
import openai
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
messages=messages,
)
assistant_message = completion["choices"][0]["message"]
>>> assistant_message
{
"role": "assistant",
"content": "To identify common factors, you need to look for a pattern in an expression containing a product raised to a power. The exponent applies to each of the factors in this case. \n\nFor example, let's consider the expression (ab)^2. Here, (ab) is the product, and the exponent 2 applies to both 'a' and 'b'. To identify the common factors, you can separate the product into its individual factors:\n\n(ab)^2 = ab * ab\n\nNow, you can see that both 'a' and 'b' appear as factors in the expression. Therefore, 'a' and 'b' are the common factors. By identifying the factors that appear in multiple terms, you can determine the common factors of an expression.\n\nUsing numbers to verify this property, suppose we have the expression (2*3)^2, which simplifies to (6)^2. In this case, the common factor is 6, as both 2 and 3 are factors of 6."
}
Using PromptManager for multi-turn chat conversations
Add stored messages to continue the conversation.
pm.add_stored_message(assistant_message)
messages = pm.build_query("I have a follow-up question...")
Clear stored messages to start a new conversation on the next call to build_query().
pm.clear_stored_messages()
Using built-in prompts for math QA or hint generation
from llm_math_education.prompts import mathqa as mathqa_prompts
pm.set_intro_messages(mathqa_prompts.intro_prompts["general_math_qa_intro"])
Development
See the developer's guide.
Primary contributor:
- Zachary Levonian (levon003@umn.edu)
Other contributors:
- Owen Henkel
- Bill Roberts
FAQ
- How can I cite this work?
You can cite this work using the repository's CITATION.cff file (or the "Cite this repository" drop-down on GitHub for BibTeX).
Levonian, Z., Henkel, O., & Roberts, B. (2023). llm-math-education: Retrieval augmented generation for middle-school math question answering and hint generation (Version 0.5.1) [Computer software]. https://doi.org/10.5281/zenodo.8284412
We hope to publish a whitepaper describing this work and its evaluation with students in more detail; we'll update the CITATION.cff file if that happens.
- How should I use this code?
We aren't currently planning to add additional features to this package, although pull requests and bug reports are welcome.
You should use the Python package as a dependency if you want a quick way to try retrieval augmented generation with the OpenAI API. However, this code is likely more useful as inspiration. You should fork or otherwise borrow from various components if you want some of the specific functionality implemented here. Here's a quick overview of the most important modules and their implementations:
- llm_math_education.prompts.{mathqa,hints} - Contains the prompt templates we use for math QA and hint generation.
- llm_math_education.prompt_utils - PromptManager is an abstraction for iteratively creating conversations that include a retrieval component.
- llm_math_education.retrieval_strategies - RetrievalStrategy and its implementations demonstrate how embeddings can be used to fill a slot within a prompt template with relevant documents.
- llm_math_education.retrieval - RetrievalDb creates an embedding-backed in-memory lookup database for a pandas DataFrame with a text column.
- llm_math_education.logit_bias - Uses the most frequent tokens in a retrieved document to create a logit_bias that can increase the faithfulness of generations based on that retrieved document.
- What license does this repository use?
The code is released under the MIT license. The example data used in the Streamlit app is released CC BY-SA 4.0; see the
data/app_data
folder for more info. Additional details on the data are present in the developer's guide.