Simple middleware library for Pydantic-AI - before/after hooks without imposed guardrails structure
Project description
Pydantic AI Middleware
Intercept, Transform & Guard Every Pydantic AI Call
7 Lifecycle Hooks • Parallel Execution • Async Guardrails • Fully Type-Safe
Get Started in 60 Seconds
pip install pydantic-ai-middleware
from pydantic_ai import Agent
from pydantic_ai_middleware import MiddlewareAgent, AgentMiddleware, InputBlocked
class ContentFilter(AgentMiddleware[None]):
"""Block dangerous prompts before they reach the LLM."""
async def before_run(self, prompt, deps, ctx=None):
if "ignore all instructions" in prompt.lower():
raise InputBlocked("Prompt injection attempt blocked")
return prompt
async def after_run(self, prompt, output, deps, ctx=None):
# Redact sensitive patterns from the LLM response
return output.replace("SSN:", "[REDACTED]")
# Wrap any pydantic-ai Agent with middleware
base_agent = Agent("openai:gpt-4o", instructions="You are a helpful assistant.")
agent = MiddlewareAgent(agent=base_agent, middleware=[ContentFilter()])
result = await agent.run("Hello, how are you?")
print(result.output)
That's it. Your pydantic-ai agent now has input validation, output filtering, and prompt injection protection — all with a simple wrapper.
Why Middleware, Not Guardrails?
pydantic-ai-middleware takes a different approach from traditional guardrails libraries:
| Aspect | Middleware (this library) | Traditional Guardrails |
|---|---|---|
| Complexity | Low | High |
| Structure | No imposed structure | Fixed result types, actions |
| Flexibility | Maximum | Constrained by design |
| Learning curve | Flat | Steeper |
| Built-in guardrails | None (you build what you need) | Pre-built (PII, moderation) |
| Parallel execution | Built-in with early cancellation | Often built-in |
| Type safety | Full generics support | Varies |
You decide what to build. Logging, guardrails, metrics, rate limiting, PII redaction — all using the same simple API.
Features
- 7 Lifecycle Hooks —
before_run,after_run,before_model_request,before_tool_call,on_tool_error,after_tool_call,on_error - Parallel Execution — Run multiple middleware concurrently with 4 aggregation strategies and early cancellation
- Async Guardrails — Run guardrails alongside LLM calls with
BLOCKING,CONCURRENT, orASYNC_POSTtiming - Middleware Chains — Compose middleware into reusable, ordered sequences with
+operator - Conditional Routing — Route to different middleware based on runtime conditions
- Config Loading — Build pipelines from JSON/YAML configuration files
- Decorator Syntax — Create middleware from simple decorated functions
- Context Sharing — Share data between hooks with access control
- Tool Name Filtering — Scope middleware to specific tools with
tool_names - Hook Timeouts — Per-middleware timeout enforcement with
MiddlewareTimeout - Permission Decisions — Structured ALLOW/DENY/ASK protocol for tool authorization
- Cost Tracking — Automatic token usage and USD cost monitoring with budget limits
- Lightweight — Only requires pydantic-ai-slim and genai-prices
Hook Lifecycle
| Hook | When Called | Can Modify |
|---|---|---|
before_run |
Before agent starts | Prompt |
after_run |
After agent finishes | Output |
before_model_request |
Before each model call | Messages |
before_tool_call |
Before tool execution | Tool arguments |
on_tool_error |
When a tool raises an exception | Exception (replace or re-raise) |
after_tool_call |
After tool execution | Tool result |
on_error |
When error occurs | Exception |
Real-World Examples
Input Validation + Rate Limiting
import time
from pydantic_ai import Agent
from pydantic_ai_middleware import (
MiddlewareAgent,
AgentMiddleware,
InputBlocked,
MiddlewareContext,
)
class RateLimiter(AgentMiddleware[None]):
"""Limit how many requests per minute."""
def __init__(self, max_per_minute: int = 10):
self.max_per_minute = max_per_minute
self._timestamps: list[float] = []
async def before_run(self, prompt, deps, ctx=None):
now = time.time()
self._timestamps = [t for t in self._timestamps if now - t < 60]
if len(self._timestamps) >= self.max_per_minute:
raise InputBlocked("Rate limit exceeded — try again later")
self._timestamps.append(now)
return prompt
class PromptSanitizer(AgentMiddleware[None]):
"""Remove potentially harmful instructions from prompts."""
BLOCKED_PATTERNS = ["ignore previous", "system prompt", "jailbreak"]
async def before_run(self, prompt, deps, ctx=None):
lower = prompt.lower() if isinstance(prompt, str) else str(prompt).lower()
for pattern in self.BLOCKED_PATTERNS:
if pattern in lower:
raise InputBlocked(f"Blocked pattern detected: {pattern}")
return prompt
# Build the agent with middleware pipeline
base_agent = Agent("openai:gpt-4o", instructions="You are a customer support agent.")
agent = MiddlewareAgent(
agent=base_agent,
middleware=[RateLimiter(max_per_minute=20), PromptSanitizer()],
context=MiddlewareContext(), # enable cross-middleware data sharing
)
result = await agent.run("How do I reset my password?")
Tool Authorization with Permission Decisions
from pydantic_ai import Agent, RunContext
from pydantic_ai_middleware import (
MiddlewareAgent,
AgentMiddleware,
ToolDecision,
ToolPermissionResult,
)
# Define a pydantic-ai agent with tools
base_agent = Agent("openai:gpt-4o", instructions="You are a file manager.")
@base_agent.tool
async def read_file(ctx: RunContext[None], path: str) -> str:
"""Read a file from disk."""
return f"Contents of {path}"
@base_agent.tool
async def delete_file(ctx: RunContext[None], path: str) -> str:
"""Delete a file from disk."""
return f"Deleted {path}"
# Middleware that controls tool access
class FileAccessControl(AgentMiddleware[None]):
"""Require explicit approval for destructive file operations."""
tool_names = {"delete_file"} # only intercept delete_file
async def before_tool_call(self, tool_name, tool_args, deps, ctx=None):
return ToolPermissionResult(
decision=ToolDecision.ASK,
reason=f"Agent wants to delete: {tool_args.get('path')}",
)
# Permission handler — called when middleware returns ASK
async def approval_callback(tool_name: str, tool_args: dict, reason: str) -> bool:
print(f"[APPROVAL REQUIRED] {reason}")
response = input("Allow? (y/n): ")
return response.lower() == "y"
agent = MiddlewareAgent(
agent=base_agent,
middleware=[FileAccessControl()],
permission_handler=approval_callback,
)
# read_file works without approval, delete_file triggers the callback
result = await agent.run("Read config.yaml then delete temp.log")
Structured Output with Audit Logging
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai_middleware import (
MiddlewareAgent,
AgentMiddleware,
MiddlewareContext,
ScopedContext,
)
class SupportTicket(BaseModel):
category: str
priority: str
summary: str
class AuditLogger(AgentMiddleware[None]):
"""Log all agent interactions for compliance."""
async def before_run(self, prompt, deps, ctx=None):
if ctx:
ctx.set("input_prompt", prompt)
print(f"[AUDIT] Input: {prompt[:80]}...")
return prompt
async def after_run(self, prompt, output, deps, ctx=None):
print(f"[AUDIT] Output type: {type(output).__name__}")
return output
async def before_tool_call(self, tool_name, tool_args, deps, ctx=None):
print(f"[AUDIT] Tool call: {tool_name}({tool_args})")
return tool_args
# Agent with structured output — middleware works transparently
base_agent = Agent(
"openai:gpt-4o",
instructions="Classify support tickets.",
output_type=SupportTicket,
)
agent = MiddlewareAgent(
agent=base_agent,
middleware=[AuditLogger()],
context=MiddlewareContext(),
)
result = await agent.run("My payment failed and I can't access my account")
ticket: SupportTicket = result.output
print(f"Category: {ticket.category}, Priority: {ticket.priority}")
Parallel Validators with Async Guardrails
from pydantic_ai import Agent
from pydantic_ai_middleware import (
MiddlewareAgent,
AgentMiddleware,
ParallelMiddleware,
AsyncGuardrailMiddleware,
AggregationStrategy,
GuardrailTiming,
InputBlocked,
)
class ProfanityFilter(AgentMiddleware[None]):
async def before_run(self, prompt, deps, ctx=None):
# Simulated check — replace with real classifier
if any(word in prompt.lower() for word in ["badword"]):
raise InputBlocked("Profanity detected")
return prompt
class PIIDetector(AgentMiddleware[None]):
async def before_run(self, prompt, deps, ctx=None):
import re
if re.search(r"\b\d{3}-\d{2}-\d{4}\b", prompt):
raise InputBlocked("SSN detected in input")
return prompt
class ToxicityChecker(AgentMiddleware[None]):
async def before_run(self, prompt, deps, ctx=None):
# Simulated slow ML-based check
import asyncio
await asyncio.sleep(0.5)
return prompt
# Run ProfanityFilter + PIIDetector in parallel (fast, both must pass)
fast_validators = ParallelMiddleware(
middleware=[ProfanityFilter(), PIIDetector()],
strategy=AggregationStrategy.ALL_MUST_PASS,
)
# Run ToxicityChecker concurrently with the LLM (saves latency)
toxicity_guard = AsyncGuardrailMiddleware(
guardrail=ToxicityChecker(),
timing=GuardrailTiming.CONCURRENT,
cancel_on_failure=True,
)
base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(
agent=base_agent,
middleware=[fast_validators, toxicity_guard],
)
result = await agent.run("Summarize this document for me")
Decorator Syntax for Quick Middleware
from pydantic_ai import Agent
from pydantic_ai_middleware import (
MiddlewareAgent,
before_run,
after_run,
before_tool_call,
on_tool_error,
)
@before_run
async def log_input(prompt, deps, ctx=None):
print(f">>> {prompt}")
return prompt
@after_run
async def log_output(prompt, output, deps, ctx=None):
print(f"<<< {output}")
return output
@before_tool_call(tools={"web_search"})
async def validate_search(tool_name, tool_args, deps, ctx=None):
"""Only runs for web_search tool, skipped for all others."""
query = tool_args.get("query", "")
if len(query) > 500:
tool_args["query"] = query[:500]
return tool_args
@on_tool_error(tools={"web_search"})
async def handle_search_error(tool_name, tool_args, error, deps, ctx=None):
if isinstance(error, TimeoutError):
return ConnectionError("Search service temporarily unavailable")
return None # re-raise original error
base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(
agent=base_agent,
middleware=[log_input, log_output, validate_search, handle_search_error],
)
result = await agent.run("Search for the latest Python release notes")
Cost Tracking with Budget Limits
from pydantic_ai import Agent
from pydantic_ai_middleware import (
MiddlewareAgent,
MiddlewareContext,
CostTrackingMiddleware,
create_cost_tracking_middleware,
)
cost_mw = create_cost_tracking_middleware(
model_name="openai:gpt-4.1",
budget_limit_usd=5.0,
on_cost_update=lambda info: print(
f"Run #{info.run_count}: ${info.run_cost_usd:.4f} "
f"(total: ${info.total_cost_usd:.4f})"
),
)
base_agent = Agent("openai:gpt-4.1")
agent = MiddlewareAgent(
agent=base_agent,
middleware=[cost_mw],
context=MiddlewareContext(), # required for cost tracking
)
result = await agent.run("Explain quantum computing")
# Run #1: $0.0023 (total: $0.0023)
print(f"Total tokens: {cost_mw.total_request_tokens + cost_mw.total_response_tokens}")
Middleware Chains + Conditional Routing
from pydantic_ai import Agent
from pydantic_ai_middleware import (
MiddlewareAgent,
MiddlewareChain,
ConditionalMiddleware,
AgentMiddleware,
)
class AuthMiddleware(AgentMiddleware[None]):
async def before_run(self, prompt, deps, ctx=None):
if ctx:
ctx.set("authenticated", True)
return prompt
class AdminAudit(AgentMiddleware[None]):
async def before_run(self, prompt, deps, ctx=None):
print("[ADMIN AUDIT] Elevated access logged")
return prompt
class UserAudit(AgentMiddleware[None]):
async def before_run(self, prompt, deps, ctx=None):
print("[USER AUDIT] Standard access logged")
return prompt
# Reusable security chain
security = MiddlewareChain([AuthMiddleware()], name="security")
# Route to different audit middleware based on runtime condition
audit = ConditionalMiddleware(
condition=lambda ctx: ctx is not None and ctx.get("is_admin", False),
when_true=AdminAudit(),
when_false=UserAudit(),
)
# Combine with + operator
pipeline = security + MiddlewareChain([audit])
base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(agent=base_agent, middleware=[pipeline])
Hook Timeouts
from pydantic_ai import Agent
from pydantic_ai_middleware import MiddlewareAgent, AgentMiddleware, MiddlewareTimeout
class ExternalAPICheck(AgentMiddleware[None]):
timeout = 3.0 # seconds — applies to every hook on this middleware
async def before_run(self, prompt, deps, ctx=None):
# If this takes longer than 3s, MiddlewareTimeout is raised
result = await call_external_api(prompt)
return prompt
base_agent = Agent("openai:gpt-4o")
agent = MiddlewareAgent(agent=base_agent, middleware=[ExternalAPICheck()])
try:
result = await agent.run("Check this input")
except MiddlewareTimeout as e:
print(f"Middleware '{e.middleware_name}' timed out in {e.hook_name} after {e.timeout}s")
Architecture
pydantic-ai Agent pydantic-ai-middleware
┌──────────────────┐ ┌─────────────────────────────────────────┐
│ │ │ │
│ Agent(model, │ │ MiddlewareAgent(agent, middleware) │
│ tools, │◄───│ │
│ instructions) │ │ middleware = [ │
│ │ │ MiddlewareChain([MW1, MW2]) │
└──────────────────┘ │ ParallelMiddleware([MW3, MW4]) │
│ ConditionalMiddleware(cond, MW5) │
│ AsyncGuardrailMiddleware(MW6) │
│ ] │
│ │
│ + MiddlewareContext (data sharing) │
│ + PermissionHandler (tool auth) │
│ + PipelineSpec (config loading) │
└─────────────────────────────────────────┘
Use Cases
| What You Want to Build | Key Components |
|---|---|
| Input Validation | before_run + InputBlocked |
| PII Redaction | before_run + after_run |
| Rate Limiting | before_run + context + timeout |
| Tool Authorization | before_tool_call + ToolPermissionResult |
| Scoped Tool Guards | before_tool_call + tool_names |
| Tool Error Recovery | on_tool_error + tool_names |
| Audit Logging | All hooks + context |
| Content Moderation | Parallel + AsyncGuardrail |
| A/B Testing | ConditionalMiddleware |
| Cost Tracking | CostTrackingMiddleware + BudgetExceededError |
| Config-Driven Pipelines | PipelineSpec + Config Loading |
Part of the Ecosystem
| Package | Description |
|---|---|
| pydantic-ai | The foundation: Agent framework by Pydantic |
| pydantic-deep | Full agent framework with planning, subagents, skills |
| pydantic-ai-backend | File storage and sandbox backends |
| pydantic-ai-todo | Task planning toolset for agents |
| subagents-pydantic-ai | Multi-agent orchestration |
| summarization-pydantic-ai | Context management processors |
Contributing
git clone https://github.com/vstorm-co/pydantic-ai-middleware.git
cd pydantic-ai-middleware
make install
make test # 100% coverage required
make all # lint + typecheck + test
See CONTRIBUTING.md for full guidelines.
Star History
License
MIT — see LICENSE
Project details
Download files
Download the file for your platform. If you're not sure which to choose, learn more about installing packages.
Source Distribution
Built Distribution
Filter files by name, interpreter, ABI, and platform.
If you're not sure about the file name format, learn more about wheel file names.
Copy a direct link to the current filters
File details
Details for the file pydantic_ai_middleware-0.2.3.tar.gz.
File metadata
- Download URL: pydantic_ai_middleware-0.2.3.tar.gz
- Upload date:
- Size: 502.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
f1c0241410e77875d3c49aa5fbe534cb91e73dd5d5ce05ccfa678a96d7c57c26
|
|
| MD5 |
0dd4cfbd5bda6fc45dc96131680a018f
|
|
| BLAKE2b-256 |
faa0685b02dbc69074215a696389cac83b28744eeff52c500b1b2c1d15536c4a
|
Provenance
The following attestation bundles were made for pydantic_ai_middleware-0.2.3.tar.gz:
Publisher:
publish.yml on vstorm-co/pydantic-ai-middleware
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
pydantic_ai_middleware-0.2.3.tar.gz -
Subject digest:
f1c0241410e77875d3c49aa5fbe534cb91e73dd5d5ce05ccfa678a96d7c57c26 - Sigstore transparency entry: 1088513036
- Sigstore integration time:
-
Permalink:
vstorm-co/pydantic-ai-middleware@3d6351ff95fb60bf487269af6ece81457354db38 -
Branch / Tag:
refs/tags/0.2.3 - Owner: https://github.com/vstorm-co
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
publish.yml@3d6351ff95fb60bf487269af6ece81457354db38 -
Trigger Event:
release
-
Statement type:
File details
Details for the file pydantic_ai_middleware-0.2.3-py3-none-any.whl.
File metadata
- Download URL: pydantic_ai_middleware-0.2.3-py3-none-any.whl
- Upload date:
- Size: 49.7 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest | |
|---|---|---|
| SHA256 |
3f571075e74c2c1ddd6dc6c14726d769f2b7a0940d30833817b09ccbdff275a5
|
|
| MD5 |
e08c5f03a5346a553e8041f0d2e4ce8e
|
|
| BLAKE2b-256 |
6cfda21c30db3363cae7126d0efe3e3907f9b5a759c8db4c71fe2c21ed56a634
|
Provenance
The following attestation bundles were made for pydantic_ai_middleware-0.2.3-py3-none-any.whl:
Publisher:
publish.yml on vstorm-co/pydantic-ai-middleware
-
Statement:
-
Statement type:
https://in-toto.io/Statement/v1 -
Predicate type:
https://docs.pypi.org/attestations/publish/v1 -
Subject name:
pydantic_ai_middleware-0.2.3-py3-none-any.whl -
Subject digest:
3f571075e74c2c1ddd6dc6c14726d769f2b7a0940d30833817b09ccbdff275a5 - Sigstore transparency entry: 1088513108
- Sigstore integration time:
-
Permalink:
vstorm-co/pydantic-ai-middleware@3d6351ff95fb60bf487269af6ece81457354db38 -
Branch / Tag:
refs/tags/0.2.3 - Owner: https://github.com/vstorm-co
-
Access:
public
-
Token Issuer:
https://token.actions.githubusercontent.com -
Runner Environment:
github-hosted -
Publication workflow:
publish.yml@3d6351ff95fb60bf487269af6ece81457354db38 -
Trigger Event:
release
-
Statement type: