
Simple middleware library for Pydantic-AI - before/after hooks without imposed guardrails structure


Pydantic AI Middleware

Intercept, Transform & Guard Every Pydantic AI Call


Python 3.10+  •  License: MIT

7 Lifecycle Hooks  •  Parallel Execution  •  Async Guardrails  •  Fully Type-Safe


Get Started in 60 Seconds

pip install pydantic-ai-middleware

from pydantic_ai import Agent
from pydantic_ai_middleware import MiddlewareAgent, AgentMiddleware, InputBlocked

class ContentFilter(AgentMiddleware[None]):
    """Block dangerous prompts before they reach the LLM."""

    async def before_run(self, prompt, deps, ctx=None):
        if "ignore all instructions" in prompt.lower():
            raise InputBlocked("Prompt injection attempt blocked")
        return prompt

    async def after_run(self, prompt, output, deps, ctx=None):
        # Redact SSN-shaped values (e.g. 123-45-6789) from a text response
        import re
        return re.sub(r"\b\d{3}-\d{2}-\d{4}\b", "[REDACTED]", output)

# Wrap any pydantic-ai Agent with middleware
base_agent = Agent("openai:gpt-4o", instructions="You are a helpful assistant.")
agent = MiddlewareAgent(agent=base_agent, middleware=[ContentFilter()])

result = await agent.run("Hello, how are you?")
print(result.output)

That's it. Your pydantic-ai agent now validates input and filters output through a single wrapper, and the same two hooks can host a stronger prompt-injection check than the substring match shown here.


Why Middleware, Not Guardrails?

pydantic-ai-middleware takes a different approach from traditional guardrails libraries:

| Aspect | Middleware (this library) | Traditional Guardrails |
|---|---|---|
| Complexity | Low | High |
| Structure | No imposed structure | Fixed result types, actions |
| Flexibility | Maximum | Constrained by design |
| Learning curve | Flat | Steeper |
| Built-in guardrails | None (you build what you need) | Pre-built (PII, moderation) |
| Parallel execution | Built-in with early cancellation | Often built-in |
| Type safety | Full generics support | Varies |

You decide what to build. Logging, guardrails, metrics, rate limiting, PII redaction — all using the same simple API.


Features

  • 7 Lifecycle Hooks — before_run, after_run, before_model_request, before_tool_call, on_tool_error, after_tool_call, on_error
  • Parallel Execution — Run multiple middleware concurrently with 4 aggregation strategies and early cancellation
  • Async Guardrails — Run guardrails alongside LLM calls with BLOCKING, CONCURRENT, or ASYNC_POST timing
  • Middleware Chains — Compose middleware into reusable, ordered sequences with + operator
  • Conditional Routing — Route to different middleware based on runtime conditions
  • Config Loading — Build pipelines from JSON/YAML configuration files
  • Decorator Syntax — Create middleware from simple decorated functions
  • Context Sharing — Share data between hooks with access control
  • Tool Name Filtering — Scope middleware to specific tools with tool_names
  • Hook Timeouts — Per-middleware timeout enforcement with MiddlewareTimeout
  • Permission Decisions — Structured ALLOW/DENY/ASK protocol for tool authorization
  • Zero Overhead — No mandatory dependencies beyond pydantic-ai
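
As an illustration of what "parallel execution with early cancellation" can mean in practice, here is a minimal standard-library sketch of an all-must-pass strategy. It is not the library's implementation: `all_must_pass`, `fast_fail`, `slow_pass`, and `quick_pass` are hypothetical names, and the cancellation behavior shown is an assumption about how such a strategy could work.

```python
import asyncio

finished = []  # records which checks ran to completion


async def fast_fail(prompt):
    raise ValueError("blocked")


async def slow_pass(prompt):
    await asyncio.sleep(0.2)
    finished.append("slow_pass")


async def quick_pass(prompt):
    finished.append("quick_pass")


async def all_must_pass(checks, prompt):
    """Run every check concurrently; the first failure cancels the rest."""
    tasks = [asyncio.create_task(check(prompt)) for check in checks]
    try:
        await asyncio.gather(*tasks)  # propagates the first exception raised
    except Exception:
        for task in tasks:
            task.cancel()  # early cancellation of still-running checks
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
    return prompt


passed = asyncio.run(all_must_pass([quick_pass], "hi"))

try:
    asyncio.run(all_must_pass([fast_fail, slow_pass], "hi"))
    failed = False
except ValueError:
    failed = True
```

In this sketch the slow check never finishes once the fast one has failed, which is the latency saving that early cancellation is meant to provide.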

Hook Lifecycle

| Hook | When Called | Can Modify |
|---|---|---|
| before_run | Before agent starts | Prompt |
| after_run | After agent finishes | Output |
| before_model_request | Before each model call | Messages |
| before_tool_call | Before tool execution | Tool arguments |
| on_tool_error | When a tool raises an exception | Exception (replace or re-raise) |
| after_tool_call | After tool execution | Tool result |
| on_error | When an error occurs | Exception |
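
To make the first two rows concrete, the following toy model shows how before_run and after_run hooks could wrap a single call. It uses only the standard library. `run_with_hooks`, `Blocked`, `Upper`, `Redact`, and `fake_model` are hypothetical names for illustration, and the ordering (hooks applied in list order both before and after the call) is an assumption, not a statement about the library's internals.

```python
import asyncio


class Blocked(Exception):
    """Stand-in for an input-blocking exception such as InputBlocked."""


class Upper:
    """Toy middleware: transforms the prompt before the call."""

    async def before_run(self, prompt):
        return prompt.upper()

    async def after_run(self, prompt, output):
        return output


class Redact:
    """Toy middleware: rejects some inputs and rewrites the output."""

    async def before_run(self, prompt):
        if "SECRET" in prompt:
            raise Blocked("secret in prompt")
        return prompt

    async def after_run(self, prompt, output):
        return output.replace("HELLO", "[REDACTED]")


async def fake_model(prompt):
    # Placeholder for the real agent call.
    return f"echo: {prompt}"


async def run_with_hooks(prompt, middleware):
    for mw in middleware:  # each hook may replace the prompt
        prompt = await mw.before_run(prompt)
    output = await fake_model(prompt)
    for mw in middleware:  # each hook may replace the output
        output = await mw.after_run(prompt, output)
    return output


result = asyncio.run(run_with_hooks("hello", [Upper(), Redact()]))
print(result)
```

Because each hook returns the value it received or a replacement, a middleware that only observes (for logging or metrics) simply returns its input unchanged.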

Real-World Examples

Input Validation + Rate Limiting

import time
from pydantic_ai import Agent
from pydantic_ai_middleware import (
    MiddlewareAgent,
    AgentMiddleware,
    InputBlocked,
    MiddlewareContext,
)

class RateLimiter(AgentMiddleware[None]):
    """Limit how many requests per minute."""

    def __init__(self, max_per_minute: int = 10):
        self.max_per_minute = max_per_minute
        self._timestamps: list[float] = []

    async def before_run(self, prompt, deps, ctx=None):
        now = time.time()
        self._timestamps = [t for t in self._timestamps if now - t < 60]
        if len(self._timestamps) >= self.max_per_minute:
            raise InputBlocked("Rate limit exceeded — try again later")
        self._timestamps.append(now)
        return prompt

class PromptSanitizer(AgentMiddleware[None]):
    """Remove potentially harmful instructions from prompts."""

    BLOCKED_PATTERNS = ["ignore previous", "system prompt", "jailbreak"]

    async def before_run(self, prompt, deps, ctx=None):
        lower = prompt.lower() if isinstance(prompt, str) else str(prompt).lower()
        for pattern in self.BLOCKED_PATTERNS:
            if pattern in lower:
                raise InputBlocked(f"Blocked pattern detected: {pattern}")
        return prompt

# Build the agent with middleware pipeline
base_agent = Agent("openai:gpt-4o", instructions="You are a customer support agent.")

agent = MiddlewareAgent(
    agent=base_agent,
    middleware=[RateLimiter(max_per_minute=20), PromptSanitizer()],
    context=MiddlewareContext(),  # enable cross-middleware data sharing
)

result = await agent.run("How do I reset my password?")
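
The window logic inside RateLimiter can be checked in isolation. The sketch below is a standalone variant that takes the current time as an argument so it can be tested without waiting; `SlidingWindow` and `allow` are hypothetical names and are not part of pydantic-ai-middleware.

```python
from collections import deque


class SlidingWindow:
    """Allow at most `limit` events within any `window`-second span."""

    def __init__(self, limit, window=60.0):
        self.limit = limit
        self.window = window
        self._events = deque()

    def allow(self, now):
        # Drop timestamps that have aged out of the window.
        while self._events and now - self._events[0] >= self.window:
            self._events.popleft()
        if len(self._events) >= self.limit:
            return False
        self._events.append(now)
        return True


limiter = SlidingWindow(limit=2, window=60.0)
decisions = [limiter.allow(t) for t in (0.0, 1.0, 2.0, 61.0)]
print(decisions)
```

As in the RateLimiter example, the state lives on one object in one process, so a limit shared across workers would need an external store.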

Tool Authorization with Permission Decisions

from pydantic_ai import Agent, RunContext
from pydantic_ai_middleware import (
    MiddlewareAgent,
    AgentMiddleware,
    ToolDecision,
    ToolPermissionResult,
)

# Define a pydantic-ai agent with tools
base_agent = Agent("openai:gpt-4o", instructions="You are a file manager.")

@base_agent.tool
async def read_file(ctx: RunContext[None], path: str) -> str:
    """Read a file from disk."""
    return f"Contents of {path}"

@base_agent.tool
async def delete_file(ctx: RunContext[None], path: str) -> str:
    """Delete a file from disk."""
    return f"Deleted {path}"

# Middleware that controls tool access
class FileAccessControl(AgentMiddleware[None]):
    """Require explicit approval for destructive file operations."""

    tool_names = {"delete_file"}  # only intercept delete_file

    async def before_tool_call(self, tool_name, tool_args, deps, ctx=None):
        return ToolPermissionResult(
            decision=ToolDecision.ASK,
            reason=f"Agent wants to delete: {tool_args.get('path')}",
        )

# Permission handler — called when middleware returns ASK
async def approval_callback(tool_name: str, tool_args: dict, reason: str) -> bool:
    print(f"[APPROVAL REQUIRED] {reason}")
    response = input("Allow? (y/n): ")
    return response.lower() == "y"

agent = MiddlewareAgent(
    agent=base_agent,
    middleware=[FileAccessControl()],
    permission_handler=approval_callback,
)

# read_file works without approval, delete_file triggers the callback
result = await agent.run("Read config.yaml then delete temp.log")
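
The ALLOW/DENY/ASK flow can be pictured as a small decision function. This is a standard-library sketch only: `Decision`, `resolve`, and `auto_approve_tmp` are hypothetical names, and denying an ASK when no handler is configured is a design choice made for this sketch, not documented library behavior.

```python
import asyncio
from enum import Enum


class Decision(Enum):
    ALLOW = "allow"
    DENY = "deny"
    ASK = "ask"


async def resolve(decision, reason, ask_handler=None):
    """Return True if the tool call may proceed."""
    if decision is Decision.ALLOW:
        return True
    if decision is Decision.DENY:
        return False
    # ASK: defer to a handler; with no handler, fail closed.
    if ask_handler is None:
        return False
    return await ask_handler(reason)


async def auto_approve_tmp(reason):
    # Hypothetical policy: approve only operations under /tmp.
    return "/tmp/" in reason


allowed = asyncio.run(resolve(Decision.ASK, "delete /tmp/a.log", auto_approve_tmp))
denied = asyncio.run(resolve(Decision.ASK, "delete /etc/passwd", auto_approve_tmp))
no_handler = asyncio.run(resolve(Decision.ASK, "delete /tmp/a.log"))
```

A non-interactive handler like this one can stand in for the `input()` prompt when the agent runs in a service rather than a terminal.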

Structured Output with Audit Logging

from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai_middleware import (
    MiddlewareAgent,
    AgentMiddleware,
    MiddlewareContext,
)

class SupportTicket(BaseModel):
    category: str
    priority: str
    summary: str

class AuditLogger(AgentMiddleware[None]):
    """Log all agent interactions for compliance."""

    async def before_run(self, prompt, deps, ctx=None):
        if ctx:
            ctx.set("input_prompt", prompt)
        print(f"[AUDIT] Input: {str(prompt)[:80]}...")
        return prompt

    async def after_run(self, prompt, output, deps, ctx=None):
        print(f"[AUDIT] Output type: {type(output).__name__}")
        return output

    async def before_tool_call(self, tool_name, tool_args, deps, ctx=None):
        print(f"[AUDIT] Tool call: {tool_name}({tool_args})")
        return tool_args

# Agent with structured output — middleware works transparently
base_agent = Agent(
    "openai:gpt-4o",
    instructions="Classify support tickets.",
    output_type=SupportTicket,
)

agent = MiddlewareAgent(
    agent=base_agent,
    middleware=[AuditLogger()],
    context=MiddlewareContext(),
)

result = await agent.run("My payment failed and I can't access my account")
ticket: SupportTicket = result.output
print(f"Category: {ticket.category}, Priority: {ticket.priority}")

Parallel Validators with Async Guardrails

from pydantic_ai import Agent
from pydantic_ai_middleware import (
    MiddlewareAgent,
    AgentMiddleware,
    ParallelMiddleware,
    AsyncGuardrailMiddleware,
    AggregationStrategy,
    GuardrailTiming,
    InputBlocked,
)

class ProfanityFilter(AgentMiddleware[None]):
    async def before_run(self, prompt, deps, ctx=None):
        # Simulated check — replace with real classifier
        if any(word in prompt.lower() for word in ["badword"]):
            raise InputBlocked("Profanity detected")
        return prompt

class PIIDetector(AgentMiddleware[None]):
    async def before_run(self, prompt, deps, ctx=None):
        import re
        if re.search(r"\b\d{3}-\d{2}-\d{4}\b", prompt):
            raise InputBlocked("SSN detected in input")
        return prompt

class ToxicityChecker(AgentMiddleware[None]):
    async def before_run(self, prompt, deps, ctx=None):
        # Simulated slow ML-based check
        import asyncio
        await asyncio.sleep(0.5)
        return prompt

# Run ProfanityFilter + PIIDetector in parallel (fast, both must pass)
fast_validators = ParallelMiddleware(
    middleware=[ProfanityFilter(), PIIDetector()],
    strategy=AggregationStrategy.ALL_MUST_PASS,
)

# Run ToxicityChecker concurrently with the LLM (saves latency)
toxicity_guard = AsyncGuardrailMiddleware(
    guardrail=ToxicityChecker(),
    timing=GuardrailTiming.CONCURRENT,
    cancel_on_failure=True,
)

base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(
    agent=base_agent,
    middleware=[fast_validators, toxicity_guard],
)

result = await agent.run("Summarize this document for me")
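
The idea behind concurrent timing with cancel-on-failure can be sketched with plain asyncio. The names `run_concurrent`, `slow_model`, `guardrail`, and `GuardrailFailed` are hypothetical, and the sketch assumes the response is released only after the guardrail has passed. It is not how AsyncGuardrailMiddleware is implemented.

```python
import asyncio


class GuardrailFailed(Exception):
    pass


async def slow_model(prompt):
    await asyncio.sleep(0.2)  # placeholder for the LLM call
    return f"answer to {prompt}"


async def guardrail(prompt):
    await asyncio.sleep(0.01)  # placeholder for a classifier
    if "toxic" in prompt:
        raise GuardrailFailed("toxicity detected")


async def run_concurrent(prompt):
    # Start both at once so the guardrail adds no latency when it passes.
    model_task = asyncio.create_task(slow_model(prompt))
    guard_task = asyncio.create_task(guardrail(prompt))
    try:
        await guard_task
    except GuardrailFailed:
        model_task.cancel()  # stop a response that would be discarded
        await asyncio.gather(model_task, return_exceptions=True)
        raise
    return await model_task


answer = asyncio.run(run_concurrent("hello"))

try:
    asyncio.run(run_concurrent("toxic text"))
    rejected = False
except GuardrailFailed:
    rejected = True
```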

Decorator Syntax for Quick Middleware

from pydantic_ai import Agent
from pydantic_ai_middleware import (
    MiddlewareAgent,
    before_run,
    after_run,
    before_tool_call,
    on_tool_error,
)

@before_run
async def log_input(prompt, deps, ctx=None):
    print(f">>> {prompt}")
    return prompt

@after_run
async def log_output(prompt, output, deps, ctx=None):
    print(f"<<< {output}")
    return output

@before_tool_call(tools={"web_search"})
async def validate_search(tool_name, tool_args, deps, ctx=None):
    """Only runs for web_search tool, skipped for all others."""
    query = tool_args.get("query", "")
    if len(query) > 500:
        tool_args["query"] = query[:500]
    return tool_args

@on_tool_error(tools={"web_search"})
async def handle_search_error(tool_name, tool_args, error, deps, ctx=None):
    if isinstance(error, TimeoutError):
        return ConnectionError("Search service temporarily unavailable")
    return None  # re-raise original error

base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(
    agent=base_agent,
    middleware=[log_input, log_output, validate_search, handle_search_error],
)

result = await agent.run("Search for the latest Python release notes")
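
For readers curious how a decorator can accept both the bare form and the `tools={...}` form, here is one possible standard-library sketch. `FunctionMiddleware`, `tool_hook`, and `truncate_query` are hypothetical names, and the simplified two-argument hook signature is an assumption for brevity. The library's own decorators may be built differently.

```python
import asyncio


class FunctionMiddleware:
    """Wraps a plain async function and an optional tool-name filter."""

    def __init__(self, func, tools=None):
        self.func = func
        self.tools = tools

    async def before_tool_call(self, tool_name, tool_args):
        if self.tools is not None and tool_name not in self.tools:
            return tool_args  # not in scope: pass through untouched
        return await self.func(tool_name, tool_args)


def tool_hook(func=None, *, tools=None):
    """Usable bare (@tool_hook) or with arguments (@tool_hook(tools={...}))."""
    if func is not None:
        return FunctionMiddleware(func)
    return lambda f: FunctionMiddleware(f, tools=tools)


@tool_hook(tools={"web_search"})
async def truncate_query(tool_name, tool_args):
    # Return a new dict rather than mutating the caller's arguments.
    return {**tool_args, "query": tool_args.get("query", "")[:5]}


searched = asyncio.run(
    truncate_query.before_tool_call("web_search", {"query": "abcdefgh"})
)
skipped = asyncio.run(
    truncate_query.before_tool_call("read_file", {"query": "abcdefgh"})
)
```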

Middleware Chains + Conditional Routing

from pydantic_ai import Agent
from pydantic_ai_middleware import (
    MiddlewareAgent,
    MiddlewareChain,
    ConditionalMiddleware,
    AgentMiddleware,
)

class AuthMiddleware(AgentMiddleware[None]):
    async def before_run(self, prompt, deps, ctx=None):
        if ctx:
            ctx.set("authenticated", True)
        return prompt

class AdminAudit(AgentMiddleware[None]):
    async def before_run(self, prompt, deps, ctx=None):
        print("[ADMIN AUDIT] Elevated access logged")
        return prompt

class UserAudit(AgentMiddleware[None]):
    async def before_run(self, prompt, deps, ctx=None):
        print("[USER AUDIT] Standard access logged")
        return prompt

# Reusable security chain
security = MiddlewareChain([AuthMiddleware()], name="security")

# Route to different audit middleware based on runtime condition
audit = ConditionalMiddleware(
    condition=lambda ctx: ctx is not None and ctx.get("is_admin", False),
    when_true=AdminAudit(),
    when_false=UserAudit(),
)

# Combine with + operator
pipeline = security + MiddlewareChain([audit])

base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(agent=base_agent, middleware=[pipeline])
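
Composition with `+` and per-call routing can be modeled in a few lines. The sketch below is synchronous for brevity and uses hypothetical names (`Chain`, `conditional`, `authenticate`) with a plain dict as the shared context. It illustrates the pattern, not the library's MiddlewareChain or ConditionalMiddleware classes.

```python
class Chain:
    """Ordered, immutable sequence of middleware-like callables."""

    def __init__(self, items, name=None):
        self.items = tuple(items)
        self.name = name

    def __add__(self, other):
        if not isinstance(other, Chain):
            return NotImplemented
        return Chain(self.items + other.items)

    def run(self, prompt, ctx):
        for item in self.items:
            prompt = item(prompt, ctx)
        return prompt


def conditional(condition, when_true, when_false):
    """Pick a branch per call, based on the shared context."""
    def route(prompt, ctx):
        branch = when_true if condition(ctx) else when_false
        return branch(prompt, ctx)
    return route


def authenticate(prompt, ctx):
    ctx["authenticated"] = True
    return prompt


audit = conditional(
    lambda ctx: ctx.get("is_admin", False),
    when_true=lambda p, ctx: f"[admin] {p}",
    when_false=lambda p, ctx: f"[user] {p}",
)

pipeline = Chain([authenticate], name="security") + Chain([audit])
admin_ctx = {"is_admin": True}
admin_out = pipeline.run("hi", admin_ctx)
user_out = pipeline.run("hi", {})
```

Because `+` returns a new chain, the original `security` chain stays reusable in other pipelines.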

Hook Timeouts

from pydantic_ai import Agent
from pydantic_ai_middleware import MiddlewareAgent, AgentMiddleware, MiddlewareTimeout

class ExternalAPICheck(AgentMiddleware[None]):
    timeout = 3.0  # seconds — applies to every hook on this middleware

    async def before_run(self, prompt, deps, ctx=None):
        # call_external_api is a placeholder for your own async check.
        # If it takes longer than 3s, MiddlewareTimeout is raised.
        await call_external_api(prompt)
        return prompt

base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(agent=base_agent, middleware=[ExternalAPICheck()])

try:
    result = await agent.run("Check this input")
except MiddlewareTimeout as e:
    print(f"Middleware '{e.middleware_name}' timed out in {e.hook_name} after {e.timeout}s")
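
A per-hook timeout of this kind is commonly built on `asyncio.wait_for`. The following sketch assumes that approach; `HookTimeout`, `call_with_timeout`, `slow_check`, and `fast_check` are hypothetical names, and the library's MiddlewareTimeout may be raised by different machinery.

```python
import asyncio


class HookTimeout(Exception):
    """Carries which hook timed out and the limit that was exceeded."""

    def __init__(self, hook_name, timeout):
        super().__init__(f"{hook_name} exceeded {timeout}s")
        self.hook_name = hook_name
        self.timeout = timeout


async def call_with_timeout(hook, timeout, *args):
    try:
        return await asyncio.wait_for(hook(*args), timeout=timeout)
    except asyncio.TimeoutError:
        # Translate the generic timeout into one that names the hook.
        raise HookTimeout(hook.__name__, timeout) from None


async def slow_check(prompt):
    await asyncio.sleep(1.0)
    return prompt


async def fast_check(prompt):
    return prompt


async def main():
    ok = await call_with_timeout(fast_check, 0.05, "hi")
    try:
        await call_with_timeout(slow_check, 0.05, "hi")
    except HookTimeout as exc:
        return ok, exc.hook_name, exc.timeout
    return ok, None, None


outcome = asyncio.run(main())
```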

Architecture

 pydantic-ai Agent              pydantic-ai-middleware
┌──────────────────┐    ┌─────────────────────────────────────────┐
│                  │    │                                         │
│  Agent(model,    │    │  MiddlewareAgent(agent, middleware)     │
│    tools,        │◄───│                                         │
│    instructions) │    │  middleware = [                         │
│                  │    │    MiddlewareChain([MW1, MW2])          │
└──────────────────┘    │    ParallelMiddleware([MW3, MW4])       │
                        │    ConditionalMiddleware(cond, MW5)     │
                        │    AsyncGuardrailMiddleware(MW6)        │
                        │  ]                                      │
                        │                                         │
                        │  + MiddlewareContext (data sharing)     │
                        │  + PermissionHandler (tool auth)        │
                        │  + PipelineSpec (config loading)        │
                        └─────────────────────────────────────────┘

Use Cases

| What You Want to Build | Key Components |
|---|---|
| Input Validation | before_run + InputBlocked |
| PII Redaction | before_run + after_run |
| Rate Limiting | before_run + context + timeout |
| Tool Authorization | before_tool_call + ToolPermissionResult |
| Scoped Tool Guards | before_tool_call + tool_names |
| Tool Error Recovery | on_tool_error + tool_names |
| Audit Logging | All hooks + context |
| Content Moderation | Parallel + AsyncGuardrail |
| A/B Testing | ConditionalMiddleware |
| Config-Driven Pipelines | PipelineSpec + Config Loading |
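
For the config-driven row, a common pattern is a registry that maps names in a config file to middleware classes. The sketch below shows that pattern with the standard library only. The JSON schema (`middleware`, `type`, `options`), the `register` decorator, and `build_pipeline` are all assumptions made for illustration; they do not describe the actual PipelineSpec format.

```python
import json

REGISTRY = {}  # maps a config name to a middleware class


def register(name):
    def decorator(cls):
        REGISTRY[name] = cls
        return cls
    return decorator


@register("rate_limiter")
class RateLimiter:
    def __init__(self, max_per_minute=10):
        self.max_per_minute = max_per_minute


@register("prompt_sanitizer")
class PromptSanitizer:
    def __init__(self, patterns=()):
        self.patterns = list(patterns)


def build_pipeline(config_text):
    """Instantiate middleware in the order listed in the config."""
    spec = json.loads(config_text)
    return [
        REGISTRY[entry["type"]](**entry.get("options", {}))
        for entry in spec["middleware"]
    ]


CONFIG = """
{
  "middleware": [
    {"type": "rate_limiter", "options": {"max_per_minute": 20}},
    {"type": "prompt_sanitizer", "options": {"patterns": ["jailbreak"]}}
  ]
}
"""

pipeline = build_pipeline(CONFIG)
```

An unknown `type` raises `KeyError` here, which surfaces config typos at startup rather than at request time.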

Part of the Ecosystem

| Package | Description |
|---|---|
| pydantic-ai | The foundation: Agent framework by Pydantic |
| pydantic-deep | Full agent framework with planning, subagents, skills |
| pydantic-ai-backend | File storage and sandbox backends |
| pydantic-ai-todo | Task planning toolset for agents |
| subagents-pydantic-ai | Multi-agent orchestration |
| summarization-pydantic-ai | Context management processors |

Contributing

git clone https://github.com/vstorm-co/pydantic-ai-middleware.git
cd pydantic-ai-middleware
make install
make test  # 100% coverage required
make all   # lint + typecheck + test

See CONTRIBUTING.md for full guidelines.



License

MIT — see LICENSE

Built with ❤️ by vstorm-co
