This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 51c150206cbce38f45ba32cf89a0ce5cf9857d5d
Author: WenjinXie <[email protected]>
AuthorDate: Thu Dec 25 15:58:37 2025 +0800

    [plan] Refactor built-in chat action.
    
    fix
    
    [plan] Refactor built-in chat action in java.
---
 .../flink/agents/plan/actions/ChatModelAction.java | 233 ++++++++++++--------
 .../flink_agents/plan/actions/chat_model_action.py | 241 +++++++++++++--------
 2 files changed, 292 insertions(+), 182 deletions(-)

diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index 2b153ae..0fb96e3 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -50,6 +50,144 @@ public class ChatModelAction {
                 List.of(ChatRequestEvent.class.getName(), ToolResponseEvent.class.getName()));
     }
 
+    /**
+     * Appends messages to the tool call context of a request and writes the context back to
+     * sensory memory.
+     *
+     * <p>The added messages are either a chat model response carrying tool calls or the results
+     * of executing those tool calls. If no context exists yet for {@code initialRequestId}, it is
+     * seeded with {@code initialMessages}.
+     *
+     * @param sensoryMem the sensory memory holding the tool call context
+     * @param initialRequestId id of the chat request that started the tool call loop
+     * @param initialMessages messages used to seed the context when none exists yet
+     * @param addedMessages messages to append
+     * @return the full message history of the request after the append
+     */
+    @SuppressWarnings("unchecked")
+    private static List<ChatMessage> updateToolCallContext(
+            MemoryObject sensoryMem,
+            UUID initialRequestId,
+            List<ChatMessage> initialMessages,
+            List<ChatMessage> addedMessages)
+            throws Exception {
+        Map<UUID, Object> toolCallContext;
+        if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) {
+            toolCallContext = (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue();
+        } else {
+            toolCallContext = new HashMap<>();
+        }
+        // Always copy, so appending never mutates a caller-owned list such as
+        // Collections.emptyList().
+        List<ChatMessage> messageContext =
+                new ArrayList<>(
+                        (List<ChatMessage>)
+                                toolCallContext.getOrDefault(initialRequestId, initialMessages));
+        messageContext.addAll(addedMessages);
+        toolCallContext.put(initialRequestId, messageContext);
+        sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext);
+        return messageContext;
+    }
+
+    /**
+     * Removes the tool call context of {@code initialRequestId} from sensory memory, if present.
+     *
+     * @param sensoryMem the sensory memory holding the tool call context
+     * @param initialRequestId id of the chat request whose context is discarded
+     */
+    @SuppressWarnings("unchecked")
+    private static void clearToolCallContext(MemoryObject sensoryMem, UUID initialRequestId)
+            throws Exception {
+        if (!sensoryMem.isExist(TOOL_CALL_CONTEXT)) {
+            return;
+        }
+        Map<UUID, Object> toolCallContext =
+                (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue();
+        // Only write the map back when an entry was actually removed.
+        if (toolCallContext.remove(initialRequestId) != null) {
+            sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext);
+        }
+    }
+
+    /**
+     * Records which chat request and model a {@link ToolRequestEvent} belongs to, so that the
+     * chat can be resumed when the matching {@link ToolResponseEvent} arrives.
+     *
+     * @param sensoryMem the sensory memory holding the tool request event context
+     * @param toolRequestEventId id of the emitted tool request event
+     * @param initialRequestId id of the chat request that started the tool call loop
+     * @param model name of the chat model resource to continue the chat with
+     */
+    @SuppressWarnings("unchecked")
+    private static void saveToolRequestEventContext(
+            MemoryObject sensoryMem, UUID toolRequestEventId, UUID initialRequestId, String model)
+            throws Exception {
+        Map<UUID, Object> toolRequestEventContext;
+        if (sensoryMem.isExist(TOOL_REQUEST_EVENT_CONTEXT)) {
+            toolRequestEventContext =
+                    (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue();
+        } else {
+            toolRequestEventContext = new HashMap<>();
+        }
+        toolRequestEventContext.put(
+                toolRequestEventId, Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model));
+        sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
+    }
+
+    /**
+     * Removes and returns the context saved for a tool request event.
+     *
+     * @param sensoryMem the sensory memory holding the tool request event context
+     * @param requestId id of the tool request event being answered
+     * @return the saved map with the {@code INITIAL_REQUEST_ID} and {@code MODEL} entries
+     * @throws IllegalStateException if no context was saved for {@code requestId}
+     */
+    @SuppressWarnings("unchecked")
+    private static Map<String, Object> removeToolRequestEventContext(
+            MemoryObject sensoryMem, UUID requestId) throws Exception {
+        Map<String, Object> context = null;
+        if (sensoryMem.isExist(TOOL_REQUEST_EVENT_CONTEXT)) {
+            Map<UUID, Object> toolRequestEventContext =
+                    (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue();
+            context = (Map<String, Object>) toolRequestEventContext.remove(requestId);
+            sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
+        }
+        if (context == null) {
+            // Fail with a descriptive error rather than an NPE in the caller.
+            throw new IllegalStateException(
+                    String.format("No context saved for tool request event %s.", requestId));
+        }
+        return context;
+    }
+
+    /**
+     * Records a tool-calling chat response in the tool call context and emits a {@link
+     * ToolRequestEvent} for the requested tool calls.
+     *
+     * @param response chat model response containing tool calls
+     * @param initialRequestId id of the chat request that started the tool call loop
+     * @param model name of the chat model resource used for this conversation
+     * @param messages messages that were sent to the chat model to produce {@code response}
+     * @param ctx the runner context used to access sensory memory and send events
+     */
+    private static void handleToolCalls(
+            ChatMessage response,
+            UUID initialRequestId,
+            String model,
+            List<ChatMessage> messages,
+            RunnerContext ctx)
+            throws Exception {
+        MemoryObject sensoryMem = ctx.getSensoryMemory();
+        updateToolCallContext(
+                sensoryMem, initialRequestId, messages, Collections.singletonList(response));
+
+        ToolRequestEvent toolRequestEvent = new ToolRequestEvent(model, response.getToolCalls());
+        // Remember which request and model to resume with once the tool responses arrive.
+        saveToolRequestEventContext(sensoryMem, toolRequestEvent.getId(), initialRequestId, model);
+
+        ctx.sendEvent(toolRequestEvent);
+    }
+
     /**
      * Chat with chat model.
      *
@@ -67,54 +154,79 @@ public class ChatModelAction {
                 (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL);
 
         ChatMessage response = chatModel.chat(messages, Map.of());
-        MemoryObject sensoryMem = ctx.getSensoryMemory();
 
         if (!response.getToolCalls().isEmpty()) {
-            Map<UUID, Object> toolCallContext;
-            if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) {
-                toolCallContext = (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue();
-            } else {
-                toolCallContext = new HashMap<>();
-            }
-            if (!toolCallContext.containsKey(initialRequestId)) {
-                toolCallContext.put(initialRequestId, messages);
-            }
-            List<ChatMessage> messageContext =
-                    new ArrayList<>((List<ChatMessage>) toolCallContext.get(initialRequestId));
+            handleToolCalls(response, initialRequestId, model, messages, ctx);
+        } else {
+            // No tool calls requested: drop the tool call context and emit the final response.
+            clearToolCallContext(ctx.getSensoryMemory(), initialRequestId);
 
-            messageContext.add(response);
-            toolCallContext.put(initialRequestId, messageContext);
-            sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext);
+            ctx.sendEvent(new ChatResponseEvent(initialRequestId, response));
+        }
+    }
 
-            ToolRequestEvent toolRequestEvent =
-                    new ToolRequestEvent(model, response.getToolCalls());
+    /**
+     * Handles a {@link ChatRequestEvent} by chatting with the requested model; the event id
+     * becomes the initial request id of the conversation.
+     *
+     * @param event the chat request event to process
+     * @param ctx the runner context this action executed in
+     */
+    private static void processChatRequest(ChatRequestEvent event, RunnerContext ctx)
+            throws Exception {
+        chat(event.getId(), event.getModel(), event.getMessages(), ctx);
+    }
 
-            Map<UUID, Object> toolRequestEventContext;
-            if (sensoryMem.isExist(TOOL_REQUEST_EVENT_CONTEXT)) {
-                toolRequestEventContext =
-                        (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue();
-            } else {
-                toolRequestEventContext = new HashMap<>();
-            }
-            toolRequestEventContext.put(
-                    toolRequestEvent.getId(),
-                    Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model));
-            sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
+    /**
+     * Handles a {@link ToolResponseEvent}: appends the tool results to the tool call context of
+     * the chat request that triggered the tool calls, then chats with the model again using the
+     * full message history.
+     *
+     * @param event the tool response event to process
+     * @param ctx the runner context this action executed in
+     */
+    private static void processToolResponse(ToolResponseEvent event, RunnerContext ctx)
+            throws Exception {
+        MemoryObject sensoryMem = ctx.getSensoryMemory();
 
-            ctx.sendEvent(toolRequestEvent);
-        } else {
-            // clean tool call context
-            if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) {
-                Map<UUID, Object> toolCallContext =
-                        (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue();
-                if (toolCallContext.containsKey(initialRequestId)) {
-                    toolCallContext.remove(initialRequestId);
-                    sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext);
-                }
+        // get tool request context from memory
+        Map<String, Object> context =
+                removeToolRequestEventContext(sensoryMem, event.getRequestId());
+
+        UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID);
+        String model = (String) context.get(MODEL);
+
+        Map<String, ToolResponse> responses = event.getResponses();
+        Map<String, Boolean> success = event.getSuccess();
+
+        List<ChatMessage> toolResponseMessages = new ArrayList<>();
+
+        for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) {
+            Map<String, Object> extraArgs = new HashMap<>();
+            String toolCallId = entry.getKey();
+            if (event.getExternalIds().containsKey(toolCallId)) {
+                extraArgs.put("externalId", event.getExternalIds().get(toolCallId));
             }
 
-            ctx.sendEvent(new ChatResponseEvent(initialRequestId, response));
+            ToolResponse response = entry.getValue();
+            // A missing success flag counts as a failure instead of causing an unboxing NPE.
+            boolean succeeded =
+                    Boolean.TRUE.equals(success.get(toolCallId)) && response.isSuccess();
+            String content =
+                    succeeded
+                            ? String.valueOf(response.getResult())
+                            : String.valueOf(response.getError());
+            toolResponseMessages.add(new ChatMessage(MessageRole.TOOL, content, extraArgs));
         }
+
+        List<ChatMessage> messages =
+                updateToolCallContext(
+                        sensoryMem,
+                        initialRequestId,
+                        Collections.emptyList(),
+                        toolResponseMessages);
+
+        chat(initialRequestId, model, messages, ctx);
     }
 
     /**
@@ -128,66 +226,12 @@ public class ChatModelAction {
      *     ToolResponseEvent}
      * @param ctx The runner context this action executed in.
      */
-    @SuppressWarnings("unchecked")
     public static void processChatRequestOrToolResponse(Event event, RunnerContext ctx)
             throws Exception {
-        MemoryObject sensoryMem = ctx.getSensoryMemory();
         if (event instanceof ChatRequestEvent) {
-            ChatRequestEvent chatRequestEvent = (ChatRequestEvent) event;
-            chat(
-                    chatRequestEvent.getId(),
-                    chatRequestEvent.getModel(),
-                    chatRequestEvent.getMessages(),
-                    ctx);
+            processChatRequest((ChatRequestEvent) event, ctx);
         } else if (event instanceof ToolResponseEvent) {
-            ToolResponseEvent toolResponseEvent = (ToolResponseEvent) event;
-            UUID toolRequestId = toolResponseEvent.getRequestId();
-            // get tool request context from memory
-            Map<UUID, Object> toolRequestEventContext =
-                    (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue();
-            Map<String, Object> context =
-                    (Map<String, Object>) toolRequestEventContext.get(toolRequestId);
-            UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID);
-            String model = (String) context.get(MODEL);
-            toolRequestEventContext.remove(toolRequestId);
-            sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext);
-            Map<String, ToolResponse> responses = toolResponseEvent.getResponses();
-            Map<String, Boolean> success = toolResponseEvent.getSuccess();
-
-            // get tool call context
-            Map<UUID, Object> toolCallContext =
-                    (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue();
-            // update tool call context
-            List<ChatMessage> messages =
-                    new ArrayList<>((List<ChatMessage>) toolCallContext.get(initialRequestId));
-
-            for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) {
-                Map<String, Object> extraArgs = new HashMap<>();
-                String toolCallId = entry.getKey();
-                if (toolResponseEvent.getExternalIds().containsKey(toolCallId)) {
-                    extraArgs.put("externalId", toolResponseEvent.getExternalIds().get(toolCallId));
-                }
-
-                ToolResponse response = entry.getValue();
-                if (success.get(toolCallId) && response.isSuccess()) {
-                    messages.add(
-                            new ChatMessage(
-                                    MessageRole.TOOL,
-                                    String.valueOf(response.getResult()),
-                                    extraArgs));
-                } else {
-                    messages.add(
-                            new ChatMessage(
-                                    MessageRole.TOOL,
-                                    String.valueOf(response.getError()),
-                                    extraArgs));
-                }
-            }
-            toolCallContext.put(initialRequestId, messages);
-            // overwrite tool call context
-            sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext);
-
-            chat(initialRequestId, model, messages, ctx);
+            processToolResponse((ToolResponseEvent) event, ctx);
         } else {
             throw new RuntimeException(String.format("Unexpected type event %s", event));
         }
diff --git a/python/flink_agents/plan/actions/chat_model_action.py 
b/python/flink_agents/plan/actions/chat_model_action.py
index 6db3908..7299f13 100644
--- a/python/flink_agents/plan/actions/chat_model_action.py
+++ b/python/flink_agents/plan/actions/chat_model_action.py
@@ -16,13 +16,14 @@
 # limitations under the License.
 
#################################################################################
 import copy
-from typing import TYPE_CHECKING, List, cast
+from typing import TYPE_CHECKING, Dict, List, cast
 from uuid import UUID
 
 from flink_agents.api.chat_message import ChatMessage, MessageRole
 from flink_agents.api.events.chat_event import ChatRequestEvent, 
ChatResponseEvent
 from flink_agents.api.events.event import Event
 from flink_agents.api.events.tool_event import ToolRequestEvent, 
ToolResponseEvent
+from flink_agents.api.memory_object import MemoryObject
 from flink_agents.api.resource import ResourceType
 from flink_agents.api.runner_context import RunnerContext
 from flink_agents.plan.actions.action import Action
@@ -35,6 +36,140 @@ _TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT"
 _TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT"
 
 
+# ============================================================================
+# Helper Functions for Tool Call Context Management
+# ============================================================================
+def _update_tool_call_context(
+    sensory_memory: MemoryObject,
+    initial_request_id: UUID,
+    initial_messages: List[ChatMessage] | None,
+    added_messages: List[ChatMessage],
+) -> List[ChatMessage]:
+    """Append messages to the tool call context of a request.
+
+    The messages may be a chat model response with tool calls, or tool execution
+    results. If no context exists yet for ``initial_request_id``, it is
+    initialized with a deep copy of ``initial_messages``.
+
+    Args:
+        sensory_memory: Memory holding the tool call context.
+        initial_request_id: Id of the chat request that started the tool calls.
+        initial_messages: Messages used to initialize a missing context, or
+            ``None`` when the context is expected to exist already.
+        added_messages: Messages to append.
+
+    Returns:
+        The full message history of the request after the append.
+
+    Raises:
+        RuntimeError: If there is no context for ``initial_request_id`` and
+            ``initial_messages`` is ``None``.
+    """
+    # TODO: Because memory doesn't support remove currently, so we use
+    #  dict to store tool context in memory and remove the specific
+    #  tool context from dict after consuming. This will cause write and
+    #  read amplification for we need get the whole dict and overwrite it
+    #  to memory each time we update a specific tool context.
+    #  After memory supports remove, we can use "TOOL_CALL_CONTEXT/request_id"
+    #  to store and remove the specific tool context directly.
+
+    # init if not exists
+    tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {}
+    if initial_request_id not in tool_call_context:
+        if initial_messages is None:
+            msg = f"No tool call context found for request {initial_request_id}."
+            raise RuntimeError(msg)
+        tool_call_context[initial_request_id] = copy.deepcopy(initial_messages)
+
+    tool_call_context[initial_request_id].extend(added_messages)
+
+    # update tool call context
+    sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
+    return tool_call_context[initial_request_id]
+
+
+def _clear_tool_call_context(
+    sensory_memory: MemoryObject, initial_request_id: UUID
+) -> None:
+    """Clear the tool call context of a specific request, if present."""
+    context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {}
+    if initial_request_id in context:
+        context.pop(initial_request_id)
+        sensory_memory.set(_TOOL_CALL_CONTEXT, context)
+
+
+def _save_tool_request_event_context(
+    sensory_memory: MemoryObject,
+    tool_request_event_id: UUID,
+    initial_request_id: UUID,
+    model: str,
+) -> None:
+    """Save the initial request id and model of a specific tool request event.
+
+    The matching ToolResponseEvent uses this context to resume the chat.
+    """
+    context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
+    context[tool_request_event_id] = {
+        "initial_request_id": initial_request_id,
+        "model": model,
+    }
+    sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context)
+
+
+def _remove_tool_request_event_context(
+    sensory_memory: MemoryObject, request_id: UUID
+) -> Dict:
+    """Get and remove the context of a specific tool request event.
+
+    Args:
+        sensory_memory: Memory holding the tool request event context.
+        request_id: Id of the tool request event being answered.
+
+    Returns:
+        The removed context, with ``initial_request_id`` and ``model`` keys.
+
+    Raises:
+        RuntimeError: If no context was saved for ``request_id``.
+    """
+    context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
+    if request_id not in context:
+        msg = f"No context saved for tool request event {request_id}."
+        raise RuntimeError(msg)
+    removed_context = context.pop(request_id)
+    # Write back the remaining contexts, not the entry that was just removed.
+    sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context)
+    return removed_context
+
+
+def _handle_tool_calls(
+    response: ChatMessage,
+    initial_request_id: UUID,
+    model: str,
+    messages: List[ChatMessage],
+    ctx: RunnerContext,
+) -> None:
+    """Handle tool calls in chat response.
+
+    Appends ``response`` to the tool call context of ``initial_request_id``,
+    saves the context of the new ToolRequestEvent, then sends that event.
+    """
+    _update_tool_call_context(
+        ctx.sensory_memory, initial_request_id, messages, [response]
+    )
+
+    tool_request_event = ToolRequestEvent(
+        model=model,
+        tool_calls=response.tool_calls,
+    )
+
+    # save tool request event context
+    _save_tool_request_event_context(
+        ctx.sensory_memory, tool_request_event.id, initial_request_id, model
+    )
+
+    ctx.send_event(tool_request_event)
+
+
 def chat(
     initial_request_id: UUID,
     model: str,
@@ -53,52 +148,14 @@ def chat(
 
     # TODO: support async execution of chat.
     response = chat_model.chat(messages)
-    sensory_memory = ctx.sensory_memory
 
-    # generate tool request event according tool calls in response
-    if len(response.tool_calls) > 0:
-        # TODO: Because memory doesn't support remove currently, so we use
-        #  dict to store tool context in memory and remove the specific
-        #  tool context from dict after consuming. This will cause write and
-        #  read amplification for we need get the whole dict and overwrite it
-        #  to memory each time we update a specific tool context.
-        #  After memory supports remove, we can use "TOOL_CALL_CONTEXT/request_id"
-        #  to store and remove the specific tool context directly.
-
-        # save tool call context
-        tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT)
-        if not tool_call_context:
-            tool_call_context = {}
-        if initial_request_id not in tool_call_context:
-            tool_call_context[initial_request_id] = copy.deepcopy(messages)
-        # append response to tool call context
-        tool_call_context[initial_request_id].append(response)
-        # update tool call context
-        sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
-
-        tool_request_event = ToolRequestEvent(
-            model=model,
-            tool_calls=response.tool_calls,
-        )
+    if len(response.tool_calls) > 0:
+        # the model asked for tools: emit a tool request event for its tool calls
+        _handle_tool_calls(response, initial_request_id, model, messages, ctx)
+    else:
+        # no tool calls: drop the tool call context and return the chat response
+        _clear_tool_call_context(ctx.sensory_memory, initial_request_id)
 
-        # save tool request event context
-        tool_request_event_context = tool_call_context.get(_TOOL_REQUEST_EVENT_CONTEXT)
-        if not tool_request_event_context:
-            tool_request_event_context = {}
-        tool_request_event_context[tool_request_event.id] = {
-            "initial_request_id": initial_request_id,
-            "model": model,
-        }
-        sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, tool_request_event_context)
-
-        ctx.send_event(tool_request_event)
-    # if there is no tool call generated, return chat response directly
-    else:
-        # clear tool call context related to specific request id
-        tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT)
-        if tool_call_context and initial_request_id in tool_call_context:
-            tool_call_context.pop(initial_request_id)
-            sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
         ctx.send_event(
             ChatResponseEvent(
                 request_id=initial_request_id,
@@ -107,55 +164,88 @@ def chat(
         )
 
 
+def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None:
+    """Start a chat with the model and messages of a chat request event.
+
+    The id of ``event`` is used as the initial request id of the chat.
+    """
+    chat(
+        initial_request_id=event.id,
+        model=event.model,
+        messages=event.messages,
+        ctx=ctx,
+    )
+
+
+def _build_tool_response_messages(event: ToolResponseEvent) -> List[ChatMessage]:
+    """Convert the responses of a tool response event into TOOL chat messages.
+
+    A non-empty external id of a tool call is passed on to the message as
+    ``extra_args["external_id"]``.
+    """
+    external_ids = event.external_ids or {}
+    messages = []
+    for tool_id, response in event.responses.items():
+        external_id = external_ids.get(tool_id)
+        messages.append(
+            ChatMessage(
+                role=MessageRole.TOOL,
+                content=str(response),
+                extra_args={"external_id": external_id} if external_id else {},
+            )
+        )
+    return messages
+
+
+def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None:
+    """Append tool results to the tool call context and chat with the model again.
+
+    The results are appended to the tool call context of the chat request that
+    triggered the tool calls, and the model is then called with the whole
+    message history of that request.
+
+    Raises:
+        RuntimeError: If no context was saved for the answered tool request.
+    """
+    sensory_memory = ctx.sensory_memory
+    request_id = event.request_id
+
+    # get the corresponding tool request event context
+    tool_request_event_context = _remove_tool_request_event_context(
+        sensory_memory, request_id
+    )
+    if not tool_request_event_context:
+        msg = f"No context saved for tool request event {request_id}."
+        raise RuntimeError(msg)
+    initial_request_id = tool_request_event_context["initial_request_id"]
+
+    # update tool call context, and get the entire chat messages.
+    messages = _update_tool_call_context(
+        sensory_memory,
+        initial_request_id,
+        None,
+        _build_tool_response_messages(event),
+    )
+
+    chat(
+        initial_request_id=initial_request_id,
+        model=tool_request_event_context["model"],
+        messages=messages,
+        ctx=ctx,
+    )
+
+
 def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None:
     """Built-in action for processing a chat request or tool response.
 
-    Internally, this action will use short term memory to save the tool call context,
-    which is a dict mapping request id to chat messages.
+    This action listens to ChatRequestEvent and ToolResponseEvent, and handles
+    the complete chat flow including tool calls. It uses sensory memory to save
+    the tool call context, which is a dict mapping request id to chat messages.
     """
-    sensory_memory = ctx.sensory_memory
     if isinstance(event, ChatRequestEvent):
-        chat(
-            initial_request_id=event.id,
-            model=event.model,
-            messages=event.messages,
-            ctx=ctx,
-        )
-
+        _process_chat_request(event, ctx)
     elif isinstance(event, ToolResponseEvent):
-        request_id = event.request_id
-
-        # get correspond tool request event context
-        tool_request_event_context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT)
-        initial_request_id = tool_request_event_context[request_id][
-            "initial_request_id"
-        ]
-        model = tool_request_event_context[request_id]["model"]
-        # clear tool request event context
-        tool_request_event_context.pop(request_id)
-        sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, tool_request_event_context)
-
-        responses = event.responses
-        # update tool call context
-        tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT)
-        for id, response in responses.items():
-            tool_call_context[initial_request_id].append(
-                ChatMessage(
-                    role=MessageRole.TOOL,
-                    content=str(response),
-                    extra_args={"external_id": event.external_ids[id]}
-                    if event.external_ids[id]
-                    else {},
-                )
-            )
-        sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
-
-        chat(
-            initial_request_id=initial_request_id,
-            model=model,
-            messages=tool_call_context[initial_request_id],
-            ctx=ctx,
-        )
+        _process_tool_response(event, ctx)
 
 
 CHAT_MODEL_ACTION = Action(

Reply via email to