31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
@dataclass
class ToolCodeGenerator:
    """Generates code artifacts for a single tool.

    Produces a textual call signature (``get_function_signature``) and, where
    the return schema allows it, source code for a return model
    (``generate_return_model``).
    """

    schema: OpenAIFunctionTool
    """Schema of the tool."""

    callable: Callable
    """Tool to generate code for."""

    name_override: str | None = None
    """Explicit tool name; when unset or empty, the callable's ``__name__`` is used."""

    @classmethod
    def from_tool(cls, tool: Tool) -> ToolCodeGenerator:
        """Create a ToolCodeGenerator from a Tool.

        Args:
            tool: Object providing ``schema``, ``callable`` and ``name``.

        Returns:
            A generator whose ``name_override`` is set to ``tool.name``.
        """
        return cls(schema=tool.schema, callable=tool.callable, name_override=tool.name)

    @property
    def name(self) -> str:
        """Name of the tool (the override if set, else the callable's name)."""
        return self.name_override or self.callable.__name__

    def _required_parameter_names(self) -> set[str]:
        """Return the names of parameters the schema marks as required.

        The list is read from ``function.parameters.required``. A
        function-level ``required`` list is honoured as well, so schemas that
        were previously interpreted through that key keep working.
        """
        function_schema = self.schema.get("function", {})
        required = set(function_schema.get("parameters", {}).get("required", []))
        # Not a declared key of the function schema, hence the ignore.
        required.update(function_schema.get("required", []))  # type: ignore
        return required

    def _extract_basic_signature(self, return_type: str = "Any") -> str:
        """Build a ``name(params) -> return_type`` signature from the tool schema.

        Context parameters are omitted. Parameters the schema does not mark
        as required are rendered with a ``= None`` default.

        Args:
            return_type: Text placed after ``->`` in the signature.

        Returns:
            The signature string.

        Raises:
            KeyError: If the schema has no ``"function"`` entry.
        """
        function_schema = self.schema["function"]
        properties = function_schema.get("parameters", {}).get("properties", {})
        required = self._required_parameter_names()
        rendered = []
        for name, param_info in properties.items():
            # Skip context parameters that should be hidden from users
            if self._is_context_parameter(name):
                continue
            type_hint = self._infer_parameter_type(name, param_info)
            default = "" if name in required else " = None"
            rendered.append(f"{name}: {type_hint}{default}")
        return f"{self.name}({', '.join(rendered)}) -> {return_type}"

    def _infer_parameter_type(self, param_name: str, param_info: Property) -> str:
        """Infer parameter type from schema and function inspection.

        Args:
            param_name: Name of the parameter in the callable's signature.
            param_info: Schema entry for the parameter; only its ``"type"``
                key is read.

        Returns:
            The ``TYPE_MAP`` entry for a non-``"object"`` schema type.
            For ``"object"``: the annotation's name (or ``str()`` of it), else
            the type name of an ``int``/``float``/``str``/``bool`` default,
            else ``"str"`` for a required parameter, else ``"Any"``.
        """
        schema_type = param_info.get("type", "Any")
        # If schema has a specific type, use it
        if schema_type != "object":
            return TYPE_MAP.get(schema_type, "Any")

        # For 'object' type, try to infer from function signature.
        # Best effort: any failure while inspecting falls through to "Any".
        try:
            param = self._get_callable_signature().parameters.get(param_name)
            if param is not None:
                # Try annotation first
                if param.annotation is not inspect.Parameter.empty:
                    if hasattr(param.annotation, "__name__"):
                        return param.annotation.__name__
                    return str(param.annotation)

                # Infer from default value. Identity comparison avoids
                # invoking a custom ``__eq__``/``__ne__`` on the default.
                if param.default is not inspect.Parameter.empty:
                    default_type = type(param.default).__name__
                    if default_type in {"int", "float", "str", "bool"}:
                        return default_type

                # If no usable default and it's required, assume str for
                # web-like functions
                if param_name in self._required_parameter_names():
                    return "str"
        except Exception:  # noqa: BLE001
            pass

        # Fallback to Any for unresolved object types
        return "Any"

    def _get_return_model_name(self) -> str:
        """Get the return type name for the tool's signature.

        Returns:
            ``"<Name>Response"`` for an ``"object"`` return schema,
            ``"list[<Name>Item]"`` for ``"array"``, otherwise the ``TYPE_MAP``
            entry for the return type (a missing type is treated as
            ``"string"``). ``"Any"`` if anything raises along the way.
        """
        try:
            schema = create_schema(self.callable)
            return_type = schema.returns.get("type")
            if return_type == "object":
                return f"{self.name.title()}Response"
            if return_type == "array":
                return f"list[{self.name.title()}Item]"
            return TYPE_MAP.get(schema.returns.get("type", "string"), "Any")
        except Exception:  # noqa: BLE001
            return "Any"

    def get_function_signature(self) -> str:
        """Extract the function signature, including the inferred return type.

        Falls back to a signature with an ``Any`` return type if building the
        typed signature raises.
        """
        try:
            return_model_name = self._get_return_model_name()
            return self._extract_basic_signature(return_model_name)
        except Exception:  # noqa: BLE001
            return self._extract_basic_signature("Any")

    def _get_callable_signature(self) -> inspect.Signature:
        """Get signature from callable, respecting wrapped signatures."""
        # Use wrapped signature if available (for context parameter hiding)
        return getattr(self.callable, "__signature__", None) or inspect.signature(
            self.callable
        )

    def _is_context_parameter(self, param_name: str) -> bool:
        """Check if a parameter is a context parameter that should be hidden.

        A parameter counts as a context parameter when its annotation is
        ``RunContext``, or the annotation's string form mentions
        ``RunContext`` or ``AgentContext``, or its ``__name__`` is
        ``AgentContext``. Unknown or unannotated parameters, and any failure
        during inspection, yield ``False``.
        """
        try:
            param = self._get_callable_signature().parameters.get(param_name)
            if param is None or param.annotation is inspect.Parameter.empty:
                return False

            annotation = param.annotation
            annotation_str = str(annotation)
            # String matching also covers parameterized forms such as
            # RunContext[None] and annotations stored as plain strings.
            if annotation is RunContext or "RunContext" in annotation_str:
                return True
            if getattr(annotation, "__name__", None) == "AgentContext":
                return True
        except Exception:  # noqa: BLE001
            return False
        else:
            return "AgentContext" in annotation_str

    def generate_return_model(self) -> str | None:
        """Generate source code for the tool's return model.

        Returns:
            The stripped model code produced by the schema's
            ``to_pydantic_model_code`` for a class named ``<Name>Response``,
            or ``None`` when the return type is neither ``"object"`` nor
            ``"array"``, the generated code is blank, or generation raises.
        """
        try:
            schema = create_schema(self.callable)
            if schema.returns.get("type") not in {"object", "array"}:
                return None

            class_name = f"{self.name.title()}Response"
            model_code = schema.to_pydantic_model_code(class_name=class_name)
            return model_code.strip() or None
        except Exception:  # noqa: BLE001
            return None
|