Wrap tool with confirmation handling.
We wrap the tool to intercept pydantic-ai's tool calls and add our confirmation
logic before the actual execution happens. The actual tool execution (including
moving sync functions to threads) is handled by pydantic-ai.
Current situation: we only get full information about tool calls for functions that
take a RunContext. To mitigate this, we fall back to the AgentContext, which at
least provides some information.
Source code in `src/llmling_agent/agent/tool_wrapping.py` (lines 26–97)
97 | def wrap_tool(
tool: Tool,
agent_ctx: AgentContext,
) -> Callable[..., Awaitable[Any]]:
"""Wrap tool with confirmation handling.
We wrap the tool to intercept pydantic-ai's tool calls and add our confirmation
logic before the actual execution happens. The actual tool execution (including
moving sync functions to threads) is handled by pydantic-ai.
Current situation is: We only get all infos for tool calls for functions with
RunContext. In order to migitate this, we "fallback" to the AgentContext, which
at least provides some information.
"""
original_tool = tool.callable.callable
if has_argument_type(original_tool, RunContext):
async def wrapped(ctx: RunContext[AgentContext], *args, **kwargs): # pyright: ignore
result = await agent_ctx.handle_confirmation(tool, kwargs)
# if agent_ctx.report_progress:
# await agent_ctx.report_progress(ctx.run_step, None)
match result:
case "allow":
return await execute(original_tool, ctx, *args, **kwargs)
case "skip":
msg = f"Tool {tool.name} execution skipped"
raise ToolSkippedError(msg)
case "abort_run":
msg = "Run aborted by user"
raise RunAbortedError(msg)
case "abort_chain":
msg = "Agent chain aborted by user"
raise ChainAbortedError(msg)
elif has_argument_type(original_tool, AgentContext):
async def wrapped(ctx: AgentContext, *args, **kwargs): # pyright: ignore
result = await agent_ctx.handle_confirmation(tool, kwargs)
match result:
case "allow":
return await execute(original_tool, agent_ctx, *args, **kwargs)
case "skip":
msg = f"Tool {tool.name} execution skipped"
raise ToolSkippedError(msg)
case "abort_run":
msg = "Run aborted by user"
raise RunAbortedError(msg)
case "abort_chain":
msg = "Agent chain aborted by user"
raise ChainAbortedError(msg)
else:
async def wrapped(*args, **kwargs): # pyright: ignore
result = await agent_ctx.handle_confirmation(tool, kwargs)
match result:
case "allow":
return await execute(original_tool, *args, **kwargs)
case "skip":
msg = f"Tool {tool.name} execution skipped"
raise ToolSkippedError(msg)
case "abort_run":
msg = "Run aborted by user"
raise RunAbortedError(msg)
case "abort_chain":
msg = "Agent chain aborted by user"
raise ChainAbortedError(msg)
wraps(original_tool)(wrapped) # pyright: ignore
wrapped.__doc__ = tool.description
wrapped.__name__ = tool.name
return wrapped
|