This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit cc46f2986edb783c1ae3cbd836e65334a9db96c0 Author: WenjinXie <wenjin...@gmail.com> AuthorDate: Tue Aug 5 10:33:22 2025 +0800 [plan][python] Add built-in actions for processing chat and tool call. --- .../compatibility/CreateJavaAgentPlanFromJson.java | 55 ++++++- python/flink_agents/api/decorators.py | 27 +++- python/flink_agents/api/events/chat_event.py | 51 +++++++ python/flink_agents/api/events/tool_event.py | 50 ++++++ python/flink_agents/api/resource.py | 15 +- .../flink_agents/plan/actions/chat_model_action.py | 90 +++++++++++ .../flink_agents/plan/actions/tool_call_action.py | 37 +++++ python/flink_agents/plan/agent_plan.py | 42 ++--- python/flink_agents/plan/resource_provider.py | 36 ++++- .../plan/tests/resources/agent_plan.json | 46 +++++- .../runtime/tests/test_built_in_actions.py | 169 +++++++++++++++++++++ 11 files changed, 571 insertions(+), 47 deletions(-) diff --git a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java index d6ed2e5..54af88d 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/compatibility/CreateJavaAgentPlanFromJson.java @@ -27,6 +27,7 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.List; +import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; @@ -45,31 +46,73 @@ public class CreateJavaAgentPlanFromJson { String agentJsonFile = args[0]; String json = Files.readString(Paths.get(agentJsonFile)); AgentPlan agentPlan = new ObjectMapper().readValue(json, AgentPlan.class); - assertEquals(2, agentPlan.getActions().size()); + assertEquals(4, agentPlan.getActions().size()); String myEvent = "flink_agents.plan.tests.compatibility.python_agent_plan_compatibility_test_agent.MyEvent"; - String inputEvent = "flink_agents.api.event.InputEvent"; + String inputEvent = "flink_agents.api.events.event.InputEvent"; // Check the first action + String testModule = + "flink_agents.plan.tests.compatibility.python_agent_plan_compatibility_test_agent"; assertTrue(agentPlan.getActions().containsKey("first_action")); Action firstAction = agentPlan.getActions().get("first_action"); assertInstanceOf(PythonFunction.class, firstAction.getExec()); + PythonFunction firstActionFunction = (PythonFunction) firstAction.getExec(); + assertEquals(testModule, firstActionFunction.getModule()); + assertEquals( + "PythonAgentPlanCompatibilityTestAgent.first_action", + firstActionFunction.getQualName()); assertEquals(List.of(inputEvent), firstAction.getListenEventTypes()); // Check the second action assertTrue(agentPlan.getActions().containsKey("second_action")); Action secondAction = agentPlan.getActions().get("second_action"); assertInstanceOf(PythonFunction.class, secondAction.getExec()); - + PythonFunction secondActionFunc = (PythonFunction) secondAction.getExec(); + assertEquals(testModule, secondActionFunc.getModule()); + assertEquals( + "PythonAgentPlanCompatibilityTestAgent.second_action", + secondActionFunc.getQualName()); assertEquals(List.of(inputEvent, myEvent), secondAction.getListenEventTypes()); + // Check the built-in actions + assertTrue(agentPlan.getActions().containsKey("chat_model_action")); + Action chatModelAction = agentPlan.getActions().get("chat_model_action"); + assertInstanceOf(PythonFunction.class, chatModelAction.getExec()); + PythonFunction processChatRequestFunc = (PythonFunction) chatModelAction.getExec(); + assertEquals( + "flink_agents.plan.actions.chat_model_action", processChatRequestFunc.getModule()); + assertEquals("process_chat_request_or_tool_response", processChatRequestFunc.getQualName()); + String chatRequestEvent = "flink_agents.api.events.chat_event.ChatRequestEvent"; + String toolResponseEvent = "flink_agents.api.events.tool_event.ToolResponseEvent"; + assertEquals( + List.of(chatRequestEvent, toolResponseEvent), + chatModelAction.getListenEventTypes()); + + assertTrue(agentPlan.getActions().containsKey("tool_call_action")); + Action toolCallAction = agentPlan.getActions().get("tool_call_action"); + assertInstanceOf(PythonFunction.class, toolCallAction.getExec()); + PythonFunction processToolRequestFunc = (PythonFunction) toolCallAction.getExec(); + assertEquals( + "flink_agents.plan.actions.tool_call_action", processToolRequestFunc.getModule()); + assertEquals("process_tool_request", processToolRequestFunc.getQualName()); + String toolRequestEvent = "flink_agents.api.events.tool_event.ToolRequestEvent"; + assertEquals(List.of(toolRequestEvent), toolCallAction.getListenEventTypes()); + // Check event trigger actions - assertEquals(2, agentPlan.getActionsByEvent().size()); - assertTrue(agentPlan.getActionsByEvent().containsKey(inputEvent)); - assertTrue(agentPlan.getActionsByEvent().containsKey(myEvent)); + Map<String, List<Action>> actionsByEvent = agentPlan.getActionsByEvent(); + assertEquals(5, actionsByEvent.size()); + assertTrue(actionsByEvent.containsKey(inputEvent)); + assertTrue(actionsByEvent.containsKey(myEvent)); + assertTrue(actionsByEvent.containsKey(chatRequestEvent)); + assertTrue(actionsByEvent.containsKey(toolRequestEvent)); + assertTrue(actionsByEvent.containsKey(toolResponseEvent)); assertEquals( List.of(firstAction, secondAction), agentPlan.getActionsByEvent().get(inputEvent)); assertEquals(List.of(secondAction), agentPlan.getActionsByEvent().get(myEvent)); + assertEquals(List.of(chatModelAction), actionsByEvent.get(chatRequestEvent)); + assertEquals(List.of(toolCallAction), actionsByEvent.get(toolRequestEvent)); + assertEquals(List.of(chatModelAction), actionsByEvent.get(toolResponseEvent)); } } diff --git a/python/flink_agents/api/decorators.py b/python/flink_agents/api/decorators.py index 1d8e3bb..8119f26 100644 --- a/python/flink_agents/api/decorators.py +++ b/python/flink_agents/api/decorators.py @@ -15,12 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Callable, Tuple, Type +from typing import Callable, Type -from flink_agents.api.event import Event +from flink_agents.api.events.event import Event -def action(*listen_events: Tuple[Type[Event], ...]) -> Callable: +def action(*listen_events: Type[Event]) -> Callable: """Decorator for marking a function as an agent action. Parameters @@ -70,7 +70,7 @@ def chat_model(func: Callable) -> Callable: def tool(func: Callable) -> Callable: - """Decorator for marking a function declaring a chat model. + """Decorator for marking a function declaring a tool. Parameters ---------- @@ -80,7 +80,24 @@ def tool(func: Callable) -> Callable: Returns: ------- Callable - Decorator function that marks the target function declare a tools. + Decorator function that marks the target function declare a tool. """ func._is_tool = True return func + + +def prompt(func: Callable) -> Callable: + """Decorator for marking a function declaring a prompt. + + Parameters + ---------- + func : Callable + Function to be decorated. + + Returns: + ------- + Callable + Decorator function that marks the target function declare a prompt. + """ + func._is_prompt = True + return func diff --git a/python/flink_agents/api/events/chat_event.py b/python/flink_agents/api/events/chat_event.py new file mode 100644 index 0000000..b1f53db --- /dev/null +++ b/python/flink_agents/api/events/chat_event.py @@ -0,0 +1,51 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import List + +from flink_agents.api.chat_message import ChatMessage +from flink_agents.api.events.event import Event + + +class ChatRequestEvent(Event): + """Event representing a request to chat model. + + Attributes: + ---------- + model : str + The name of the chat model to be chatted with. + messages : List[ChatMessage] + The input to the chat model. + """ + + model: str + messages: List[ChatMessage] + + +class ChatResponseEvent(Event): + """Event representing a response from chat model. + + Attributes: + ---------- + request : ChatRequestEvent + The correspond request of the response. + response : ChatMessage + The response from the chat model. + """ + + request: ChatRequestEvent + response: ChatMessage diff --git a/python/flink_agents/api/events/tool_event.py b/python/flink_agents/api/events/tool_event.py new file mode 100644 index 0000000..a5e5c47 --- /dev/null +++ b/python/flink_agents/api/events/tool_event.py @@ -0,0 +1,50 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing import Any + +from flink_agents.api.events.event import Event + + +class ToolRequestEvent(Event): + """Event representing a tool call request. + + Attributes: + ---------- + tool : str + The name of the tool to be called. + kwargs : dict + The arguments passed to the tool. + """ + + tool: str + kwargs: dict + + +class ToolResponseEvent(Event): + """Event representing a result from tool call. + + Attributes: + ---------- + request : ToolRequestEvent + The correspond request of the response. + response : Any + The response from the tool. + """ + + request: ToolRequestEvent + response: Any diff --git a/python/flink_agents/api/resource.py b/python/flink_agents/api/resource.py index 30b440d..f8f70f4 100644 --- a/python/flink_agents/api/resource.py +++ b/python/flink_agents/api/resource.py @@ -17,20 +17,21 @@ ################################################################################# from abc import ABC, abstractmethod from enum import Enum +from typing import Callable -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator class ResourceType(Enum): """Type enum of resource. - Currently, only support chat_models and tools. + Currently, only support chat_model, tool and prompt. """ CHAT_MODEL = "chat_model" TOOL = "tool" # EMBEDDING_MODEL = "embedding_model" - # PROMPT = "prompt" + PROMPT = "prompt" # VECTOR_STORE = "vector_store" # MCP_SERVER = "mcp_server" @@ -46,11 +47,15 @@ class Resource(BaseModel, ABC): ---------- name : str The name of the resource. - type : ResourceType - The type of the resource. + get_resource : Callable[[str, ResourceType], "Resource"] + Get other resource object declared in the same Agent. The first argument is + resource name and the second argument is resource type. """ name: str + get_resource: Callable[[str, ResourceType], "Resource"] = Field( + exclude=True, default=None + ) @classmethod @abstractmethod diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py new file mode 100644 index 0000000..03ef48e --- /dev/null +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -0,0 +1,90 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# + +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent +from flink_agents.api.events.event import Event +from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent +from flink_agents.api.resource import ResourceType +from flink_agents.api.runner_context import RunnerContext +from flink_agents.plan.actions.action import Action +from flink_agents.plan.function import PythonFunction + + +def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None: + """Built-in action for processing a chat request or tool response.""" + if isinstance(event, ChatRequestEvent): + chat_model = ctx.get_resource(event.model, ResourceType.CHAT_MODEL) + # TODO: support async execution of chat. + response = chat_model.chat(event.messages) + # call tool + if len(response.tool_calls) > 0: + for tool_call in response.tool_calls: + # store the tool call context in short term memory + state = ctx.get_short_term_memory() + # TODO: Because memory doesn't support remove currently, so we use + # dict to store tool context in memory and remove the specific + # tool context from dict after consuming. This will cause some + # overhead for we need get the whole dict and overwrite it to memory + # each time we update a specific tool context. + # After memory supports remove, we can use + # "__tool_context/tool_call_id" to store and remove the specific tool + # context directly. + if not state.is_exist("__tool_context"): + state.set("__tool_context", {}) + tool_context = state.get("__tool_context") + tool_call_id = tool_call["id"] + tool_context[tool_call_id] = event + tool_context[tool_call_id].messages.append(response) + state.set("__tool_context", tool_context) + ctx.send_event( + ToolRequestEvent( + id=tool_call_id, + tool=tool_call["function"]["name"], + kwargs=tool_call["function"]["arguments"], + ) + ) + + # send response + else: + ctx.send_event(ChatResponseEvent(request=event, response=response)) + elif isinstance(event, ToolResponseEvent): + state = ctx.get_short_term_memory() + + if state.is_exist("__tool_context"): + tool_context = state.get("__tool_context") + tool_call_id = event.request.id + if tool_context is not None and tool_call_id in tool_context: + # get the specific tool call context from short term memory + specific_tool_ctx = tool_context.pop(tool_call_id) + specific_tool_ctx.messages.append( + ChatMessage(role=MessageRole.TOOL, content=str(event.response)) + ) + ctx.send_event(specific_tool_ctx) + # update short term memory to remove the specific tool call context + state.set("__tool_context", tool_context) + + +CHAT_MODEL_ACTION = Action( + name="chat_model_action", + exec=PythonFunction.from_callable(process_chat_request_or_tool_response), + listen_event_types=[ + f"{ChatRequestEvent.__module__}.{ChatRequestEvent.__name__}", + f"{ToolResponseEvent.__module__}.{ToolResponseEvent.__name__}", + ], +) diff --git a/python/flink_agents/plan/actions/tool_call_action.py b/python/flink_agents/plan/actions/tool_call_action.py new file mode 100644 index 0000000..70788ed --- /dev/null +++ b/python/flink_agents/plan/actions/tool_call_action.py @@ -0,0 +1,37 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from flink_agents.api.events.tool_event import ToolRequestEvent, ToolResponseEvent +from flink_agents.api.resource import ResourceType +from flink_agents.api.runner_context import RunnerContext +from flink_agents.plan.actions.action import Action +from flink_agents.plan.function import PythonFunction + + +def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None: + """Built-in action for processing a tool call request.""" + tool = ctx.get_resource(event.tool, ResourceType.TOOL) + # TODO: support async execution of tool call. + response = tool.call(**event.kwargs) + ctx.send_event(ToolResponseEvent(request=event, response=response)) + + +TOOL_CALL_ACTION = Action( + name="tool_call_action", + exec=PythonFunction.from_callable(process_tool_request), + listen_event_types=[f"{ToolRequestEvent.__module__}.{ToolRequestEvent.__name__}"], +) diff --git a/python/flink_agents/plan/agent_plan.py b/python/flink_agents/plan/agent_plan.py index cb77f09..c1582e9 100644 --- a/python/flink_agents/plan/agent_plan.py +++ b/python/flink_agents/plan/agent_plan.py @@ -21,7 +21,9 @@ from pydantic import BaseModel, field_serializer, model_validator from flink_agents.api.agent import Agent from flink_agents.api.resource import Resource, ResourceType -from flink_agents.plan.action import Action +from flink_agents.plan.actions.action import Action +from flink_agents.plan.actions.chat_model_action import CHAT_MODEL_ACTION +from flink_agents.plan.actions.tool_call_action import TOOL_CALL_ACTION from flink_agents.plan.function import PythonFunction from flink_agents.plan.resource_provider import ( JavaResourceProvider, @@ -30,7 +32,9 @@ from flink_agents.plan.resource_provider import ( PythonSerializableResourceProvider, ResourceProvider, ) -from flink_agents.plan.tools.function_tool import FunctionTool +from flink_agents.plan.tools.function_tool import from_callable + +BUILT_IN_ACTIONS = [CHAT_MODEL_ACTION, TOOL_CALL_ACTION] class AgentPlan(BaseModel): @@ -116,7 +120,7 @@ class AgentPlan(BaseModel): """Build a AgentPlan from user defined agent.""" actions = {} actions_by_event = {} - for action in _get_actions(agent): + for action in _get_actions(agent) + BUILT_IN_ACTIONS: assert action.name not in actions, f"Duplicate action name: {action.name}" actions[action.name] = action for event_type in action.listen_event_types: @@ -169,7 +173,8 @@ class AgentPlan(BaseModel): self.__resources[type] = {} if name not in self.__resources[type]: resource_provider = self.resource_providers[type][name] - self.__resources[type][name] = resource_provider.provide() + resource = resource_provider.provide(get_resource=self.get_resource) + self.__resources[type][name] = resource return self.__resources[type][name] @@ -222,30 +227,33 @@ def _get_resource_providers(agent: Agent) -> List[ResourceProvider]: if callable(value): clazz, kwargs = value() - module = clazz.__module__ provider = PythonResourceProvider( name=name, type=clazz.resource_type(), - module=module, + module=clazz.__module__, clazz=clazz.__name__, kwargs=kwargs, ) resource_providers.append(provider) - if hasattr(value, "_is_tool"): + elif hasattr(value, "_is_tool"): if isinstance(value, staticmethod): value = value.__func__ if callable(value): # TODO: support other tool type. - func = PythonFunction.from_callable(value) - tool = FunctionTool(name=name, func=func) - provider = PythonSerializableResourceProvider( - name=tool.name, - type=tool.resource_type(), - serialized=tool.model_dump(), - module=tool.__module__, - clazz=tool.__class__.__name__, - resource=tool, + tool = from_callable(name=name, func=value) + resource_providers.append( + PythonSerializableResourceProvider.from_resource( + name=name, resource=tool + ) ) - resource_providers.append(provider) + elif hasattr(value, "_is_prompt"): + if isinstance(value, staticmethod): + value = value.__func__ + prompt = value() + resource_providers.append( + PythonSerializableResourceProvider.from_resource( + name=name, resource=prompt + ) + ) return resource_providers diff --git a/python/flink_agents/plan/resource_provider.py b/python/flink_agents/plan/resource_provider.py index f26e04a..7b988da 100644 --- a/python/flink_agents/plan/resource_provider.py +++ b/python/flink_agents/plan/resource_provider.py @@ -17,6 +17,7 @@ ################################################################################# import importlib from abc import ABC, abstractmethod +from collections.abc import Callable from typing import Any, Dict, Optional from pydantic import BaseModel @@ -44,8 +45,14 @@ class ResourceProvider(BaseModel, ABC): type: ResourceType @abstractmethod - def provide(self) -> Resource: - """Create resource in runtime.""" + def provide(self, get_resource: Callable) -> Resource: + """Create resource in runtime. + + Parameters + ---------- + get_resource : Callable + The helper function to get other resource declared in the same Agent. + """ class SerializableResourceProvider(ResourceProvider, ABC): @@ -81,11 +88,12 @@ class PythonResourceProvider(ResourceProvider): clazz: str kwargs: Dict[str, Any] - def provide(self) -> Resource: + def provide(self, get_resource: Callable) -> Resource: """Create resource in runtime.""" module = importlib.import_module(self.module) cls = getattr(module, self.clazz) - return cls(**self.kwargs) + resource = cls(**self.kwargs, get_resource=get_resource) + return resource class PythonSerializableResourceProvider(SerializableResourceProvider): @@ -102,7 +110,21 @@ class PythonSerializableResourceProvider(SerializableResourceProvider): serialized: Dict[str, Any] resource: Optional[SerializableResource] = None - def provide(self) -> Resource: + @staticmethod + def from_resource( + name: str, resource: SerializableResource + ) -> "PythonSerializableResourceProvider": + """Create PythonSerializableResourceProvider from SerializableResource.""" + return PythonSerializableResourceProvider( + name=name, + type=resource.resource_type(), + serialized=resource.model_dump(), + module=resource.__module__, + clazz=resource.__class__.__name__, + resource=resource, + ) + + def provide(self, get_resource: Callable) -> Resource: """Get or deserialize resource in runtime.""" if self.resource is None: module = importlib.import_module(self.module) @@ -118,7 +140,7 @@ class JavaResourceProvider(ResourceProvider): Currently, this class only used for deserializing Java agent plan json """ - def provide(self) -> Resource: + def provide(self, get_resource: Callable) -> Resource: """Create resource in runtime.""" err_msg = ( "Currently, flink-agents doesn't support create resource " @@ -134,7 +156,7 @@ class JavaSerializableResourceProvider(SerializableResourceProvider): Currently, this class only used for deserializing Java agent plan json """ - def provide(self) -> Resource: + def provide(self, get_resource: Callable) -> Resource: """Get or deserialize resource in runtime.""" err_msg = ( "Currently, flink-agents doesn't support create resource " diff --git a/python/flink_agents/plan/tests/resources/agent_plan.json b/python/flink_agents/plan/tests/resources/agent_plan.json index abbc3a4..9659a60 100644 --- a/python/flink_agents/plan/tests/resources/agent_plan.json +++ b/python/flink_agents/plan/tests/resources/agent_plan.json @@ -3,34 +3,66 @@ "first_action": { "name": "first_action", "exec": { + "func_type": "PythonFunction", "module": "flink_agents.plan.tests.test_agent_plan", - "qualname": "MyAgent.first_action", - "func_type": "PythonFunction" + "qualname": "MyAgent.first_action" }, "listen_event_types": [ - "flink_agents.api.event.InputEvent" + "flink_agents.api.events.event.InputEvent" ] }, "second_action": { "name": "second_action", "exec": { + "func_type": "PythonFunction", "module": "flink_agents.plan.tests.test_agent_plan", - "qualname": "MyAgent.second_action", - "func_type": "PythonFunction" + "qualname": "MyAgent.second_action" }, "listen_event_types": [ - "flink_agents.api.event.InputEvent", + "flink_agents.api.events.event.InputEvent", "flink_agents.plan.tests.test_agent_plan.MyEvent" ] + }, + "chat_model_action": { + "name": "chat_model_action", + "exec": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.actions.chat_model_action", + "qualname": "process_chat_request_or_tool_response" + }, + "listen_event_types": [ + "flink_agents.api.events.chat_event.ChatRequestEvent", + "flink_agents.api.events.tool_event.ToolResponseEvent" + ] + }, + "tool_call_action": { + "name": "tool_call_action", + "exec": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.actions.tool_call_action", + "qualname": "process_tool_request" + }, + "listen_event_types": [ + "flink_agents.api.events.tool_event.ToolRequestEvent" + ] } }, "actions_by_event": { - "flink_agents.api.event.InputEvent": [ + "flink_agents.api.events.event.InputEvent": [ "first_action", "second_action" ], "flink_agents.plan.tests.test_agent_plan.MyEvent": [ "second_action" + ], + "flink_agents.api.events.chat_event.ChatRequestEvent": [ + "chat_model_action" + ], + "flink_agents.api.events.tool_event.ToolResponseEvent": [ + "chat_model_action" + ], + "flink_agents.api.events.tool_event.ToolRequestEvent": [ + "tool_call_action" ] }, "resource_providers": { diff --git a/python/flink_agents/runtime/tests/test_built_in_actions.py b/python/flink_agents/runtime/tests/test_built_in_actions.py new file mode 100644 index 0000000..f86dfff --- /dev/null +++ b/python/flink_agents/runtime/tests/test_built_in_actions.py @@ -0,0 +1,169 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import uuid +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type + +from flink_agents.api.agent import Agent +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.chat_models.chat_model import BaseChatModel +from flink_agents.api.decorators import action, chat_model, prompt, tool +from flink_agents.api.events.chat_event import ChatRequestEvent, ChatResponseEvent +from flink_agents.api.events.event import ( + InputEvent, + OutputEvent, +) +from flink_agents.api.execution_environment import AgentsExecutionEnvironment +from flink_agents.api.prompts.prompt import Prompt +from flink_agents.api.resource import ResourceType +from flink_agents.api.runner_context import RunnerContext +from flink_agents.api.tools.tool import ToolMetadata, ToolType + + +class MockChatModel(BaseChatModel): + """Mock ChatModel for testing integrating prompt and tool.""" + + __tools: List[ToolMetadata] + + def __init__(self, /, **kwargs: Any) -> None: + """Init method of MockChatModel.""" + super().__init__(**kwargs) + # bind tools + if self.tools is not None: + tools = [ + self.get_resource(tool_name, ResourceType.TOOL) + for tool_name in self.tools + ] + self.__tools = [tool.metadata for tool in tools] + # bind prompt + if self.prompt is not None and isinstance(self.prompt, str): + self.prompt = self.get_resource(self.prompt, ResourceType.PROMPT) + + def chat( + self, + messages: Sequence[ChatMessage], + chat_history: Optional[List[ChatMessage]] = None, + ) -> ChatMessage: + """Generate tool call or response according to input.""" + # generate tool call + if "sum" in messages[-1].content: + input = self.prompt.format_string(**messages[-1].extra_args) + # validate bind_tools + assert self.__tools[0].name == "add" + function = {"name": "add", "arguments": {"a": 1, "b": 2}} + tool_call = { + "id": uuid.uuid4(), + "type": ToolType.FUNCTION, + "function": function, + } + return ChatMessage( + role=MessageRole.ASSISTANT, content=input, tool_calls=[tool_call] + ) + # generate response including tool call context + else: + content = "\n".join([message.content for message in messages]) + return ChatMessage(role=MessageRole.ASSISTANT, content=content) + + +class MyAgent(Agent): + """Mock agent for testing built-in actions.""" + + @prompt + @staticmethod + def prompt() -> Prompt: + """Prompt can be used in action or chat model.""" + return Prompt.from_text( + name="prompt", + text="Please call the appropriate tool to do the following task: {task}", + ) + + @chat_model + @staticmethod + def chat_model() -> Tuple[Type[BaseChatModel], Dict[str, Any]]: + """ChatModel can be used in action.""" + return MockChatModel, { + "name": "chat_model", + "prompt": "prompt", + "tools": ["add"], + } + + @tool + @staticmethod + def add(a: int, b: int) -> int: + """Calculate the sum of a and b. + + Parameters + ---------- + a : int + The first operand + b : int + The second operand + + Returns: + ------- + int: + The sum of a and b + """ + return a + b + + @action(InputEvent) + @staticmethod + def process_input(event: InputEvent, ctx: RunnerContext) -> None: + """User defined action for processing input. + + In this action, we will send ChatRequestEvent to trigger built-in actions. + """ + input = event.input + ctx.send_event( + ChatRequestEvent( + model="chat_model", + messages=[ + ChatMessage( + role=MessageRole.USER, content=input, extra_args={"task": input} + ) + ], + ) + ) + + @action(ChatResponseEvent) + @staticmethod + def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: + """User defined action for processing chat model response.""" + input = event.response + ctx.send_event(OutputEvent(output=input.content)) + + +def test_built_in_actions() -> None: # noqa: D103 + env = AgentsExecutionEnvironment.get_execution_environment() + + input_list = [] + agent = MyAgent() + + output_list = env.from_list(input_list).apply(agent).to_list() + + input_list.append({"key": "0001", "value": "calculate the sum of 1 and 2."}) + + env.execute() + + assert output_list == [ + { + "0001": "calculate the sum of 1 and 2.\n" + "Please call the appropriate tool to do the following task: " + "calculate the sum of 1 and 2.\n" + "3" + } + ]