A Tree Search Library with Flexible API for LLM Inference-Time Scaling
TreeQuest
A flexible answer tree search library featuring AB-MCTS, useful for (but not limited to) LLM inference-time scaling.
Quick Start
```python
import random

import treequest as tq

# Each node is associated with a user-definable `state`.
State = str


# 1. Define a function to be used for node generation.
def generate(parent_state: State | None) -> tuple[State, float]:
    """Generates new states and scores based on the parent state."""
    if parent_state is None:  # None represents the expansion from root.
        new_state = "Initial state"
    else:
        new_state = f"State after {parent_state}"
    score = random.random()  # A score for the new state; It should be normalized to the [0, 1] range.
    return new_state, score


# 2. Instantiate the algorithm and a search tree object.
algo = tq.ABMCTSA()
search_tree = algo.init_tree()

# 3. Run the search with a generation budget (10 in this case).
for _ in range(10):
    search_tree = algo.step(search_tree, {'Action A': generate})

# 4. Extract the best score and state.
best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best state: {best_state}, Score: {best_node_score}")
```
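The Quick Start uses random.random() as a stand-in score, but real scorers usually produce raw metrics on some other scale, and scores should be normalized to the [0, 1] range. A minimal sketch of one way to do that, using hypothetical helpers normalize_score and pass_rate_score (a fraction-of-tests-passed metric is only an example, not something TreeQuest prescribes):

```python
def normalize_score(raw: float, lo: float, hi: float) -> float:
    """Map a raw metric from [lo, hi] onto [0, 1], clipping out-of-range values."""
    if hi <= lo:
        raise ValueError("hi must be greater than lo")
    scaled = (raw - lo) / (hi - lo)
    return min(1.0, max(0.0, scaled))


def pass_rate_score(num_passed: int, num_tests: int) -> float:
    """Hypothetical scorer: fraction of test cases an answer passes."""
    if num_tests <= 0:
        return 0.0  # No tests means no evidence; treat as the worst score.
    return normalize_score(num_passed, 0, num_tests)
```

Clipping keeps occasional out-of-range metrics from leaking into the search; whether 0.0 is the right fallback for an unscorable answer depends on your task.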
Alternatively, you can use an ask–tell interface with batched AB-MCTS sampling steps:
```python
import random

import treequest as tq

State = str


def generate(parent_state: State | None) -> tuple[State, float]:
    ...  # Same as in the Quick Start example above.


generate_fns = {"Action A": generate}
actions = list(generate_fns.keys())

# We use batch_size=5 here.
batch_size = 5

# Run the AB-MCTS-M sampling step with 5 worker processes in parallel.
algo = tq.ABMCTSM(max_process_workers=batch_size)
search_tree = algo.init_tree()

total_budget = 50
num_steps = total_budget // batch_size

for _ in range(num_steps):
    # ask_batch returns a list of `Trial` objects, each with `action`, `parent_state`, and `trial_id` attributes.
    search_tree, trials = algo.ask_batch(search_tree, batch_size, actions)
    for trial in trials:
        result = generate_fns[trial.action](trial.parent_state)
        # Call tell with the trial_id to update search_tree.
        search_tree = algo.tell(search_tree, trial.trial_id, result)

best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
```
Each step call can be slow, particularly for AB-MCTS-M. If execution is slow, prefer ask_batch over step.
Note that a large batch_size can skew the shape of the search tree (i.e., the tree may become too wide), so avoid overly large values; see PROFILING.md for example trees.
We recommend batch_size <= 5 as a starting point.
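In the ask-tell example, num_steps = total_budget // batch_size silently drops the remainder when the budget is not a multiple of batch_size (a budget of 52 with batch_size 5 would spend only 50 generations). A small sketch with a hypothetical helper, batch_schedule, that keeps the remainder as a final smaller batch:

```python
def batch_schedule(total_budget: int, batch_size: int) -> list[int]:
    """Split a generation budget into batch sizes, keeping any remainder.

    total_budget // batch_size full batches are followed by one smaller
    batch if the budget is not a multiple of batch_size.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, remainder = divmod(total_budget, batch_size)
    schedule = [batch_size] * full
    if remainder:
        schedule.append(remainder)
    return schedule
```

Each entry would then be passed as the batch_size argument of ask_batch in place of the fixed value, assuming the algorithm accepts a smaller final batch.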
Features
- Easy-to-use API with customizable node generation and node scoring logic.
- AB-MCTS-A and AB-MCTS-M, as well as Multi-LLM AB-MCTS support (see our paper for algorithm details).
- Checkpointing and resuming searches.
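Because algorithms are stateless and the whole search state lives in the tree object returned by init_tree and step, one plausible way to checkpoint is to serialize that object. The sketch below assumes the tree is picklable and uses hypothetical helpers save_checkpoint and load_checkpoint; it is not necessarily how TreeQuest implements its own checkpointing.

```python
import os
import pickle
import tempfile
from typing import Any


def save_checkpoint(obj: Any, path: str) -> None:
    """Pickle obj to path via a temp file + rename, so a crash never leaves a half-written file."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)  # Atomic rename on POSIX filesystems.
    except BaseException:
        os.remove(tmp_path)
        raise


def load_checkpoint(path: str) -> Any:
    """Load an object previously written by save_checkpoint."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

For example, call save_checkpoint(search_tree, "tree.pkl") after each step and resume with search_tree = load_checkpoint("tree.pkl"), passing the same generation functions to step again. Only unpickle files you trust.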
Installation
uv
First, install uv. Then you can install TreeQuest with the following command:
```bash
uv add "treequest[abmcts-m]"
```
pip
Alternatively, you can use pip to install TreeQuest:
```bash
pip install "treequest[abmcts-m]"
```
Usage
Using an LLM as a Node Generator
You can use any object as a node state. You only need to define a generation function that takes the parent state as an argument and returns a (state, score) tuple:
```python
import dataclasses

import treequest as tq


@dataclasses.dataclass
class State:
    llm_answer: str
    score: float


def generate(parent_state: State | None) -> tuple[State, float]:
    """Generate a new node by calling an LLM."""
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)
    return state, state.score


def initial_generation() -> State:
    """
    Call LLM API to generate an initial answer.
    """
    ...


def refine_answer(llm_answer: str, score: float) -> State:
    """
    Call LLM API to refine an answer.
    """
    ...


algo = tq.ABMCTSM()
search_tree = algo.init_tree()
for i in range(20):
    search_tree = algo.step(search_tree, {'Action Label': generate})

    # Logging best node during the search.
    if (i + 1) % 5 == 0:
        best_interim_state, _ = tq.top_k(search_tree, algo, k=1)[0]
        print(f"Iteration {i+1}: Best state so far = {best_interim_state}")

best_state, _ = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best Answer: {best_state.llm_answer}, Best Score: {best_state.score}")
```
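To dry-run the search loop without calling an API, the two LLM helpers can be swapped for deterministic stand-ins with the same signatures, so generate works unchanged. A minimal sketch; the stub outputs and the "+0.1 per refinement" scoring rule are hypothetical and only exist to exercise the loop:

```python
import dataclasses


@dataclasses.dataclass
class State:
    llm_answer: str
    score: float


def initial_generation() -> State:
    """Hypothetical offline stand-in for the LLM call: a fixed first draft."""
    return State(llm_answer="draft v1", score=0.2)


def refine_answer(llm_answer: str, score: float) -> State:
    """Hypothetical stand-in for refinement: bump the version, nudge the score up.

    The score is capped at 1.0 so it stays in the [0, 1] range.
    """
    version = int(llm_answer.rsplit("v", 1)[1]) + 1
    return State(llm_answer=f"draft v{version}", score=min(1.0, score + 0.1))
```

Because scores only increase with depth here, a correctly wired loop should tend to report deeper drafts as the best state.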
Using Multiple LLMs (and Beyond)
TreeQuest supports multiple action types. For example, you can provide multiple generation functions backed by different LLMs to represent different action types:
```python
from functools import partial

import treequest as tq


# llm_name is keyword-only so that the partial below can still be called
# with the parent state as its only positional argument.
def generate(parent_state=None, *, llm_name: str):
    """
    Call LLM API using litellm, vllm, etc., to generate a new node.
    """
    ...
    return new_state, new_score  # Both come from the elided LLM call above.


llm_names = ["o4-mini", "gemini-2.5-pro"]
# Create dict of different actions backed by different LLMs.
generate_fns = {llm_name: partial(generate, llm_name=llm_name) for llm_name in llm_names}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)
```
The variation is not limited to the choice of LLM; the functions in generate_fns can also differ in prompts, actions, scoring logic, and so on.
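A sketch of combining two axes of variation, LLM and prompt style, into one action per pair. The prompt templates, the action-naming scheme "llm/style", and the placeholder score are hypothetical. Both bound parameters are keyword-only: the ask-tell loop calls each action as generate_fns[trial.action](trial.parent_state), and binding a first positional parameter by keyword with partial would make that call raise TypeError.

```python
from functools import partial
from itertools import product
from typing import Optional

# Hypothetical prompt styles; replace with your own.
PROMPT_TEMPLATES = {
    "direct": "Answer the task directly.",
    "critique": "Critique the previous answer, then improve it.",
}


def generate(
    parent_state: Optional[str] = None, *, llm_name: str, prompt_style: str
) -> tuple[str, float]:
    """Hypothetical action whose behavior depends on both the LLM and the prompt."""
    prompt = PROMPT_TEMPLATES[prompt_style]
    # A real implementation would send prompt (plus parent_state) to llm_name.
    new_state = f"[{llm_name}] {prompt}"
    return new_state, 0.5  # Placeholder score in [0, 1].


llm_names = ["o4-mini", "gemini-2.5-pro"]
# One action per (LLM, prompt style) pair: four actions in total.
generate_fns = {
    f"{llm}/{style}": partial(generate, llm_name=llm, prompt_style=style)
    for llm, style in product(llm_names, PROMPT_TEMPLATES)
}
```

More actions generally means more ways to expand each node, so it may be worth starting with a small set and growing it.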
Batch Semantics and Concurrency
- Algorithms are stateless objects; the evolving tree/search state is returned from init_tree, step, ask, and tell.
- ask_batch(state, batch_size, actions) returns exactly batch_size Trial objects to expand next.
  - Non-queue algorithms (e.g., ABMCTSM, ABMCTSA, MultiArmedBanditUCB) return exactly batch_size Trials.
  - Queue-based algorithms (e.g., StandardMCTS, BestFirstSearchAlgo, TreeOfThoughtsBFS) precompute a set of parent/action pairs and duplicate them if needed to fill batch_size.
- tell(state, trial_id, (new_state, score)) reflects the result for the corresponding Trial.
  - Order-independent: you can call tell in any order; each result is matched to its Trial by trial_id.
  - Idempotent: calling tell twice with the same trial_id does not add extra nodes.
  - For queue-based algorithms, Trials told beyond the number of children a parent node can have (e.g., (# actions) * samples_per_action for StandardMCTS) become INVALID and are not reflected.
- Scores are expected to be normalized to the [0, 1] range.
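Since tell is order-independent, the generation calls for a batch can run concurrently and be told back as they finish. A minimal sketch with a hypothetical helper, run_trials, assuming each Trial exposes action, parent_state, and trial_id as described for ask_batch, and that your generation functions are thread-safe:

```python
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any


def run_trials(
    trials: Iterable[Any],
    generate_fns: dict[str, Callable[[Any], tuple[Any, float]]],
    max_workers: int = 5,
) -> Iterator[tuple[Any, tuple[Any, float]]]:
    """Evaluate trials concurrently, yielding (trial_id, result) as each finishes.

    Results arrive in completion order, not submission order.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to the trial_id it belongs to.
        futures = {
            pool.submit(generate_fns[t.action], t.parent_state): t.trial_id
            for t in trials
        }
        for future in as_completed(futures):
            yield futures[future], future.result()
```

Inside the ask-tell loop this would replace the inner for loop: for trial_id, result in run_trials(trials, generate_fns): search_tree = algo.tell(search_tree, trial_id, result). Only the generation calls run in worker threads; tell still runs sequentially in the calling thread. Threads suit I/O-bound LLM API calls; CPU-heavy scoring might call for ProcessPoolExecutor instead.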
Algorithms
ABMCTS-A: ABMCTS with Node Aggregation
ABMCTS-A uses node aggregation for adaptive branching:
```python
import treequest as tq

# generate_fns maps action names to generation functions, as in the examples above.
# Instantiate the ABMCTS-A algorithm.
ab_mcts_a = tq.ABMCTSA()
search_tree = ab_mcts_a.init_tree()
for _ in range(50):
    search_tree = ab_mcts_a.step(search_tree, generate_fns)
```
ABMCTS-M: ABMCTS with Mixed Models
ABMCTS-M leverages PyMC's mixed modeling capabilities:
```python
import treequest as tq

# generate_fns maps action names to generation functions, as in the examples above.
# Instantiate the ABMCTS-M algorithm.
ab_mcts_m = tq.ABMCTSM()
search_tree = ab_mcts_m.init_tree()
for _ in range(30):
    search_tree = ab_mcts_m.step(search_tree, generate_fns)
```
NOTE: To run AB-MCTS-M, you need to install extra dependencies with the treequest[abmcts-m] option.
Requirements
- Python 3.11+
Contributing
Contributions are welcome! Please see CONTRIBUTING.md for development tips.
Citation
```bibtex
@article{inoue2025wider,
  title={Wider or Deeper? Scaling LLM Inference-Time Compute with Adaptive Branching Tree Search},
  author={Inoue, Yuichi and Misaki, Kou and Imajuku, Yuki and Kuroki, So and Nakamura, Taishi and Akiba, Takuya},
  journal={arXiv preprint arXiv:2503.04412},
  year={2025}
}
```
License
Download files
Source Distribution
Built Distribution
File details
Details for the file treequest-0.2.0.tar.gz.
File metadata
- Download URL: treequest-0.2.0.tar.gz
- Upload date:
- Size: 809.2 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.9.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 3cc4eead2a1bfee6c47407979321fbf8a7982629b501f29ab996673d6d71725b |
| MD5 | 21e6d969e3dabe8cf49ae3300a76e43f |
| BLAKE2b-256 | 9a1cbc316703ca9721e8f228c294d104653b518d19395b80f6988f89046c2d8b |
File details
Details for the file treequest-0.2.0-py3-none-any.whl.
File metadata
- Download URL: treequest-0.2.0-py3-none-any.whl
- Upload date:
- Size: 47.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: uv/0.9.3
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d02fada7cb8da5fb0f97f1dc1389762fe4dbc2341d27ba5490b24ac4fd5c8b9a |
| MD5 | e0564b74e4256279cb70c79f45a4b97f |
| BLAKE2b-256 | fdd7b10fe46e702d40174d21219746b9a994c6f153e70518267df36737b067f6 |