Speculative grammar backtracking algorithm for LLM decoding conforming to some lark context-free grammar (CFG)
Grammar Guide
Speculative grammar backtracking algorithm to perform grammar-constrained decoding with any text generation function (OpenAI, Anthropic, etc.). This repo is a slightly modified implementation of the decoding mechanism described in Section 3.2 of Grammar Prompting for Domain-Specific Language Generation with Large Language Models by @berlino. I refer to the general algorithm as speculative grammar backtracking.
It is a form of constrained decoding, and can be used to guide even proprietary, black-box LLM APIs according to some context-free grammar. It's rooted in the idea that as LLMs get better, not all steps of the decoding process need to apply a strict logit mask - like a good teacher, we let the student give an answer, and only intervene and correct when necessary.
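At a high level, each iteration speculates with the draft model, checks the result against the grammar, and backtracks only if needed. Below is a minimal, illustrative sketch of that loop (not the library's internal implementation); `generate`, `longest_valid_prefix`, and `constrained_fill` are hypothetical stand-ins for the draft model, the Lark/Earley prefix parser, and the guidance-constrained target model.

```python
# Illustrative sketch only; the helper names are hypothetical.
def speculative_grammar_backtrack(prompt: str, max_corrections: int = 20) -> str:
    text = ""
    for _ in range(max_corrections):
        # 1. Let the draft model speculate freely for a bounded number of tokens.
        text += generate(prompt + text, max_new_tokens=20)
        # 2. Find the longest prefix that parses under the grammar, plus the
        #    continuations the grammar allows at that point.
        prefix, candidates = longest_valid_prefix(text)
        if candidates is None:
            return text  # the whole draft parses; nothing to correct
        # 3. Backtrack to the valid prefix and let the constrained target model
        #    fill in a small, guaranteed-valid continuation.
        text = prefix + constrained_fill(prefix, candidates)
    return text
```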
When using local Transformer models, we can efficiently backtrack the KV cache according to a given Lark CFG. The Benchmarks section below reports tokens/sec when generating a JSON object with [10, 20, 30, 40] keys using HuggingFaceTB/SmolLM-135M on my MacBook M1.
Features
- Compatible with any text generation function
  - OpenAI, Anthropic, etc. - as long as you can provide some `generate(prompt: str) -> str` function!
- Efficient re-use of KV cache for all CausalLM Transformer models
- Optimistic, speculative decoding = no need to manually update to support new tokenizers
- Visualization and logging of grammar corrections
- Token healing to ensure high-probability continuations (see the sketch below)
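Roughly, token healing backs up over the last (possibly fragmented) token of the prefix and lets the model regenerate it, constraining the continuation to start with the characters that were dropped. A simplified sketch of the backup step with a HuggingFace tokenizer; the `heal_last_token` helper is illustrative, not the library's API.

```python
from transformers import AutoTokenizer

def heal_last_token(tokenizer, text: str):
    # Drop the final token of the prefix, then record the characters the model
    # must re-emit so the continuation starts on a natural token boundary.
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    healed_prefix = tokenizer.decode(ids[:-1])
    forced_start = text[len(healed_prefix):]
    return healed_prefix, forced_start

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
print(heal_last_token(tokenizer, 'Here is a JSON object: {"na'))
```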
Install from PyPI:

```
pip install grammar-guide
```
Examples
With Transformer Models
When using HuggingFace Transformer models, we get an extra speed boost by leveraging efficient caching and backtracking of the KV cache. When a grammar correction is made, we backtrack to the state of the KV cache aligned to the longest prefix that is valid under our Lark context-free grammar.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import guidance
import grammar_guide as gg

model_name_or_path = "HuggingFaceTB/SmolLM-135M"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
parser = gg.load_parser("../grammars/json_five_values_string_only.lark")

res = gg.guide(
    draft_model=model,
    tokenizer=tokenizer,
    parser=parser,
    prompt="Here's a JSON object with only string values:",
    target_model=guidance.models.Transformers(
        model_name_or_path, echo=False
    ),
    stop_at=['```'],
    token_lookahead=20,
    max_grammar_corrections=20,
    temperature=0.0,
)
```
In the generated visualization:
- Green highlights = text generated by the draft model
- Blue highlights = candidates selected by our target model
- Red text = backtracked text that violated the context-free grammar
- Orange text = tokens that were fed through the token healing logits processor
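Conceptually, backtracking the KV cache just means truncating the cached keys and values to the length of the longest grammar-valid prefix, so generation can resume from that point without re-encoding it. A minimal sketch, assuming the legacy tuple format of `past_key_values` (the library's actual cache handling may differ):

```python
def backtrack_kv_cache(past_key_values, n_prefix_tokens: int):
    # Each layer holds a (key, value) pair of tensors shaped
    # [batch, num_heads, seq_len, head_dim]; keep only the first
    # `n_prefix_tokens` positions along the sequence dimension.
    return tuple(
        (k[:, :, :n_prefix_tokens, :], v[:, :, :n_prefix_tokens, :])
        for k, v in past_key_values
    )
```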
With General API-based Providers
```python
import os

from openai import OpenAI
import guidance
import grammar_guide as gg

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
parser = gg.load_parser("../grammars/json_five_values_string_only.lark")

# Define our core completion predict function.
# This just needs to follow the `fn(s: str) -> str` contract
# so we can use any black-box API provider.
def openai_generate(prefix: str, prompt: str, max_new_tokens: int) -> str:
    messages = [
        {
            "role": "system",
            "content": prompt,
        }
    ]
    if prefix:
        messages += [
            {
                "role": "assistant",
                "content": prefix,
            }
        ]
    chat_completion = client.chat.completions.create(
        messages=messages,
        model="gpt-4o-mini",
        max_tokens=max_new_tokens,
        temperature=0.0,
    )
    return chat_completion.choices[0].message.content

res = gg.guide(
    draft_model=openai_generate,
    parser=parser,
    prompt="Here's a JSON object with only string values:",
    target_model=guidance.models.Transformers(
        "HuggingFaceTB/SmolLM-135M", echo=False
    ),
    max_grammar_corrections=20,
    verbose=True,
)
```
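The same pattern extends to any provider you can wrap in this callable shape. For example, a hypothetical Anthropic wrapper might look like the sketch below; the model name and prefill handling are illustrative assumptions, not part of grammar-guide.

```python
import os

from anthropic import Anthropic

anthropic_client = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

def anthropic_generate(prefix: str, prompt: str, max_new_tokens: int) -> str:
    # Same contract as `openai_generate` above: the grammar-valid prefix, if any,
    # is passed as a prefilled assistant turn for the model to continue.
    messages = [{"role": "user", "content": prompt}]
    if prefix:
        messages.append({"role": "assistant", "content": prefix})
    response = anthropic_client.messages.create(
        model="claude-3-5-sonnet-20240620",
        max_tokens=max_new_tokens,
        messages=messages,
    )
    return response.content[0].text
```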
Documentation
All calls to gg.guide
take the following arguments. When draft_model
is of type AutoModelForCausalLM
, we have a bit of extra control, hence the 'Transformers only' arguments.
Argument | Type | Description |
---|---|---|
`draft_model` | `Union[AutoModelForCausalLM, Callable[[str], str]]` | A Transformers model or callable to use for text generation. |
`tokenizer` | `AutoTokenizer` | Transformers only, the tokenizer associated with the model. |
`parser` | `EarleyParser` | The parser used for grammar checking. |
`prompt` | `str` | The initial prompt for text generation. |
`target_model` | `guidance.models.Model` | The guidance model to use for constrained grammar correction. See guidance-ai/guidance. |
`seed_str` | `Optional[str]` | An optional seed string to start the generation. |
`max_grammar_corrections` | `int` | Maximum number of grammar corrections to attempt. |
`stop_at` | `Collection[str]` | Collection of strings to stop generation at. |
`token_healing` | `Optional[bool]` | Transformers only, whether to use token healing during generation. |
`top_p` | `float` | Transformers only, the cumulative probability for top-p sampling. |
`temperature` | `float` | Transformers only, the temperature for controlling randomness in generation. |
`token_lookahead` | `int` | Maximum number of new tokens to generate using the draft model. Essentially the $K$ parameter in speculative decoding. |
`save_html` | `bool` | Whether to save the generation process as HTML. |
`verbose` | `bool` | Whether to print verbose output. |
`debug` | `bool` | Whether to run in debug mode with additional checks. |
As described in the paper, many existing libraries achieve constrained decoding by enforcing a constraint at each decoding timestep. For local models, it is possible to pre-process the logit masks so that this is relatively efficient. For closed models (think OpenAI, Anthropic, etc.), however, this is 'prohibitively expensive', since it would require calling the API at each timestep with the full prompt and the valid continuation tokens.

Instead, this library takes an optimistic approach to constrained decoding. Autoregressive language models are only going to get better, and oftentimes the overhead of strict, mask-driven constrained decoding isn't worth it.
For example, if we want gpt-4o to generate some SQLite query, chances are, it'll generate a valid query without any constraints.
If there is a mistake, though, we parse out the longest prefix that abides by our grammar definition.
```python
prediction = "SELECT * FROM students WHERE name SIMILAR TO 'Dan%';"
# Oops! `SIMILAR TO` works in PostgreSQL, but not SQLite
prefix, candidates = obtain_correction_pairs(prediction, parser)
print(prefix)
# SELECT * FROM students WHERE name
print(candidates)
# ['IN', '>', '=', 'NOT', 'BETWEEN', 'LIKE', ...]
```
Once we have a list of candidates, we can use our target model to select a valid continuation. In the above example, our candidates are fairly simple strings. However, our grammar may define regular expression continuations as well (e.g. `(?:(?:[A-Z]|[a-z])|_)(?:(?:(?:[A-Z]|[a-z])|[0-9]|_))*`). This is powered by the library guidance.
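As a rough illustration of that selection step (not grammar-guide's internal code), guidance can constrain a tiny continuation to one of the fixed-string candidates, or to a regex allowed by the grammar:

```python
import guidance
from guidance import gen, select

# The target model only ever generates small, always-accepted spans.
target_model = guidance.models.Transformers("HuggingFaceTB/SmolLM-135M", echo=False)

prefix = "SELECT * FROM students WHERE name "
candidates = ["IN", ">", "=", "NOT", "BETWEEN", "LIKE"]

# Fixed-string candidates: force the target model to pick exactly one of them.
lm = target_model + prefix + select(candidates, name="choice")
print(lm["choice"])  # e.g. 'LIKE'

# Regex continuations, e.g. an identifier pattern from the grammar.
lm = target_model + prefix + gen(regex=r"(?:[A-Za-z]|_)[A-Za-z0-9_]*", name="ident")
print(lm["ident"])
```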
Once the target model has selected a valid continuation, we are free to pass the new prefix back to the draft language model to complete the prediction.
```python
selected_candidate = choose_candidate(candidates, prefix, target_model)
print(selected_candidate)
# 'LIKE'
# Now, pass back to the main model to continue its prediction from this new breakpoint
draft_model.predict(prefix + selected_candidate)
```
[!TIP] We borrow the 'draft' and 'target' terminology from one of the original speculative decoding papers (1). However, in our case, we consider the grammar-constrained model, which generates only very small bits of text, to be the 'target' model, since its generations will always be accepted. The draft model, then, is the often larger model that generates unconstrained text up until we tell it to stop (governed by the `token_lookahead` parameter).
Benchmarks
The below benchmarks were run on my MacBook M1, with the command `python -m examples.benchmarks.run`.
They measure the tokens/sec for the respective methods to generate a JSON object with exactly n string key-value pairs, using HuggingFaceTB/SmolLM-135M and the prompt below.

```
Here is a JSON object, with {n} keys, using only string values:\n\n```json\n
```
For most general use cases with local Transformers models, I highly recommend the transformers-CFG library!