
A Tree Search Library with Flexible API for LLM Inference-Time Scaling


TreeQuest



A flexible answer tree search library featuring AB-MCTS, useful for (but not limited to) LLM inference-time scaling.

Quick Start

import random

import treequest as tq

# Each node is associated with a user-definable `state`.
State = str

# 1. Define a function to be used for node generation.
def generate(parent_state: State | None) -> tuple[State, float]:
    """Generates new states and scores based on the parent state."""
    if parent_state is None: # None represents the expansion from root.
        new_state = "Initial state"
    else:
        new_state = f"State after {parent_state}"

    score = random.random()  # Score for the new state; scores should be normalized to the [0, 1] range.
    return new_state, score

# 2. Instantiate the algorithm and a search tree object.
algo = tq.ABMCTSA()
search_tree = algo.init_tree()

# 3. Run the search with a generation budget (10 in this case).
for _ in range(10):
    search_tree = algo.step(search_tree, {'Action A': generate})

# 4. Extract the best score and state.
best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best state: {best_state}, Score: {best_node_score}")
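Because scores are expected to lie in [0, 1], raw metrics usually need rescaling before `generate` returns them. The following is a minimal sketch, assuming a hypothetical task where the raw signal is the number of passed test cases out of a known total; `normalize_score` is a hypothetical helper, not part of TreeQuest.

```python
def normalize_score(raw: float, low: float, high: float) -> float:
    """Min-max scale `raw` from [low, high] into [0, 1], clamping out-of-range values."""
    if high <= low:
        raise ValueError("high must be greater than low")
    scaled = (raw - low) / (high - low)
    # Clamp so that noisy or out-of-range metrics still yield a valid score.
    return min(1.0, max(0.0, scaled))


# Hypothetical example: 7 of 10 test cases passed.
passed, total = 7, 10
score = normalize_score(passed, 0, total)
print(score)  # 0.7
```

Inside `generate`, you would then return `new_state, normalize_score(raw_metric, low, high)` with bounds that fit your task.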

Alternatively, you can use an ask–tell interface with batched AB-MCTS sampling steps:

import random
import treequest as tq

State = str

def generate(parent_state: State | None) -> tuple[State, float]:
    ...

generate_fns = {"Action A": generate}
actions = list(generate_fns.keys())

# We use batch_size=5 here
batch_size = 5

# Run AB-MCTS sampling steps with 5 worker processes in parallel
algo = tq.ABMCTSM(max_process_workers=batch_size)
search_tree = algo.init_tree()

total_budget = 50
num_steps = total_budget // batch_size
for _ in range(num_steps):
    # ask_batch returns a list of `Trial` objects, each with `action`, `parent_state`, and `trial_id` attributes
    search_tree, trials = algo.ask_batch(search_tree, batch_size, actions)

    for trial in trials:
        result = generate_fns[trial.action](trial.parent_state)
        # Call tell method with trial_id to update search_tree
        search_tree = algo.tell(search_tree, trial.trial_id, result)

best_state, best_node_score = tq.top_k(search_tree, algo, k=1)[0]

AB-MCTS-M in particular can be slow on each step call. If execution is slow, prefer ask_batch over step. Note that a large batch_size can skew the shape of the search tree (i.e., the tree may become too wide), so avoid overly large values; see PROFILING.md for example trees. We recommend batch_size<=5 as a starting point.
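In the ask–tell loop above, `num_steps = total_budget // batch_size` drops any remainder, so a budget of 52 with batch_size=5 would run only 50 generations. A minimal sketch of splitting the budget so that it is used exactly; `plan_batches` is a hypothetical helper, and it assumes that passing a smaller final `batch_size` to `ask_batch` suits your algorithm.

```python
def plan_batches(total_budget: int, batch_size: int) -> list[int]:
    """Split a generation budget into batch sizes that sum exactly to total_budget."""
    if total_budget < 0 or batch_size <= 0:
        raise ValueError("total_budget must be >= 0 and batch_size must be > 0")
    full_steps, remainder = divmod(total_budget, batch_size)
    sizes = [batch_size] * full_steps
    if remainder:
        # A final, smaller batch uses up the budget that `//` would drop.
        sizes.append(remainder)
    return sizes


print(plan_batches(50, 5))  # [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
print(plan_batches(52, 5))  # [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2]
```

Each entry would then be used as the `batch_size` argument of one `ask_batch` call, keeping every batch at or below the recommended size.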

Features

  • Easy-to-use API with customizable node generation and node scoring logic.
  • AB-MCTS-A and AB-MCTS-M, as well as Multi-LLM AB-MCTS support (see our paper for algorithm details).
  • Checkpointing and resuming searches.

Installation

uv

First, install uv. Then you can install TreeQuest with the following command:

uv add "treequest[abmcts-m]"

pip

Alternatively, you can use pip to install TreeQuest:

pip install "treequest[abmcts-m]"

Usage

Using an LLM as a Node Generator

You can use any object as a node state. You only need to define a generating function that takes the parent state as an argument and returns a (state, score) tuple:

import dataclasses

import treequest as tq

@dataclasses.dataclass
class State:
    llm_answer: str
    score: float

def generate(parent_state: State | None) -> tuple[State, float]:
    """Generate a new node by calling an LLM."""
    if parent_state is None:
        state = initial_generation()
    else:
        state = refine_answer(parent_state.llm_answer, parent_state.score)

    return state, state.score
    
def initial_generation() -> State:
    """
    Call LLM API to generate an initial answer.
    """
    ...

def refine_answer(llm_answer: str, score: float) -> State:
    """
    Call LLM API to refine an answer.
    """
    ...


algo = tq.ABMCTSM()
search_tree = algo.init_tree()
for i in range(20):
    search_tree = algo.step(search_tree, {'Action Label': generate})
    # Logging best node during the search.
    if (i + 1) % 5 == 0:
        best_interim_state, _ = tq.top_k(search_tree, algo, k=1)[0]
        print(f"Iteration {i+1}: Best state so far = {best_interim_state}")

best_state, _ = tq.top_k(search_tree, algo, k=1)[0]
print(f"Best Answer: {best_state.llm_answer}, Best Score: {best_state.score}")

Using Multiple LLMs (and Beyond)

TreeQuest supports multiple action types. For example, you can provide multiple generation functions backed by different LLMs to represent different action types:

from functools import partial

import treequest as tq

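# llm_name is keyword-only so that partial(generate, llm_name=...) can be called as fn(parent_state).
def generate(parent_state=None, *, llm_name: str):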
    """
    Call LLM API using litellm, vllm, etc., to generate a new node
    """
    ...
    return new_state, new_score

llm_names = ["o4-mini", "gemini-2.5-pro"]
# Create dict of different actions backed by different LLMs.
generate_fns = {llm_name: partial(generate, llm_name=llm_name) for llm_name in llm_names}

algo = tq.StandardMCTS()
search_tree = algo.init_tree()
for _ in range(20):
    search_tree = algo.step(search_tree, generate_fns)

Variation is not limited to the choice of LLM; the entries in generate_fns can also differ in prompts, actions, scoring logic, and so on.

Batch Semantics and Concurrency

  • Algorithms are stateless objects; the evolving tree/search state is returned from init_tree, step, ask, and tell.
  • ask_batch(state, batch_size, actions) returns exactly batch_size Trial objects to expand next.
    • Non-queue algorithms (e.g., ABMCTSM, ABMCTSA, MultiArmedBanditUCB) return exactly batch_size Trials.
    • Queue-based algorithms (e.g., StandardMCTS, BestFirstSearchAlgo, TreeOfThoughtsBFS) precompute a set of parent/action pairs and duplicate them if needed to fill batch_size.
  • tell(state, trial_id, (new_state, score)) reflects the result for the corresponding Trial.
    • Order-independent: you can call tell in any order; reflection is tied to trial_id.
    • Idempotent: calling tell twice on the same trial_id does not add extra nodes.
    • For queue-based algorithms, Trials told beyond the maximum number of children a parent node can have (e.g., (# actions)*samples_per_action for StandardMCTS) become INVALID and are not reflected.
  • Scores are expected to be normalized to the [0, 1] range.

Algorithms

AB-MCTS-A: AB-MCTS with Node Aggregation

AB-MCTS-A uses node aggregation for adaptive branching:

import treequest as tq

# Instantiate the ABMCTS-A algorithm.
ab_mcts_a = tq.ABMCTSA()

search_tree = ab_mcts_a.init_tree()
for _ in range(50):
    search_tree = ab_mcts_a.step(search_tree, generate_fns)
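The central decision in AB-MCTS is whether to go "wider" (generate a new child) or "deeper" (continue with an existing child). The sketch below is a deliberately simplified illustration of a Thompson-sampling-style choice between those two options using Beta posteriors over [0, 1] scores. It is not TreeQuest's implementation; the actual AB-MCTS-A posterior model and node-aggregation details are described in the paper, and `choose_wider_or_deeper` is hypothetical.

```python
import random


def choose_wider_or_deeper(
    gen_scores: list[float],
    cont_scores: list[float],
    rng: random.Random,
) -> str:
    """Pick "GEN" (add a new child) or "CONT" (descend into an existing child).

    Simplified Thompson sampling: each option starts from a Beta(1, 1) prior and is
    updated with fractional successes/failures from [0, 1] scores; the option with
    the larger posterior sample wins.
    """
    def sample(scores: list[float]) -> float:
        alpha = 1.0 + sum(scores)
        beta = 1.0 + sum(1.0 - s for s in scores)
        return rng.betavariate(alpha, beta)

    if not cont_scores:
        return "GEN"  # No existing children yet, so the only option is to widen.
    return "GEN" if sample(gen_scores) > sample(cont_scores) else "CONT"


rng = random.Random(0)
picks = [choose_wider_or_deeper([0.1, 0.2], [0.8, 0.9, 0.7], rng) for _ in range(1000)]
print(picks.count("CONT") / len(picks))
```

Sampling from the posteriors, rather than comparing their means, keeps some probability of widening even when existing children look better, which is what lets the search keep exploring.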

AB-MCTS-M: AB-MCTS with Mixed Models

AB-MCTS-M uses PyMC to fit Bayesian mixed-effects models:

import treequest as tq

# Instantiate the ABMCTS-M algorithm.
ab_mcts_m = tq.ABMCTSM()

search_tree = ab_mcts_m.init_tree()
for _ in range(30):
    search_tree = ab_mcts_m.step(search_tree, generate_fns)

NOTE: To run AB-MCTS-M, you need to install extra dependencies with the treequest[abmcts-m] option.
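If you want code that runs with or without the extra, one option is to check for the optional dependency before choosing an algorithm. A minimal sketch, assuming the `abmcts-m` extra pulls in PyMC (which AB-MCTS-M uses); the exact package list of the extra is an assumption, so check the project's metadata. `abmcts_m_extras_available` is a hypothetical helper.

```python
import importlib.util


def abmcts_m_extras_available(module_name: str = "pymc") -> bool:
    """Return True if `module_name` can be imported in the current environment."""
    return importlib.util.find_spec(module_name) is not None


# Hypothetical fallback policy: prefer AB-MCTS-M when its extra dependencies are present.
algo_name = "ABMCTSM" if abmcts_m_extras_available() else "ABMCTSA"
print(f"Using {algo_name}")
```

With TreeQuest installed, the chosen class could then be created with `getattr(tq, algo_name)()`.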

Requirements

  • Python 3.11+

Contributing

Contributions are welcome! Please see CONTRIBUTING.md for development tips.

Citation

@article{inoue2025wider,
  title={Wider or Deeper? Scaling LLM Inference-Time Compute with Adaptive Branching Tree Search},
  author={Inoue, Yuichi and Misaki, Kou and Imajuku, Yuki and Kuroki, So and Nakamura, Taishi and Akiba, Takuya},
  journal={arXiv preprint arXiv:2503.04412},
  year={2025}
}

License

Apache 2.0



Download files


Source Distribution

treequest-0.2.0.tar.gz (809.2 kB)

Built Distribution


treequest-0.2.0-py3-none-any.whl (47.4 kB)

File details

Details for the file treequest-0.2.0.tar.gz.

File metadata

  • Download URL: treequest-0.2.0.tar.gz
  • Upload date:
  • Size: 809.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.3

File hashes

Hashes for treequest-0.2.0.tar.gz
Algorithm Hash digest
SHA256 3cc4eead2a1bfee6c47407979321fbf8a7982629b501f29ab996673d6d71725b
MD5 21e6d969e3dabe8cf49ae3300a76e43f
BLAKE2b-256 9a1cbc316703ca9721e8f228c294d104653b518d19395b80f6988f89046c2d8b


File details

Details for the file treequest-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: treequest-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 47.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: uv/0.9.3

File hashes

Hashes for treequest-0.2.0-py3-none-any.whl
Algorithm Hash digest
SHA256 d02fada7cb8da5fb0f97f1dc1389762fe4dbc2341d27ba5490b24ac4fd5c8b9a
MD5 e0564b74e4256279cb70c79f45a4b97f
BLAKE2b-256 fdd7b10fe46e702d40174d21219746b9a994c6f153e70518267df36737b067f6

