
LLM MCTS Inference


An experimental project using Monte Carlo Tree Search (MCTS) to refine Language Model (LLM) responses for better accuracy and decision-making.

Overview

This project uses MCTS to explore multiple answer candidates generated by an LLM. By iteratively generating an initial answer, evaluating it, and refining it based on targeted self-feedback, the system improves response quality and decision-making. The approach trades additional test-time compute for more precise and robust model outputs.

MCTS Inference Process

The process follows these key steps:

  • Initial Answer Generation: Uses greedy decoding to generate an initial response.
  • Feedback Generation: Provides constructive, concise feedback on initial answers. The feedback is generated by the model itself.
  • Iterative Refinement: Refines responses based on the feedback through additional model queries.
  • Monte Carlo Tree Search: Employs MCTS to explore and evaluate multiple answer paths.
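The loop above can be sketched in a few dozen lines. Note that `generate_initial`, `refine`, and `score` are hypothetical stand-ins for the actual LLM calls (answer generation, self-feedback plus refinement, and candidate scoring); they are not the package's API:

```python
import math

# Hypothetical stand-ins for the model calls described above. In the real
# package these would query the LLM; here they are deterministic toy functions.
def generate_initial(prompt: str) -> str:
    return f"draft answer to: {prompt}"

def refine(answer: str) -> str:
    # stands in for "feedback + refinement" in a single step
    return answer + " (refined)"

def score(answer: str) -> float:
    # toy reward: deeper refinement chains score higher, capped at 1.0
    return min(answer.count("refined") / 3, 1.0)

class Node:
    def __init__(self, answer, parent=None):
        self.answer = answer
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0

    def uct(self, c=1.4):
        # standard UCT: exploitation term plus exploration bonus
        if self.visits == 0:
            return float("inf")
        return self.value / self.visits + c * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )

def mcts(prompt, iterations=20, max_children=3):
    root = Node(generate_initial(prompt))
    for _ in range(iterations):
        # Selection: descend by UCT until a node has room for a new child
        node = root
        while node.children and len(node.children) >= max_children:
            node = max(node.children, key=Node.uct)
        # Expansion: refine the selected answer into a new child
        child = Node(refine(node.answer), parent=node)
        node.children.append(child)
        # Evaluation + backpropagation up to the root
        reward = score(child.answer)
        while child is not None:
            child.visits += 1
            child.value += reward
            child = child.parent
    # Return the most-visited answer under the root
    best = max(root.children, key=lambda n: n.visits)
    return best.answer

print(mcts("What is 2 + 2?"))
```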

Experimental Results

The performance of this approach was evaluated on a subset of the GSM8k test split using the Llama3.2-1B-instruct model with vLLM. A baseline run using zero-shot prompting achieved a pass@8 score of 74% and a majority@8 score of 27%. When applying MCTS for iterative refinement, the pass@8 score marginally increased to 75%, while the majority@8 score improved significantly to 39%. The evaluation was done with llm-eval.

These results suggest that while MCTS does not drastically improve the probability of generating at least one correct answer (pass@8), it significantly enhances response consistency (majority@8), making the model more reliable in decision-making scenarios.
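The distinction between the two metrics can be made concrete with a toy computation (the answer strings below are made up for illustration, not actual GSM8k samples):

```python
from collections import Counter

def pass_at_k(samples, correct):
    # pass@k: at least one of the k sampled answers is correct
    return any(s == correct for s in samples)

def majority_at_k(samples, correct):
    # majority@k: the most frequent sampled answer is correct
    most_common, _ = Counter(samples).most_common(1)[0]
    return most_common == correct

# One correct answer among 8, but the wrong answer dominates:
samples = ["41", "41", "41", "42", "40", "40", "39", "38"]
print(pass_at_k(samples, "42"))      # True
print(majority_at_k(samples, "42"))  # False

# Consistent sampling: the correct answer also wins the vote:
samples = ["42", "41", "42", "42", "40", "42", "39", "42"]
print(pass_at_k(samples, "42"))      # True
print(majority_at_k(samples, "42"))  # True
```

A refinement procedure that concentrates samples on one answer can therefore lift majority@8 substantially even when pass@8 barely moves, which matches the numbers reported above.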

Why Llama3.2-1B-instruct?

A smaller model was selected for this experiment to better illustrate the impact of MCTS. Larger models already achieve high accuracy on GSM8k, making it difficult to demonstrate meaningful improvements. The 1B parameter model provides a more realistic proof of concept by:

  • Being resource-efficient, allowing for scalable experimentation.
  • Providing a challenging test case: smaller models struggle more with GSM8k, so improvements are more noticeable.
  • Keeping the evaluation relevant, since larger models have already been benchmarked extensively on GSM8k, leaving little room for additional gains.

Installation

Dependencies

  • Python: Version 3.11 or higher

The project depends mainly on the following packages:

  • instructor for guided (structured) generation
  • litellm for a unified API to interact with multiple LLM providers

Setup Instructions

To install the package directly from PyPI, run:

    pip install llm-mcts-inference

To install from source, follow these steps:

  1. Clone the Repository:

    git clone https://github.com/brotSchimmelt/llm-mcts-inference.git
    cd llm-mcts-inference
    
  2. Install the Project Dependencies:

    If you use uv, run the following commands to create a virtualenv and install all requirements:

    uv venv --python 3.11
    uv sync
    

    Otherwise, install the project and its dependencies with pip (pip reads pyproject.toml automatically; `pip install -r` only works with requirements files):

    pip install .
    
  3. Configure Environment Variables: Rename the provided example.env file to .env and update it with your API keys or other configuration details as needed.
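For instance, a minimal `.env` for an OpenAI-backed setup might look like the following (the key name is the standard variable read by litellm; the actual contents of `example.env` may differ):

```shell
# .env — API credentials picked up by litellm at runtime
OPENAI_API_KEY=sk-...
```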

Usage

Use the MonteCarloLLM class to generate and improve responses via MCTS:

from llm_mcts_inference.MonteCarloLLM import MonteCarloLLM

# Initialize with a specific model; defaults are defined in settings
llm = MonteCarloLLM(model_name="openai/gpt-4o-mini")

# Define your prompt
prompt = "What is the capital of France?"

# Generate a response using Monte Carlo Tree Search
result = llm.generate(prompt=prompt, iterations=5, max_children=3)

# Output the final improved answer
print("Final Answer:", result.answer)

# Optionally, display the sequence of nodes (answers) along the best path
print("Best Path:", [node.answer for node in result.valid_path])

License

This project is licensed under the MIT license.
