This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 51c150206cbce38f45ba32cf89a0ce5cf9857d5d Author: WenjinXie <[email protected]> AuthorDate: Thu Dec 25 15:58:37 2025 +0800 [plan] Refactor built-in chat action. fix [plan] Refactor built-in chat action in java. --- .../flink/agents/plan/actions/ChatModelAction.java | 233 ++++++++++++-------- .../flink_agents/plan/actions/chat_model_action.py | 241 +++++++++++++-------- 2 files changed, 292 insertions(+), 182 deletions(-) diff --git a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java index 2b153ae..0fb96e3 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java @@ -50,6 +50,93 @@ public class ChatModelAction { List.of(ChatRequestEvent.class.getName(), ToolResponseEvent.class.getName())); } + @SuppressWarnings("unchecked") + private static List<ChatMessage> updateToolCallContext( + MemoryObject sensoryMem, + UUID initialRequestId, + List<ChatMessage> initialMessages, + List<ChatMessage> addedMessages) + throws Exception { + + Map<UUID, Object> toolCallContext; + if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) { + toolCallContext = (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue(); + } else { + toolCallContext = new HashMap<>(); + } + if (!toolCallContext.containsKey(initialRequestId)) { + toolCallContext.put(initialRequestId, initialMessages); + } + List<ChatMessage> messageContext = + new ArrayList<>((List<ChatMessage>) toolCallContext.get(initialRequestId)); + + messageContext.addAll(addedMessages); + toolCallContext.put(initialRequestId, messageContext); + sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext); + return messageContext; + } + + @SuppressWarnings("unchecked") + private static void clearToolCallContext(MemoryObject sensoryMem, UUID initialRequestId) + throws Exception { + if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) { + Map<UUID, Object> toolCallContext = + (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue(); + if (toolCallContext.containsKey(initialRequestId)) { + toolCallContext.remove(initialRequestId); + sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext); + } + } + } + + @SuppressWarnings("unchecked") + private static void saveToolRequestEventContext( + MemoryObject sensoryMem, UUID toolRequestEventId, UUID initialRequestId, String model) + throws Exception { + Map<UUID, Object> toolRequestEventContext; + if (sensoryMem.isExist(TOOL_REQUEST_EVENT_CONTEXT)) { + toolRequestEventContext = + (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue(); + } else { + toolRequestEventContext = new HashMap<>(); + } + toolRequestEventContext.put( + toolRequestEventId, Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model)); + sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); + } + + @SuppressWarnings("unchecked") + private static Map<String, Object> removeToolRequestEventContext( + MemoryObject sensoryMem, UUID requestId) throws Exception { + Map<UUID, Object> toolRequestEventContext = + (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue(); + Map<String, Object> context = + (Map<String, Object>) toolRequestEventContext.remove(requestId); + sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); + return context; + } + + private static void handleToolCalls( + ChatMessage response, + UUID initialRequestId, + String model, + List<ChatMessage> messages, + RunnerContext ctx) + throws Exception { + updateToolCallContext( + ctx.getSensoryMemory(), + initialRequestId, + messages, + Collections.singletonList(response)); + + ToolRequestEvent toolRequestEvent = new ToolRequestEvent(model, response.getToolCalls()); + + saveToolRequestEventContext( + ctx.getSensoryMemory(), toolRequestEvent.getId(), initialRequestId, model); + + ctx.sendEvent(toolRequestEvent); + } + /** * Chat with chat model. * @@ -67,54 +154,65 @@ public class ChatModelAction { (BaseChatModelSetup) ctx.getResource(model, ResourceType.CHAT_MODEL); ChatMessage response = chatModel.chat(messages, Map.of()); - MemoryObject sensoryMem = ctx.getSensoryMemory(); if (!response.getToolCalls().isEmpty()) { - Map<UUID, Object> toolCallContext; - if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) { - toolCallContext = (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue(); - } else { - toolCallContext = new HashMap<>(); - } - if (!toolCallContext.containsKey(initialRequestId)) { - toolCallContext.put(initialRequestId, messages); - } - List<ChatMessage> messageContext = - new ArrayList<>((List<ChatMessage>) toolCallContext.get(initialRequestId)); + handleToolCalls(response, initialRequestId, model, messages, ctx); + } else { + // clean tool call context + clearToolCallContext(ctx.getSensoryMemory(), initialRequestId); - messageContext.add(response); - toolCallContext.put(initialRequestId, messageContext); - sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext); + ctx.sendEvent(new ChatResponseEvent(initialRequestId, response)); + } + } - ToolRequestEvent toolRequestEvent = - new ToolRequestEvent(model, response.getToolCalls()); + private static void processChatRequest(ChatRequestEvent event, RunnerContext ctx) + throws Exception { + chat(event.getId(), event.getModel(), event.getMessages(), ctx); + } - Map<UUID, Object> toolRequestEventContext; - if (sensoryMem.isExist(TOOL_REQUEST_EVENT_CONTEXT)) { - toolRequestEventContext = - (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue(); - } else { - toolRequestEventContext = new HashMap<>(); - } - toolRequestEventContext.put( - toolRequestEvent.getId(), - Map.of(INITIAL_REQUEST_ID, initialRequestId, MODEL, model)); - sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); + private static void processToolResponse(ToolResponseEvent event, RunnerContext ctx) + throws Exception { + MemoryObject sensoryMem = ctx.getSensoryMemory(); - ctx.sendEvent(toolRequestEvent); - } else { - // clean tool call context - if (sensoryMem.isExist(TOOL_CALL_CONTEXT)) { - Map<UUID, Object> toolCallContext = - (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue(); - if (toolCallContext.containsKey(initialRequestId)) { - toolCallContext.remove(initialRequestId); - sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext); - } + // get tool request context from memory + Map<String, Object> context = + removeToolRequestEventContext(sensoryMem, event.getRequestId()); + + UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID); + String model = (String) context.get(MODEL); + + Map<String, ToolResponse> responses = event.getResponses(); + Map<String, Boolean> success = event.getSuccess(); + + List<ChatMessage> toolResponseMessages = new ArrayList<>(); + + for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) { + Map<String, Object> extraArgs = new HashMap<>(); + String toolCallId = entry.getKey(); + if (event.getExternalIds().containsKey(toolCallId)) { + extraArgs.put("externalId", event.getExternalIds().get(toolCallId)); } - ctx.sendEvent(new ChatResponseEvent(initialRequestId, response)); + ToolResponse response = entry.getValue(); + if (success.get(toolCallId) && response.isSuccess()) { + toolResponseMessages.add( + new ChatMessage( + MessageRole.TOOL, String.valueOf(response.getResult()), extraArgs)); + } else { + toolResponseMessages.add( + new ChatMessage( + MessageRole.TOOL, String.valueOf(response.getError()), extraArgs)); + } } + + List<ChatMessage> messages = + updateToolCallContext( + ctx.getSensoryMemory(), + initialRequestId, + Collections.emptyList(), + toolResponseMessages); + + chat(initialRequestId, model, messages, ctx); } /** @@ -128,66 +226,13 @@ public class ChatModelAction { * ToolResponseEvent} * @param ctx The runner context this action executed in. */ - @SuppressWarnings("unchecked") public static void processChatRequestOrToolResponse(Event event, RunnerContext ctx) throws Exception { MemoryObject sensoryMem = ctx.getSensoryMemory(); if (event instanceof ChatRequestEvent) { - ChatRequestEvent chatRequestEvent = (ChatRequestEvent) event; - chat( - chatRequestEvent.getId(), - chatRequestEvent.getModel(), - chatRequestEvent.getMessages(), - ctx); + processChatRequest((ChatRequestEvent) event, ctx); } else if (event instanceof ToolResponseEvent) { - ToolResponseEvent toolResponseEvent = (ToolResponseEvent) event; - UUID toolRequestId = toolResponseEvent.getRequestId(); - // get tool request context from memory - Map<UUID, Object> toolRequestEventContext = - (Map<UUID, Object>) sensoryMem.get(TOOL_REQUEST_EVENT_CONTEXT).getValue(); - Map<String, Object> context = - (Map<String, Object>) toolRequestEventContext.get(toolRequestId); - UUID initialRequestId = (UUID) context.get(INITIAL_REQUEST_ID); - String model = (String) context.get(MODEL); - toolRequestEventContext.remove(toolRequestId); - sensoryMem.set(TOOL_REQUEST_EVENT_CONTEXT, toolRequestEventContext); - Map<String, ToolResponse> responses = toolResponseEvent.getResponses(); - Map<String, Boolean> success = toolResponseEvent.getSuccess(); - - // get tool call context - Map<UUID, Object> toolCallContext = - (Map<UUID, Object>) sensoryMem.get(TOOL_CALL_CONTEXT).getValue(); - // update tool call context - List<ChatMessage> messages = - new ArrayList<>((List<ChatMessage>) toolCallContext.get(initialRequestId)); - - for (Map.Entry<String, ToolResponse> entry : responses.entrySet()) { - Map<String, Object> extraArgs = new HashMap<>(); - String toolCallId = entry.getKey(); - if (toolResponseEvent.getExternalIds().containsKey(toolCallId)) { - extraArgs.put("externalId", toolResponseEvent.getExternalIds().get(toolCallId)); - } - - ToolResponse response = entry.getValue(); - if (success.get(toolCallId) && response.isSuccess()) { - messages.add( - new ChatMessage( - MessageRole.TOOL, - String.valueOf(response.getResult()), - extraArgs)); - } else { - messages.add( - new ChatMessage( - MessageRole.TOOL, - String.valueOf(response.getError()), - extraArgs)); - } - } - toolCallContext.put(initialRequestId, messages); - // overwrite tool call context - sensoryMem.set(TOOL_CALL_CONTEXT, toolCallContext); - - chat(initialRequestId, model, messages, ctx); + processToolResponse((ToolResponseEvent) event, ctx); } else { throw new RuntimeException(String.format("Unexpected type event %s", event)); } diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index 6db3908..7299f13 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -16,13 +16,14 @@ # limitations under the License. ################################################################################# import copy -from typing import TYPE_CHECKING, List, cast +from typing import TYPE_CHECKING, Dict, List, cast from uuid import UUID from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent from flink_agents.api.events.event import Event from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent +from flink_agents.api.memory_object import MemoryObject from flink_agents.api.resource import ResourceType from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.actions.action import Action @@ -35,6 +36,100 @@ _TOOL_CALL_CONTEXT = "_TOOL_CALL_CONTEXT" _TOOL_REQUEST_EVENT_CONTEXT = "_TOOL_REQUEST_EVENT_CONTEXT" +# ============================================================================ +# Helper Functions for Tool Call Context Management +# ============================================================================ +def _update_tool_call_context( + sensory_memory: MemoryObject, + initial_request_id: UUID, + initial_messages: List[ChatMessage] | None, + added_messages: List[ChatMessage], +) -> List[ChatMessage]: + """Append messages to tool call context. + + The messages maybe chat model response with tool calls, or tool execute results. May + initialize the context for initial_request_id if needed. + """ + # TODO: Because memory doesn't support remove currently, so we use + # dict to store tool context in memory and remove the specific + # tool context from dict after consuming. This will cause write and + # read amplification for we need get the whole dict and overwrite it + # to memory each time we update a specific tool context. + # After memory supports remove, we can use "TOOL_CALL_CONTEXT/request_id" + # to store and remove the specific tool context directly. + + # init if not exists + tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {} + if initial_request_id not in tool_call_context and initial_messages is not None: + tool_call_context[initial_request_id] = copy.deepcopy(initial_messages) + + tool_call_context[initial_request_id].extend(added_messages) + + # update tool call context + sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) + return tool_call_context[initial_request_id] + + +def _clear_tool_call_context( + sensory_memory: MemoryObject, initial_request_id: UUID +) -> None: + """Clear tool call context for a specific request ID.""" + context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {} + if initial_request_id in context: + context.pop(initial_request_id) + sensory_memory.set(_TOOL_CALL_CONTEXT, context) + + +def _save_tool_request_event_context( + sensory_memory: MemoryObject, + tool_request_event_id: UUID, + initial_request_id: UUID, + model: str, +) -> None: + """Save the context for a specific tool request event.""" + context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {} + context[tool_request_event_id] = { + "initial_request_id": initial_request_id, + "model": model, + } + sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context) + + +def _remove_tool_request_event_context( + sensory_memory: MemoryObject, request_id: UUID +) -> Dict: + """Get and remove the context for a specific tool request event.""" + context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {} + removed_context = context.pop(request_id, {}) + sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, removed_context) + return removed_context + + +def _handle_tool_calls( + response: ChatMessage, + initial_request_id: UUID, + model: str, + messages: List[ChatMessage], + ctx: RunnerContext, +) -> None: + """Handle tool calls in chat response.""" + _update_tool_call_context( + ctx.sensory_memory, initial_request_id, messages, [response] + ) + + tool_request_event = ToolRequestEvent( + model=model, + tool_calls=response.tool_calls, + ) + + # save tool request event context + _save_tool_request_event_context( + ctx.sensory_memory, tool_request_event.id, initial_request_id, model + ) + + ctx.send_event(tool_request_event) + + def chat( initial_request_id: UUID, model: str, @@ -53,52 +148,14 @@ def chat( # TODO: support async execution of chat. response = chat_model.chat(messages) - sensory_memory = ctx.sensory_memory - # generate tool request event according tool calls in response - if len(response.tool_calls) > 0: - # TODO: Because memory doesn't support remove currently, so we use - # dict to store tool context in memory and remove the specific - # tool context from dict after consuming. This will cause write and - # read amplification for we need get the whole dict and overwrite it - # to memory each time we update a specific tool context. - # After memory supports remove, we can use "TOOL_CALL_CONTEXT/request_id" - # to store and remove the specific tool context directly. - - # save tool call context - tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) - if not tool_call_context: - tool_call_context = {} - if initial_request_id not in tool_call_context: - tool_call_context[initial_request_id] = copy.deepcopy(messages) - # append response to tool call context - tool_call_context[initial_request_id].append(response) - # update tool call context - sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) - - tool_request_event = ToolRequestEvent( - model=model, - tool_calls=response.tool_calls, - ) + if ( + len(response.tool_calls) > 0 + ): # generate tool request event according tool calls in response + _handle_tool_calls(response, initial_request_id, model, messages, ctx) + else: # if there is no tool call generated, return chat response directly + _clear_tool_call_context(ctx.sensory_memory, initial_request_id) - # save tool request event context - tool_request_event_context = tool_call_context.get(_TOOL_REQUEST_EVENT_CONTEXT) - if not tool_request_event_context: - tool_request_event_context = {} - tool_request_event_context[tool_request_event.id] = { - "initial_request_id": initial_request_id, - "model": model, - } - sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, tool_request_event_context) - - ctx.send_event(tool_request_event) - # if there is no tool call generated, return chat response directly - else: - # clear tool call context related to specific request id - tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) - if tool_call_context and initial_request_id in tool_call_context: - tool_call_context.pop(initial_request_id) - sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) ctx.send_event( ChatResponseEvent( request_id=initial_request_id, @@ -107,55 +164,63 @@ def chat( ) +def _process_chat_request(event: ChatRequestEvent, ctx: RunnerContext) -> None: + """Process chat request event.""" + chat( + initial_request_id=event.id, + model=event.model, + messages=event.messages, + ctx=ctx, + ) + + +def _process_tool_response(event: ToolResponseEvent, ctx: RunnerContext) -> None: + """Organize the tool call context and return it to the LLM.""" + sensory_memory = ctx.sensory_memory + request_id = event.request_id + + # get correspond tool request event context + tool_request_event_context = _remove_tool_request_event_context( + sensory_memory, request_id + ) + initial_request_id = tool_request_event_context["initial_request_id"] + + # update tool call context, and get the entire chat messages. + messages = _update_tool_call_context( + sensory_memory, + initial_request_id, + None, + [ + ChatMessage( + role=MessageRole.TOOL, + content=str(response), + extra_args={"external_id": event.external_ids.get(tool_id)} + if event.external_ids and event.external_ids.get(tool_id) + else {}, + ) + for tool_id, response in event.responses.items() + ], + ) + + chat( + initial_request_id=initial_request_id, + model=tool_request_event_context["model"], + messages=messages, + ctx=ctx, + ) + + def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None: """Built-in action for processing a chat request or tool response. - Internally, this action will use short term memory to save the tool call context, - which is a dict mapping request id to chat messages. + This action listens to ChatRequestEvent and ToolResponseEvent, and handles + the complete chat flow including tool calls. It uses sensory memory to save + the tool call context, which is a dict mapping request id to chat messages. """ - sensory_memory = ctx.sensory_memory if isinstance(event, ChatRequestEvent): - chat( - initial_request_id=event.id, - model=event.model, - messages=event.messages, - ctx=ctx, - ) - + _process_chat_request(event, ctx) elif isinstance(event, ToolResponseEvent): - request_id = event.request_id - - # get correspond tool request event context - tool_request_event_context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) - initial_request_id = tool_request_event_context[request_id][ - "initial_request_id" - ] - model = tool_request_event_context[request_id]["model"] - # clear tool request event context - tool_request_event_context.pop(request_id) - sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, tool_request_event_context) - - responses = event.responses - # update tool call context - tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) - for id, response in responses.items(): - tool_call_context[initial_request_id].append( - ChatMessage( - role=MessageRole.TOOL, - content=str(response), - extra_args={"external_id": event.external_ids[id]} - if event.external_ids[id] - else {}, - ) - ) - sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) - - chat( - initial_request_id=initial_request_id, - model=model, - messages=tool_call_context[initial_request_id], - ctx=ctx, - ) + _process_tool_response(event, ctx) CHAT_MODEL_ACTION = Action(
