This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 40822f2cb2fa16ce4550e72722c3bfb2c8015c51 Author: WenjinXie <wenjin...@gmail.com> AuthorDate: Tue Aug 5 10:24:42 2025 +0800 [api][python] Introduce Tool api and FunctionTool in python. --- .../api/tests/resources/tool_metadata.json | 24 +++ python/flink_agents/api/tests/test_tool.py | 69 +++++++++ python/flink_agents/api/tools/tool.py | 93 ++++++++++- python/flink_agents/api/tools/utils.py | 171 +++++++++++++++++++++ .../tools/tool.py => plan/tests/tools/__init__.py} | 18 --- .../plan/tests/tools/resources/function_tool.json | 32 ++++ .../plan/tests/tools/test_function_tool.py | 64 ++++++++ python/flink_agents/plan/tools/function_tool.py | 53 +++++-- python/pyproject.toml | 1 + 9 files changed, 495 insertions(+), 30 deletions(-) diff --git a/python/flink_agents/api/tests/resources/tool_metadata.json b/python/flink_agents/api/tests/resources/tool_metadata.json new file mode 100644 index 0000000..ca0808d --- /dev/null +++ b/python/flink_agents/api/tests/resources/tool_metadata.json @@ -0,0 +1,24 @@ +{ + "name": "foo", + "description": "Function for testing ToolMetadata", + "args_schema": { + "properties": { + "bar": { + "description": "The bar value.", + "title": "Bar", + "type": "integer" + }, + "baz": { + "description": "The baz value.", + "title": "Baz", + "type": "string" + } + }, + "required": [ + "bar", + "baz" + ], + "title": "foo", + "type": "object" + } +} \ No newline at end of file diff --git a/python/flink_agents/api/tests/test_tool.py b/python/flink_agents/api/tests/test_tool.py new file mode 100644 index 0000000..c8d36e7 --- /dev/null +++ b/python/flink_agents/api/tests/test_tool.py @@ -0,0 +1,69 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import json +from pathlib import Path + +import pytest + +from flink_agents.api.tools.tool import ToolMetadata +from flink_agents.api.tools.utils import create_schema_from_function + +current_dir = Path(__file__).parent + + +def foo(bar: int, baz: str) -> str: + """Function for testing ToolMetadata. + + Parameters + ---------- + bar : int + The bar value. + baz : str + The baz value. + + Returns: + ------- + str + Response string value. + """ + raise NotImplementedError + + +@pytest.fixture(scope="module") +def tool_metadata() -> ToolMetadata: # noqa: D103 + return ToolMetadata( + name="foo", + description="Function for testing ToolMetadata", + args_schema=create_schema_from_function(name="foo", func=foo), + ) + + +def test_serialize_tool_metadata(tool_metadata: ToolMetadata) -> None: # noqa: D103 + json_value = tool_metadata.model_dump_json(serialize_as_any=True) + with Path(f"{current_dir}/resources/tool_metadata.json").open() as f: + expected_json = f.read() + actual = json.loads(json_value) + expected = json.loads(expected_json) + assert actual == expected + + +def test_deserialize_tool_metadata(tool_metadata: ToolMetadata) -> None: # noqa: D103 + with Path(f"{current_dir}/resources/tool_metadata.json").open() as f: + expected_json = f.read() + actual_tool_metadata = tool_metadata.model_validate_json(expected_json) + assert actual_tool_metadata == tool_metadata diff --git a/python/flink_agents/api/tools/tool.py b/python/flink_agents/api/tools/tool.py index 566fa10..891e068 100644 --- a/python/flink_agents/api/tools/tool.py +++ b/python/flink_agents/api/tools/tool.py @@ -15,21 +15,108 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from abc import ABC +import typing +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Type +from pydantic import BaseModel, field_serializer, model_validator from typing_extensions import override from flink_agents.api.resource import ResourceType, SerializableResource +from flink_agents.api.tools.utils import create_model_from_schema + + +class ToolType(Enum): + """Tool type enum. + + Currently, only support function tool. + + Attributes: + ---------- + MODEL_BUILT_IN : str + The tools from the model provider, like 'web_search_preview' of OpenAI models. + FUNCTION : str + The python/java function defined by user. + REMOTE_FUNCTION : str + The remote function indicated by name. + MCP : str + The tools provided by MCP server. + """ + + MODEL_BUILT_IN = "model_built_in" + FUNCTION = "function" + REMOTE_FUNCTION = "remote_function" + MCP = "mcp" + + +class ToolMetadata(BaseModel): + """Metadata of a tools which describes what the tools does and + how to call the tools. + + Attributes: + ---------- + name : str + The name of the tools. + description : str + The description of the tools, tells what the tools does. + args_schema : Type[BaseModel] + The schema of the arguments passed to the tools. + """ + + name: str + description: str + args_schema: Type[BaseModel] + + @field_serializer("args_schema") + def __serialize_args_schema(self, args_schema: Type[BaseModel]) -> dict[str, Any]: + return args_schema.model_json_schema() + + @model_validator(mode="before") + def __custom_deserialize(self) -> "ToolMetadata": + args_schema = self["args_schema"] + if isinstance(args_schema, dict): + self["args_schema"] = create_model_from_schema( + args_schema["title"], args_schema + ) + return self + + def __eq__(self, other: "ToolMetadata") -> bool: + return ( + other.name == self.name + and other.description == self.description + and other.args_schema.model_json_schema() + == self.args_schema.model_json_schema() + ) -#TODO: Complete BaseTool class BaseTool(SerializableResource, ABC): """Base abstract class of all kinds of tools. - Currently, this class is empty just for testing purposes + Attributes: + ---------- + metadata : ToolMetadata + The metadata of the tools, includes name, description and arguments schema. """ + metadata: ToolMetadata + @classmethod @override def resource_type(cls) -> ResourceType: + """Return resource type of class.""" return ResourceType.TOOL + + @classmethod + @abstractmethod + def tool_type(cls) -> ToolType: + """Return tool type of class.""" + + @abstractmethod + def call( + self, *args: typing.Tuple[Any, ...], **kwargs: typing.Dict[str, Any] + ) -> Any: + """Call the tools with arguments. + + This is the method that should be implemented by the tools' developer. + """ diff --git a/python/flink_agents/api/tools/utils.py b/python/flink_agents/api/tools/utils.py new file mode 100644 index 0000000..a3059fc --- /dev/null +++ b/python/flink_agents/api/tools/utils.py @@ -0,0 +1,171 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import typing +from inspect import signature +from typing import Any, Callable, Optional, Type, Union + +from docstring_parser import parse +from pydantic import BaseModel, create_model +from pydantic.fields import Field, FieldInfo + + +def create_schema_from_function(name: str, func: Callable) -> Type[BaseModel]: + """Create a pydantic schema from a function's signature.""" + docstr = func.__doc__ + + docstr = parse(docstr) + doc_params = {} + for param in docstr.params: + doc_params[param.arg_name] = param + + fields = {} + params = signature(func).parameters + for param_name in params: + param_type = params[param_name].annotation + param_default = params[param_name].default + description = doc_params[param_name].description + + if typing.get_origin(param_type) is typing.Annotated: + args = typing.get_args(param_type) + param_type = args[0] + if isinstance(args[1], str): + description = args[1] + elif isinstance(args[1], FieldInfo): + description = args[1].description + + if param_type is params[param_name].empty: + param_type = typing.Any + + if param_default is params[param_name].empty: + # Required field + fields[param_name] = (param_type, FieldInfo(description=description)) + elif isinstance(param_default, FieldInfo): + # Field with pydantic.Field as default value + fields[param_name] = (param_type, param_default) + else: + fields[param_name] = ( + param_type, + FieldInfo(default=param_default, description=description), + ) + + return create_model(name, **fields) + +TYPE_MAPPING: dict[str, type] = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "object": dict, + "array": list, + "null": type(None), +} + +CONSTRAINT_MAPPING: dict[str, str] = { + "minimum": "ge", + "maximum": "le", + "exclusiveMinimum": "gt", + "exclusiveMaximum": "lt", + "inclusiveMinimum": "ge", + "inclusiveMaximum": "le", + "minItems": "min_length", + "maxItems": "max_length", +} + + +def __get_field_params_from_field_schema(field_schema: dict) -> dict: + """Gets Pydantic field parameters from a JSON schema field.""" + field_params = {} + for constraint, constraint_value in CONSTRAINT_MAPPING.items(): + if constraint in field_schema: + field_params[constraint_value] = field_schema[constraint] + if "description" in field_schema: + field_params["description"] = field_schema["description"] + if "default" in field_schema: + field_params["default"] = field_schema["default"] + return field_params + + +def create_model_from_schema(name: str, schema: dict) -> type[BaseModel]: + """Create Pydantic model from a JSON schema generated by + BaseModel.model_json_schema(). + """ + models: dict[str, type[BaseModel]] = {} + + def resolve_field_type(field_schema: dict) -> type[typing.Any]: + """Resolves field type, including optional types and nullability.""" + if "$ref" in field_schema: + model_reference = field_schema["$ref"].split("/")[-1] + return models.get(model_reference, Any) # type: ignore[arg-type] + + if "anyOf" in field_schema: + types = [ + TYPE_MAPPING.get(t["type"], typing.Any) + for t in field_schema["anyOf"] + if t.get("type") + ] + if type(None) in types: + types.remove(type(None)) + if len(types) == 1: + return typing.Optional[types[0]] + return Optional[Union[tuple(types)]] # type: ignore[return-value] + else: + return Union[tuple(types)] # type: ignore[return-value] + field_type = TYPE_MAPPING.get(field_schema.get("type"), typing.Any) # type: ignore[arg-type] + + # Handle arrays (lists) + if field_schema.get("type") == "array": + items = field_schema.get("items", {}) + item_type = resolve_field_type(items) + return list[item_type] # type: ignore[valid-type] + + # Handle objects (dicts with specified value types) + if field_schema.get("type") == "object": + additional_props = field_schema.get("additionalProperties") + value_type = ( + resolve_field_type(additional_props) if additional_props else typing.Any + ) + return dict[str, value_type] # type: ignore[valid-type] + + return field_type # type: ignore[return-value] + + # First, create models for definitions + definitions = schema.get("$defs", {}) + for model_name, model_schema in definitions.items(): + fields = {} + for field_name, field_schema in model_schema.get("properties", {}).items(): + field_type = resolve_field_type(field_schema=field_schema) + field_params = __get_field_params_from_field_schema(field_schema=field_schema) + fields[field_name] = (field_type, Field(**field_params)) + + models[model_name] = create_model( + model_name, **fields, __doc__=model_schema.get("description", "") + ) # type: ignore[call-overload] + + # Now, create the main model, resolving references + main_fields = {} + for field_name, field_schema in schema.get("properties", {}).items(): + if "$ref" in field_schema: + model_reference = field_schema["$ref"].split("/")[-1] + field_type = models.get(model_reference, Any) # type: ignore[arg-type] + else: + field_type = resolve_field_type(field_schema=field_schema) + + field_params = __get_field_params_from_field_schema(field_schema=field_schema) + main_fields[field_name] = (field_type, Field(**field_params)) + + return create_model(name, **main_fields, __doc__=schema.get("description", "")) diff --git a/python/flink_agents/api/tools/tool.py b/python/flink_agents/plan/tests/tools/__init__.py similarity index 68% copy from python/flink_agents/api/tools/tool.py copy to python/flink_agents/plan/tests/tools/__init__.py index 566fa10..e154fad 100644 --- a/python/flink_agents/api/tools/tool.py +++ b/python/flink_agents/plan/tests/tools/__init__.py @@ -15,21 +15,3 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from abc import ABC - -from typing_extensions import override - -from flink_agents.api.resource import ResourceType, SerializableResource - - -#TODO: Complete BaseTool -class BaseTool(SerializableResource, ABC): - """Base abstract class of all kinds of tools. - - Currently, this class is empty just for testing purposes - """ - - @classmethod - @override - def resource_type(cls) -> ResourceType: - return ResourceType.TOOL diff --git a/python/flink_agents/plan/tests/tools/resources/function_tool.json b/python/flink_agents/plan/tests/tools/resources/function_tool.json new file mode 100644 index 0000000..52aee0d --- /dev/null +++ b/python/flink_agents/plan/tests/tools/resources/function_tool.json @@ -0,0 +1,32 @@ +{ + "name": "foo", + "metadata": { + "name": "foo", + "description": "Function for testing ToolMetadata.\n", + "args_schema": { + "properties": { + "bar": { + "description": "The bar value.", + "title": "Bar", + "type": "integer" + }, + "baz": { + "description": "The baz value.", + "title": "Baz", + "type": "string" + } + }, + "required": [ + "bar", + "baz" + ], + "title": "foo", + "type": "object" + } + }, + "func": { + "module": "flink_agents.plan.tests.tools.test_function_tool", + "qualname": "foo", + "func_type": "PythonFunction" + } +} \ No newline at end of file diff --git a/python/flink_agents/plan/tests/tools/test_function_tool.py b/python/flink_agents/plan/tests/tools/test_function_tool.py new file mode 100644 index 0000000..19d7147 --- /dev/null +++ b/python/flink_agents/plan/tests/tools/test_function_tool.py @@ -0,0 +1,64 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import json +from pathlib import Path + +import pytest + +from flink_agents.plan.tools.function_tool import FunctionTool, from_callable + +current_dir = Path(__file__).parent + + +def foo(bar: int, baz: str) -> str: + """Function for testing ToolMetadata. + + Parameters + ---------- + bar : int + The bar value. + baz : str + The baz value. + + Returns: + ------- + str + Response string value. + """ + raise NotImplementedError + + +@pytest.fixture(scope="module") +def func_tool() -> FunctionTool: # noqa: D103 + return from_callable("foo", foo) + + +def test_serialize_function_tool(func_tool: FunctionTool) -> None: # noqa: D103 + json_value = func_tool.model_dump_json(serialize_as_any=True, indent=4) + with Path(f"{current_dir}/resources/function_tool.json").open() as f: + expected_json = f.read() + actual = json.loads(json_value) + expected = json.loads(expected_json) + assert actual == expected + + +def test_deserialize_function_tool(func_tool: FunctionTool) -> None: # noqa: D103 + with Path(f"{current_dir}/resources/function_tool.json").open() as f: + json_value = f.read() + actual_func_tool = FunctionTool.model_validate_json(json_value) + assert actual_func_tool == func_tool diff --git a/python/flink_agents/plan/tools/function_tool.py b/python/flink_agents/plan/tools/function_tool.py index 2847b90..f78d559 100644 --- a/python/flink_agents/plan/tools/function_tool.py +++ b/python/flink_agents/plan/tools/function_tool.py @@ -15,19 +15,54 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any, Dict, Tuple +from typing import Any, Callable, Union -from flink_agents.api.tools.tool import BaseTool -from flink_agents.plan.function import Function +from docstring_parser import parse +from typing_extensions import override + +from flink_agents.api.tools.tool import BaseTool, ToolMetadata, ToolType +from flink_agents.api.tools.utils import create_schema_from_function +from flink_agents.plan.function import JavaFunction, PythonFunction -#TODO: Complete FunctionTool class FunctionTool(BaseTool): - """Function tool. + """Tool that takes in a function. - Currently, this class is just for testing purposes. + Attributes: + ---------- + func : Function + User defined function. """ - func: Function - def call(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: - """Call function.""" + + func: Union[PythonFunction, JavaFunction] + + @classmethod + @override + def tool_type(cls) -> ToolType: + """Get the tool type.""" + return ToolType.FUNCTION + + def call(self, *args: Any, **kwargs: Any) -> Any: + """Call the function tool.""" return self.func(*args, **kwargs) + + +def from_callable(name: str, func: Callable) -> FunctionTool: + """Create FunctionTool from a user defined function. + + Parameters + ---------- + name : str + Name of the tool function. + func : Callable + The function to analyze. + """ + description = parse(func.__doc__).description + metadata = ToolMetadata( + name=name, + description=description, + args_schema=create_schema_from_function(name=name, func=func), + ) + return FunctionTool( + name=name, func=PythonFunction.from_callable(func), metadata=metadata + ) diff --git a/python/pyproject.toml b/python/pyproject.toml index 95487af..e6b8f83 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -46,6 +46,7 @@ classifiers = [ dependencies = [ "apache-flink==1.20.1", "pydantic==2.11.4", + "docstring-parser==0.16", ] # Optional dependencies (dependency groups)