A package to route chat requests between LLMs based on prompt classification

LLM Predictive Router Package

This package routes chat requests between a small and a large LLM based on prompt classification. A classifier predicts the complexity of each user input, and the router selects the matching model while maintaining the conversation context.

Installation

You can install the package using pip:

pip install llm-predictive-router

Example Usage

from llm_predictive_router import LLMRouter

# Define model configuration
config = {
    "classifier": {
        "model_id": "DevQuasar/roberta-prompt_classifier-v0.1"
    },
    # The entity name should match the predicted label from your prompt classifier
    "small_llm": {
        "escalation_order": 0,
        "url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "model_id": "lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
        "max_ctx": 4096
    },
    # The entity name should match the predicted label from your prompt classifier
    "large_llm": {
        "escalation_order": 1,
        "url": "http://localhost:1234/v1",
        "api_key": "lm-studio",
        "model_id": "lmstudio-community/Meta-Llama-3-70B-Instruct-GGUF/Meta-Llama-3-70B-Instruct-Q4_K_M.gguf",
        "max_ctx": 8192
    }
}

router = LLMRouter(config)

# Example call with a simple prompt -> routed to "small_llm"
response, context, selected_model = router.chat(
    "Hello", 
    temperature=0.5,   # Lower temperature for more focused responses
    max_tokens=100,    # Limit the response length
    verbose=True
)

# Another simple prompt -> routed to "small_llm"
response, context, selected_model = router.chat(
    "Tell me a story about a cat",
    curr_ctx=context,  # carry the chat history
    model_store_entry=selected_model,
    temperature=0.5,   # Lower temperature for more focused responses
    max_tokens=512,    # Limit the response length
    verbose=True
)

# The default prompt classifier still treats this as a generic, simple prompt -> routed to "small_llm"
response, context, selected_model = router.chat(
    "Now explain the biology of the cat",
    curr_ctx=context,
    model_store_entry=selected_model,
    temperature=0.5,   # Lower temperature for more focused responses
    max_tokens=512,    # Limit the response length
    verbose=True
)

# This prompt gets into domain-specific detail, so the router escalates -> routed to "large_llm"
response, context, selected_model = router.chat(
    "Get into the details of his metabolism, especially interested in the detailed role of the liver",
    curr_ctx=context,
    model_store_entry=selected_model,
    temperature=0.5,   # Lower temperature for more focused responses
    max_tokens=512,    # Limit the response length
    verbose=True
)

Model Store JSON

The model store JSON defines the configuration of the small and large LLMs that the router switches between based on the prompt classification. It also includes a special classifier entry, which names the model used to predict the complexity of the user's prompt.

Example Model Store Structure:

{
  "classifier": {
    "model_id": "DevQuasar/roberta-prompt_classifier-v0.1"
  },
  "small_llm": {
    "escalation_order": 0,
    "url": "http://localhost:1234/v1",
    "api_key": "lm-studio",
    "model_id": "lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
    "max_ctx": 4096
  },
  "large_llm": {
    "escalation_order": 1,
    "url": "http://localhost:1234/v1",
    "api_key": "lm-studio",
    "model_id": "lmstudio-community/Meta-Llama-3-70B-Instruct-GGUF/Meta-Llama-3-70B-Instruct-Q4_K_M.gguf",
    "max_ctx": 8192
  }
}
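
The structure above can also live in a file instead of inline Python. The following is a minimal sketch of reading it back, assuming the router accepts the parsed dict exactly as it accepts the inline config in the first example. The helper names load_model_store and split_model_store are hypothetical and not part of the package.

```python
import json
from pathlib import Path


def load_model_store(path):
    """Read a model store JSON file and return it as a plain dict."""
    store = json.loads(Path(path).read_text(encoding="utf-8"))
    if "classifier" not in store:
        raise ValueError("model store is missing the 'classifier' entry")
    return store


def split_model_store(store):
    """Separate the special classifier entry from the LLM entries."""
    classifier = store["classifier"]
    llms = {name: entry for name, entry in store.items() if name != "classifier"}
    return classifier, llms
```

The dict returned by load_model_store could then be passed to LLMRouter in place of the inline config shown earlier.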

Explanation of Fields:

Classifier Entry:

  • classifier: This special entry defines the model used to classify the complexity of the user's input.
    • model_id: The identifier of the classifier model (e.g., a Roberta-based classifier fine-tuned for prompt classification). This model does not generate text. It predicts the complexity of the user prompt, and the router uses that prediction to choose the LLM that answers.

LLM Entries:

  • escalation_order: Sets the model's position in the escalation sequence. Entries with lower values serve less complex prompts, while entries with higher values serve more complex prompts.
  • url: The URL of the API endpoint where the model is hosted.
  • api_key: The API key required to authenticate with the model service.
  • model_id: The specific model identifier (for example, from a model hub like Hugging Face or a local deployment).
  • max_ctx: Maximum context size (in tokens) the model can handle.
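
The field list above can be checked mechanically before a router is built. This is a minimal sketch of such a check, not something the package provides: validate_model_store and REQUIRED_LLM_FIELDS are hypothetical names, the expected types are inferred from the example store, and treating a repeated escalation_order as a problem is an assumption of this sketch.

```python
# Field names come from the model store description; types are inferred
# from the example values and are an assumption.
REQUIRED_LLM_FIELDS = {
    "escalation_order": int,
    "url": str,
    "api_key": str,
    "model_id": str,
    "max_ctx": int,
}


def validate_model_store(store):
    """Return a list of problems; an empty list means nothing was found."""
    problems = []

    classifier = store.get("classifier")
    if not isinstance(classifier, dict) or not isinstance(classifier.get("model_id"), str):
        problems.append("classifier: missing string field 'model_id'")

    seen_orders = {}
    for name, entry in store.items():
        if name == "classifier":
            continue
        for field, expected in REQUIRED_LLM_FIELDS.items():
            if not isinstance(entry.get(field), expected):
                problems.append(f"{name}: field '{field}' should be {expected.__name__}")
        order = entry.get("escalation_order")
        if isinstance(order, int):
            if order in seen_orders:
                # Assumption: two entries sharing a rank is a configuration mistake.
                problems.append(
                    f"{name}: escalation_order {order} already used by {seen_orders[order]}"
                )
            else:
                seen_orders[order] = name
    return problems
```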

router.chat Method Documentation

Overview

The chat method is responsible for handling user input, selecting the appropriate model based on the prompt classification, and managing the conversation context. It interacts with the model API to generate responses.

Inputs

  • user_prompt (str): The text prompt provided by the user for the model to respond to.
  • model_store_entry (dict, optional): An entry from the model store representing the current model. If None, the function will classify the prompt and select the initial model.
  • curr_ctx (list, optional): The current conversation context, a list of message objects between the user and the assistant.
  • system_prompt (str, optional): A system prompt or directive that provides additional instructions for the model (default is an empty string).
  • temperature (float, optional): Controls the randomness of the model’s output. A lower value (e.g., 0.1) makes responses more deterministic, while a higher value (e.g., 0.9) produces more random and creative outputs (default is 0.7).
  • max_tokens (int, optional): The maximum number of tokens to generate in the response (default: no explicit limit).
  • verbose (bool, optional): If True, the function will print additional debugging information, such as the selected model and generated completion (default is False).

Outputs

The method returns a tuple of three items:

  1. completion (str): The text completion generated by the model.
  2. messages (list): The updated conversation context, including the new user prompt and model response.
  3. model_store_entry (dict): The model entry that was used for generating the response. This can be passed back into subsequent calls to ensure the same model is used unless escalation is needed.
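
Because each call returns the context and model entry that the next call expects, a multi-turn conversation reduces to a small loop. The sketch below assumes only the chat signature and return tuple documented above. run_conversation is a hypothetical helper, and router can be any object with a compatible chat method.

```python
def run_conversation(router, prompts, **chat_kwargs):
    """Send prompts in order, carrying context and model entry between turns."""
    context = None
    selected_model = None
    replies = []
    for prompt in prompts:
        kwargs = dict(chat_kwargs)
        if context is not None:
            # From the second turn on, pass the previous outputs back in,
            # mirroring the usage examples.
            kwargs["curr_ctx"] = context
            kwargs["model_store_entry"] = selected_model
        response, context, selected_model = router.chat(prompt, **kwargs)
        replies.append(response)
    return replies, context, selected_model
```

Extra keyword arguments such as temperature or max_tokens are forwarded unchanged to every chat call.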

Example Usage

# Start with a simple prompt
response, context, selected_model = router.chat(
    user_prompt="hello", 
    verbose=True
)

# Follow up with a more complex prompt, which may trigger escalation
p = "Discuss the challenges and potential solutions for achieving sustainable development in the context of increasing global urbanization."
response, context, selected_model = router.chat(
    user_prompt=p, 
    curr_ctx=context, 
    model_store_entry=selected_model,
    temperature=0.5,
    max_tokens=200,
    verbose=True
)

Solution Overview

The llm-predictive-router solution routes chat requests between models of different sizes based on the complexity of the prompt. A fine-tuned prompt classifier labels each user input, and the router escalates to a larger model when the label calls for it.

Key components of the solution include:

  • Model Store: Defines the configuration of the available LLMs, including small and large variants.
  • Prompt Classifier: A fine-tuned model (e.g., Roberta) that classifies user prompts to determine complexity.
  • Router: Responsible for selecting and switching between models dynamically, depending on the classification result and current context.
  • Chat Handling: Manages the conversation, tracks context, and interacts with the models to generate coherent responses.

This approach trades off resource usage against response quality: simple prompts stay on the small model, and the large model is used only when the prompt calls for it.
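
The routing decision described in this section can be sketched in a few lines. This is an illustration under stated assumptions, not the package's actual implementation: choose_entry is a hypothetical function, it works with entry names rather than the entry dicts that router.chat exchanges, and it assumes the router only moves upward in escalation_order within a conversation, which the documentation suggests but does not state outright.

```python
def choose_entry(store, predicted_label, current_label=None):
    """Return the name of the model store entry to use for the next turn."""
    llms = {name: entry for name, entry in store.items() if name != "classifier"}
    if predicted_label not in llms:
        # Entry names are expected to match the classifier's labels.
        raise KeyError(f"no model store entry named {predicted_label!r}")
    if current_label is None:
        # First turn: take whatever the classifier predicted.
        return predicted_label
    predicted_rank = llms[predicted_label]["escalation_order"]
    current_rank = llms[current_label]["escalation_order"]
    # Assumption: escalate on a higher rank, otherwise keep the current model.
    return predicted_label if predicted_rank > current_rank else current_label
```

For example, with the two-entry store shown earlier, a "large_llm" prediction during a "small_llm" conversation would switch models, while a later "small_llm" prediction would leave the conversation on "large_llm".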
