A constructive prior over small Python programs using AST-driven constrained decoding
PyPrior
PyPrior samples programs from a restricted Python grammar by deciding, at every step, the exact set of tokens that are legal next. As a result, every program it produces is syntactically valid, scope-correct (no undefined variables), and bounded in size by construction. Under the hood, PyPrior uses AST-CD (AST-driven constrained decoding, imported as `astcd`): the token stream is a preorder serialization of the AST.
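To make "preorder serialization" concrete, here is a minimal sketch using the standard-library `ast` module on a `+`/`-`/`*` integer expression. The token spellings (operator symbols, variable names, digit strings) are hypothetical and are not AST-CD's actual vocabulary. The point is that each operator is emitted before its operands, and because every token has a fixed arity, the flat stream decodes back to the tree without any parentheses tokens.

```python
import ast

# Hypothetical toy vocabulary: operator symbols, variable names, and integer literals.
OPS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*"}
ARITY = {"+": 2, "-": 2, "*": 2}  # names and literals are leaves (arity 0)


def to_preorder(node):
    """Serialize an integer expression AST: node first, then children left to right."""
    if isinstance(node, ast.BinOp) and type(node.op) in OPS:
        return [OPS[type(node.op)], *to_preorder(node.left), *to_preorder(node.right)]
    if isinstance(node, ast.Name):
        return [node.id]
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return [str(node.value)]
    raise NotImplementedError(type(node).__name__)


def from_preorder(tokens):
    """Rebuild fully parenthesized source; fixed arities make the stream unambiguous."""
    stream = iter(tokens)

    def build():
        token = next(stream, None)
        if token is None:
            raise ValueError("stream ended before the expression was complete")
        if token in ARITY:
            left = build()   # left subtree is serialized first
            right = build()
            return f"({left} {token} {right})"
        return token

    source = build()
    if next(stream, None) is not None:
        raise ValueError("trailing tokens after a complete expression")
    return source


toks = to_preorder(ast.parse("(b * (d + d)) * a", mode="eval").body)
print(toks)                 # ['*', '*', 'b', '+', 'd', 'd', 'a']
print(from_preorder(toks))  # ((b * (d + d)) * a)
```

The rebuilt string uses the same fully parenthesized style as the sample program printed by the example that follows.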
```python
import random
from astcd.vocab import Vocabulary
from astcd.types import DecodeConfig, Start
from astcd.decoder import sample_tokens, weighted_policy, PolicyConfig
from astcd.codec import tokens_to_ast, ast_to_source
from astcd.exec import run

cfg = DecodeConfig()
tok = Vocabulary(cfg)
rng = random.Random(2)
toks = sample_tokens(Start.FUNCTION, rng, cfg, tok, max_len=30,
                     policy=weighted_policy(PolicyConfig(), tok, rng))
fn = tokens_to_ast(toks, tok)
print(ast_to_source(fn))
print("output:", run(fn, (3, 4)).value)
```
```text
def func(a, b):
    d = a
    return ((b * (d + d)) * a)
output: 72
```
Why
- Constrained decoding for LLM-based Python AST generation. `Decoder` exposes the exact legal-token mask at each step, so a model can generate only programs inside the supported AST subset.
- Random Python code generation for data augmentation. The same grammar, scope, and budget constraints can sample many small, well-formed programs for training data or evaluation corpora.
- No generate-then-reject loop. Every sampled token stream decodes to syntactically valid, scope-safe restricted Python by construction.
- Safe + deterministic by construction. A small direct interpreter (no `exec`/`compile`) runs programs under integer/step caps and returns a structured outcome. Overflow and step-limit cases are reported as data; nothing can import or do I/O.
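For intuition about the last point, here is a minimal sketch of a direct tree-walking evaluator for `+`/`-`/`*` integer expressions that reports step-limit and overflow cases as data instead of raising. It is not PyPrior's interpreter: `eval_expr`, `Outcome`, the status strings, and the cap defaults are all hypothetical. It walks an already-built `ast` node; the demo uses `ast.parse` only as a convenient way to build a tree, whereas AST-CD's trees come straight from the decoder.

```python
import ast
import operator
from dataclasses import dataclass
from typing import Optional

# Toy operator table; the real interpreter's supported set may differ.
OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}


@dataclass(frozen=True)
class Outcome:
    status: str                  # "ok", "overflow", or "step_limit"
    value: Optional[int] = None  # set only when status == "ok"


class _Halt(Exception):
    """Internal signal that carries a non-ok status out of the recursion."""

    def __init__(self, status):
        super().__init__(status)
        self.status = status


def eval_expr(node, env, *, max_steps=100, max_abs=10**6):
    """Walk an integer expression AST directly, under step and magnitude caps."""
    steps = 0

    def ev(n):
        nonlocal steps
        steps += 1  # one step per visited node
        if steps > max_steps:
            raise _Halt("step_limit")
        if isinstance(n, ast.Constant) and type(n.value) is int:
            value = n.value
        elif isinstance(n, ast.Name):
            value = env[n.id]
        elif isinstance(n, ast.BinOp) and type(n.op) in OPS:
            value = OPS[type(n.op)](ev(n.left), ev(n.right))
        else:
            raise ValueError(f"unsupported node: {type(n).__name__}")
        if abs(value) > max_abs:
            raise _Halt("overflow")
        return value

    try:
        return Outcome("ok", ev(node))
    except _Halt as halt:
        return Outcome(halt.status)


tree = ast.parse("(b * (d + d)) * a", mode="eval").body  # built here only for the demo
print(eval_expr(tree, {"a": 3, "b": 4, "d": 3}))  # Outcome(status='ok', value=72)
```

Because every failure mode is a value, a caller can batch-evaluate thousands of sampled programs without wrapping each call in `try`/`except`.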
Why not just use a Python grammar mask?
Generic constrained decoders such as SynCode, XGrammar, and Outlines are great when a normal text-token LLM already has a useful prior and you want to stop it from leaving a grammar. They are not designed to be high-quality unconditional random program generators.
For example, uniform random sampling through SynCode's Python grammar mask with a GPT-2 tokenizer can produce syntactically valid Python, but the distribution is over natural-language subword tokens. A prompt-conditioned sample can look like:
```python
def f(x):
    KCatioushaIntroductioneeAddedRushymmWrittennyderescalALTHpecAccess713EyeensitiveenderedtriggerRahredditmenuappedicheixtpossiblybandJJwatersORD
```
That parses as Python, but it is just a bare expression using an undefined name. The grammar mask did its job: it controlled syntax. It did not provide a program prior, scope safety, AST shape control, or execution semantics.
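Scope safety can be made concrete with a post-hoc check. This hypothetical sketch uses the standard-library `ast` module to list names that are read before they are definitely assigned in each top-level function. It is deliberately conservative and only understands plain positional parameters, assignments, and `if`/`else`; AST-CD avoids needing such a check at all by masking out-of-scope names during decoding.

```python
import ast


def undefined_reads(source):
    """Names read before definite assignment in each top-level def (toy, conservative)."""
    problems = []
    for node in ast.parse(source).body:
        if isinstance(node, ast.FunctionDef):
            params = {a.arg for a in node.args.args}  # positional params only, for brevity
            _check_block(node.body, params, problems)
    return problems


def _loads(node):
    """Every variable name read anywhere inside `node`."""
    return [n.id for n in ast.walk(node)
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)]


def _check_block(stmts, scope, problems):
    """Check statements in order; return the names definitely assigned afterwards."""
    scope = set(scope)
    for stmt in stmts:
        if isinstance(stmt, ast.Assign):
            problems += [n for n in _loads(stmt.value) if n not in scope]
            for target in stmt.targets:
                scope |= {n.id for n in ast.walk(target)
                          if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)}
        elif isinstance(stmt, ast.If):
            problems += [n for n in _loads(stmt.test) if n not in scope]
            in_body = _check_block(stmt.body, scope, problems)
            in_else = _check_block(stmt.orelse, scope, problems)
            scope = in_body & in_else  # assigned on every path through the if
        else:
            problems += [n for n in _loads(stmt) if n not in scope]
    return scope


print(undefined_reads("def g(a):\n    if a:\n        b = 1\n    return b\n"))  # ['b']
```

On the SynCode-style sample it reports the bare name; on AST-CD samples it reports nothing, because out-of-scope names are never legal tokens.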
AST-CD samples directly from a small AST vocabulary, so even uniform random sampling produces structured, scope-safe programs:
```python
def func(a, b):
    d = (3 * 2)
    return 3
    return 5
    e = (3 - ((6 + 9) - 2))
```
| Approach | Good at | Missing for AST-CD's use case |
|---|---|---|
| Text-token grammar masks | Keeping an LLM's text output syntactically valid | Scope tracking, AST-token output, small controlled vocabulary, bounded AST size/depth, direct execution model |
| AST-CD | Random code augmentation and AST-vocab constrained decoding | Full Python coverage, arbitrary libraries/imports, unconstrained human-style code |
This is the core distinction: text-level constrained decoders constrain a language model; AST-CD defines the language, the vocabulary, the decoder state, and the executable subset together.
Install
Install from PyPI:
```bash
pip install pyprior            # core (pure standard library, zero dependencies)
pip install "pyprior[gif]"     # + Pillow, to export the decode animation as a GIF
```
The constrained-decoding API
The core object is a Decoder: ask it what's legal, advance with any of those
tokens, repeat. That seam is where a random sampler, a weighted prior, or a neural
model all plug in identically.
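```python
import random
from astcd.decoder import Decoder
from astcd.types import DecodeConfig, Start
from astcd.vocab import Vocabulary
cfg, rng = DecodeConfig(), random.Random(0)
tok = Vocabulary(cfg)
d = Decoder(Start.FUNCTION, cfg, tok, max_len=30)
while not d.is_complete():
    legal = d.next_legal_tokens()       # the mask: grammar ∩ scope ∩ depth/size guards
    d.advance(rng.choice(list(legal)))  # uniform pick; a weighted prior or model fits here too
fn = d.function()
```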
`next_legal_tokens()` is the intersection of four constraints: the grammar's FIRST set for the current hole, the in-scope variable names (definite assignment), the depth/length guards, and a min-cost-to-complete check that guarantees the program can still close within `max_len`. The result is never empty until the program is complete.
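The four-way intersection can be illustrated with a toy sketch. Everything here is hypothetical: a single `<expr>` nonterminal with two binary operators, two literals, and three names, where `holes` is a stack of pending hole depths. It is not AST-CD's implementation, but it shows why the budget check keeps the mask non-empty: a leaf always fits, because every earlier step only accepted tokens whose minimum completion cost (one leaf per open hole) still fit within `max_len`.

```python
import random

# Hypothetical toy vocabulary; AST-CD's real vocabulary is larger.
BINOPS = frozenset({"+", "*"})      # each fills one hole and opens two child holes
CONSTS = frozenset({"1", "2"})
NAMES = frozenset({"a", "b", "c"})  # every name the grammar can spell
GRAMMAR = BINOPS | CONSTS | NAMES   # FIRST(<expr>) in this toy grammar


def legal_next(holes, n_emitted, scope, max_len, max_depth):
    """Mask for the leftmost open hole; `holes` is a stack of pending hole depths."""
    if not holes:
        return frozenset()                           # complete: nothing left to fill
    depth = holes[-1]
    in_scope = BINOPS | CONSTS | (NAMES & scope)     # no undefined variables
    depth_ok = GRAMMAR if depth < max_depth else CONSTS | NAMES
    remaining = len(holes) - 1                       # holes left after filling this one
    # Min cost to complete is one leaf per open hole, so an operator adds 2.
    budget_ok = frozenset(
        t for t in GRAMMAR
        if n_emitted + 1 + remaining + (2 if t in BINOPS else 0) <= max_len
    )
    return GRAMMAR & in_scope & depth_ok & budget_ok


def sample_expr(rng, scope, max_len=9, max_depth=3):
    """Uniformly sample a preorder token list using the mask above."""
    holes, tokens = [0], []
    while holes:
        legal = legal_next(holes, len(tokens), scope, max_len, max_depth)
        assert legal, "the mask is never empty before completion"
        token = rng.choice(sorted(legal))
        tokens.append(token)
        depth = holes.pop()
        if token in BINOPS:
            # push right child, then left, so the left child is filled next (preorder)
            holes += [depth + 1, depth + 1]
    return tokens


print(sample_expr(random.Random(0), scope={"a", "b"}))
```

Tightening `max_len` or `max_depth` only removes operators from the mask; literals and in-scope names stay legal, so sampling never dead-ends.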
Planned ideal API
The current low-level API is intentionally explicit: build a Vocabulary, create
a Decoder, inspect the legal-token mask, and advance one token at a time. The
planned public API should keep that same mental model:
state -> legal mask -> choose token -> advance state
Random program generation should not be a separate subsystem. It should be the
same constrained decoder driven by a random sampler. There should be no separate
scalar generation API: a single sample is represented as a one-item input batch
with n=1. When n > 1, each input state is cloned into n independent
continuations.
```python
from astcd import decode, detokenize, init_states, tokenize
from astcd.api import DecodeRequest
from astcd.samplers import UniformSampler
from astcd.types import DecodeConfig, Start

cfg = DecodeConfig(max_block_stmts=1)

# Free-form constrained decoding. Long term, Start.MODULE should be the default root.
states = init_states([DecodeRequest(start=Start.MODULE)], cfg=cfg)
batch = decode(
    states,
    sampler=UniformSampler(seed=0),
    cfg=cfg,
    n=8,
    max_total_tokens=300,
    max_new_tokens=200,
    stop_at="module_boundary",
)

# Continue the same batch without recomputing prompt/decoder state.
batch = decode(
    batch.states,
    sampler=UniformSampler(seed=1),
    n=1,
    max_new_tokens=100,
    stop_at="module_boundary",
)

# A single sample is a one-item batch.
one = decode(
    init_states([DecodeRequest(start=Start.FUNCTION)], cfg=cfg),
    sampler=UniformSampler(seed=2),
    n=1,
    max_new_tokens=200,
)
print(one.results[0].source)

# Prompt-conditioned decoding uses the same state-resume path.
tokenized = tokenize(["def func(a):", "def func(a, b):"], cfg=cfg, partial=True)
prompt_states = [result.state for result in tokenized.results if result.ok]
batch = decode(
    prompt_states,
    sampler=UniformSampler(seed=3),
    n=4,  # four completions per prompt state
    max_new_tokens=200,
)

# Convert token batches back to Python source strings.
sources = detokenize([result.tokens for result in batch.results], cfg=cfg)
```
The planned constrained-decoding surface is batch-only:
- `init_states(requests=[...], cfg=...)`: create fresh decoder states from starts and optional AST-CD token prefixes.
- `legal_masks(states)`: return the legal next-token mask for each active state.
- `advance(states, token_ids)`: consume one AST-CD token per state and return the next states.
- `decode(states, sampler=..., n=1, ...)`: run the mask/sample/advance loop with a sampler. This is a convenience driver, not a different algorithm. `n` is a per-input fanout, not a scalar generation mode.
- `tokenize(sources=[...], ...)`: convert supported Python source strings, or supported incomplete prompts such as `def func(a):`, into AST-CD token ids and resumable `DecodeState` objects.
- `detokenize(token_batches=[...], ...)`: convert batches of AST-CD token ids back to Python source strings. The lower-level pieces already exist today as `tokens_to_source(...)`, `tokens_to_ast(...)`, and `ast_to_source(...)`.
For len(states) == B, decode(..., n=N) returns B * N results and states,
ordered by (input_index, sample_index).
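A minimal sketch of that fan-out and ordering, assuming decoder states are immutable (as the frozen `DecodeState` dataclass suggests) so all `n` continuations of one input can start from the same object. `Slot` and `fan_out` are hypothetical names, not part of the planned API.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Slot:
    """Hypothetical bookkeeping for one continuation in a fanned-out batch."""
    input_index: int
    sample_index: int
    state: Any  # assumed immutable, so continuations can share the starting state


def fan_out(states, n):
    """Clone each of B input states into n continuations: B * n slots, input-major."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return [Slot(i, j, s) for i, s in enumerate(states) for j in range(n)]


print([(s.input_index, s.sample_index) for s in fan_out(["p0", "p1"], 3)])
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

With `n=1` the fan-out is the identity, which is why a separate scalar generation path is unnecessary.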
The planned result/state shape should make continuation, visualization, and training data export straightforward:
```python
from __future__ import annotations  # lets annotations name classes defined later

from dataclasses import dataclass

from astcd.types import Start


@dataclass(frozen=True, slots=True)
class DecodeBatch:
    results: tuple[DecodeResult, ...]
    states: tuple[DecodeState, ...]
    input_count: int
    n: int
    sampler_state: object | None = None


@dataclass(frozen=True, slots=True)
class DecodeRequest:
    start: Start = Start.MODULE
    prefix_tokens: tuple[int, ...] = ()


@dataclass(frozen=True, slots=True)
class DecodeResult:
    input_index: int
    sample_index: int
    tokens: tuple[int, ...]
    new_tokens: tuple[int, ...]
    source: str
    complete: bool
    stop_reason: str
    trace: tuple[DecodeStep, ...] = ()
    padded_tokens: tuple[int, ...] | None = None


@dataclass(frozen=True, slots=True)
class DecodeState:
    start: Start
    tokens: tuple[int, ...]
    complete: bool
    token_budget_remaining: int
    # plus an internal decoder snapshot
```
tokenize(...) should be useful both for prompt conditioning and for measuring
how much human-written Python data fits the supported grammar. It should return
structured feedback instead of just raising a string error:
```python
token_batch = tokenize([source], cfg=cfg, partial=True)
result = token_batch.results[0]
if result.ok:
    state = result.state  # pass this to decode(...) to continue
else:
    print(result.error.stage)  # "python_parse", "unsupported_syntax",
                               # "name_normalization", "vocab_encode", ...
    print(result.error.message)
    print(result.error.line, result.error.column)
    print(result.error.node_type)  # e.g. "Import", "Call", "Constant[str]"
    print(result.tokens)  # longest valid prefix, if available
```
The result shape should make corpus auditing straightforward:
```python
from __future__ import annotations  # lets annotations name classes defined later

from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class TokenizeBatch:
    results: tuple[TokenizeResult, ...]


@dataclass(frozen=True, slots=True)
class TokenizeResult:
    ok: bool
    tokens: tuple[int, ...]
    start: Start
    state: DecodeState | None
    complete: bool
    normalized_source: str | None
    error: TokenizeError | None
```
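As an illustration of the corpus-auditing use, here is a hypothetical summary over tokenize results: coverage (the fraction of sources that fit the supported grammar) plus failure counts by `error.stage` and `error.node_type`. `StubError` and `StubResult` are stand-ins that carry only the fields used here, since the planned classes do not exist yet.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class StubError:
    """Hypothetical stand-in for the planned TokenizeError (only the fields used here)."""
    stage: str
    node_type: Optional[str] = None


@dataclass(frozen=True)
class StubResult:
    """Hypothetical stand-in for the planned TokenizeResult."""
    ok: bool
    error: Optional[StubError] = None


def audit(results):
    """Coverage of the supported grammar, plus failure counts by stage and node type."""
    failures = [r.error for r in results if not r.ok and r.error is not None]
    return {
        "coverage": sum(r.ok for r in results) / len(results) if results else 0.0,
        "by_stage": Counter(e.stage for e in failures),
        "by_node": Counter(e.node_type for e in failures if e.node_type),
    }


report = audit([
    StubResult(True),
    StubResult(False, StubError("unsupported_syntax", "Import")),
    StubResult(False, StubError("unsupported_syntax", "Call")),
    StubResult(False, StubError("python_parse")),
])
print(report["coverage"], report["by_stage"].most_common(1))
```

Ranking `by_node` this way points at which unsupported constructs would unlock the most real-world code if added to the grammar.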
For LLM integrations, assume the model emits AST-CD vocabulary token ids. The core constrained-decoding API should expose batched masking and advancement directly:
```python
from astcd.constraint import advance, legal_masks

tokenized = tokenize(["def func(a):", "def func(a, b):"], cfg=cfg, partial=True)
states = [result.state for result in tokenized.results if result.ok]
while not all(s.complete for s in states):
    masks = legal_masks(states)  # shape: [batch, astcd_vocab_size]
    next_tokens = astcd_vocab_model.sample(masks)
    states = advance(states, next_tokens)
```
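`astcd_vocab_model.sample(masks)` is a placeholder. One plausible way to implement it, assuming the model produces one logit per AST-CD token id, is to take a softmax over the legal ids only, so illegal ids get exactly zero probability. This pure-standard-library sketch uses hypothetical helper names (`masked_sample`, `masked_sample_batch`).

```python
import math
import random


def masked_sample(logits, mask, rng):
    """Sample one token id from `logits`, restricted to ids where `mask` is True."""
    if len(logits) != len(mask):
        raise ValueError("logits and mask must have the same length")
    legal = [i for i, ok in enumerate(mask) if ok]
    if not legal:
        raise ValueError("empty mask: the state should already be complete")
    top = max(logits[i] for i in legal)                   # subtract max for stability
    weights = [math.exp(logits[i] - top) for i in legal]  # softmax over legal ids only
    return rng.choices(legal, weights=weights, k=1)[0]


def masked_sample_batch(batch_logits, masks, rng):
    """One token per row, mirroring a [batch, vocab] mask."""
    if len(batch_logits) != len(masks):
        raise ValueError("batch sizes differ")
    return [masked_sample(row, mask, rng) for row, mask in zip(batch_logits, masks)]


rng = random.Random(0)
print(masked_sample([5.0, 100.0, 1.0], [True, False, True], rng))  # never 1: it is masked out
```

Renormalizing over the legal set, rather than sampling and rejecting, keeps decoding to one model call per token.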
decode(...) is then just a sampler-driven loop built on the same primitives. It
should support optional padding for training data. By default it returns only the
generated sequence. If pad=True, it appends a configured padding token up to
the requested fixed length:
```python
batch = decode(
    init_states([DecodeRequest(start=Start.FUNCTION)], cfg=cfg),
    sampler=UniformSampler(seed=0),
    n=32,
    max_new_tokens=200,
    pad=True,
)
assert all(len(r.padded_tokens) == 200 for r in batch.results)
```
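The padding contract in that assertion can be sketched as a small helper. `pad_to` and the pad id are hypothetical. The sketch also returns a 1/0 mask so a training loss can ignore pad positions; that mask is a common companion to fixed-length padding, not something the planned API specifies.

```python
def pad_to(tokens, length, pad_id):
    """Right-pad to a fixed length; also return a mask marking real (1) vs pad (0) tokens."""
    if len(tokens) > length:
        raise ValueError(f"{len(tokens)} tokens do not fit in length {length}")
    padding = length - len(tokens)
    padded = tuple(tokens) + (pad_id,) * padding
    loss_mask = (1,) * len(tokens) + (0,) * padding
    return padded, loss_mask


print(pad_to([7, 3, 9], 5, pad_id=0))  # ((7, 3, 9, 0, 0), (1, 1, 1, 0, 0))
```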
Adapters for model-serving libraries should be thin wrappers around
legal_masks(...) and advance(...).
For Hugging Face Transformers, the adapter should be a logits processor for models whose token ids already correspond to AST-CD vocabulary ids:
```python
from astcd.integrations.transformers import ASTCDLogitsProcessor

states = init_states([DecodeRequest(start=Start.MODULE)], cfg=cfg)
processor = ASTCDLogitsProcessor(states=states, cfg=cfg)
outputs = model.generate(
    input_ids=processor.input_ids(),
    max_new_tokens=200,
    logits_processor=[processor],
)
```
For vLLM, the adapter should use vLLM's custom logits-processor mechanism rather than `structured_outputs`. vLLM's built-in structured-output formats (`choice`, `regex`, `json`, `grammar`, and `structural_tag`) target normal text/BPE generation, and a static `allowed_token_ids` list cannot change from one step to the next. AST-CD needs a dynamic, stateful mask over the AST-CD vocabulary:
```python
from vllm import LLM, SamplingParams
from astcd.integrations.vllm import ASTCDLogitsProcessor

llm = LLM(
    model="your-astcd-vocab-model",
    logits_processors=[ASTCDLogitsProcessor],
)
params = SamplingParams(
    max_tokens=200,
    n=8,
    extra_args={
        "astcd_cfg": cfg.to_json(),
        "astcd_state": states[0].to_json(),
    },
)
outputs = llm.generate(
    prompt_token_ids=[list(states[0].tokens)],
    sampling_params=params,
)
```
Generic Python constrained-decoding libraries such as SynCode, XGrammar, and Outlines are worth using as baselines or adapters when the model emits normal Python text tokens. AST-CD is still useful as a separate package because it constrains an AST-token language, tracks scope, enforces bounded completion, supports prompt tokenization into decoder state, and can decode directly to an executable restricted Python AST.
Visualization should be implemented as helpers that consume decode batches, not as special modes inside the sampler:
```python
from astcd.display import show, animate, save_gif, dump_json

show(batch, index=0)
animate(batch, index=0)
save_gif(batch, "decode.gif", index=0)
dump_json(batch, "samples.json")
```
Longer term, Start.MODULE should become the default public root so ordinary
files can be modeled as a module-level block. Specialized starts such as
Start.FUNCTION, Start.BLOCK, and Start.EXPR should remain useful for tests,
focused generation, and continuations inside a known syntactic context.
What it generates
A small, bounded subset of Python over integers:
```python
def func(a, b):          # arity 1..N
    d = a * b + 1        # assignments to fresh / existing locals
    if d < 0:            # if / else
        return 0         # return is a statement (early returns OK)
    return d             # ...or fall through to None
```
Optional features are set through `DecodeConfig`. `for v in range(e):` loops are off by default and enabled with `allow_loops`. The size/shape knobs are `arities`, `max_block_stmts`, `max_expr_depth`, `max_control_depth`, and the integer `value_range` / `const_range`.
Watch it decode
The most fun way to understand it — animate a program assembling itself one legal
token at a time, holes (<expr>, <stmt>, ...) filling in:
```bash
python -m astcd.viz                       # animate a random function
python -m astcd.viz --seed 7 --delay 0.3
python -m astcd.viz --uniform             # flat policy (more complex programs)
python -m astcd.viz --loops --max-control-depth 3
python -m astcd.viz --play                # interactive: YOU pick each token from the menu
python -m astcd.viz --gif decode.gif      # export the animation
```
Status
Early and experimental, but the core is complete and tested (pytest):
grammar / vocabulary / scope / interpreter / codec round-trip / decoder soundness /
samplers / visualization.
License
MIT — see LICENSE.
File details
Details for the file pyprior-0.0.1.tar.gz.
File metadata
- Download URL: pyprior-0.0.1.tar.gz
- Upload date:
- Size: 32.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2c3c0b54140f70f4373b222aa660735ce7b98db825b47b590d19237a7c31287f |
| MD5 | cca439476e3897dc565bcc59980a3044 |
| BLAKE2b-256 | f79f4ccabfa2a378efb78eb93ad531d72693d9034feb73a8b36a18e19233f821 |
File details
Details for the file pyprior-0.0.1-py3-none-any.whl.
File metadata
- Download URL: pyprior-0.0.1-py3-none-any.whl
- Upload date:
- Size: 37.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.10.13
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | c2f17164687a8b0dea5d86455a48ab46d3691ddf6a83c19324ea006f2a805e34 |
| MD5 | 9098186b7c7dcdb39656aaa98550063c |
| BLAKE2b-256 | 5e3c8d9ae6cc807a007fb98815add219e37b617c2bf36b33f0bcd3d040481e6e |