
LangChain integration for Parallel Context-of-Experts Decoding (PCED) for RAG pipelines


langchain-pced

A LangChain integration for Parallel Context-of-Experts Decoding (PCED), a training-free framework that addresses the evidence aggregation challenge in Retrieval-Augmented Generation (RAG) by shifting from attention-based to decoding-based fusion.

Reference Paper: Parallel Context-of-Experts Decoding for Retrieval Augmented Generation

Method Overview

Figure 1: Overview of the Parallel Context-of-Experts Decoding (PCED) framework.

PCED treats each retrieved document as an independent "expert" context, processing them in parallel rather than concatenating them into a single sequence. At each decoding step, the expert logits undergo a contrastive transformation against an "amateur" (context-free) baseline, followed by retrieval-aware selection:

$$e_i^* = (1 + \beta) \cdot \text{logits}_{e_i} - \beta \cdot \text{logits}_{\text{amateur}}$$

$$\text{winner} = \arg\max_i \left[ \max(e_i^*) + \gamma \cdot s_i \right]$$

where $\beta$ is the contrastive strength (computed dynamically via the Jensen-Shannon divergence when unspecified), $\gamma$ weights the retrieval scores $s_i$, and the final token is sampled from the winning expert's distribution.
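
For intuition, the two equations amount to a single decoding step. Below is a minimal PyTorch sketch of that step; the tensor names and shapes are assumptions for illustration, not the library's internals:

import torch

def pced_step(expert_logits, amateur_logits, scores, beta=1.0, gamma=2.5):
    """One hypothetical PCED decoding step.

    expert_logits: [N, V] next-token logits, one row per expert context
    amateur_logits: [V] context-free (amateur) baseline logits
    scores: [N] fused retrieval scores s_i
    """
    # Contrastive transformation: e_i^* = (1 + beta) * logits_i - beta * logits_amateur
    contrasted = (1 + beta) * expert_logits - beta * amateur_logits
    # Retrieval-aware selection: winner = argmax_i [ max(e_i^*) + gamma * s_i ]
    winner = torch.argmax(contrasted.max(dim=-1).values + gamma * scores)
    # Sample the next token from the winning expert's distribution
    return torch.distributions.Categorical(logits=contrasted[winner]).sample()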

Installation

pip install langchain-pced

Required: Flash Attention

Flash Attention is required for PCED to work. Install it before using the library:

pip install flash-attn --no-build-isolation

Dependencies

| Package | Version | Purpose |
|---|---|---|
| transformers | ≥ 5.0.0 | PagedAttention support for efficient KV cache management |
| langchain-core | ≥ 1.0.0 | Chain abstraction and prompt templating |
| torch | ≥ 2.0.0 | Tensor operations and CUDA support |

Quick Start

This library follows standard LangChain patterns, providing PCEDPromptTemplate for template construction and supporting the canonical prompt | chat | parser chain composition.

Basic Usage

from langchain_pced import PCEDLLM, PCEDChatModel, PCEDPromptTemplate, harmonic_mean_fusion
from langchain_core.output_parsers import StrOutputParser

# Initialize the PCED-enabled LLM
llm = PCEDLLM.from_model_id(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    device="cuda:0",
)
chat = PCEDChatModel(llm=llm)

# Define the prompt template with expert placeholder specification
prompt = PCEDPromptTemplate.from_messages(
    messages=[
        ("system", "Answer ONLY using the provided context."),
        ("human", "Context:\n{context}\n\nQuestion:\n{question}"),
    ],
    expert_placeholder_key="context",  # Placeholder populated per-expert
)

# Retrieval scores (from your retrieval pipeline)
retrieval_scores = [0.9, 0.8, 0.7]
reranker_scores = [0.95, 0.6, 0.85]
fused_scores = harmonic_mean_fusion(retrieval_scores, reranker_scores)

# Retrieved documents
documents = [
    "The giant panda is a bear species endemic to China.",
    "Paris is the capital of France.",
    "Python is a programming language.",
]

# Construct and invoke the chain
chain = prompt | chat | StrOutputParser()

result = chain.invoke({
    "question": "What is a panda?",
    "expert_documents": documents,
    "expert_scores": fused_scores,
})
print(result)

Optional: Answer Prefix

To constrain the model's response format, an optional answer_prefix can be specified:

prompt = PCEDPromptTemplate.from_messages(
    messages=[
        ("system", "Answer ONLY using the provided context."),
        ("human", "Context:\n{context}\n\nQuestion:\n{question}"),
    ],
    expert_placeholder_key="context",
    answer_prefix="Based on the provided context, ",
)

Comparative Analysis: PCED vs. Standard RAG

The following examples illustrate the key differences between standard RAG (context concatenation) and PCED (parallel expert decoding). Note that the retrieval and reranking stages are identical—only the generation methodology differs.

Standard RAG Implementation

Additional Dependencies:

pip install langchain-huggingface>=1.0.0 langchain-community>=0.4.0

from __future__ import annotations
from typing import List

from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline, ChatHuggingFace
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# === Retrieval Infrastructure ===
embeddings = HuggingFaceEmbeddings(
    model_name="BAAI/bge-m3",
    model_kwargs={"device": "cuda"},
    encode_kwargs={"normalize_embeddings": True},
)
vectorstore = InMemoryVectorStore(embeddings)

docs = [
    Document(page_content="The giant panda is a bear species endemic to China."),
    Document(page_content="Paris is the capital of France."),
]
vectorstore.add_documents(docs)

cross_encoder = HuggingFaceCrossEncoder(
    model_name="BAAI/bge-reranker-v2-m3",
    model_kwargs={"device": "cuda"},
)

# === Retrieval and Reranking ===
def retrieve_and_rerank(query: str, top_k: int = 10, top_n: int = 3):
    pairs = vectorstore.similarity_search_with_score(query, k=top_k)
    docs, vec_scores = [], []
    
    for d, score in pairs:
        docs.append(d)
        vec_scores.append(float(score))
    
    text_pairs = [(query, d.page_content) for d in docs]
    rerank_scores = list(cross_encoder.score(text_pairs))
    
    sorted_data = sorted(zip(docs, vec_scores, rerank_scores), 
                         key=lambda x: x[2], reverse=True)[:top_n]
    return ([d for d, _, _ in sorted_data], 
            [v for _, v, _ in sorted_data], 
            [r for _, _, r in sorted_data])

# === Standard RAG: Context Concatenation ===
base_llm = HuggingFacePipeline.from_model_id(
    model_id="Qwen/Qwen3-8B",
    task="text-generation",
    device_map="auto",
    pipeline_kwargs={"max_new_tokens": 512, "return_full_text": False},
)
chat_llm = ChatHuggingFace(llm=base_llm)

prompt = ChatPromptTemplate.from_messages([
    ("system", "Answer ONLY using the provided context."),
    ("human", "Context:\n{context}\n\nQuestion:\n{question}"),
])

chain = prompt | chat_llm | StrOutputParser()

question = "What is a panda?"
top_docs, vec_scores, rerank_scores = retrieve_and_rerank(question)

# Concatenate all documents into single context
context = "\n\n".join(d.page_content for d in top_docs)
result = chain.invoke({"context": context, "question": question})
print(result)

PCED Implementation

from langchain_pced import PCEDLLM, PCEDChatModel, PCEDPromptTemplate, harmonic_mean_fusion
from langchain_core.output_parsers import StrOutputParser

# ... identical retrieval infrastructure as above ...

# === PCED: Parallel Expert Decoding ===
llm = PCEDLLM.from_model_id(
    model_id="Qwen/Qwen3-8B",
    device="cuda:0",
    max_new_tokens=512,
    enable_thinking=True
)
chat = PCEDChatModel(llm=llm)

prompt = PCEDPromptTemplate.from_messages(
    messages=[
        ("system", "Answer ONLY using the provided context."),
        ("human", "Context:\n{context}\n\nQuestion:\n{question}"),
    ],
    expert_placeholder_key="context",
)

question = "What is a panda?"
top_docs, vec_scores, rerank_scores = retrieve_and_rerank(question)

# Fuse retrieval and reranker scores
fused_scores = harmonic_mean_fusion(vec_scores, rerank_scores)

chain = prompt | chat | StrOutputParser()

# Each document processed as independent expert
result = chain.invoke({
    "question": question,
    "expert_documents": [d.page_content for d in top_docs],
    "expert_scores": fused_scores,
})
print(result)

Methodological Comparison

| Aspect | Standard RAG | PCED |
|---|---|---|
| Context Handling | Documents concatenated into a single sequence | Each document processed independently |
| Attention Complexity | $O((N \cdot L)^2)$ | $O(N \cdot L^2)$ |
| Score Utilization | Retrieval ranking only | Incorporated during decoding for expert weighting |
| Cross-Document Reasoning | Implicit via attention | Explicit via expert selection and score fusion |
| Memory Efficiency | Single long KV cache | Shared prefix with forked expert caches (PagedAttention) |
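
For concreteness on the complexity row: with $N = 3$ documents of $L = 1{,}000$ tokens each, concatenation attends over $(3 \cdot 1000)^2 = 9 \times 10^6$ position pairs per layer, while PCED's independent experts attend over only $3 \cdot 1000^2 = 3 \times 10^6$.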

Core Concepts

Expert Documents and Scores

PCED treats each retrieved document as an independent "expert" that processes the query in isolation. Two inputs are required:

  • expert_documents: A list of $N$ document strings
  • expert_scores: A list of $N$ relevance scores (same cardinality as documents)

Scores are mandatory because PCED uses them to weight expert contributions during the decoding process.

Score Fusion via Harmonic Mean

The harmonic_mean_fusion() function combines retrieval scores (e.g., dense vector similarity) with reranker scores (e.g., cross-encoder relevance):

from langchain_pced import harmonic_mean_fusion

vec_scores = [0.9, 0.8, 0.7]      # Dense retrieval scores
rerank_scores = [0.95, 0.6, 0.85]  # Cross-encoder scores

fused = harmonic_mean_fusion(vec_scores, rerank_scores)
# Result: [0.924, 0.686, 0.768]

The harmonic mean is employed due to its conservative aggregation property—it penalizes cases where one score is low even when the other is high, promoting consensus between retrieval and reranking signals.
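
For two scores, the fusion reduces to the familiar pairwise harmonic mean. A minimal sketch of that formula (for intuition only; not the library's implementation, which operates on score lists):

def harmonic_mean(a: float, b: float) -> float:
    # Low values dominate: both signals must be high for the result to be high.
    return 2 * a * b / (a + b)

harmonic_mean(0.9, 0.95)  # ~0.924 (both high, stays high)
harmonic_mean(0.8, 0.60)  # ~0.686 (one low, pulled down)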

Expert Placeholder Mechanism

The expert_placeholder_key parameter specifies which template variable is populated per-expert:

prompt = PCEDPromptTemplate.from_messages(
    messages=[
        ("system", "Answer using the provided context."),
        ("human", "Context: {context}\n\nQuestion: {question}"),
    ],
    expert_placeholder_key="context",
)

PCED internally constructs two prompt variants (illustrated by the sketch after this list):

  • Amateur prompt: Template with {context} → empty string (no retrieval context)
  • Expert $i$ prompt: Template with {context} → document $d_i$
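
The substitution can be illustrated with a plain ChatPromptTemplate (a sketch of the behavior, not langchain-pced's internal code):

from langchain_core.prompts import ChatPromptTemplate

template = ChatPromptTemplate.from_messages([
    ("system", "Answer using the provided context."),
    ("human", "Context: {context}\n\nQuestion: {question}"),
])

question = "What is a panda?"

# Amateur prompt: the expert placeholder resolves to an empty string
amateur_messages = template.format_messages(context="", question=question)

# Expert i prompt: the placeholder resolves to document d_i
expert_messages = template.format_messages(
    context="The giant panda is a bear species endemic to China.",
    question=question,
)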

Automatic Context Truncation

Documents exceeding the model's context window are automatically truncated:

$$\text{max\_doc\_length} = \text{max\_position\_embeddings} - \text{max\_new\_tokens} - \text{prompt\_overhead}$$

This prevents out-of-memory errors and ensures sufficient capacity for generation.
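
A rough sketch of how such a budget could be derived with transformers (the variable names and the prompt_overhead value are assumptions for illustration; the library computes this internally):

from transformers import AutoConfig, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

max_new_tokens = 256
prompt_overhead = 128  # assumed token allowance for the chat template and question

max_doc_length = config.max_position_embeddings - max_new_tokens - prompt_overhead

document = "..."  # a retrieved document
doc_ids = tokenizer(document, truncation=True, max_length=max_doc_length)["input_ids"]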

Streaming Support

PCED supports streaming generation for real-time applications:

for chunk in chain.stream({
    "question": "What is a panda?",
    "expert_documents": [d.page_content for d in top_docs],
    "expert_scores": fused_scores,
}):
    print(chunk, end="", flush=True)

API Reference

PCEDLLM Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model_id | str | required | HuggingFace model identifier |
| beta | float | None | Contrastive weight; None enables dynamic JSD-based computation |
| gamma | float | 2.5 | Retrieval score weight for expert selection |
| max_new_tokens | int | 256 | Maximum generation length |
| temperature | float | None | Sampling temperature (None uses model default) |
| top_k | int | None | Top-k sampling parameter |
| top_p | float | None | Nucleus sampling parameter |
| do_sample | bool | None | Enable stochastic sampling |
| enable_thinking | bool | False | Enable thinking mode (Qwen3-style models) |
| device | str | "cuda:0" | Computation device |
| attn_implementation | str | "paged\|flash_attention_2" | Attention implementation (requires flash-attn) |

Note: When sampling parameters are None, the model's generation_config.json defaults are used. For example:

  • Llama-3.1: do_sample=True, temperature=0.6, top_p=0.9
  • Qwen3: do_sample=True, temperature=0.6, top_k=20, top_p=0.95
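
To override these defaults rather than inherit them, the sampling and PCED parameters from the table above can be passed explicitly (the values here are illustrative):

from langchain_pced import PCEDLLM

llm = PCEDLLM.from_model_id(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    device="cuda:0",
    beta=0.5,          # fixed contrastive strength instead of dynamic JSD
    gamma=2.5,         # retrieval score weight in expert selection
    do_sample=False,   # greedy decoding, ignoring generation_config defaults
    max_new_tokens=256,
)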

PCEDPromptTemplate.from_messages() Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| messages | List[tuple] | required | List of (role, content) tuples defining the prompt |
| expert_placeholder_key | str | "context" | Template variable populated per-expert |
| answer_prefix | str | None | Optional prefix prepended to model responses |

chain.invoke() Parameters

| Parameter | Type | Required | Description |
|---|---|---|---|
| expert_documents | List[str] | Yes | Document strings, one per expert |
| expert_scores | List[float] | Yes | Relevance scores (use harmonic_mean_fusion) |
| <template_variables> | Any | Varies | Additional template variables (e.g., question) |

Citation

If you use this library in your research, please cite the original paper:

@article{corallo2026pced,
  title={Parallel Context-of-Experts Decoding for Retrieval Augmented Generation},
  author={Corallo, Giulio and Papotti, Paolo},
  journal={arXiv preprint arXiv:2601.08670},
  year={2026}
}

License

This project is licensed under the MIT License.

