This is an automated email from the ASF dual-hosted git repository.

wenjin272 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new 81e55ca3 [test][runtime] Strengthen live-LLM e2e tests with structured tool-invocation assertions (#722)
81e55ca3 is described below

commit 81e55ca304c08f9f0dd026a835d99e537be2e7b4
Author: Weiqing Yang <[email protected]>
AuthorDate: Tue Jun 9 01:54:26 2026 -0700

    [test][runtime] Strengthen live-LLM e2e tests with structured tool-invocation assertions (#722)
---
 .../e2e_tests_integration/react_agent_test.py      |  31 +++-
 .../chat_model_cross_language_test.py              |  13 +-
 .../yaml_cross_language_test.py                    |  21 ++-
 python/flink_agents/e2e_tests/test_utils.py        | 115 +++++++++++++++
 .../flink_agents/e2e_tests/test_utils_unit_test.py | 150 +++++++++++++++++++
 .../runtime/local_execution_environment.py         |   5 +
 python/flink_agents/runtime/local_runner.py        |  18 +++
 .../tests/test_local_execution_environment.py      | 163 ++++++++++++++++++++-
 8 files changed, 503 insertions(+), 13 deletions(-)

diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py
index 04d0be41..f0c77178 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_integration/react_agent_test.py
@@ -44,7 +44,12 @@ from flink_agents.api.resource import (
 )
 from flink_agents.api.tools.tool import Tool
 from flink_agents.e2e_tests.e2e_tests_integration.react_agent_tools import add, multiply
-from flink_agents.e2e_tests.test_utils import pull_model
+from flink_agents.e2e_tests.test_utils import (
+    assert_tool_invoked,
+    collect_tool_invocations,
+    pull_model,
+    tool_invocations_from_events,
+)
 
 current_dir = Path(__file__).parent
 
@@ -132,7 +137,13 @@ def test_react_agent_on_local_runner(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(output_list) == 1, (
         "This may be caused by the LLM response does not match the output schema, you can rerun this case."
     )
-    assert output_list[0]["0001"].result == 1386528
+    assert int(output_list[0]["0001"].result) == 1386528
+
+    # multiply's first arg (4444 = 2123 + 2321) proves the addition was computed
+    # correctly and the multiply tool was used; the model often does the addition
+    # without the add tool, so add is not a reliable signal to assert on.
+    invocations = tool_invocations_from_events(env.get_tool_request_events())
+    assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})
 
 
 @pytest.mark.skipif(
@@ -149,7 +160,7 @@ def test_react_agent_on_remote_runner(
     t_env = StreamTableEnvironment.create(stream_execution_environment=stream_env)
 
     table = t_env.from_elements(
-        elements=[(1, 2, 3)],
+        elements=[(2123, 2321, 312)],
         schema=DataTypes.ROW(
             [
                 DataTypes.FIELD("a", DataTypes.INT()),
@@ -169,6 +180,10 @@ def test_react_agent_on_remote_runner(
 
     env.get_config().set(AgentExecutionOptions.MAX_RETRIES, 3)
 
+    log_dir = tmp_path / "event_logs"
+    log_dir.mkdir(parents=True, exist_ok=True)
+    env.get_config().set_str("baseLogDir", str(log_dir))
+
     # register resource to execution environment
     (
         env.add_resource(
@@ -243,4 +258,12 @@ def test_react_agent_on_remote_runner(
     assert len(actual_result) == 1, (
         "This may be caused by the LLM response does not match the output schema, you can rerun this case."
     )
-    assert "result" in json.loads(actual_result[0].strip())
+    assert json.loads(actual_result[0].strip())["result"] == 1386528
+
+    # multiply's first arg (4444 = 2123 + 2321) proves the addition was computed
+    # correctly and threaded into multiply; the model often does the addition
+    # without the add tool, so add is not a reliable signal to assert on. This
+    # exercises the same reasoning chain as the local-runner test, but read back
+    # through the event-log capture path.
+    invocations = collect_tool_invocations(log_dir)
+    assert_tool_invoked(invocations, "multiply", {"a": 4444, "b": 312})
diff --git a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py
index 8af10c3f..ece74eca 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/chat_model_cross_language_test.py
@@ -36,7 +36,11 @@ from flink_agents.api.execution_environment import AgentsExecutionEnvironment
 from flink_agents.e2e_tests.e2e_tests_resource_cross_language.chat_model_cross_language_agent import (
     ChatModelCrossLanguageAgent,
 )
-from flink_agents.e2e_tests.test_utils import pull_model
+from flink_agents.e2e_tests.test_utils import (
+    assert_tool_invoked,
+    collect_tool_invocations,
+    pull_model,
+)
 
 current_dir = Path(__file__).parent
 
@@ -72,6 +76,9 @@ def test_java_chat_model_integration(
     deserialize_datastream = input_datastream.map(lambda x: str(x))
 
     agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
+    log_dir = tmp_path / "event_logs"
+    log_dir.mkdir(parents=True, exist_ok=True)
+    agents_env.get_config().set_str("baseLogDir", str(log_dir))
     output_datastream = (
         agents_env.from_datastream(
             input=deserialize_datastream, key_selector=lambda x: "orderKey"
@@ -106,6 +113,8 @@ def test_java_chat_model_integration(
             with file.open() as f:
                 actual_result.extend(f.readlines())
 
+    invocations = collect_tool_invocations(log_dir)
+    assert_tool_invoked(invocations, "add", {"a": 1, "b": 2})
+
     joined = "\n".join(actual_result).lower()
-    assert "3" in joined, f"math answer missing '3': {actual_result!r}"
     assert "cat" in joined, f"creative answer missing 'cat': {actual_result!r}"
diff --git 
a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_test.py
 
b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_test.py
index ee9ed231..f26b23c4 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_test.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_resource_cross_language/yaml_cross_language_test.py
@@ -46,7 +46,11 @@ from pyflink.datastream.connectors.file_system import (
 )
 
 from flink_agents.api.execution_environment import AgentsExecutionEnvironment
-from flink_agents.e2e_tests.test_utils import pull_model
+from flink_agents.e2e_tests.test_utils import (
+    assert_tool_invoked,
+    collect_tool_invocations,
+    pull_model,
+)
 
 current_dir = Path(__file__).parent
 _RESOURCES = current_dir.parent / "resources"
@@ -116,6 +120,9 @@ def test_yaml_cross_language_agent(
     deserialize_datastream = input_datastream.map(lambda x: str(x))
 
     agents_env = AgentsExecutionEnvironment.get_execution_environment(env=env)
+    log_dir = tmp_path / "event_logs"
+    log_dir.mkdir(parents=True, exist_ok=True)
+    agents_env.get_config().set_str("baseLogDir", str(log_dir))
     agents_env.load_yaml(_RESOURCES / "yaml_cross_language_agent.yaml")
 
     output_datastream = (
@@ -152,12 +159,16 @@ def test_yaml_cross_language_agent(
             with file.open() as f:
                 actual_result.extend(f.readlines())
 
-    # Math path went through the Java ``calculateBMI`` tool:
-    # 70 / (1.75 * 1.75) ≈ 22.86, so the final answer should mention 22.
-    # Creative path doesn't use any tool.
+    # Math path went through the Java ``calculateBMI`` tool, called with the
+    # weight/height parsed from the input ("1.75 meters tall and weighs 70 kg").
+    assert_tool_invoked(
+        collect_tool_invocations(log_dir),
+        "calculateBMI",
+        {"weightKg": 70, "heightM": 1.75},
+    )
+    # Creative path doesn't use any tool; its answer mentions a cat.
     # NOTE: We join all results and search without relying on order, because
     # StreamingFileSink may produce multiple part files and iterdir() does not
     # guarantee a deterministic traversal order across platforms.
     joined = "\n".join(actual_result).lower()
-    assert "22" in joined, f"math answer missing '22': {actual_result!r}"
     assert "cat" in joined, f"creative answer missing 'cat': {actual_result!r}"
diff --git a/python/flink_agents/e2e_tests/test_utils.py b/python/flink_agents/e2e_tests/test_utils.py
index eab34a07..7e8e7505 100644
--- a/python/flink_agents/e2e_tests/test_utils.py
+++ b/python/flink_agents/e2e_tests/test_utils.py
@@ -15,14 +15,129 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
+import json
 import subprocess
 from pathlib import Path
 
 from ollama import Client
 
+from flink_agents.api.events.tool_event import ToolRequestEvent
+
 current_dir = Path(__file__).parent
 
 
+def _normalize_arguments(arguments: object) -> dict:
+    """Return tool-call arguments as a dict, parsing a JSON string if needed.
+
+    Args:
+        arguments: Tool-call arguments, either a mapping (Ollama path), a
+            JSON-encoded string (some providers ``json.dumps`` the arguments),
+            or ``None`` for a no-argument tool call.
+
+    Returns:
+        The arguments as a dict; an empty dict when ``arguments`` is ``None``.
+    """
+    if arguments is None:
+        return {}
+    if isinstance(arguments, str):
+        return json.loads(arguments)
+    return dict(arguments)
+
+
+def collect_tool_invocations(log_dir: str | Path) -> list[dict]:
+    """Read ``events-*.log`` under ``log_dir`` and return tool invocations in order.
+
+    Globs the per-subtask event-log files the ``FileEventLogger`` writes, parses
+    each JSONL record, and extracts every ``_tool_request_event`` tool call. The
+    tool-call dict is nested under ``function`` in the wire format.
+
+    Args:
+        log_dir: Directory containing the ``events-*.log`` files (the configured
+            ``baseLogDir``).
+
+    Returns:
+        Ordered list of ``{"name": str, "arguments": dict | str}``. Empty when the
+        model invoked no tool (a legitimate, assertable outcome).
+    """
+    invocations = []
+    for log_file in sorted(Path(log_dir).glob("events-*.log")):
+        with log_file.open() as handle:
+            for line in handle:
+                if not line.strip():
+                    continue
+                record = json.loads(line)
+                if record.get("eventType") != "_tool_request_event":
+                    continue
+                tool_calls = record["event"]["attributes"].get("tool_calls", [])
+                for tool_call in tool_calls:
+                    function = tool_call["function"]
+                    invocations.append(
+                        {
+                            "name": function["name"],
+                            "arguments": function["arguments"],
+                        }
+                    )
+    return invocations
+
+
+def tool_invocations_from_events(events: list[ToolRequestEvent]) -> list[dict]:
+    """Normalize live ``ToolRequestEvent`` objects to the same invocation shape.
+
+    Adapts the in-memory capture (the ``LocalRunner`` hook) to the same
+    ``{name, arguments}`` shape :func:`collect_tool_invocations` returns from the
+    event log, so both sources feed :func:`assert_tool_invoked` identically. Each
+    event's ``tool_calls`` is a list of nested ``{id, type, function:{name,
+    arguments}}`` dicts; order is preserved.
+
+    Args:
+        events: ``ToolRequestEvent`` objects captured during a local run.
+
+    Returns:
+        Ordered list of ``{"name": str, "arguments": dict | str}``, one per tool
+        call across all events.
+    """
+    invocations = []
+    for event in events:
+        for tool_call in event.tool_calls:
+            function = tool_call["function"]
+            invocations.append(
+                {
+                    "name": function["name"],
+                    "arguments": function["arguments"],
+                }
+            )
+    return invocations
+
+
+def assert_tool_invoked(invocations: list[dict], name: str, arguments: dict) -> None:
+    """Assert some invocation called tool ``name`` with arguments equal to ``arguments``.
+
+    Argument values are compared after normalizing both sides to a dict (a
+    JSON-string ``arguments`` is parsed first), so the comparison is
+    order-independent and tolerant of providers that encode arguments as a string.
+
+    Args:
+        invocations: Tool invocations as returned by :func:`collect_tool_invocations`.
+        name: Expected tool name.
+        arguments: Expected tool arguments.
+
+    Raises:
+        AssertionError: If no invocation matches both ``name`` and ``arguments``;
+            the message dumps the actual invocations.
+    """
+    expected_args = _normalize_arguments(arguments)
+    for invocation in invocations:
+        if invocation["name"] != name:
+            continue
+        if _normalize_arguments(invocation["arguments"]) == expected_args:
+            return
+    message = (
+        f"No invocation of tool {name!r} with arguments {expected_args!r}; "
+        f"actual invocations: {invocations!r}"
+    )
+    raise AssertionError(message)
+
+
 def pull_model(ollama_model: str) -> Client:
     """Run ollama pull ollama_model."""
     try:
diff --git a/python/flink_agents/e2e_tests/test_utils_unit_test.py b/python/flink_agents/e2e_tests/test_utils_unit_test.py
new file mode 100644
index 00000000..2236abd7
--- /dev/null
+++ b/python/flink_agents/e2e_tests/test_utils_unit_test.py
@@ -0,0 +1,150 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import json
+from pathlib import Path
+
+import pytest
+
+from flink_agents.api.events.tool_event import ToolRequestEvent
+from flink_agents.e2e_tests.test_utils import (
+    assert_tool_invoked,
+    collect_tool_invocations,
+    tool_invocations_from_events,
+)
+
+
+def _tool_request_record(tool_calls: list) -> dict:
+    """Build a record matching the FileEventLogger wire format for a tool request."""
+    return {
+        "timestamp": "2026-05-31T00:00:00Z",
+        "logLevel": "STANDARD",
+        "eventType": "_tool_request_event",
+        "event": {
+            "eventType": "_tool_request_event",
+            "id": "00000000-0000-0000-0000-000000000001",
+            "attributes": {"model": "qwen3:1.7b", "tool_calls": tool_calls},
+        },
+    }
+
+
+def _function_tool_call(name: str, arguments: object) -> dict:
+    return {
+        "id": "00000000-0000-0000-0000-0000000000aa",
+        "type": "function",
+        "function": {"name": name, "arguments": arguments},
+    }
+
+
+def _write_log(log_dir: Path, records: list) -> None:
+    log_dir.mkdir(parents=True, exist_ok=True)
+    with (log_dir / "events-0.log").open("w") as handle:
+        for record in records:
+            handle.write(json.dumps(record) + "\n")
+
+
+def test_collect_tool_invocations(tmp_path: Path) -> None:
+    """Parser extracts nested function.name/function.arguments from a tool request line.
+
+    Records with a different eventType are ignored.
+    """
+    log_dir = tmp_path / "event_logs"
+    other_event = {
+        "timestamp": "2026-05-31T00:00:00Z",
+        "logLevel": "STANDARD",
+        "eventType": "_input_event",
+        "event": {"eventType": "_input_event", "id": "x", "attributes": {}},
+    }
+    _write_log(
+        log_dir,
+        [
+            other_event,
+            _tool_request_record([_function_tool_call("add", {"a": 1, "b": 2})]),
+        ],
+    )
+
+    assert collect_tool_invocations(log_dir) == [
+        {"name": "add", "arguments": {"a": 1, "b": 2}}
+    ]
+
+
+def test_collect_tool_invocations_no_tool(tmp_path: Path) -> None:
+    """A run with no tool request event yields an empty list."""
+    log_dir = tmp_path / "event_logs"
+    _write_log(
+        log_dir,
+        [
+            {
+                "timestamp": "2026-05-31T00:00:00Z",
+                "logLevel": "STANDARD",
+                "eventType": "_output_event",
+                "event": {
+                    "eventType": "_output_event",
+                    "id": "y",
+                    "attributes": {},
+                },
+            }
+        ],
+    )
+
+    assert collect_tool_invocations(log_dir) == []
+
+
+def test_assert_tool_invoked_dict_args() -> None:
+    """Passes when an invocation matches name and dict args (order-independent)."""
+    invocations = [{"name": "add", "arguments": {"b": 2, "a": 1}}]
+    assert_tool_invoked(invocations, "add", {"a": 1, "b": 2})
+
+
+def test_assert_tool_invoked_json_string_args() -> None:
+    """Passes when the recorded arguments are a JSON string rather than a dict."""
+    invocations = [{"name": "add", "arguments": '{"a": 1, "b": 2}'}]
+    assert_tool_invoked(invocations, "add", {"a": 1, "b": 2})
+
+
+def test_assert_tool_invoked_none_args() -> None:
+    """A no-arg tool call recorded as ``None`` matches an expected empty dict."""
+    invocations = [{"name": "now", "arguments": None}]
+    assert_tool_invoked(invocations, "now", {})
+
+
+def test_assert_tool_invoked_mismatch_reports_invocations() -> None:
+    """Raises AssertionError dumping the actual invocations on a mismatch."""
+    invocations = [{"name": "add", "arguments": {"a": 9, "b": 9}}]
+    with pytest.raises(AssertionError) as exc_info:
+        assert_tool_invoked(invocations, "add", {"a": 1, "b": 2})
+    assert "add" in str(exc_info.value)
+    assert "9" in str(exc_info.value)
+
+
+def test_tool_invocations_from_events() -> None:
+    """Live ToolRequestEvents normalize to the same {name, arguments} shape.
+
+    One event carrying two tool calls yields one invocation per call, in order.
+    """
+    event = ToolRequestEvent(
+        model="qwen3:1.7b",
+        tool_calls=[
+            _function_tool_call("add", {"a": 1, "b": 2}),
+            _function_tool_call("multiply", {"a": 4444, "b": 312}),
+        ],
+    )
+
+    assert tool_invocations_from_events([event]) == [
+        {"name": "add", "arguments": {"a": 1, "b": 2}},
+        {"name": "multiply", "arguments": {"a": 4444, "b": 312}},
+    ]
diff --git a/python/flink_agents/runtime/local_execution_environment.py b/python/flink_agents/runtime/local_execution_environment.py
index 89096307..c9768fc3 100644
--- a/python/flink_agents/runtime/local_execution_environment.py
+++ b/python/flink_agents/runtime/local_execution_environment.py
@@ -22,6 +22,7 @@ from pyflink.datastream import DataStream, KeySelector, StreamExecutionEnvironme
 from pyflink.table import Schema, StreamTableEnvironment, Table
 
 from flink_agents.api.agents.agent import Agent
+from flink_agents.api.events.tool_event import ToolRequestEvent
 from flink_agents.api.execution_environment import (
     AgentBuilder,
     AgentsExecutionEnvironment,
@@ -141,6 +142,10 @@ class LocalExecutionEnvironment(AgentsExecutionEnvironment):
         for output in outputs:
             self.__output.append(output)
 
+    def get_tool_request_events(self) -> List[ToolRequestEvent]:
+        """Get the ToolRequestEvents captured by the runner during execution."""
+        return self.__runner.get_tool_request_events()
+
     def from_datastream(
         self, input: DataStream, key_selector: KeySelector | Callable | None = None
     ) -> AgentBuilder:
diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py
index 078fd7f1..38436a2c 100644
--- a/python/flink_agents/runtime/local_runner.py
+++ b/python/flink_agents/runtime/local_runner.py
@@ -26,6 +26,7 @@ from typing_extensions import override
 
 from flink_agents.api.agents.agent import Agent
 from flink_agents.api.events.event import Event, InputEvent, OutputEvent
+from flink_agents.api.events.tool_event import ToolRequestEvent
 from flink_agents.api.memory.long_term_memory import BaseLongTermMemory
 from flink_agents.api.memory_object import MemoryObject, MemoryType
 from flink_agents.api.metric_group import MetricGroup
@@ -273,6 +274,8 @@ class LocalRunner(AgentRunner):
         Dictionary of active contexts indexed by key.
     __outputs:
         Outputs generated by agent execution.
+    __tool_request_events:
+        ToolRequestEvents observed during agent execution.
     __config:
         Internal configration.
     """
@@ -280,6 +283,7 @@ class LocalRunner(AgentRunner):
     __agent_plan: Any
     __keyed_contexts: Dict[Any, LocalRunnerContext]
     __outputs: List[Dict[str, Any]]
+    __tool_request_events: List[ToolRequestEvent]
     __config: AgentConfiguration
 
     def __init__(self, agent: Agent, config: AgentConfiguration) -> None:
@@ -295,6 +299,7 @@ class LocalRunner(AgentRunner):
         self.__agent_plan = AgentPlan.from_agent(agent, config)
         self.__keyed_contexts = {}
         self.__outputs = []
+        self.__tool_request_events = []
         self.__config = config
 
     @override
@@ -340,6 +345,9 @@ class LocalRunner(AgentRunner):
             if isinstance(event, OutputEvent):
                 self.__outputs.append({key: event.output})
                 continue
+            if isinstance(event, ToolRequestEvent):
+                self.__tool_request_events.append(event)
+                # Fall through: the request must still dispatch to its action.
             event_type = event.get_type()
             for action in self.__agent_plan.get_actions(event_type):
                 logger.info("key: %s, performing action: %s", key, action.name)
@@ -367,3 +375,13 @@ class LocalRunner(AgentRunner):
             The agent execution outputs.
         """
         return self.__outputs
+
+    def get_tool_request_events(self) -> List[ToolRequestEvent]:
+        """Get the ToolRequestEvents captured during agent execution.
+
+        Returns:
+        -------
+        List[ToolRequestEvent]
+            The ToolRequestEvents observed in the run loop, in order.
+        """
+        return self.__tool_request_events
diff --git a/python/flink_agents/runtime/tests/test_local_execution_environment.py b/python/flink_agents/runtime/tests/test_local_execution_environment.py
index 51250c0c..1c861b35 100644
--- a/python/flink_agents/runtime/tests/test_local_execution_environment.py
+++ b/python/flink_agents/runtime/tests/test_local_execution_environment.py
@@ -16,15 +16,29 @@
 # limitations under the License.
 
#################################################################################
 import time
-from typing import ClassVar
+import uuid
+from typing import Any, ClassVar, Dict, List, Sequence
 
 import pytest
 
 from flink_agents.api.agents.agent import Agent
-from flink_agents.api.decorators import action
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.chat_models.chat_model import (
+    BaseChatModelConnection,
+    BaseChatModelSetup,
+)
+from flink_agents.api.decorators import (
+    action,
+    chat_model_connection,
+    chat_model_setup,
+    tool,
+)
+from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent
 from flink_agents.api.events.event import Event, InputEvent, OutputEvent
 from flink_agents.api.execution_environment import AgentsExecutionEnvironment
+from flink_agents.api.resource import ResourceDescriptor, ResourceType
 from flink_agents.api.runner_context import RunnerContext
+from flink_agents.api.tools.tool import ToolType
 
 
 class Agent1(Agent):
@@ -218,3 +232,148 @@ def test_mixed_event_workflow() -> None:
     env.execute()
 
     assert output_list == [{"bob": "done:42"}]
+
+
+# ── Tool-request capture hook (Track B) ──────────────────────────────────
+
+
+class _ToolConnection(BaseChatModelConnection):
+    """Mock connection emitting a single ``add`` tool call."""
+
+    def chat(
+        self,
+        messages: Sequence[ChatMessage],
+        tools: List | None = None,
+        **kwargs: Any,
+    ) -> ChatMessage:
+        """Emit an ``add`` tool call, then echo the tool result as content."""
+        last = messages[-1]
+        if last.role == MessageRole.TOOL:
+            return ChatMessage(role=MessageRole.ASSISTANT, content=str(last.content))
+        tool_call = {
+            "id": str(uuid.uuid4()),
+            "type": ToolType.FUNCTION,
+            "function": {"name": "add", "arguments": {"a": 1, "b": 2}},
+        }
+        return ChatMessage(
+            role=MessageRole.ASSISTANT, content="", tool_calls=[tool_call]
+        )
+
+
+class _ToolChatModel(BaseChatModelSetup):
+    """Mock setup binding the ``add`` tool to the connection."""
+
+    def open(self) -> None:
+        """Do nothing."""
+
+    @property
+    def model_kwargs(self) -> Dict[str, Any]:
+        """Return model kwargs."""
+        return {}
+
+    def chat(
+        self,
+        messages: Sequence[ChatMessage],
+        prompt_args: Dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> ChatMessage:
+        """Bind tools and delegate to the connection."""
+        server = self.resource_context.get_resource(
+            self.connection, ResourceType.CHAT_MODEL_CONNECTION
+        )
+        tools = [
+            self.resource_context.get_resource(name, ResourceType.TOOL)
+            for name in (self.tools or [])
+        ]
+        return server.chat(messages, tools=tools, **kwargs)
+
+
+class ToolRequestAgent(Agent):
+    """Agent whose chat model emits a ToolRequestEvent dispatched to ``add``.
+
+    The InputEvent action sends a ChatRequestEvent; the mock chat model returns
+    an ``add`` tool call, which the built-in chat/tool actions turn into a real
+    ToolRequestEvent flowing through the runner. The ToolRequestEvent is captured
+    by the runner AND still dispatched to ``tool_call_action`` — the final output
+    (the tool result) proves capture did not swallow the event.
+    """
+
+    @chat_model_connection
+    @staticmethod
+    def conn() -> ResourceDescriptor:
+        """Mock chat model connection."""
+        return ResourceDescriptor(
+            clazz=f"{_ToolConnection.__module__}.{_ToolConnection.__name__}"
+        )
+
+    @chat_model_setup
+    @staticmethod
+    def model() -> ResourceDescriptor:
+        """Mock chat model bound to the ``add`` tool."""
+        return ResourceDescriptor(
+            clazz=f"{_ToolChatModel.__module__}.{_ToolChatModel.__name__}",
+            connection="conn",
+            model="mock-model",
+            tools=["add"],
+        )
+
+    @tool
+    @staticmethod
+    def add(a: int, b: int) -> int:
+        """Return the sum of a and b.
+
+        Parameters
+        ----------
+        a : int
+            The first operand.
+        b : int
+            The second operand.
+
+        Returns:
+        -------
+        int:
+            The sum of a and b.
+        """
+        return a + b
+
+    @action(InputEvent.EVENT_TYPE)
+    @staticmethod
+    def process_input(event: Event, ctx: RunnerContext) -> None:
+        """Send a ChatRequestEvent to drive the tool-calling flow."""
+        input = InputEvent.from_event(event).input
+        ctx.send_event(
+            ChatRequestEvent(
+                model="model",
+                messages=[ChatMessage(role=MessageRole.USER, content=input)],
+            )
+        )
+
+    @action(ChatResponseEvent.EVENT_TYPE)
+    @staticmethod
+    def process_response(event: Event, ctx: RunnerContext) -> None:
+        """Emit the final assistant content as output."""
+        response = ChatResponseEvent.from_event(event).response
+        ctx.send_event(OutputEvent(output=response.content))
+
+
+def test_local_runner_captures_tool_request_events() -> None:
+    """A ToolRequestEvent is captured AND still dispatched to its action."""
+    env = AgentsExecutionEnvironment.get_execution_environment()
+
+    input_list = []
+    agent = ToolRequestAgent()
+
+    output_list = env.from_list(input_list).apply(agent).to_list()
+
+    input_list.append({"key": "0001", "value": "add 1 and 2"})
+    env.execute()
+
+    captured = env.get_tool_request_events()
+    assert len(captured) == 1
+    assert captured[0].tool_calls[0]["function"] == {
+        "name": "add",
+        "arguments": {"a": 1, "b": 2},
+    }
+    # Dispatch was not swallowed: tool_call_action ran, producing the tool result
+    # that the model echoed back as the final output.
+    assert output_list == [{"0001": "3"}]

Reply via email to