PyPrior

A constructive prior over small Python programs using AST-driven constrained decoding.

[Animation: a program with a for-loop assembling itself one legal token at a time.]

PyPrior samples programs from a restricted Python grammar by deciding, at every step, the exact set of tokens that are legal next. As a result, every program it produces is syntactically valid, scope-correct (no undefined variables), and bounded in size by construction. Under the hood, PyPrior uses AST-CD (AST-driven constrained decoding): the token stream is a preorder serialization of the AST.
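To make "preorder serialization" concrete, here is a minimal sketch on a toy tuple-based expression tree. The node encoding, the ARITY table, and the serialize/deserialize helpers are illustrative assumptions; they are not the astcd codec or vocabulary.

```python
# Toy illustration only: these names are hypothetical, not part of astcd.
ARITY = {"add": 2, "mul": 2, "neg": 1}  # operator token -> number of children

def serialize(node):
    """Flatten a tree into a preorder token list (parent before children)."""
    if isinstance(node, tuple):
        op, *children = node
        tokens = [op]
        for child in children:
            tokens.extend(serialize(child))
        return tokens
    return [node]  # leaf: a variable name or an integer constant

def deserialize(tokens):
    """Rebuild the tree; each token's arity says how many subtrees follow."""
    it = iter(tokens)

    def build():
        tok = next(it)
        if tok in ARITY:
            return (tok, *(build() for _ in range(ARITY[tok])))
        return tok

    tree = build()
    assert next(it, None) is None, "trailing tokens"
    return tree

tree = ("mul", ("mul", "b", ("add", "d", "d")), "a")   # (b * (d + d)) * a
tokens = serialize(tree)
print(tokens)   # ['mul', 'mul', 'b', 'add', 'd', 'd', 'a']
```

Because every token has a fixed arity, a decoder that has emitted a prefix always knows how many subtrees are still open, which is what makes a per-step legal-token set computable.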

import random
from astcd.vocab import Vocabulary
from astcd.types import DecodeConfig, Start
from astcd.decoder import sample_tokens, weighted_policy, PolicyConfig
from astcd.codec import tokens_to_ast, ast_to_source
from astcd.exec import run

cfg = DecodeConfig()
tok = Vocabulary(cfg)
rng = random.Random(2)

toks = sample_tokens(Start.FUNCTION, rng, cfg, tok, max_len=30,
                     policy=weighted_policy(PolicyConfig(), tok, rng))
fn = tokens_to_ast(toks, tok)
print(ast_to_source(fn))
print("output:", run(fn, (3, 4)).value)

This prints:

def func(a, b):
    d = a
    return ((b * (d + d)) * a)
output: 72

Why

  • Constrained decoding for LLM Python AST code generation. The decoder exposes the exact legal-token mask at each step, so a model can generate only programs inside the supported AST subset.
  • Random Python code generation for data augmentation. The same grammar, scope, and budget constraints can sample many small, well-formed programs for training data or evaluation corpora.
  • No generate-then-reject loop. Every sampled token stream decodes to syntactically valid, scope-safe restricted Python by construction.
  • Safe and deterministic by construction. A small direct interpreter (no exec/compile) runs programs under integer/step caps and returns a structured outcome. Overflow and step-limit cases are reported as data, and programs cannot import modules or perform I/O.
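The last point can be illustrated with a minimal sketch of a capped, direct interpreter for integer expressions. It uses the standard ast module only to obtain a tree and then walks that tree itself. The Outcome type, the evaluate function, and the cap parameters are assumptions made for this sketch; they are not astcd.exec's actual API.

```python
import ast
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Outcome:                 # hypothetical result type, not astcd's
    status: str                # "ok", "overflow", or "step_limit"
    value: Optional[int] = None

class _Stop(Exception):
    """Internal signal carrying the reason evaluation was cut short."""

OPS = {ast.Add: lambda x, y: x + y,
       ast.Sub: lambda x, y: x - y,
       ast.Mult: lambda x, y: x * y}

def evaluate(src, env, max_steps=100, max_abs=10**6):
    """Walk the parsed expression tree directly; no exec() or eval() of code."""
    steps = 0

    def ev(node):
        nonlocal steps
        steps += 1                      # one step per visited node
        if steps > max_steps:
            raise _Stop("step_limit")
        if isinstance(node, ast.Constant) and type(node.value) is int:
            result = node.value
        elif isinstance(node, ast.Name):
            result = env[node.id]
        elif isinstance(node, ast.BinOp) and type(node.op) in OPS:
            result = OPS[type(node.op)](ev(node.left), ev(node.right))
        else:
            raise ValueError(f"unsupported node: {type(node).__name__}")
        if abs(result) > max_abs:       # integer cap, reported as data
            raise _Stop("overflow")
        return result

    try:
        return Outcome("ok", ev(ast.parse(src, mode="eval").body))
    except _Stop as stop:
        return Outcome(str(stop))

print(evaluate("(b * (d + d)) * a", {"a": 3, "b": 4, "d": 3}))
print(evaluate("a * a * a", {"a": 1000}))
print(evaluate("a + a + a", {"a": 1}, max_steps=3))
```

Since only whitelisted node types are evaluated, anything resembling a call, import, or attribute access is rejected before it can run.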

Why not just use a Python grammar mask?

Generic constrained decoders such as SynCode, XGrammar, and Outlines are great when a normal text-token LLM already has a useful prior and you want to stop it from leaving a grammar. They are not designed to be high-quality unconditional random program generators.

For example, uniform random sampling through SynCode's Python grammar mask with a GPT-2 tokenizer can produce syntactically valid Python, but the distribution is over natural-language subword tokens. A prompt-conditioned sample can look like:

def f(x):
     KCatioushaIntroductioneeAddedRushymmWrittennyderescalALTHpecAccess713EyeensitiveenderedtriggerRahredditmenuappedicheixtpossiblybandJJwatersORD

That parses as Python, but it is just a bare expression using an undefined name. The grammar mask did its job: it controlled syntax. It did not provide a program prior, scope safety, AST shape control, or execution semantics.
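This is easy to confirm with the standard library. The sketch below uses a shortened stand-in for the sampled identifier (an assumption for readability) and shows that the text parses and defines a function, while calling it fails on the undefined name.

```python
import ast

src = "def f(x):\n    KCatiousha\n"    # shortened stand-in for the sample above
tree = ast.parse(src)                  # succeeds: the text is valid Python
stmt = tree.body[0].body[0]
print(type(stmt).__name__, type(stmt.value).__name__)   # Expr Name

namespace = {}
exec(compile(tree, "<sample>", "exec"), namespace)      # defining f is fine
try:
    namespace["f"](1)
    failed = False
except NameError as err:               # the bare name is never defined
    failed = True
    print("NameError:", err)
```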

AST-CD samples directly in a small AST vocabulary, so even uniform random sampling produces structured, scope-safe programs:

def func(a, b):
    d = (3 * 2)
    return 3
    return 5
    e = (3 - ((6 + 9) - 2))

| Approach | Good at | Missing for AST-CD's use case |
| --- | --- | --- |
| Text-token grammar masks | Keeping an LLM's text output syntactically valid | Scope tracking, AST-token output, small controlled vocabulary, bounded AST size/depth, direct execution model |
| AST-CD | Random code augmentation and AST-vocab constrained decoding | Full Python coverage, arbitrary libraries/imports, unconstrained human-style code |

This is the core distinction: text-level constrained decoders constrain a language model; AST-CD defines the language, the vocabulary, the decoder state, and the executable subset together.

Install

Install from PyPI:

pip install pyprior            # core (pure standard library, zero dependencies)
pip install "pyprior[gif]"     # + Pillow, to export the decode animation as a GIF

The constrained-decoding API

The core object is a Decoder: ask it what's legal, advance with any of those tokens, repeat. That seam is where a random sampler, a weighted prior, or a neural model all plug in identically.

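from astcd.decoder import Decoder
from astcd.types import Start

# cfg and tok are the DecodeConfig and Vocabulary built in the first example;
# sample stands for any chooser over the legal tokens (rng / weighted / model).
d = Decoder(Start.FUNCTION, cfg, tok, max_len=30)
while not d.is_complete():
    legal = d.next_legal_tokens()     # the mask: grammar ∩ scope ∩ depth/size guards
    d.advance(sample(legal))          # pick any legal token
fn = d.function()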

The set returned by next_legal_tokens() is the intersection of four sources: the grammar's FIRST set, the in-scope variable names (definite assignment), the depth/length guards, and a min-cost-to-complete check that guarantees the program closes within max_len. It is never empty until the program is complete.
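The interaction of these guards can be sketched on a toy preorder expression grammar. Everything here (ToyDecoder, the ARITY table, the constant tokens "0" and "1") is a hypothetical stand-in rather than astcd's Decoder, but it shows why the min-cost check keeps the legal set non-empty and forces completion within max_len: each open hole needs at least one token, so an operator is legal only if the holes it leaves behind still fit in the remaining budget.

```python
import random

ARITY = {"add": 2, "mul": 2, "neg": 1}   # toy operator tokens (hypothetical)

class ToyDecoder:
    """Preorder expression decoder; a sketch, not astcd's Decoder."""

    def __init__(self, scope, max_len, max_depth):
        self.scope = tuple(scope)      # names already defined, so legal to read
        self.max_len = max_len
        self.max_depth = max_depth
        self.tokens = []
        self.holes = [0]               # stack of open holes, stored as depths

    def is_complete(self):
        return not self.holes

    def next_legal_tokens(self):
        if self.is_complete():
            return []
        depth = self.holes[-1]
        budget = self.max_len - len(self.tokens) - 1   # tokens left after this one
        other_holes = len(self.holes) - 1
        legal = list(self.scope) + ["0", "1"]          # leaves: grammar and scope
        for op, arity in ARITY.items():
            deep_ok = depth + 1 <= self.max_depth      # depth guard
            # min-cost-to-complete: every open hole needs at least one token
            closes = other_holes + arity <= budget
            if deep_ok and closes:
                legal.append(op)
        return legal

    def advance(self, token):
        assert token in self.next_legal_tokens(), "illegal token"
        depth = self.holes.pop()
        self.tokens.append(token)
        self.holes.extend([depth + 1] * ARITY.get(token, 0))

rng = random.Random(0)
d = ToyDecoder(scope=["a", "b"], max_len=9, max_depth=3)
while not d.is_complete():
    d.advance(rng.choice(d.next_legal_tokens()))
print(d.tokens, len(d.tokens) <= 9)
```

A leaf is always legal because the decoder maintains the invariant that the number of open holes never exceeds the remaining token budget; operators are admitted only when they preserve it.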

Planned ideal API

The current low-level API is intentionally explicit: build a Vocabulary, create a Decoder, inspect the legal-token mask, and advance one token at a time. The planned public API should keep that same mental model:

state -> legal mask -> choose token -> advance state

Random program generation should not be a separate subsystem. It should be the same constrained decoder driven by a random sampler. There should be no separate scalar generation API: a single sample is represented as a one-item input batch with n=1. When n > 1, each input state is cloned into n independent continuations.

from astcd import decode, detokenize, init_states, tokenize
from astcd.api import DecodeRequest
from astcd.samplers import UniformSampler
from astcd.types import DecodeConfig, Start

cfg = DecodeConfig(max_block_stmts=1)

# Free-form constrained decoding. Long term, Start.MODULE should be the default root.
states = init_states([DecodeRequest(start=Start.MODULE)], cfg=cfg)
batch = decode(
    states,
    sampler=UniformSampler(seed=0),
    cfg=cfg,
    n=8,
    max_total_tokens=300,
    max_new_tokens=200,
    stop_at="module_boundary",
)

# Continue the same batch without recomputing prompt/decoder state.
batch = decode(
    batch.states,
    sampler=UniformSampler(seed=1),
    n=1,
    max_new_tokens=100,
    stop_at="module_boundary",
)

# A single sample is a one-item batch.
one = decode(
    init_states([DecodeRequest(start=Start.FUNCTION)], cfg=cfg),
    sampler=UniformSampler(seed=2),
    n=1,
    max_new_tokens=200,
)
print(one.results[0].source)

# Prompt-conditioned decoding uses the same state-resume path.
tokenized = tokenize(["def func(a):", "def func(a, b):"], cfg=cfg, partial=True)
prompt_states = [result.state for result in tokenized.results if result.ok]

batch = decode(
    prompt_states,
    sampler=UniformSampler(seed=3),
    n=4,  # four completions per prompt state
    max_new_tokens=200,
)

# Convert token batches back to Python source strings.
sources = detokenize([result.tokens for result in batch.results], cfg=cfg)

The planned constrained-decoding surface is batch-only:

  • init_states(requests=[...], cfg=...): create fresh decoder states from starts and optional AST-CD token prefixes.
  • legal_masks(states): return the legal next-token mask for each active state.
  • advance(states, token_ids): consume one AST-CD token per state and return the next states.
  • decode(states, sampler=..., n=1, ...): run the mask/sample/advance loop with a sampler. This is a convenience driver, not a different algorithm. n is a per-input fanout, not a scalar generation mode.
  • tokenize(sources=[...], ...): convert supported Python source strings, or supported incomplete prompts such as def func(a):, into AST-CD token ids and resumable DecodeState objects.
  • detokenize(token_batches=[...], ...): convert batches of AST-CD token ids back to Python source strings. The lower-level pieces already exist today as tokens_to_source(...), tokens_to_ast(...), and ast_to_source(...).

For len(states) == B, decode(..., n=N) returns B * N results and states, ordered by (input_index, sample_index).
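The ordering rule can be sketched with plain Python. The fanout helper and the string stand-ins for states are assumptions made for illustration; the planned decode(...) would carry real DecodeState objects.

```python
from itertools import product

def fanout(states, n):
    """Clone each input state into n continuations, ordered by
    (input_index, sample_index). States are opaque toy stand-ins here."""
    return [(i, j, states[i]) for i, j in product(range(len(states)), range(n))]

slots = fanout(["state_A", "state_B"], n=3)
print(len(slots))                      # 6, i.e. B * N
print([(i, j) for i, j, _ in slots])
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]

# With this ordering, the flat position of a result is input_index * n + sample_index.
flat = {(i, j): pos for pos, (i, j, _) in enumerate(slots)}
print(flat[(1, 2)])                    # 5
```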

The planned result/state shape should make continuation, visualization, and training data export straightforward:

@dataclass(frozen=True, slots=True)
class DecodeBatch:
    results: tuple[DecodeResult, ...]
    states: tuple[DecodeState, ...]
    input_count: int
    n: int
    sampler_state: object | None = None

@dataclass(frozen=True, slots=True)
class DecodeRequest:
    start: Start = Start.MODULE
    prefix_tokens: tuple[int, ...] = ()

@dataclass(frozen=True, slots=True)
class DecodeResult:
    input_index: int
    sample_index: int
    tokens: tuple[int, ...]
    new_tokens: tuple[int, ...]
    source: str
    complete: bool
    stop_reason: str
    trace: tuple[DecodeStep, ...] = ()
    padded_tokens: tuple[int, ...] | None = None

@dataclass(frozen=True, slots=True)
class DecodeState:
    start: Start
    tokens: tuple[int, ...]
    complete: bool
    token_budget_remaining: int
    # plus an internal decoder snapshot

tokenize(...) should be useful both for prompt conditioning and for measuring how much human-written Python data fits the supported grammar. It should return structured feedback instead of just raising a string error:

token_batch = tokenize([source], cfg=cfg, partial=True)
result = token_batch.results[0]

if result.ok:
    state = result.state              # pass this to decode(...) to continue
else:
    print(result.error.stage)       # "python_parse", "unsupported_syntax",
                                    # "name_normalization", "vocab_encode", ...
    print(result.error.message)
    print(result.error.line, result.error.column)
    print(result.error.node_type)   # e.g. "Import", "Call", "Constant[str]"
    print(result.tokens)            # longest valid prefix, if available

The result shape should make corpus auditing straightforward:

@dataclass(frozen=True, slots=True)
class TokenizeBatch:
    results: tuple[TokenizeResult, ...]

@dataclass(frozen=True, slots=True)
class TokenizeResult:
    ok: bool
    tokens: tuple[int, ...]
    start: Start
    state: DecodeState | None
    complete: bool
    normalized_source: str | None
    error: TokenizeError | None
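A corpus audit along these lines can be prototyped today with the standard ast module. The sketch below is an assumption-laden stand-in: the SUPPORTED whitelist is a guess at a small integer subset rather than astcd's real grammar, and AuditError and audit are hypothetical names that only mirror the error fields described above.

```python
import ast
from dataclasses import dataclass
from typing import Optional

# Assumed toy whitelist; the real supported subset is defined by astcd's grammar.
SUPPORTED = (ast.Module, ast.FunctionDef, ast.arguments, ast.arg, ast.Assign,
             ast.Return, ast.If, ast.Name, ast.Constant, ast.BinOp, ast.Compare,
             ast.Add, ast.Sub, ast.Mult, ast.Lt, ast.Load, ast.Store)

@dataclass(frozen=True)
class AuditError:                      # hypothetical, mirrors the fields above
    stage: str
    message: str
    line: Optional[int] = None
    column: Optional[int] = None
    node_type: Optional[str] = None

def audit(source):
    """Return None if the source fits the toy subset, else the first error."""
    try:
        tree = ast.parse(source)
    except SyntaxError as err:
        return AuditError("python_parse", err.msg, err.lineno, err.offset)
    for node in ast.walk(tree):
        kind = type(node).__name__
        if isinstance(node, ast.Constant) and type(node.value) is not int:
            kind = f"Constant[{type(node.value).__name__}]"   # e.g. Constant[str]
        elif isinstance(node, SUPPORTED):
            continue
        return AuditError("unsupported_syntax", f"{kind} is outside the subset",
                          getattr(node, "lineno", None),
                          getattr(node, "col_offset", None), kind)
    return None

corpus = ["def func(a, b):\n    d = a * b + 1\n    return d\n",
          "import os\n", 'x = "hi"\n', "def f(:\n"]
for src in corpus:
    err = audit(src)
    print("ok" if err is None else (err.stage, err.node_type))
```

Counting the (stage, node_type) pairs over a corpus gives a quick picture of which unsupported constructs would need to be added to cover more human-written code.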

For LLM integrations, assume the model emits AST-CD vocabulary token ids. The core constrained-decoding API should expose batched masking and advancement directly:

from astcd.constraint import advance, legal_masks

tokenized = tokenize(["def func(a):", "def func(a, b):"], cfg=cfg, partial=True)
states = [result.state for result in tokenized.results if result.ok]

while not all(s.complete for s in states):
    masks = legal_masks(states)          # shape: [batch, astcd_vocab_size]
    next_tokens = astcd_vocab_model.sample(masks)
    states = advance(states, next_tokens)

decode(...) is then just a sampler-driven loop built on the same primitives. It should support optional padding for training data. By default it returns only the generated sequence. If pad=True, it appends a configured padding token up to the requested fixed length:

batch = decode(
    init_states([DecodeRequest(start=Start.FUNCTION)], cfg=cfg),
    sampler=UniformSampler(seed=0),
    n=32,
    max_new_tokens=200,
    pad=True,
)
assert all(len(r.padded_tokens) == 200 for r in batch.results)
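
As a rough sketch of what pad=True would do per sequence, the helper below right-pads a token tuple and builds a matching attention mask. The pad_tokens name, the padding id 0, and the mask convention are assumptions for illustration, not astcd behavior.

```python
def pad_tokens(tokens, length, pad_id):
    """Right-pad a token sequence to a fixed length (toy helper, not astcd's).

    Returns the padded tuple and a 1/0 mask marking real tokens."""
    if len(tokens) > length:
        raise ValueError("sequence is longer than the target length")
    extra = length - len(tokens)
    padded = tuple(tokens) + (pad_id,) * extra
    mask = (1,) * len(tokens) + (0,) * extra
    return padded, mask

PAD = 0                                  # assumed padding token id
padded, mask = pad_tokens((5, 9, 2), length=6, pad_id=PAD)
print(padded)   # (5, 9, 2, 0, 0, 0)
print(mask)     # (1, 1, 1, 0, 0, 0)
```

Building the mask from the original length, rather than by comparing against the padding id, keeps it correct even if the padding id also occurs as a real token.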

Adapters for model-serving libraries should be thin wrappers around legal_masks(...) and advance(...).
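
What such a wrapper does at each step can be sketched without any serving library: restrict the model's logits to the legal ids, then sample from the renormalized distribution. The sample_masked function and the example logits are hypothetical; a real adapter would typically set illegal logits to negative infinity inside the framework's own tensor type.

```python
import math
import random

def sample_masked(logits, mask, rng):
    """Sample a token id from softmax(logits) restricted to ids where mask is True."""
    legal = [i for i, ok in enumerate(mask) if ok]
    if not legal:
        raise ValueError("empty legal-token mask")
    top = max(logits[i] for i in legal)               # subtract max for stability
    weights = [math.exp(logits[i] - top) for i in legal]
    return rng.choices(legal, weights=weights, k=1)[0]

logits = [2.0, 9.5, 0.1, 1.3]        # the model strongly prefers id 1 ...
mask = [True, False, True, True]     # ... but id 1 is illegal in this state
rng = random.Random(0)
picks = {sample_masked(logits, mask, rng) for _ in range(200)}
print(sorted(picks))                 # never contains 1
```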

For Hugging Face Transformers, the adapter should be a logits processor for models whose token ids already correspond to AST-CD vocabulary ids:

from astcd.integrations.transformers import ASTCDLogitsProcessor

states = init_states([DecodeRequest(start=Start.MODULE)], cfg=cfg)
processor = ASTCDLogitsProcessor(states=states, cfg=cfg)

outputs = model.generate(
    input_ids=processor.input_ids(),
    max_new_tokens=200,
    logits_processor=[processor],
)

For vLLM, the adapter should use vLLM's custom logits-processor mechanism rather than structured_outputs. vLLM's built-in structured-output formats (choice, regex, json, grammar, and structural_tag) are useful for normal text/BPE generation, but they do not operate on an AST-token vocabulary. A static allowed_token_ids list is also insufficient, because AST-CD needs a dynamic, stateful mask over the AST-CD vocabulary:

from vllm import LLM, SamplingParams
from astcd.integrations.vllm import ASTCDLogitsProcessor

llm = LLM(
    model="your-astcd-vocab-model",
    logits_processors=[ASTCDLogitsProcessor],
)

params = SamplingParams(
    max_tokens=200,
    n=8,
    extra_args={
        "astcd_cfg": cfg.to_json(),
        "astcd_state": states[0].to_json(),
    },
)

# Token-id prompts are passed as prompt dicts in current vLLM releases.
outputs = llm.generate(
    [{"prompt_token_ids": list(states[0].tokens)}],
    sampling_params=params,
)

Generic Python constrained-decoding libraries such as SynCode, XGrammar, and Outlines are worth using as baselines or adapters when the model emits normal Python text tokens. AST-CD is still useful as a separate package because it constrains an AST-token language, tracks scope, enforces bounded completion, supports prompt tokenization into decoder state, and can decode directly to an executable restricted Python AST.

Visualization should be implemented as helpers that consume decode batches, not as special modes inside the sampler:

from astcd.display import show, animate, save_gif, dump_json

show(batch, index=0)
animate(batch, index=0)
save_gif(batch, "decode.gif", index=0)
dump_json(batch, "samples.json")

Longer term, Start.MODULE should become the default public root so ordinary files can be modeled as a module-level block. Specialized starts such as Start.FUNCTION, Start.BLOCK, and Start.EXPR should remain useful for tests, focused generation, and continuations inside a known syntactic context.

What it generates

A small, bounded subset of Python over integers:

def func(a, b):           # arity 1..N
    d = a * b + 1         # assignments to fresh / existing locals
    if d < 0:             # if / else
        return 0          # return is a statement (early returns OK)
    return d              # final return (a body may also fall through and yield None)

DecodeConfig controls the optional features and the size limits: for v in range(e): loops (allow_loops, off by default), the size/shape knobs arities, max_block_stmts, max_expr_depth, and max_control_depth, and the integer bounds value_range and const_range.
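
For this subset, scope safety amounts to definite assignment: a name may be read only if every path reaching that read has assigned it. AST-CD enforces this during decoding; the sketch below instead checks it after the fact on ordinary source using the standard ast module. The helper names are hypothetical, loops are not handled, and statements after a return are simply skipped.

```python
import ast

def require(expr, defined):
    """Raise NameError if the expression reads a name that is not defined."""
    missing = {n.id for n in ast.walk(expr) if isinstance(n, ast.Name)} - defined
    if missing:
        raise NameError(f"undefined: {sorted(missing)}")

def check_block(stmts, defined):
    """Return the names definitely assigned after the block, or None if every
    path through the block has already returned."""
    defined = set(defined)
    for stmt in stmts:
        if isinstance(stmt, ast.Assign):
            require(stmt.value, defined)
            defined |= {t.id for t in stmt.targets if isinstance(t, ast.Name)}
        elif isinstance(stmt, ast.Return):
            if stmt.value is not None:
                require(stmt.value, defined)
            return None                   # later statements are skipped here
        elif isinstance(stmt, ast.If):
            require(stmt.test, defined)
            then = check_block(stmt.body, defined)
            other = check_block(stmt.orelse, defined)
            live = [s for s in (then, other) if s is not None]
            if not live:
                return None
            # definite only if every branch that falls through assigns it
            defined = set.intersection(*live)
        else:
            raise ValueError(f"unsupported statement: {type(stmt).__name__}")
    return defined

def scope_ok(source):
    fn = ast.parse(source).body[0]
    try:
        check_block(fn.body, {a.arg for a in fn.args.args})
        return True
    except NameError:
        return False

good = "def func(a, b):\n    d = a * b + 1\n    if d < 0:\n        return 0\n    return d\n"
bad = "def func(a):\n    if a < 0:\n        d = 1\n    return d\n"
print(scope_ok(good), scope_ok(bad))   # True False
```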

Watch it decode

The most fun way to understand it is to animate a program assembling itself one legal token at a time, with holes (<expr>, <stmt>, ...) filling in:

python -m astcd.viz                     # animate a random function
python -m astcd.viz --seed 7 --delay 0.3
python -m astcd.viz --uniform           # flat policy (more complex programs)
python -m astcd.viz --loops --max-control-depth 3
python -m astcd.viz --play              # interactive: YOU pick each token from the menu
python -m astcd.viz --gif decode.gif    # export the animation

Status

Early and experimental, but the core is complete and tested (pytest): grammar / vocabulary / scope / interpreter / codec round-trip / decoder soundness / samplers / visualization.

License

MIT — see LICENSE.
