This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 40822f2cb2fa16ce4550e72722c3bfb2c8015c51
Author: WenjinXie <wenjin...@gmail.com>
AuthorDate: Tue Aug 5 10:24:42 2025 +0800

    [api][python] Introduce Tool api and FunctionTool in python.
---
 .../api/tests/resources/tool_metadata.json         |  24 +++
 python/flink_agents/api/tests/test_tool.py         |  69 +++++++++
 python/flink_agents/api/tools/tool.py              |  93 ++++++++++-
 python/flink_agents/api/tools/utils.py             | 171 +++++++++++++++++++++
 .../tools/tool.py => plan/tests/tools/__init__.py} |  18 ---
 .../plan/tests/tools/resources/function_tool.json  |  32 ++++
 .../plan/tests/tools/test_function_tool.py         |  64 ++++++++
 python/flink_agents/plan/tools/function_tool.py    |  53 +++++--
 python/pyproject.toml                              |   1 +
 9 files changed, 495 insertions(+), 30 deletions(-)

diff --git a/python/flink_agents/api/tests/resources/tool_metadata.json 
b/python/flink_agents/api/tests/resources/tool_metadata.json
new file mode 100644
index 0000000..ca0808d
--- /dev/null
+++ b/python/flink_agents/api/tests/resources/tool_metadata.json
@@ -0,0 +1,24 @@
+{
+    "name": "foo",
+    "description": "Function for testing ToolMetadata",
+    "args_schema": {
+        "properties": {
+            "bar": {
+                "description": "The bar value.",
+                "title": "Bar",
+                "type": "integer"
+            },
+            "baz": {
+                "description": "The baz value.",
+                "title": "Baz",
+                "type": "string"
+            }
+        },
+        "required": [
+            "bar",
+            "baz"
+        ],
+        "title": "foo",
+        "type": "object"
+    }
+}
\ No newline at end of file
diff --git a/python/flink_agents/api/tests/test_tool.py 
b/python/flink_agents/api/tests/test_tool.py
new file mode 100644
index 0000000..c8d36e7
--- /dev/null
+++ b/python/flink_agents/api/tests/test_tool.py
@@ -0,0 +1,69 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import json
+from pathlib import Path
+
+import pytest
+
+from flink_agents.api.tools.tool import ToolMetadata
+from flink_agents.api.tools.utils import create_schema_from_function
+
+current_dir = Path(__file__).parent
+
+
+# The tests never call this function: the fixture below derives the args
+# schema from its signature and docstring, so the parameter descriptions must
+# stay in sync with resources/tool_metadata.json.
+def foo(bar: int, baz: str) -> str:
+    """Function for testing ToolMetadata.
+
+    Parameters
+    ----------
+    bar : int
+        The bar value.
+    baz : str
+        The baz value.
+
+    Returns:
+    -------
+    str
+        Response string value.
+    """
+    # Placeholder body; only the signature and docstring matter.
+    raise NotImplementedError
+
+
+@pytest.fixture(scope="module")
+def tool_metadata() -> ToolMetadata:  # noqa: D103
+    return ToolMetadata(
+        name="foo",
+        description="Function for testing ToolMetadata",
+        args_schema=create_schema_from_function(name="foo", func=foo),
+    )
+
+
+def test_serialize_tool_metadata(tool_metadata: ToolMetadata) -> None:  # 
noqa: D103
+    json_value = tool_metadata.model_dump_json(serialize_as_any=True)
+    with Path(f"{current_dir}/resources/tool_metadata.json").open() as f:
+        expected_json = f.read()
+    actual = json.loads(json_value)
+    expected = json.loads(expected_json)
+    assert actual == expected
+
+
+def test_deserialize_tool_metadata(tool_metadata: ToolMetadata) -> None:  # 
noqa: D103
+    with Path(f"{current_dir}/resources/tool_metadata.json").open() as f:
+        expected_json = f.read()
+    actual_tool_metadata = tool_metadata.model_validate_json(expected_json)
+    assert actual_tool_metadata == tool_metadata
diff --git a/python/flink_agents/api/tools/tool.py 
b/python/flink_agents/api/tools/tool.py
index 566fa10..891e068 100644
--- a/python/flink_agents/api/tools/tool.py
+++ b/python/flink_agents/api/tools/tool.py
@@ -15,21 +15,108 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
-from abc import ABC
+import typing
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, Type
 
+from pydantic import BaseModel, field_serializer, model_validator
 from typing_extensions import override
 
 from flink_agents.api.resource import ResourceType, SerializableResource
+from flink_agents.api.tools.utils import create_model_from_schema
+
+
+class ToolType(Enum):
+    """Tool type enum.
+
+    Currently, only support function tool.
+
+    Attributes:
+    ----------
+    MODEL_BUILT_IN : str
+        The tools from the model provider, like 'web_search_preview' of OpenAI 
models.
+    FUNCTION : str
+        The python/java function defined by user.
+    REMOTE_FUNCTION : str
+        The remote function indicated by name.
+    MCP : str
+        The tools provided by MCP server.
+    """
+
+    MODEL_BUILT_IN = "model_built_in"
+    FUNCTION = "function"
+    REMOTE_FUNCTION = "remote_function"
+    MCP = "mcp"
+
+
+class ToolMetadata(BaseModel):
+    """Metadata of a tools which describes what the tools does and
+     how to call the tools.
+
+    Attributes:
+    ----------
+    name : str
+        The name of the tools.
+    description : str
+        The description of the tools, tells what the tools does.
+    args_schema : Type[BaseModel]
+        The schema of the arguments passed to the tools.
+    """
+
+    name: str
+    description: str
+    args_schema: Type[BaseModel]
+
+    @field_serializer("args_schema")
+    def __serialize_args_schema(self, args_schema: Type[BaseModel]) -> 
dict[str, Any]:
+        return args_schema.model_json_schema()
+
+    @model_validator(mode="before")
+    def __custom_deserialize(self) -> "ToolMetadata":
+        args_schema = self["args_schema"]
+        if isinstance(args_schema, dict):
+            self["args_schema"] = create_model_from_schema(
+                args_schema["title"], args_schema
+            )
+        return self
+
+    def __eq__(self, other: "ToolMetadata") -> bool:
+        return (
+            other.name == self.name
+            and other.description == self.description
+            and other.args_schema.model_json_schema()
+            == self.args_schema.model_json_schema()
+        )
 
 
-#TODO: Complete BaseTool
 class BaseTool(SerializableResource, ABC):
     """Base abstract class of all kinds of tools.
 
-    Currently, this class is empty just for testing purposes
+    Attributes:
+    ----------
+    metadata : ToolMetadata
+        The metadata of the tools, includes name, description and arguments 
schema.
     """
 
+    metadata: ToolMetadata
+
     @classmethod
     @override
     def resource_type(cls) -> ResourceType:
+        """Return resource type of class."""
         return ResourceType.TOOL
+
+    @classmethod
+    @abstractmethod
+    def tool_type(cls) -> ToolType:
+        """Return tool type of class."""
+
+    @abstractmethod
+    def call(
+        self, *args: typing.Tuple[Any, ...], **kwargs: typing.Dict[str, Any]
+    ) -> Any:
+        """Call the tools with arguments.
+
+        This is the method that should be implemented by the tools' developer.
+        """
diff --git a/python/flink_agents/api/tools/utils.py 
b/python/flink_agents/api/tools/utils.py
new file mode 100644
index 0000000..a3059fc
--- /dev/null
+++ b/python/flink_agents/api/tools/utils.py
@@ -0,0 +1,171 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import typing
+from inspect import signature
+from typing import Any, Callable, Optional, Type, Union
+
+from docstring_parser import parse
+from pydantic import BaseModel, create_model
+from pydantic.fields import Field, FieldInfo
+
+
+def create_schema_from_function(name: str, func: Callable) -> Type[BaseModel]:
+    """Create a pydantic schema from a function's signature."""
+    docstr = func.__doc__
+
+    docstr = parse(docstr)
+    doc_params = {}
+    for param in docstr.params:
+        doc_params[param.arg_name] = param
+
+    fields = {}
+    params = signature(func).parameters
+    for param_name in params:
+        param_type = params[param_name].annotation
+        param_default = params[param_name].default
+        description = doc_params[param_name].description
+
+        if typing.get_origin(param_type) is typing.Annotated:
+            args = typing.get_args(param_type)
+            param_type = args[0]
+            if isinstance(args[1], str):
+                description = args[1]
+            elif isinstance(args[1], FieldInfo):
+                description = args[1].description
+
+        if param_type is params[param_name].empty:
+            param_type = typing.Any
+
+        if param_default is params[param_name].empty:
+            # Required field
+            fields[param_name] = (param_type, 
FieldInfo(description=description))
+        elif isinstance(param_default, FieldInfo):
+            # Field with pydantic.Field as default value
+            fields[param_name] = (param_type, param_default)
+        else:
+            fields[param_name] = (
+                param_type,
+                FieldInfo(default=param_default, description=description),
+            )
+
+    return create_model(name, **fields)
+
+TYPE_MAPPING: dict[str, type] = {
+    "string": str,
+    "integer": int,
+    "number": float,
+    "boolean": bool,
+    "object": dict,
+    "array": list,
+    "null": type(None),
+}
+
+CONSTRAINT_MAPPING: dict[str, str] = {
+    "minimum": "ge",
+    "maximum": "le",
+    "exclusiveMinimum": "gt",
+    "exclusiveMaximum": "lt",
+    "inclusiveMinimum": "ge",
+    "inclusiveMaximum": "le",
+    "minItems": "min_length",
+    "maxItems": "max_length",
+}
+
+
+def __get_field_params_from_field_schema(field_schema: dict) -> dict:
+    """Gets Pydantic field parameters from a JSON schema field."""
+    field_params = {}
+    for constraint, constraint_value in CONSTRAINT_MAPPING.items():
+        if constraint in field_schema:
+            field_params[constraint_value] = field_schema[constraint]
+    if "description" in field_schema:
+        field_params["description"] = field_schema["description"]
+    if "default" in field_schema:
+        field_params["default"] = field_schema["default"]
+    return field_params
+
+
+def create_model_from_schema(name: str, schema: dict) -> type[BaseModel]:
+    """Create Pydantic model from a JSON schema generated by
+    BaseModel.model_json_schema().
+    """
+    models: dict[str, type[BaseModel]] = {}
+
+    def resolve_field_type(field_schema: dict) -> type[typing.Any]:
+        """Resolves field type, including optional types and nullability."""
+        if "$ref" in field_schema:
+            model_reference = field_schema["$ref"].split("/")[-1]
+            return models.get(model_reference, Any)  # type: ignore[arg-type]
+
+        if "anyOf" in field_schema:
+            types = [
+                TYPE_MAPPING.get(t["type"], typing.Any)
+                for t in field_schema["anyOf"]
+                if t.get("type")
+            ]
+            if type(None) in types:
+                types.remove(type(None))
+                if len(types) == 1:
+                    return typing.Optional[types[0]]
+                return Optional[Union[tuple(types)]]  # type: 
ignore[return-value]
+            else:
+                return Union[tuple(types)]  # type: ignore[return-value]
+        field_type = TYPE_MAPPING.get(field_schema.get("type"), typing.Any)  # 
type: ignore[arg-type]
+
+        # Handle arrays (lists)
+        if field_schema.get("type") == "array":
+            items = field_schema.get("items", {})
+            item_type = resolve_field_type(items)
+            return list[item_type]  # type: ignore[valid-type]
+
+        # Handle objects (dicts with specified value types)
+        if field_schema.get("type") == "object":
+            additional_props = field_schema.get("additionalProperties")
+            value_type = (
+                resolve_field_type(additional_props) if additional_props else 
typing.Any
+            )
+            return dict[str, value_type]  # type: ignore[valid-type]
+
+        return field_type  # type: ignore[return-value]
+
+    # First, create models for definitions
+    definitions = schema.get("$defs", {})
+    for model_name, model_schema in definitions.items():
+        fields = {}
+        for field_name, field_schema in model_schema.get("properties", 
{}).items():
+            field_type = resolve_field_type(field_schema=field_schema)
+            field_params = 
__get_field_params_from_field_schema(field_schema=field_schema)
+            fields[field_name] = (field_type, Field(**field_params))
+
+        models[model_name] = create_model(
+            model_name, **fields, __doc__=model_schema.get("description", "")
+        )  # type: ignore[call-overload]
+
+    # Now, create the main model, resolving references
+    main_fields = {}
+    for field_name, field_schema in schema.get("properties", {}).items():
+        if "$ref" in field_schema:
+            model_reference = field_schema["$ref"].split("/")[-1]
+            field_type = models.get(model_reference, Any)  # type: 
ignore[arg-type]
+        else:
+            field_type = resolve_field_type(field_schema=field_schema)
+
+        field_params = 
__get_field_params_from_field_schema(field_schema=field_schema)
+        main_fields[field_name] = (field_type, Field(**field_params))
+
+    return create_model(name, **main_fields, __doc__=schema.get("description", 
""))
diff --git a/python/flink_agents/api/tools/tool.py 
b/python/flink_agents/plan/tests/tools/__init__.py
similarity index 68%
copy from python/flink_agents/api/tools/tool.py
copy to python/flink_agents/plan/tests/tools/__init__.py
index 566fa10..e154fad 100644
--- a/python/flink_agents/api/tools/tool.py
+++ b/python/flink_agents/plan/tests/tools/__init__.py
@@ -15,21 +15,3 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
-from abc import ABC
-
-from typing_extensions import override
-
-from flink_agents.api.resource import ResourceType, SerializableResource
-
-
-#TODO: Complete BaseTool
-class BaseTool(SerializableResource, ABC):
-    """Base abstract class of all kinds of tools.
-
-    Currently, this class is empty just for testing purposes
-    """
-
-    @classmethod
-    @override
-    def resource_type(cls) -> ResourceType:
-        return ResourceType.TOOL
diff --git a/python/flink_agents/plan/tests/tools/resources/function_tool.json 
b/python/flink_agents/plan/tests/tools/resources/function_tool.json
new file mode 100644
index 0000000..52aee0d
--- /dev/null
+++ b/python/flink_agents/plan/tests/tools/resources/function_tool.json
@@ -0,0 +1,32 @@
+{
+    "name": "foo",
+    "metadata": {
+        "name": "foo",
+        "description": "Function for testing ToolMetadata.\n",
+        "args_schema": {
+            "properties": {
+                "bar": {
+                    "description": "The bar value.",
+                    "title": "Bar",
+                    "type": "integer"
+                },
+                "baz": {
+                    "description": "The baz value.",
+                    "title": "Baz",
+                    "type": "string"
+                }
+            },
+            "required": [
+                "bar",
+                "baz"
+            ],
+            "title": "foo",
+            "type": "object"
+        }
+    },
+    "func": {
+        "module": "flink_agents.plan.tests.tools.test_function_tool",
+        "qualname": "foo",
+        "func_type": "PythonFunction"
+    }
+}
\ No newline at end of file
diff --git a/python/flink_agents/plan/tests/tools/test_function_tool.py 
b/python/flink_agents/plan/tests/tools/test_function_tool.py
new file mode 100644
index 0000000..19d7147
--- /dev/null
+++ b/python/flink_agents/plan/tests/tools/test_function_tool.py
@@ -0,0 +1,64 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import json
+from pathlib import Path
+
+import pytest
+
+from flink_agents.plan.tools.function_tool import FunctionTool, from_callable
+
+current_dir = Path(__file__).parent
+
+
+# The tests never call this function: from_callable derives the tool
+# description and args schema from its docstring and signature. Its module
+# path and qualname ("foo") as well as those descriptions are recorded in
+# resources/function_tool.json, so renaming, moving or re-documenting it
+# requires updating that file.
+def foo(bar: int, baz: str) -> str:
+    """Function for testing ToolMetadata.
+
+    Parameters
+    ----------
+    bar : int
+        The bar value.
+    baz : str
+        The baz value.
+
+    Returns:
+    -------
+    str
+        Response string value.
+    """
+    # Placeholder body; only the signature and docstring matter.
+    raise NotImplementedError
+
+
+@pytest.fixture(scope="module")
+def func_tool() -> FunctionTool:  # noqa: D103
+    return from_callable("foo", foo)
+
+
+def test_serialize_function_tool(func_tool: FunctionTool) -> None:  # noqa: 
D103
+    json_value = func_tool.model_dump_json(serialize_as_any=True, indent=4)
+    with Path(f"{current_dir}/resources/function_tool.json").open() as f:
+        expected_json = f.read()
+    actual = json.loads(json_value)
+    expected = json.loads(expected_json)
+    assert actual == expected
+
+
+def test_deserialize_function_tool(func_tool: FunctionTool) -> None:  # noqa: 
D103
+    with Path(f"{current_dir}/resources/function_tool.json").open() as f:
+        json_value = f.read()
+    actual_func_tool = FunctionTool.model_validate_json(json_value)
+    assert actual_func_tool == func_tool
diff --git a/python/flink_agents/plan/tools/function_tool.py 
b/python/flink_agents/plan/tools/function_tool.py
index 2847b90..f78d559 100644
--- a/python/flink_agents/plan/tools/function_tool.py
+++ b/python/flink_agents/plan/tools/function_tool.py
@@ -15,19 +15,54 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
-from typing import Any, Dict, Tuple
+from typing import Any, Callable, Union
 
-from flink_agents.api.tools.tool import BaseTool
-from flink_agents.plan.function import Function
+from docstring_parser import parse
+from typing_extensions import override
+
+from flink_agents.api.tools.tool import BaseTool, ToolMetadata, ToolType
+from flink_agents.api.tools.utils import create_schema_from_function
+from flink_agents.plan.function import JavaFunction, PythonFunction
 
 
-#TODO: Complete FunctionTool
 class FunctionTool(BaseTool):
-    """Function tool.
+    """Tool that takes in a function.
 
-    Currently, this class is just for testing purposes.
+    Attributes:
+    ----------
+    func : Function
+        User defined function.
     """
-    func: Function
-    def call(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
-        """Call function."""
+
+    func: Union[PythonFunction, JavaFunction]
+
+    @classmethod
+    @override
+    def tool_type(cls) -> ToolType:
+        """Get the tool type."""
+        return ToolType.FUNCTION
+
+    def call(self, *args: Any, **kwargs: Any) -> Any:
+        """Call the function tool."""
         return self.func(*args, **kwargs)
+
+
+def from_callable(name: str, func: Callable) -> FunctionTool:
+    """Create FunctionTool from a user defined function.
+
+    Parameters
+    ----------
+    name : str
+        Name of the tool function.
+    func : Callable
+        The function to analyze.
+    """
+    description = parse(func.__doc__).description
+    metadata = ToolMetadata(
+        name=name,
+        description=description,
+        args_schema=create_schema_from_function(name=name, func=func),
+    )
+    return FunctionTool(
+        name=name, func=PythonFunction.from_callable(func), metadata=metadata
+    )
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 95487af..e6b8f83 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -46,6 +46,7 @@ classifiers = [
 dependencies = [
     "apache-flink==1.20.1",
     "pydantic==2.11.4",
+    "docstring-parser==0.16",
 ]
 
 # Optional dependencies (dependency groups)

Reply via email to