Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,25 @@ for tool in agent.tools:
3. Functions can optionally take the `context` (must be the first argument). You can also set overrides, like the name of the tool, description, which docstring style to use, etc.
4. You can pass the decorated functions to the list of tools.

You can also decorate instance methods. Access the tool from an instance before passing it to
`Agent.tools`; the implicit `self` parameter is bound to that instance and omitted from the tool
schema.

```python
class CustomerTools:
def __init__(self, tenant_id: str) -> None:
self.tenant_id = tenant_id

@function_tool
def lookup_customer(self, customer_id: str) -> str:
"""Look up a customer by ID."""
return f"{self.tenant_id}:{customer_id}"


customer_tools = CustomerTools("tenant_123")
agent = Agent(name="Assistant", tools=[customer_tools.lookup_customer])
```

??? note "Expand to see output"

```
Expand Down
21 changes: 18 additions & 3 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class FuncSchema:
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
omitted_parameter_names: tuple[str, ...] = ()
"""Parameter names that are supplied by the SDK instead of model-generated JSON."""

def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
Expand All @@ -52,6 +54,8 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:

# Use enumerate() so we can skip the first parameter if it's context.
for idx, (name, param) in enumerate(self.signature.parameters.items()):
if name in self.omitted_parameter_names:
continue
# If the function takes a RunContextWrapper and this is the first parameter, skip it.
if self.takes_context and idx == 0:
continue
Expand Down Expand Up @@ -228,6 +232,7 @@ def function_schema(
description_override: str | None = None,
use_docstring_info: bool = True,
strict_json_schema: bool = True,
skip_first_parameter: bool = False,
) -> FuncSchema:
"""
Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
Expand All @@ -246,6 +251,8 @@ def function_schema(
the schema adheres to the "strict" standard the OpenAI API expects. We **strongly**
recommend setting this to True, as it increases the likelihood of the LLM producing
correct JSON input.
skip_first_parameter: If True, omit the first signature parameter from the tool schema and
call arguments. This is used for instance methods decorated with `@function_tool`.

Returns:
A `FuncSchema` object containing the function's name, description, parameter descriptions,
Expand Down Expand Up @@ -288,22 +295,29 @@ def function_schema(
params = list(sig.parameters.items())
takes_context = False
filtered_params = []
omitted_parameter_names: list[str] = []

params_to_check = params
if skip_first_parameter and params:
omitted_parameter_names.append(params[0][0])
params_to_check = params[1:]

if params:
first_name, first_param = params[0]
if params_to_check:
first_name, first_param = params_to_check[0]
# Prefer the evaluated type hint if available
ann = type_hints.get(first_name, first_param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper or origin is ToolContext:
takes_context = True # Mark that the function takes context
omitted_parameter_names.append(first_name)
else:
filtered_params.append((first_name, first_param))
else:
filtered_params.append((first_name, first_param))

# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
for name, param in params[1:]:
for name, param in params_to_check[1:]:
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
Expand Down Expand Up @@ -421,4 +435,5 @@ def function_schema(
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
omitted_parameter_names=tuple(omitted_parameter_names),
)
50 changes: 41 additions & 9 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,18 @@ class FunctionTool:
_emit_tool_origin: bool = field(default=True, kw_only=True, repr=False)
"""Whether runtime item generation should emit tool origin metadata for this tool."""

_method_tool_factory: Callable[[Any], FunctionTool] | None = field(
default=None,
kw_only=True,
repr=False,
)
"""Internal descriptor hook used for instance methods decorated with `@function_tool`."""

def __get__(self, instance: Any, owner: type[Any] | None = None) -> FunctionTool:
if instance is None or self._method_tool_factory is None:
return self
return self._method_tool_factory(instance)

@property
def qualified_name(self) -> str:
"""Return the public qualified name used to identify this function tool."""
Expand Down Expand Up @@ -1827,18 +1839,33 @@ def function_tool(
explicitly loads it.
"""

def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
def _is_instance_method_tool(the_func: ToolFunction[...]) -> bool:
parameters = tuple(inspect.signature(the_func).parameters.values())
return bool(parameters) and parameters[0].name == "self"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Detect methods without relying on self

For decorated instance methods whose first parameter uses any valid name other than the convention self (for example def lookup(this, account_id: str)), this check does not install the descriptor factory, so tools.lookup returns the unbound FunctionTool, leaves the instance parameter in the JSON schema, and later invokes the original function without the instance. Conversely, a non-method tool whose first argument happens to be named self is now misclassified and becomes unusable unless accessed through an instance.

Useful? React with 👍 / 👎.


def _create_function_tool(
the_func: ToolFunction[...],
*,
method_tool_instance: Any | None = None,
) -> FunctionTool:
is_sync_function_tool = not inspect.iscoroutinefunction(the_func)
is_instance_method_tool = _is_instance_method_tool(the_func)
schema = function_schema(
func=the_func,
name_override=name_override,
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
skip_first_parameter=is_instance_method_tool,
)

async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
if is_instance_method_tool and method_tool_instance is None:
raise UserError(
f"Instance method tool {schema.name} must be accessed from an instance"
)

tool_name = ctx.tool_name
json_data = _parse_function_tool_json_input(tool_name=tool_name, input_json=input)
_log_function_tool_invocation(tool_name=tool_name, input_json=input)
Expand All @@ -1857,16 +1884,16 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
if not _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}")

leading_args: list[Any] = []
if is_instance_method_tool:
leading_args.append(method_tool_instance)
if schema.takes_context:
leading_args.append(ctx)

if not is_sync_function_tool:
if schema.takes_context:
result = await the_func(ctx, *args, **kwargs_dict)
else:
result = await the_func(*args, **kwargs_dict)
result = await the_func(*leading_args, *args, **kwargs_dict)
else:
if schema.takes_context:
result = await asyncio.to_thread(the_func, ctx, *args, **kwargs_dict)
else:
result = await asyncio.to_thread(the_func, *args, **kwargs_dict)
result = await asyncio.to_thread(the_func, *leading_args, *args, **kwargs_dict)

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Tool {tool_name} completed.")
Expand Down Expand Up @@ -1897,6 +1924,11 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
defer_loading=defer_loading,
sync_invoker=is_sync_function_tool,
)
if is_instance_method_tool and method_tool_instance is None:
function_tool._method_tool_factory = lambda instance: _create_function_tool(
the_func,
method_tool_instance=instance,
)
return function_tool

# If func is actually a callable, we were used as @function_tool with no parentheses
Expand Down
50 changes: 50 additions & 0 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,56 @@ async def test_simple_function():
)


@pytest.mark.asyncio
async def test_instance_method_function_tool_binds_self():
class AccountTools:
def __init__(self, prefix: str) -> None:
self.prefix = prefix

@function_tool
def lookup(self, account_id: str) -> str:
"""Look up an account."""
return f"{self.prefix}:{account_id}"

tools = AccountTools("acct")
tool = tools.lookup

assert isinstance(AccountTools.lookup, FunctionTool)
assert tool.name == "lookup"
assert "self" not in tool.params_json_schema["properties"]
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"account_id": "123"}',
)

assert result == "acct:123"


@pytest.mark.asyncio
async def test_instance_method_function_tool_supports_context_after_self():
class AccountTools:
@function_tool
def lookup(self, ctx: ToolContext[str], account_id: str) -> str:
"""Look up an account with context."""
return f"{ctx.context}:{account_id}"

tools = AccountTools()
tool = tools.lookup

assert "self" not in tool.params_json_schema["properties"]
assert "ctx" not in tool.params_json_schema["properties"]
assert "account_id" in tool.params_json_schema["properties"]

result = await tool.on_invoke_tool(
ToolContext("tenant", tool_name=tool.name, tool_call_id="1", tool_arguments=""),
'{"account_id": "123"}',
)

assert result == "tenant:123"


@pytest.mark.asyncio
async def test_sync_function_runs_via_to_thread(monkeypatch: pytest.MonkeyPatch) -> None:
calls = {"to_thread": 0, "func": 0}
Expand Down
Loading