
dLLM-Reason

DAG-Guided Discrete Diffusion Language Models for Reasoning


Overview

dLLM-Reason is a research framework that enhances reasoning in discrete diffusion language models (dLLMs) by controlling the token unmasking order via DAG (Directed Acyclic Graph) topological structures.

Core idea: dLLMs generate text by iteratively unmasking tokens. We impose a DAG on unmasking order -- edges encode reasoning dependencies -- so prerequisite steps are generated before downstream conclusions.

Model Layer          Scheduler Layer         DAG Layer
(what to predict) <-> (where to unmask) <-> (dependency structure)
MDLM|SEDD|D3PM|LLaDA   DAGScheduler          TokenDAG + Templates

Installation

pip install dllm-reason

Or install from GitHub (latest dev):

pip install "git+https://github.com/BDeMo/dLLM_Reason.git"

With optional extras (FAISS, sentence-transformers, dev tools):

pip install "dllm-reason[dev,library]"

For development (editable install):

git clone https://github.com/BDeMo/dLLM_Reason.git
cd dLLM_Reason
pip install -e ".[dev,library]"

After installation the following CLI commands are available globally:

| Command | Equivalent |
| --- | --- |
| dllm-eval-dags | python scripts/eval_dags.py |
| dllm-train | python scripts/train.py |
| dllm-eval | python scripts/evaluate.py |
| dllm-search | python scripts/search_dag.py |
| dllm-viz | python scripts/visualize_dag.py |

Quick Start

# Download models & datasets
python scripts/download_models.py              # -> checkpoints/llada-instruct/
python scripts/download_datasets.py            # -> datasets/

# HF mirror (China)
python scripts/download_models.py --mirror https://hf-mirror.com
python scripts/download_datasets.py --mirror https://hf-mirror.com

Usage

All parameters live in configs/eval_default.yaml. CLI flags always override the config.

1. Default run (LLaDA + confidence scheduler)

bash scripts/run_eval.sh
# pass any CLI overrides after the script:
bash scripts/run_eval.sh --benchmarks mbpp --num_samples 50
bash scripts/run_eval.sh --dags cot skeleton --num_steps 64
bash scripts/run_eval.sh --verbose_errors

Results are written to results/eval_<timestamp>/.

2. Per-strategy scripts

scripts/runs/ contains one script per unmasking strategy:

bash scripts/runs/confidence.sh    # highest-confidence first (LLaDA default)
bash scripts/runs/random.sh        # uniform random
bash scripts/runs/linear.sh        # left-to-right
bash scripts/runs/cot.sh           # Chain-of-Thought DAG
bash scripts/runs/skeleton.sh      # Skeleton-then-Detail DAG
bash scripts/runs/bidirectional.sh # bidirectional DAG
bash scripts/runs/answer_first.sh  # answer region first
bash scripts/runs/all_strategies.sh  # all 8 strategies in one run

All scripts pass extra args through to eval_dags.py:

bash scripts/runs/cot.sh --benchmarks mbpp humaneval --num_samples 100 --cot_steps 6

3. Direct CLI

python scripts/eval_dags.py \
    --dags confidence cot skeleton \
    --benchmarks mbpp humaneval \
    --num_steps 64 --temperature 0.5 \
    --num_samples 100 \
    --output_dir results/my_run

4. Config file

Edit configs/eval_default.yaml to change defaults:

model:
  model_id: "checkpoints/llada-instruct"
  torch_dtype: "bfloat16"        # bfloat16 | float16 | float32

inference:
  num_steps: 128
  block_length: 32               # max_new_tokens must be divisible by this
  temperature: 0.0               # 0 = greedy argmax
  cfg_scale: 0.0                 # 0 = disabled
  remasking: "low_confidence"    # low_confidence | random
  max_new_tokens: 128

benchmarks:
  benchmarks: ["mbpp", "humaneval"]
  num_samples: null              # null = full dataset
  run_tests: true                # false = skip code execution
  verbose_errors: false          # --verbose_errors to enable

dags:
  dags: ["confidence"]
  cot_steps: 4

output:
  output_dir: "results"
  resume: false

5. Save per-sample outputs (QA pairs, ground truth, trajectory)

Add --save_outputs to any run to write per-sample files alongside the summary JSON:

# Default: writes both JSON and Excel
bash scripts/run_eval.sh --save_outputs

# Use the dedicated script (has comments explaining every option)
bash scripts/runs/save_outputs.sh --benchmarks mbpp --num_samples 50

# Also record unmasking trajectory (one entry per diffusion step per sample)
bash scripts/runs/save_outputs.sh --record_trajectory --num_samples 10

Output files written to results/<run>/:

| File | Contents |
| --- | --- |
| {bench}_{dag}_samples.json | Full per-sample records: prompt, generated answer, ground truth, pass/fail |
| {bench}_{dag}_samples.xlsx | Same data as a spreadsheet (one row per sample) |
| {bench}_{dag}_trajectory.json | Decoded token states at each diffusion step (only with --record_trajectory) |

Control what is included:

--save_outputs             # master switch (required)
--no_save_qa               # omit prompt + generated answer
--no_save_ground_truth     # omit reference answers
--record_trajectory        # add per-step unmasking states (large; keep off for big runs)
--output_formats json      # write only JSON (skip Excel)
--output_formats xlsx      # write only Excel (skip JSON)

Config file equivalents (configs/eval_default.yaml):

save:
  save_outputs: false       # master switch
  save_qa: true
  save_ground_truth: true
  record_trajectory: false
  output_formats: ["json", "xlsx"]

6. Single-prompt inference

python scripts/infer_llada.py \
    --model_id checkpoints/llada-instruct \
    --prompt "What is 7 * 8?" \
    --num_steps 128 --block_length 32 --temperature 0.0

Available strategies

| Strategy | Description |
| --- | --- |
| confidence | Unmask highest-confidence tokens first |
| random | Uniform random unmasking (no DAG constraint) |
| linear | Left-to-right sequential |
| cot | Chain-of-Thought segment DAG |
| skeleton | Structural tokens first, then detail |
| bidirectional | Both ends toward center |
| answer_first | Answer region unmasked before reasoning |

Available benchmarks

| Benchmark | Type | Metric |
| --- | --- | --- |
| mbpp | Python code generation | pass@1 |
| humaneval | Python code generation | pass@1 |
| gsm8k | Math reasoning | exact match |
| math | Competition math | exact match |
| mmlu | Knowledge (multi-subject) | accuracy |
| hotpotqa | Multi-hop QA | EM / F1 |
| arc | Science reasoning | accuracy |
| prontoqa | Logic reasoning | accuracy |

Project Structure

src/dllm_reason/
  models/          MDLM, SEDD, D3PM, LLaDA (4 dLLMs)
  graph/           TokenDAG, 6 templates, constraints, visualization
  scheduler/       Random, Confidence, Linear, DAGScheduler, Adaptive (5 schedulers)
  search/          Evolutionary, Greedy, RL Policy, NOTEARS (4 search methods)
  inference/       DiffusionSampler, DAGSampler
  training/        Pretrain, DAG-aware, Fine-tune, Diffu-GRPO
  eval/            Metrics, 4 benchmark evaluators, DAG analysis
  library/         DAG Library (store, retrieval, fusion, feedback, merge)
  data/            GSM8K, MATH, ARC, ProntoQA loaders
  utils/           Registry, logging, distributed

configs/           31 YAML configs (model, graph, search, task, eval, experiment, library)
scripts/           Train, evaluate, search, visualize, download, server setup
tests/             DAG, schedulers, models, library (4 test suites)
notebooks/         DAG exploration, results analysis
docs/              V1.0 release notes, API reference, presentation

Key Components

TokenDAG

The core data structure. A boolean adjacency matrix on GPU where edge (i, j) means "position i must unmask before position j".

from dllm_reason.graph.dag import TokenDAG

dag = TokenDAG.linear_chain(seq_len=256)
ready = dag.ready_positions(is_unmasked)  # one batched GPU op

6 templates: Chain-of-Thought, Answer-First, Skeleton-Detail, Bidirectional, Interleaved, Random.
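The readiness rule above can be sketched in plain Python. This loop version is illustrative only; per the description, the library computes the same result as a single batched boolean-matrix operation on GPU:

```python
def ready_positions(adj, is_unmasked):
    """Return positions that may be unmasked next.

    adj[i][j] == True means "position i must unmask before position j".
    A position j is ready iff it is still masked and every i with an
    edge into j has already been unmasked.
    """
    n = len(is_unmasked)
    ready = []
    for j in range(n):
        if is_unmasked[j]:
            continue  # already unmasked
        if all(is_unmasked[i] for i in range(n) if adj[i][j]):
            ready.append(j)
    return ready
```

For a linear-chain DAG this reduces to left-to-right decoding; an edgeless DAG leaves every masked position ready, recovering unconstrained unmasking.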

DAGScheduler

Injects DAG constraints at the scheduler layer -- models need zero modification.

from dllm_reason.scheduler.dag_scheduler import DAGScheduler

scheduler = DAGScheduler(dag, sub_strategy="confidence_topk")
# sub_strategies: all_ready, confidence_topk, proportional
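Conceptually, confidence_topk restricts LLaDA's default highest-confidence selection to the DAG-ready positions. A hypothetical sketch of one scheduler step (the function name and calling convention are illustrative, not the real API, which operates on batched tensors):

```python
def confidence_topk_step(ready, confidences, k):
    """One step of the "confidence_topk" sub-strategy: among the
    DAG-ready positions, unmask the k with the highest model
    confidence. `confidences` is indexed by sequence position.
    """
    ranked = sorted(ready, key=lambda pos: confidences[pos], reverse=True)
    return ranked[:k]
```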

DAG Search

Automatically discover optimal DAG structures.

from dllm_reason.search.evolutionary import EvolutionarySearch

searcher = EvolutionarySearch(
    population_size=20,
    library=dag_store,           # seed from library
    task_description="math",
)
result = searcher.search(model, eval_fn, seq_len=256, budget=200)
# result.best_dag auto-written back to library

4 methods: Evolutionary, Greedy, RL Policy, Differentiable (NOTEARS).
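The evolutionary variant follows a standard select-and-mutate loop. A minimal sketch over a generic candidate type, where `mutate` and `eval_fn` stand in for DAG mutation and benchmark scoring (none of these names are the real API):

```python
import random

def evolutionary_search(init_pop, mutate, eval_fn, generations):
    """Minimal elitist evolutionary loop: score candidates, keep the
    best half, refill the population by mutating random elites.
    Stands in for DAG-structure search; candidates can be any type.
    """
    pop = list(init_pop)
    for _ in range(generations):
        scored = sorted(pop, key=eval_fn, reverse=True)
        elite = scored[: max(1, len(pop) // 2)]
        pop = elite + [mutate(random.choice(elite))
                       for _ in range(len(pop) - len(elite))]
    return max(pop, key=eval_fn)
```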

DAG Library

Persistent storage + retrieval + feedback for DAG structures.

  • Store: SQLite + FAISS vector index
  • Retrieval: 3 channels (semantic, structural, performance) -- independently toggleable
  • Fusion: 4 strategies (weighted, RRF, max, voting)
  • Feedback: 3 sources (auto benchmark, human rating, Elo tournament)
  • Merge: 3 strategies (union, intersection, weighted)

All components independently toggleable for ablation experiments. 7 preset configs in configs/library/.
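Of the fusion strategies listed, Reciprocal Rank Fusion (RRF) has a well-known generic form; a sketch below, where `k=60` is the conventional RRF constant (the library's actual parameters and signatures may differ):

```python
def rrf_fuse(rankings, k=60):
    """Reciprocal Rank Fusion: combine ranked candidate lists from
    multiple retrieval channels (e.g. semantic, structural,
    performance) into one ranking. Each item scores sum(1 / (k + rank))
    over the channels that returned it.
    """
    scores = {}
    for ranking in rankings:
        for rank, item in enumerate(ranking, start=1):
            scores[item] = scores.get(item, 0.0) + 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)
```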

Models

| Model | Type | Reference |
| --- | --- | --- |
| MDLM | Absorbing-state continuous-time diffusion | Sahoo et al., 2024 |
| SEDD | Score-entropy discrete diffusion | Lou et al., 2024 |
| D3PM | Discrete-time structured transitions | Austin et al., 2021 |
| LLaDA | LLaMA-3 based masked diffusion (8B) | GSAI-ML |

Benchmarks

| Benchmark | Type | Metric |
| --- | --- | --- |
| GSM8K | Math reasoning | Exact match |
| MATH | Competition math | Exact match |
| MBPP | Code generation | pass@1 |
| HumanEval | Code generation | pass@1 |
| HotpotQA | Multi-hop QA | EM, F1 |
| MMLU | Knowledge | Accuracy |
| ARC | Science reasoning | Accuracy |
| ProntoQA | Logic | Accuracy |

Configuration

All configs use YAML + Hydra/OmegaConf.

| Directory | Contents |
| --- | --- |
| configs/model/ | Model hyperparameters (mdlm, sedd, d3pm, llada) |
| configs/graph/ | DAG template parameters |
| configs/search/ | Search algorithm settings |
| configs/task/ | Dataset configs |
| configs/eval/ | Benchmark settings |
| configs/experiment/ | End-to-end experiment combinations |
| configs/library/ | DAG Library ablation variants |
| configs/eval_default.yaml | Default evaluation config (used by run_eval.sh) |

License

MIT License. See LICENSE for details.
