This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 763e591b1e0a593e9cd5a2f530a23c3f94d4a0a7
Author: WenjinXie <[email protected]>
AuthorDate: Fri Jan 16 22:02:23 2026 +0800

    [plan][java] Built-in actions support async execution.
---
 .../agents/api/agents/AgentExecutionOptions.java   | 10 ++++++++
 .../integration/test/ChatModelIntegrationTest.java | 26 +++++++++++--------
 .../resource/test/ChatModelCrossLanguageTest.java  | 26 +++++++++++--------
 .../flink/agents/plan/actions/ChatModelAction.java | 26 ++++++++++++++++++-
 .../plan/actions/ContextRetrievalAction.java       | 28 ++++++++++++++++++++-
 .../flink/agents/plan/actions/ToolCallAction.java  | 29 +++++++++++++++++++++-
 6 files changed, 122 insertions(+), 23 deletions(-)

diff --git 
a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
 
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
index 64880a5a..26991b60 100644
--- 
a/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
+++ 
b/api/src/main/java/org/apache/flink/agents/api/agents/AgentExecutionOptions.java
@@ -29,4 +29,14 @@ public class AgentExecutionOptions {
 
     public static final ConfigOption<Integer> MAX_RETRIES =
             new ConfigOption<>("max-retries", Integer.class, 3);
+
+    // Async execution is only supported on jdk >= 21. NOTE(review): all three options below default to true; confirm this is intended.
+    public static final ConfigOption<Boolean> CHAT_ASYNC =
+            new ConfigOption<>("chat.async", Boolean.class, true);
+
+    public static final ConfigOption<Boolean> TOOL_CALL_ASYNC =
+            new ConfigOption<>("tool-call.async", Boolean.class, true);
+
+    public static final ConfigOption<Boolean> RAG_ASYNC =
+            new ConfigOption<>("rag.async", Boolean.class, true);
 }
diff --git 
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
 
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
index cad885ac..c843b970 100644
--- 
a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
+++ 
b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ChatModelIntegrationTest.java
@@ -31,6 +31,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 
 import static 
org.apache.flink.agents.integration.test.ChatModelIntegrationAgent.OLLAMA_MODEL;
@@ -103,18 +104,23 @@ public class ChatModelIntegrationTest extends 
OllamaPreparationUtils {
     public void checkResult(CloseableIterator<Object> results) {
         List<String> expectedWords =
                 List.of("77", "37", "89", "23", "68", "22", "26", "22", "23", 
"");
+        List<String> responses = new ArrayList<>();
+        while (results.hasNext()) {
+            responses.add((String) results.next());
+        }
+
+        Assertions.assertEquals(
+                expectedWords.size(),
+                responses.size(),
+                String.format(
+                        "LLM response count is mismatch," + "the responses are 
%s", responses));
+
+        String text = String.join("\n", responses);
         for (String expected : expectedWords) {
             Assertions.assertTrue(
-                    results.hasNext(), "Output messages count %s is less than 
expected.");
-            String res = (String) results.next();
-            if (res.contains("error") || res.contains("parameters")) {
-                LOG.warn(res);
-            } else {
-                Assertions.assertTrue(
-                        res.contains(expected),
-                        String.format(
-                                "Groud truth %s is not contained in answer 
{%s}", expected, res));
-            }
+                    text.contains(expected),
+                    String.format(
+                            "Groud truth %s is not contained in answer {%s}", 
expected, text));
         }
     }
 }
diff --git 
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
 
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
index 7d7ec1d0..d62a9726 100644
--- 
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
+++ 
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java
@@ -30,6 +30,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 
 import static 
org.apache.flink.agents.resource.test.ChatModelCrossLanguageAgent.OLLAMA_MODEL;
@@ -82,18 +83,23 @@ public class ChatModelCrossLanguageTest {
 
     public void checkResult(CloseableIterator<Object> results) {
         List<String> expectedWords = List.of("77", "22", "");
+        List<String> responses = new ArrayList<>();
+        while (results.hasNext()) {
+            responses.add((String) results.next());
+        }
+
+        Assertions.assertEquals(
+                expectedWords.size(),
+                responses.size(),
+                String.format(
+                        "LLM response count is mismatch," + "the responses are 
%s", responses));
+
+        String text = String.join("\n", responses);
         for (String expected : expectedWords) {
             Assertions.assertTrue(
-                    results.hasNext(), "Output messages count %s is less than 
expected.");
-            String res = (String) results.next();
-            if (res.contains("error") || res.contains("parameters")) {
-                LOG.warn(res);
-            } else {
-                Assertions.assertTrue(
-                        res.contains(expected),
-                        String.format(
-                                "Groud truth %s is not contained in answer 
{%s}", expected, res));
-            }
+                    text.contains(expected),
+                    String.format(
+                            "Groud truth %s is not contained in answer {%s}", 
expected, text));
         }
     }
 }
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
index 7d5b34c9..9bde346a 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ChatModelAction.java
@@ -26,6 +26,7 @@ import org.apache.flink.agents.api.agents.OutputSchema;
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
 import org.apache.flink.agents.api.chat.messages.MessageRole;
 import org.apache.flink.agents.api.chat.model.BaseChatModelSetup;
+import org.apache.flink.agents.api.context.DurableCallable;
 import org.apache.flink.agents.api.context.MemoryObject;
 import org.apache.flink.agents.api.context.RunnerContext;
 import org.apache.flink.agents.api.event.ChatRequestEvent;
@@ -196,6 +197,8 @@ public class ChatModelAction {
         BaseChatModelSetup chatModel =
                 (BaseChatModelSetup) ctx.getResource(model, 
ResourceType.CHAT_MODEL);
 
+        boolean chatAsync = 
ctx.getConfig().get(AgentExecutionOptions.CHAT_ASYNC);
+
         Agent.ErrorHandlingStrategy strategy =
                 
ctx.getConfig().get(AgentExecutionOptions.ERROR_HANDLING_STRATEGY);
         int numRetries = 0;
@@ -210,7 +213,28 @@ public class ChatModelAction {
 
         for (int attempt = 0; attempt < numRetries + 1; attempt++) {
             try {
-                response = chatModel.chat(messages, Map.of());
+                if (chatAsync) {
+                    response =
+                            ctx.durableExecuteAsync(
+                                    new DurableCallable<>() {
+                                        @Override
+                                        public String getId() {
+                                            return "chat-async";
+                                        }
+
+                                        @Override
+                                        public Class<ChatMessage> 
getResultClass() {
+                                            return ChatMessage.class;
+                                        }
+
+                                        @Override
+                                        public ChatMessage call() throws 
Exception {
+                                            return chatModel.chat(messages, 
Map.of());
+                                        }
+                                    });
+                } else {
+                    response = chatModel.chat(messages, Map.of());
+                }
                 // only generate structured output for final response.
                 if (outputSchema != null && response.getToolCalls().isEmpty()) 
{
                     response = generateStructuredOutput(response, 
outputSchema);
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java
 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java
index 504011d3..72463010 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ContextRetrievalAction.java
@@ -19,6 +19,8 @@
 package org.apache.flink.agents.plan.actions;
 
 import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.agents.AgentExecutionOptions;
+import org.apache.flink.agents.api.context.DurableCallable;
 import org.apache.flink.agents.api.context.RunnerContext;
 import org.apache.flink.agents.api.event.ContextRetrievalRequestEvent;
 import org.apache.flink.agents.api.event.ContextRetrievalResponseEvent;
@@ -46,6 +48,8 @@ public class ContextRetrievalAction {
     public static void processContextRetrievalRequest(Event event, 
RunnerContext ctx)
             throws Exception {
         if (event instanceof ContextRetrievalRequestEvent) {
+            boolean ragAsync = 
ctx.getConfig().get(AgentExecutionOptions.RAG_ASYNC);
+
             final ContextRetrievalRequestEvent contextRetrievalRequestEvent =
                     (ContextRetrievalRequestEvent) event;
 
@@ -60,7 +64,29 @@ public class ContextRetrievalAction {
                             contextRetrievalRequestEvent.getQuery(),
                             contextRetrievalRequestEvent.getMaxResults());
 
-            final VectorStoreQueryResult result = 
vectorStore.query(vectorStoreQuery);
+            VectorStoreQueryResult result;
+            if (ragAsync) {
+                result =
+                        ctx.durableExecuteAsync(
+                                new DurableCallable<VectorStoreQueryResult>() {
+                                    @Override
+                                    public String getId() {
+                                        return "rag-async";
+                                    }
+
+                                    @Override
+                                    public Class<VectorStoreQueryResult> 
getResultClass() {
+                                        return VectorStoreQueryResult.class;
+                                    }
+
+                                    @Override
+                                    public VectorStoreQueryResult call() 
throws Exception {
+                                        return 
vectorStore.query(vectorStoreQuery);
+                                    }
+                                });
+            } else {
+                result = vectorStore.query(vectorStoreQuery);
+            }
 
             ctx.sendEvent(
                     new ContextRetrievalResponseEvent(
diff --git 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java
index 592c8214..f5fba2da 100644
--- 
a/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java
+++ 
b/plan/src/main/java/org/apache/flink/agents/plan/actions/ToolCallAction.java
@@ -17,6 +17,8 @@
  */
 package org.apache.flink.agents.plan.actions;
 
+import org.apache.flink.agents.api.agents.AgentExecutionOptions;
+import org.apache.flink.agents.api.context.DurableCallable;
 import org.apache.flink.agents.api.context.RunnerContext;
 import org.apache.flink.agents.api.event.ToolRequestEvent;
 import org.apache.flink.agents.api.event.ToolResponseEvent;
@@ -44,6 +46,8 @@ public class ToolCallAction {
 
     @SuppressWarnings("unchecked")
     public static void processToolRequest(ToolRequestEvent event, 
RunnerContext ctx) {
+        boolean toolCallAsync = 
ctx.getConfig().get(AgentExecutionOptions.TOOL_CALL_ASYNC);
+
         Map<String, Boolean> success = new HashMap<>();
         Map<String, String> error = new HashMap<>();
         Map<String, ToolResponse> responses = new HashMap<>();
@@ -70,7 +74,30 @@ public class ToolCallAction {
 
             if (tool != null) {
                 try {
-                    ToolResponse response = tool.call(new 
ToolParameters(arguments));
+                    ToolResponse response;
+                    if (toolCallAsync) {
+                        final Tool toolRef = tool;
+                        response =
+                                ctx.durableExecuteAsync(
+                                        new DurableCallable<>() {
+                                            @Override
+                                            public String getId() {
+                                                return "tool-call-async";
+                                            }
+
+                                            @Override
+                                            public Class<ToolResponse> 
getResultClass() {
+                                                return ToolResponse.class;
+                                            }
+
+                                            @Override
+                                            public ToolResponse call() throws 
Exception {
+                                                return toolRef.call(new 
ToolParameters(arguments));
+                                            }
+                                        });
+                    } else {
+                        response = tool.call(new ToolParameters(arguments));
+                    }
                     success.put(id, true);
                     responses.put(id, response);
                 } catch (Exception e) {

Reply via email to