Grammar Guide

Speculative grammar backtracking algorithm to perform grammar-constrained decoding with any text generation function (OpenAI, Anthropic, etc.)

This repo is a slightly modified implementation of the decoding mechanism described in Section 3.2 of Grammar Prompting for Domain-Specific Language Generation with Large Language Models by @berlino. I refer to the general algorithm as speculative grammar backtracking.

It is a form of constrained decoding and can be used to guide even proprietary, black-box LLM APIs according to a context-free grammar. It is rooted in the idea that, as LLMs get better, not every decoding step needs a strict logit mask - like a good teacher, we let the student give an answer, and only intervene to correct when necessary.

When using local Transformer models, we can efficiently backtrack the KV cache according to a given Lark CFG. Below is a benchmark showing tokens/sec when generating a JSON object with [10, 20, 30, 40] keys, using HuggingFaceTB/SmolLM-135M on my MacBook M1.

(Figure: naive decoding vs. grammar-guide, tokens/sec)

Features

  • Compatible with any text generation function
    • OpenAI, Anthropic etc. - as long as you can provide some generate(prompt: str) -> str function!
  • Efficient re-use of KV cache for all CausalLM Transformer models
    • Optimistic, speculative decoding = no need to manually update to support new tokenizers
  • Visualization and logging of grammar corrections
  • Token healing to ensure high probability continuations

Installation

pip install grammar-guide

Examples

With Transformer Models

When using HuggingFace Transformer models, we get an extra speed boost from efficient reuse and backtracking of the KV cache. When a grammar correction is made, we backtrack the KV cache to the state aligned with the longest prefix that is valid under our Lark context-free grammar.

from transformers import AutoModelForCausalLM, AutoTokenizer
import guidance

import grammar_guide as gg

model_name_or_path = "HuggingFaceTB/SmolLM-135M"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
parser = gg.load_parser("../grammars/json_five_values_string_only.lark")

res = gg.guide(
  draft_model=model,
  tokenizer=tokenizer,
  parser=parser,
  prompt="Here's a JSON object with only string values:",
  target_model=guidance.models.Transformers(
      model_name_or_path, echo=False
  ),
  stop_at=['```'],
  token_lookahead=20,
  max_grammar_corrections=20,
  temperature=0.0
)

In the visualization below:

  • Green highlights = text generated by the draft model
  • Blue highlights = candidates selected by our target model
  • Red text = backtracked text that violated the context-free grammar
  • Orange text = tokens that were fed through the token healing logits processor

(Figure: Jupyter visualization of the guided generation, using the color coding above)
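Under the hood, the backtracking described above amounts to trimming the KV cache so it only covers the longest grammar-valid prefix, then resuming generation from there. Below is a minimal sketch of that idea written against plain transformers (it assumes a recent version where generate accepts a DynamicCache and DynamicCache.crop exists); it illustrates the mechanism and is not this library's actual implementation.

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "HuggingFaceTB/SmolLM-135M"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Draft step: generate a lookahead of tokens while filling a reusable cache.
inputs = tokenizer("Here's a JSON object with only string values: ", return_tensors="pt")
cache = DynamicCache()
out = model.generate(
    **inputs,
    past_key_values=cache,
    max_new_tokens=20,
    return_dict_in_generate=True,
)
sequence = out.sequences  # prompt + draft tokens; `cache` now covers this prefix

# Suppose grammar checking tells us only the first `valid_len` tokens parse cleanly.
valid_len = sequence.shape[1] - 5  # illustrative: pretend the last 5 tokens were invalid

# Backtrack: drop cache entries past the valid prefix (leave the final prefix token
# uncached so the next forward pass has at least one token to process).
cache.crop(valid_len - 1)

# Resume decoding from the trimmed prefix, reusing the cropped cache instead of
# recomputing attention over the whole prompt.
out = model.generate(sequence[:, :valid_len], past_key_values=cache, max_new_tokens=20)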

With General API-based Providers

import os
from openai import OpenAI
import guidance

import grammar_guide as gg

client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
parser = gg.load_parser("../grammars/json_five_values_string_only.lark")

# Define our core completion function.
# This just needs to follow the
#   `fn(prefix: str, prompt: str, max_new_tokens: int) -> str` contract,
#   so we can use any black-box API provider.
def openai_generate(prefix: str, prompt: str, max_new_tokens: int) -> str:
    messages = [
        {
            "role": "system",
            "content": prompt
        }
    ]
    if prefix:
        messages += [
            {
                "role": "assistant",
                "content": prefix
            }
        ]
    chat_completion = client.chat.completions.create(
        messages=messages,
        model="gpt-4o-mini",
        max_tokens=max_new_tokens,
        temperature=0.0
    )
    return chat_completion.choices[0].message.content

res = gg.guide(
    draft_model=openai_generate,
    parser=parser,
    prompt="Here's a JSON object with only string values:",
    target_model=guidance.models.Transformers(
        "HuggingFaceTB/SmolLM-135M", echo=False
    ),
    max_grammar_corrections=20,
    verbose=True,
)

Documentation

All calls to gg.guide take the following arguments. When draft_model is of type AutoModelForCausalLM, we have a bit of extra control, hence the 'Transformers only' arguments.

| Argument | Type | Description |
| --- | --- | --- |
| draft_model | Union[AutoModelForCausalLM, Callable[[str], str]] | A transformer model or callable to use for text generation. |
| tokenizer | AutoTokenizer | Transformers only; the tokenizer associated with the model. |
| parser | EarleyParser | The parser used for grammar checking. |
| prompt | str | The initial prompt for text generation. |
| target_model | guidance.models.Model | The guidance model to use for constrained grammar correction. See guidance-ai/guidance. |
| seed_str | Optional[str] | An optional seed string to start the generation. |
| max_grammar_corrections | int | Maximum number of grammar corrections to attempt. |
| stop_at | Collection[str] | Collection of strings to stop generation at. |
| token_healing | Optional[bool] | Transformers only; whether to use token healing during generation. |
| top_p | float | Transformers only; the cumulative probability for top-p sampling. |
| temperature | float | Transformers only; the temperature for controlling randomness in generation. |
| token_lookahead | int | Maximum number of new tokens to generate using the draft model. Essentially the $K$ parameter in speculative decoding. |
| save_html | bool | Whether to save the generation process as HTML. |
| verbose | bool | Whether to print verbose output. |
| debug | bool | Whether to run in debug mode with additional checks. |
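For example, a Transformers-backed call that exercises several of the optional arguments above might look like the following (the argument values here are illustrative, not recommendations):

from transformers import AutoModelForCausalLM, AutoTokenizer
import guidance

import grammar_guide as gg

model_name = "HuggingFaceTB/SmolLM-135M"
res = gg.guide(
    draft_model=AutoModelForCausalLM.from_pretrained(model_name),
    tokenizer=AutoTokenizer.from_pretrained(model_name),
    parser=gg.load_parser("../grammars/json_five_values_string_only.lark"),
    prompt="Here's a JSON object with only string values:",
    target_model=guidance.models.Transformers(model_name, echo=False),
    seed_str="{",                  # optional seed string to start the generation
    stop_at=["```"],               # strings that end generation early
    max_grammar_corrections=10,
    token_healing=True,            # Transformers only
    temperature=0.0,               # Transformers only
    token_lookahead=30,            # K: max new draft tokens per speculative step
    save_html=True,                # save the generation process as HTML
    verbose=True,
)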

As described in the paper, many existing libraries achieve constrained decoding by enforcing a constraint at each decoding timestep. For local models, it is possible to pre-process the logit masks so that this is relatively efficient. However, for closed models (think OpenAI, Anthropic, etc.), this can be 'prohibitively expensive', since it would require calling the API at every timestep with the full prompt and the set of valid continuation tokens.

Instead, this library takes an optimistic approach to constrained decoding. Autoregressive language models are only going to get better, and often the overhead of strict, mask-driven constrained decoding isn't worth it.

For example, if we want gpt-4o to generate some SQLite query, chances are, it'll generate a valid query without any constraints.

If there is a mistake, though, we use our grammar to extract the longest prefix that still abides by the grammar definition, along with the candidate continuations that are valid at that point.

prediction = "SELECT * FROM students WHERE name SIMILAR TO 'Dan%';"
# Oops! `SIMILAR TO` works in PostgreSQL, but not SQLite
prefix, candidates = obtain_correction_pairs(prediction, parser)
print(prefix)
# SELECT * FROM students WHERE name
print(candidates)
# ['IN', '>', '=', 'NOT', 'BETWEEN', 'LIKE', ...]
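For illustration, here is roughly how such a (prefix, candidates) pair could be recovered from a Lark Earley parser. The grammar filename, function name, and error handling below are assumptions for the sketch, not this library's code:

from lark import Lark
from lark.exceptions import UnexpectedInput

# Hypothetical SQLite grammar file; any Lark CFG works the same way.
sql_parser = Lark.open("sqlite.lark", parser="earley")

def correction_pair(prediction: str):
    try:
        sql_parser.parse(prediction)
        return prediction, set()  # already grammatical, nothing to correct
    except UnexpectedInput as e:
        # Position where parsing failed; fall back to end-of-text for EOF-style errors.
        pos = getattr(e, "pos_in_stream", None)
        if pos is None or pos < 0:
            pos = len(prediction)
        prefix = prediction[:pos]
        # Terminal names the grammar would have accepted at the failure point.
        candidates = getattr(e, "expected", None) or getattr(e, "allowed", set())
        return prefix, candidates

prefix, candidates = correction_pair("SELECT * FROM students WHERE name SIMILAR TO 'Dan%';")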

Once we have a list of candidates, we can use our target model to select a valid continuation. In the above example, our candidates are fairly simple strings. However, our grammar may define regular expression continuations as well (e.g. (?:(?:[A-Z]|[a-z])|_)(?:(?:(?:[A-Z]|[a-z])|[0-9]|_))*). This is powered by the library guidance.
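As a rough sketch of what that candidate selection can look like using guidance's public API (select for fixed strings, gen with a regex for pattern-shaped continuations); this is an illustration, not this library's internal helper:

from guidance import models, select, gen

# The small, grammar-constrained "target" model that picks the continuation.
target = models.Transformers("HuggingFaceTB/SmolLM-135M", echo=False)

prefix = "SELECT * FROM students WHERE name "
candidates = ["IN", ">", "=", "NOT", "BETWEEN", "LIKE"]

# Fixed string candidates: constrain the model to emit exactly one of them.
lm = target + prefix + select(candidates, name="continuation")
print(lm["continuation"])  # one of the candidates above, e.g. 'LIKE'

# Regex-defined candidates work too, e.g. for identifiers:
lm = target + prefix + gen(regex=r"(?:[A-Za-z]|_)[A-Za-z0-9_]*", name="continuation", max_tokens=10)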

Once the target model has selected a valid continuation, we are free to pass the new prefix back to the draft language model to complete the prediction.

selected_candidate = choose_candidate(candidates, prefix, target_model)
print(selected_candidate)
# 'LIKE'
# Now, pass back to the main model to continue its prediction from this new breakpoint
draft_model.predict(prefix + selected_candidate)

[!TIP] We borrow the "draft" and "target" terminology from one of the original speculative decoding papers (1). However, in our case, we consider the model constrained by the grammar, which generates very small bits of text, to be the 'target' model, since these generations will always be accepted. The draft model, then, is the often larger model that generates unconstrained text until we tell it to stop (governed by the token_lookahead parameter).

Benchmarks

The benchmarks below were run on my MacBook M1 with the command python -m examples.benchmarks.run.

They measure the tokens/sec for each method when generating a JSON object with exactly n string key-value pairs, using HuggingFaceTB/SmolLM-135M and the prompt below.

Here is a JSON object, with {n} keys, using only string values:\n\n```json\n

For most general use cases with local Transformers models, I highly recommend the transformers-CFG library!

(Figure: runtime benchmark, tokens/sec for each method vs. number of JSON keys)
