This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit d84e2fc1f6ce72ed7e3762d23b1f05880db28d7f Author: youjin <[email protected]> AuthorDate: Mon Jan 19 14:10:28 2026 +0800 [Feature][runtime] Support the use of Python MCP in Java Co-authored-by: Ioannis Stavrakantonakis <[email protected]> --- api/pom.xml | 53 --------- .../flink/agents/api/annotation/MCPServer.java | 25 ++++- .../apache/flink/agents/api/resource/Constant.java | 2 +- .../resource/test/ChatModelCrossLanguageTest.java | 2 +- ...java => CrossLanguageTestPreparationUtils.java} | 39 ++++++- .../resource/test/EmbeddingCrossLanguageTest.java | 2 +- .../resource/test/MCPCrossLanguageAgent.java | 85 +++++++++++++++ ...LanguageTest.java => MCPCrossLanguageTest.java} | 54 ++++------ .../test/VectorStoreCrossLanguageTest.java | 2 +- .../src/test/resources/mcp_server.py | 50 +++++++++ .../org/apache/flink/agents/plan/AgentPlan.java | 119 ++++++++++++++++++--- .../plan/resource/python/PythonMCPPrompt.java | 80 ++++++++++++++ .../plan/resource/python/PythonMCPServer.java | 91 ++++++++++++++++ .../agents/plan/resource/python/PythonMCPTool.java | 82 ++++++++++++++ .../resourceprovider/PythonResourceProvider.java | 11 +- python/flink_agents/api/tools/utils.py | 37 +++++++ python/flink_agents/runtime/python_java_utils.py | 16 ++- 17 files changed, 635 insertions(+), 115 deletions(-) diff --git a/api/pom.xml b/api/pom.xml index c49ab137..8cf08d0a 100644 --- a/api/pom.xml +++ b/api/pom.xml @@ -64,57 +64,4 @@ under the License. 
</dependency> </dependencies> - <profiles> - <profile> - <id>java-11</id> - <activation> - <jdk>[11,17)</jdk> - </activation> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <excludes> - <exclude>org/apache/flink/agents/api/annotation/MCPServer.java</exclude> - </excludes> - </configuration> - </plugin> - </plugins> - </build> - </profile> - - <!-- Profile for generating jdk11 classifier jar (excludes MCPServer) --> - <profile> - <id>java-11-target</id> - <activation> - <property> - <name>java-11-target</name> - </property> - </activation> - <build> - <plugins> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <!-- Default is already Java 11, no need to set source/target --> - <excludes> - <exclude>org/apache/flink/agents/api/annotation/MCPServer.java</exclude> - </excludes> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <configuration> - <classifier>jdk11</classifier> - </configuration> - </plugin> - </plugins> - </build> - </profile> - </profiles> - </project> \ No newline at end of file diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/MCPServer.java b/api/src/main/java/org/apache/flink/agents/api/annotation/MCPServer.java index f427c96c..17f6487c 100644 --- a/api/src/main/java/org/apache/flink/agents/api/annotation/MCPServer.java +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/MCPServer.java @@ -58,11 +58,26 @@ import java.lang.annotation.Target; * }</pre> * * <p>This is the Java equivalent of Python's {@code @mcp_server} decorator. 
- * - * @see org.apache.flink.agents.integrations.mcp.MCPServer - * @see org.apache.flink.agents.integrations.mcp.MCPTool - * @see org.apache.flink.agents.integrations.mcp.MCPPrompt */ @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) -public @interface MCPServer {} +public @interface MCPServer { + /** + * Specifies the implementation language for the MCP server connection. + * + * <p>Supported values: + * + * <ul> + * <li><b>"auto"</b> (default): Automatically selects the language based on JDK version. Uses + * Python for JDK 16 and below, and Java for JDK 17+. + * <li><b>"python"</b>: Forces the use of Python-based MCP server implementation. + * <li><b>"java"</b>: Forces the use of Java-based MCP server implementation. + * </ul> + * + * <p>The language selection affects how the agent plan communicates with the MCP server and + * which runtime dependencies are required. + * + * @return the language identifier ("auto", "python", or "java") + */ + String lang() default "auto"; +} diff --git a/api/src/main/java/org/apache/flink/agents/api/resource/Constant.java b/api/src/main/java/org/apache/flink/agents/api/resource/Constant.java index 32bc9e5d..34e62a32 100644 --- a/api/src/main/java/org/apache/flink/agents/api/resource/Constant.java +++ b/api/src/main/java/org/apache/flink/agents/api/resource/Constant.java @@ -81,5 +81,5 @@ public class Constant { "org.apache.flink.agents.integrations.vectorstores.elasticsearch.ElasticsearchVectorStore"; // MCP - public static String MCP_SERVER = "org.apache.flink.agents.integrations.mcp.MCPServer"; + public static String MCP_SERVER = "DECIDE_IN_RUNTIME_MCPServer"; } diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java index d62a9726..c471ee37 100644 --- 
a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/ChatModelCrossLanguageTest.java @@ -34,7 +34,7 @@ import java.util.ArrayList; import java.util.List; import static org.apache.flink.agents.resource.test.ChatModelCrossLanguageAgent.OLLAMA_MODEL; -import static org.apache.flink.agents.resource.test.OllamaPreparationUtils.pullModel; +import static org.apache.flink.agents.resource.test.CrossLanguageTestPreparationUtils.pullModel; public class ChatModelCrossLanguageTest { private static final Logger LOG = LoggerFactory.getLogger(ChatModelCrossLanguageTest.class); diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/OllamaPreparationUtils.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/CrossLanguageTestPreparationUtils.java similarity index 53% rename from e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/OllamaPreparationUtils.java rename to e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/CrossLanguageTestPreparationUtils.java index a42809ef..86690a7f 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/OllamaPreparationUtils.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/CrossLanguageTestPreparationUtils.java @@ -24,13 +24,14 @@ import java.io.IOException; import java.util.Objects; import java.util.concurrent.TimeUnit; -public class OllamaPreparationUtils { - private static final Logger LOG = 
LoggerFactory.getLogger(OllamaPreparationUtils.class); +public class CrossLanguageTestPreparationUtils { + private static final Logger LOG = + LoggerFactory.getLogger(CrossLanguageTestPreparationUtils.class); public static boolean pullModel(String model) throws IOException { String path = Objects.requireNonNull( - OllamaPreparationUtils.class + CrossLanguageTestPreparationUtils.class .getClassLoader() .getResource("ollama_pull_model.sh")) .getPath(); @@ -44,4 +45,36 @@ public class OllamaPreparationUtils { } return false; } + + public static Process startMCPServer() { + LOG.info("MCP Server is already running"); + + String path = + Objects.requireNonNull( + CrossLanguageTestPreparationUtils.class + .getClassLoader() + .getResource("mcp_server.py")) + .getPath(); + ProcessBuilder builder = new ProcessBuilder("python", path); + builder.redirectErrorStream(true); + + try { + Process mcpServerProcess = builder.start(); + // Give the server a moment to start up + Thread.sleep(2000); + + if (mcpServerProcess.isAlive()) { + LOG.info("MCP Server started successfully with PID: {}", mcpServerProcess.pid()); + return mcpServerProcess; + } else { + LOG.warn( + "MCP Server process exited immediately with code: {}", + mcpServerProcess.exitValue()); + return null; + } + } catch (Exception e) { + LOG.warn("Start MCP Server failed, will skip test", e); + return null; + } + } } diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java index 8806ac99..7293fcb7 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java +++ 
b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java @@ -29,7 +29,7 @@ import org.junit.jupiter.api.Test; import java.io.IOException; import java.util.Map; -import static org.apache.flink.agents.resource.test.OllamaPreparationUtils.pullModel; +import static org.apache.flink.agents.resource.test.CrossLanguageTestPreparationUtils.pullModel; /** * Example application that applies {@link EmbeddingCrossLanguageAgent} to a DataStream of prompts. diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/MCPCrossLanguageAgent.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/MCPCrossLanguageAgent.java new file mode 100644 index 00000000..82123fa8 --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/MCPCrossLanguageAgent.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.resource.test; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.agents.Agent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.annotation.MCPServer; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.junit.jupiter.api.Assertions; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.api.resource.Constant.MCP_SERVER; + +public class MCPCrossLanguageAgent extends Agent { + @MCPServer(lang = "python") + public static ResourceDescriptor pythonMCPServer() { + return ResourceDescriptor.Builder.newBuilder(MCP_SERVER) + .addInitialArgument("endpoint", "http://127.0.0.1:8000/mcp") + .build(); + } + + @Action(listenEvents = {InputEvent.class}) + public static void process(InputEvent event, RunnerContext ctx) throws Exception { + Map<String, Object> testResult = new HashMap<>(); + try { + Tool add = (Tool) ctx.getResource("add", ResourceType.TOOL); + + Assertions.assertTrue(add.getDescription().contains("Get the detailed information")); + + ToolResponse response = add.call(new ToolParameters(Map.of("a", 1, "b", 2))); + Assertions.assertTrue(response.getResult().toString().contains("3")); + System.out.println("[TEST] MCP Tools PASSED"); + + Prompt askSum = (Prompt) ctx.getResource("ask_sum", ResourceType.PROMPT); + List<ChatMessage> chatMessages = + 
askSum.formatMessages(MessageRole.USER, Map.of("a", "1", "b", "2")); + Assertions.assertEquals(1, chatMessages.size()); + Assertions.assertEquals( + "Can you please calculate the sum of 1 and 2?", + chatMessages.get(0).getContent()); + Assertions.assertEquals(MessageRole.USER, chatMessages.get(0).getRole()); + + String content = askSum.formatString(Map.of("a", "3", "b", "4")); + Assertions.assertEquals("Can you please calculate the sum of 3 and 4?", content); + System.out.println("[TEST] MCP Prompts PASSED"); + + testResult.put("test_status", "PASSED"); + ctx.sendEvent(new OutputEvent(testResult)); + } catch (Exception e) { + testResult.put("test_status", "FAILED"); + testResult.put("error", e.getMessage()); + ctx.sendEvent(new OutputEvent(testResult)); + System.err.printf("[TEST] MCP Cross Language test FAILED: %s%n", e.getMessage()); + throw e; + } + } +} diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/MCPCrossLanguageTest.java similarity index 54% copy from e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java copy to e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/MCPCrossLanguageTest.java index 8806ac99..a8092c0c 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/EmbeddingCrossLanguageTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/MCPCrossLanguageTest.java @@ -26,43 +26,28 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Test; -import java.io.IOException; 
import java.util.Map; -import static org.apache.flink.agents.resource.test.OllamaPreparationUtils.pullModel; +import static org.apache.flink.agents.resource.test.CrossLanguageTestPreparationUtils.startMCPServer; -/** - * Example application that applies {@link EmbeddingCrossLanguageAgent} to a DataStream of prompts. - */ -public class EmbeddingCrossLanguageTest { - - private final boolean ollamaReady; +/** Example application that applies {@link MCPCrossLanguageAgent} to a DataStream. */ +public class MCPCrossLanguageTest { + private final Process mcpServerProcess; - public EmbeddingCrossLanguageTest() throws IOException { - ollamaReady = pullModel(EmbeddingCrossLanguageAgent.OLLAMA_MODEL); + public MCPCrossLanguageTest() { + this.mcpServerProcess = startMCPServer(); } @Test - public void testEmbeddingIntegration() throws Exception { - Assumptions.assumeTrue(ollamaReady, "Ollama Server information is not provided"); + public void testMCPIntegration() throws Exception { + Assumptions.assumeTrue(mcpServerProcess != null, "MCP Server is not running"); // Create the execution environment StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); - // Use prompts that exercise embedding generation and similarity checks - DataStream<String> inputStream = - env.fromData( - "Generate embedding for: 'Machine learning'", - "Generate embedding for: 'Deep learning techniques'", - "Find texts similar to: 'neural networks'", - "Produce embedding and return top-3 similar items for: 'natural language processing'", - "Generate embedding for: 'hello world'", - "Compare similarity between 'cat' and 'dog'", - "Create embedding for: 'space exploration'", - "Find nearest neighbors for: 'artificial intelligence'", - "Generate embedding for: 'data science'", - "Random embedding test"); + // Use prompts that utilize the MCP tool and perform prompt checks. 
+ DataStream<String> inputStream = env.fromData("An input message to invoke the Test Action"); // Create agents execution environment AgentsExecutionEnvironment agentsEnv = @@ -72,7 +57,7 @@ public class EmbeddingCrossLanguageTest { DataStream<Object> outputStream = agentsEnv .fromDataStream(inputStream, (KeySelector<String, String>) value -> value) - .apply(new EmbeddingCrossLanguageAgent()) + .apply(new MCPCrossLanguageAgent()) .toDataStream(); // Collect the results @@ -82,16 +67,19 @@ public class EmbeddingCrossLanguageTest { agentsEnv.execute(); checkResult(results); + + mcpServerProcess.destroy(); } @SuppressWarnings("unchecked") private void checkResult(CloseableIterator<Object> results) { - for (int i = 1; i <= 10; i++) { - Assertions.assertTrue( - results.hasNext(), - String.format("Output messages count %s is less than expected 10.", i)); - Map<String, Object> res = (Map<String, Object>) results.next(); - Assertions.assertEquals("PASSED", res.get("test_status")); - } + Assertions.assertTrue( + results.hasNext(), "No output received from VectorStoreIntegrationAgent"); + + Object obj = results.next(); + Assertions.assertInstanceOf(Map.class, obj, "Output must be a Map"); + + java.util.Map<String, Object> res = (java.util.Map<String, Object>) obj; + Assertions.assertEquals("PASSED", res.get("test_status")); } } diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageTest.java b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageTest.java index 0d53d49a..554462d9 100644 --- a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/java/org/apache/flink/agents/resource/test/VectorStoreCrossLanguageTest.java @@ 
-32,7 +32,7 @@ import org.junit.jupiter.params.provider.ValueSource; import java.io.IOException; import java.util.Map; -import static org.apache.flink.agents.resource.test.OllamaPreparationUtils.pullModel; +import static org.apache.flink.agents.resource.test.CrossLanguageTestPreparationUtils.pullModel; import static org.apache.flink.agents.resource.test.VectorStoreCrossLanguageAgent.OLLAMA_MODEL; public class VectorStoreCrossLanguageTest { diff --git a/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/resources/mcp_server.py b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/resources/mcp_server.py new file mode 100644 index 00000000..4a6c331b --- /dev/null +++ b/e2e-test/flink-agents-end-to-end-tests-resource-cross-language/src/test/resources/mcp_server.py @@ -0,0 +1,50 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+################################################################################# + +try: + import dotenv + dotenv.load_dotenv() +except ImportError: + # dotenv is optional for this test server + pass + +from mcp.server.fastmcp import FastMCP + +# Create MCP server +mcp = FastMCP("BasicServer") + + [email protected]() +def ask_sum(a: int, b: int) -> str: + """Prompt of add tool.""" + return f"Can you please calculate the sum of {a} and {b}?" + [email protected]() +async def add(a: int, b: int) -> int: + """Get the detailed information of a specified IP address. + + Args: + a: The first operand. + b: The second operand. + + Returns: + int: The sum of a and b. + """ + return a + b + +mcp.run("streamable-http") diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index 9fa05a31..2cd77915 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -35,6 +35,7 @@ import org.apache.flink.agents.plan.actions.Action; import org.apache.flink.agents.plan.actions.ChatModelAction; import org.apache.flink.agents.plan.actions.ContextRetrievalAction; import org.apache.flink.agents.plan.actions.ToolCallAction; +import org.apache.flink.agents.plan.resource.python.PythonMCPServer; import org.apache.flink.agents.plan.resourceprovider.JavaResourceProvider; import org.apache.flink.agents.plan.resourceprovider.JavaSerializableResourceProvider; import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; @@ -44,6 +45,8 @@ import org.apache.flink.agents.plan.serializer.AgentPlanJsonSerializer; import org.apache.flink.agents.plan.tools.FunctionTool; import org.apache.flink.agents.plan.tools.ToolMetadataFactory; import org.apache.flink.api.java.tuple.Tuple3; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.ObjectInputStream; @@ -58,6 +61,7 @@ import 
java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; import static org.apache.flink.agents.api.resource.ResourceType.MCP_SERVER; import static org.apache.flink.agents.api.resource.ResourceType.PROMPT; @@ -67,6 +71,9 @@ import static org.apache.flink.agents.api.resource.ResourceType.TOOL; @JsonSerialize(using = AgentPlanJsonSerializer.class) @JsonDeserialize(using = AgentPlanJsonDeserializer.class) public class AgentPlan implements Serializable { + private static final Logger LOG = LoggerFactory.getLogger(AgentPlan.class); + private static final String JAVA_MCP_SERVER_CLASS_NAME = + "org.apache.flink.agents.integrations.mcp.MCPServer"; /** Mapping from action name to action itself. */ private Map<String, Action> actions; @@ -133,8 +140,60 @@ public class AgentPlan implements Serializable { this.config = config; } - public void setPythonResourceAdapter(PythonResourceAdapter adapter) { + public void setPythonResourceAdapter(PythonResourceAdapter adapter) throws Exception { this.pythonResourceAdapter = adapter; + Map<String, ResourceProvider> servers = resourceProviders.get(MCP_SERVER); + if (servers == null) { + return; + } + servers.values().stream() + .filter(PythonResourceProvider.class::isInstance) + .map(PythonResourceProvider.class::cast) + .forEach( + provider -> { + provider.setPythonResourceAdapter(adapter); + + // Get tools and prompts from server + try { + PythonMCPServer server = + (PythonMCPServer) + provider.provide( + (String anotherName, + ResourceType anotherType) -> { + try { + return this.getResource( + anotherName, anotherType); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + // Add tools to cache + server.listTools() + .forEach( + tool -> + resourceCache + .computeIfAbsent( + TOOL, + k -> + new ConcurrentHashMap<>()) + .put(tool.getName(), tool)); + + // Add prompts to cache + server.listPrompts() + .forEach( + prompt -> + 
resourceCache + .computeIfAbsent( + PROMPT, + k -> + new ConcurrentHashMap<>()) + .put(prompt.getName(), prompt)); + } catch (Exception e) { + throw new RuntimeException( + "Failed to process Python MCP server in Java", e); + } + }); } public Map<String, Action> getActions() { @@ -301,9 +360,21 @@ public class AgentPlan implements Serializable { } private void extractResource(ResourceType type, Method method) throws Exception { + extractResource(type, method, null); + } + + private void extractResource( + ResourceType type, + Method method, + Function<ResourceDescriptor, ResourceDescriptor> descriptorDecorator) + throws Exception { String name = method.getName(); ResourceProvider provider; ResourceDescriptor descriptor = (ResourceDescriptor) method.invoke(null); + + descriptor = + descriptorDecorator != null ? descriptorDecorator.apply(descriptor) : descriptor; + if (PythonResourceWrapper.class.isAssignableFrom(Class.forName(descriptor.getClazz()))) { provider = new PythonResourceProvider(name, type, descriptor); } else { @@ -329,11 +400,16 @@ public class AgentPlan implements Serializable { addResourceProvider(provider); } - private void extractMCPServer(Method method) throws Exception { + private void extractJavaMCPServer(Method method) throws Exception { // Use reflection to handle MCP classes to support Java 11 without MCP String name = method.getName(); ResourceDescriptor descriptor = (ResourceDescriptor) method.invoke(null); + descriptor = + new ResourceDescriptor( + descriptor.getModule(), + JAVA_MCP_SERVER_CLASS_NAME, + new HashMap<>(descriptor.getInitialArguments())); JavaResourceProvider provider = new JavaResourceProvider(name, MCP_SERVER, descriptor); addResourceProvider(provider); @@ -448,18 +524,33 @@ public class AgentPlan implements Serializable { extractResource(ResourceType.EMBEDDING_MODEL_CONNECTION, method); } else if (method.isAnnotationPresent(VectorStore.class)) { extractResource(ResourceType.VECTOR_STORE, method); - } else if 
(Modifier.isStatic(method.getModifiers())) { - // Check for MCPServer annotation using reflection to support Java 11 without MCP - try { - Class<?> mcpServerAnnotation = - Class.forName("org.apache.flink.agents.api.annotation.MCPServer"); - if (method.isAnnotationPresent( - (Class<? extends java.lang.annotation.Annotation>) - mcpServerAnnotation)) { - extractMCPServer(method); - } - } catch (ClassNotFoundException e) { - // MCP annotation not available (Java 11 build), skip MCP processing + } else if (method.isAnnotationPresent(MCPServer.class)) { + // Check the MCPServer annotation version to determine which version to use. + MCPServer MCPServerAnnotation = method.getAnnotation(MCPServer.class); + String lang = MCPServerAnnotation.lang(); + int javaVersion = Runtime.version().feature(); + + if (lang.equalsIgnoreCase("auto")) { + lang = javaVersion >= 17 ? "java" : "python"; + } else if (lang.equalsIgnoreCase("java") && javaVersion < 17) { + throw new UnsupportedOperationException( + "Java version is less than 17, please use python MCP server."); + } + + if (lang.equalsIgnoreCase("java")) { + extractJavaMCPServer(method); + } else { + LOG.warn( + "Using the Python MCP server with cross-language support. The Java version is " + + javaVersion); + extractResource( + ResourceType.MCP_SERVER, + method, + desc -> + new ResourceDescriptor( + desc.getModule(), + PythonMCPServer.class.getName(), + new HashMap<>(desc.getInitialArguments()))); } } } diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPPrompt.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPPrompt.java new file mode 100644 index 00000000..625a89fd --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPPrompt.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.plan.resource.python; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import pemja.core.object.PyObject; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class PythonMCPPrompt extends Prompt implements PythonResourceWrapper { + private static final String FROM_JAVA_MESSAGE_ROLE = "python_java_utils.from_java_message_role"; + + private final PyObject prompt; + private final PythonResourceAdapter adapter; + private String name; + + public PythonMCPPrompt(PythonResourceAdapter adapter, PyObject prompt) { + this.adapter = adapter; + this.prompt = prompt; + } + + @Override + public Object getPythonResource() { + return prompt; + } + + public String getName() { + if (name == null) { + name = prompt.getAttr("name").toString(); + } + return name; + } + + @Override + public String formatString(Map<String, String> kwargs) { + Map<String, Object> parameters = new 
HashMap<>(kwargs); + return adapter.callMethod(prompt, "format_string", parameters).toString(); + } + + @Override + public List<ChatMessage> formatMessages(MessageRole defaultRole, Map<String, String> kwargs) { + Map<String, Object> parameters = new HashMap<>(kwargs); + Object pythonRole = adapter.invoke(FROM_JAVA_MESSAGE_ROLE, defaultRole); + parameters.put("role", pythonRole); + + Object result = adapter.callMethod(prompt, "format_messages", parameters); + if (result instanceof List) { + List<Object> pythonMessages = (List<Object>) result; + List<ChatMessage> messages = new ArrayList<>(pythonMessages.size()); + for (Object pythonMessage : pythonMessages) { + messages.add(adapter.fromPythonChatMessage(pythonMessage)); + } + return messages; + } + return Collections.emptyList(); + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java new file mode 100644 index 00000000..6ce0f4da --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPServer.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.agents.plan.resource.python; + +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import pemja.core.object.PyObject; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; + +/** A {@link ResourceType#MCP_SERVER} resource that wraps a Python MCP server object and exposes its tools and prompts as {@link PythonMCPTool} and {@link PythonMCPPrompt} wrappers. */ public class PythonMCPServer extends Resource implements PythonResourceWrapper { + private final PyObject server; + private final PythonResourceAdapter adapter; + + /** + * Creates a new PythonMCPServer. + * + * @param adapter The Python resource adapter, used to invoke the server's Python methods and + * passed on to the {@link PythonMCPTool} and {@link PythonMCPPrompt} wrappers it creates + * @param server The Python MCP Server object + * @param descriptor The resource descriptor + * @param getResource Function to retrieve resources by name and type + */ + public PythonMCPServer( + PythonResourceAdapter adapter, + PyObject server, + ResourceDescriptor descriptor, + BiFunction<String, ResourceType, Resource> getResource) { + super(descriptor, getResource); + this.server = server; + this.adapter = adapter; + } + + /** Lists the server's tools by calling the Python {@code list_tools} method; returns an empty list when the Python result is not a {@link List}. */ @SuppressWarnings("unchecked") + public List<PythonMCPTool> listTools() { + Object result = adapter.callMethod(server, "list_tools", Collections.emptyMap()); + if (result instanceof List) { + List<Object> pythonTools = (List<Object>) result; + List<PythonMCPTool> tools = new ArrayList<>(pythonTools.size()); + for (Object pyTool : pythonTools) { + tools.add(new PythonMCPTool(adapter, (PyObject) pyTool)); + } + return tools; + } + return Collections.emptyList(); + } + + /** Lists the server's prompts by calling the Python {@code list_prompts} method; returns an empty list when the Python result is not a {@link List}. */ public List<PythonMCPPrompt> listPrompts() { + Object result = adapter.callMethod(server, "list_prompts", Collections.emptyMap()); + if (result instanceof List) { +
List<Object> pythonPrompts = (List<Object>) result; + List<PythonMCPPrompt> prompts = new ArrayList<>(pythonPrompts.size()); + for (Object pythonPrompt : pythonPrompts) { + prompts.add(new PythonMCPPrompt(adapter, (PyObject) pythonPrompt)); + } + return prompts; + } + return Collections.emptyList(); + } + + @Override + public Object getPythonResource() { + return server; + } + + @Override + public ResourceType getResourceType() { + return ResourceType.MCP_SERVER; + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPTool.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPTool.java new file mode 100644 index 00000000..89a5435d --- /dev/null +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonMCPTool.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.flink.agents.plan.resource.python; + +import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; +import org.apache.flink.agents.api.resource.python.PythonResourceWrapper; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; +import pemja.core.object.PyObject; + +import java.util.HashMap; +import java.util.Map; + +/** A {@link ToolType#MCP} tool that wraps a Python MCP tool object and invokes it through a {@link PythonResourceAdapter}. */ public class PythonMCPTool extends Tool implements PythonResourceWrapper { + private static final String GET_JAVA_TOOL_META = + "python_java_utils.get_java_tool_metadata_from_tool"; + private final PyObject tool; + private final PythonResourceAdapter adapter; + + /** + * Creates a new PythonMCPTool. + * + * @param adapter The Python resource adapter, used to read the tool's metadata at construction + * time and to invoke the Python tool's {@code call} method + * @param tool The Python MCP tool object + */ + public PythonMCPTool(PythonResourceAdapter adapter, PyObject tool) { + super(getToolMetadata(adapter, tool)); + this.tool = tool; + this.adapter = adapter; + } + + /** Builds the Java {@link ToolMetadata} from the name, description and inputSchema entries returned by {@code python_java_utils.get_java_tool_metadata_from_tool}. */ @SuppressWarnings("unchecked") + private static ToolMetadata getToolMetadata(PythonResourceAdapter adapter, PyObject tool) { + Map<String, String> metadata = + (Map<String, String>) adapter.invoke(GET_JAVA_TOOL_META, tool); + return new ToolMetadata( + metadata.get("name"), metadata.get("description"), metadata.get("inputSchema")); + } + + /** Calls the Python tool's {@code call} method with the parameters as keyword arguments; any exception is returned as an error {@link ToolResponse} instead of being thrown. */ @Override + public ToolResponse call(ToolParameters parameters) { + Map<String, Object> kwargs = new HashMap<>(); + for (String paramName : parameters.getParameterNames()) { + kwargs.put(paramName, parameters.getParameter(paramName)); + } + try { + Object result = adapter.callMethod(tool, "call", kwargs); + return ToolResponse.success(result); + } catch (Exception e) { + return
ToolResponse.error(e); + } + } + + @Override + public Object getPythonResource() { + return tool; + } + + @Override + public ToolType getToolType() { + return ToolType.MCP; + } +} diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java index 5e9a0cc3..ca87db9f 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resourceprovider/PythonResourceProvider.java @@ -27,6 +27,7 @@ import org.apache.flink.agents.api.resource.ResourceDescriptor; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.api.resource.python.PythonResourceAdapter; import org.apache.flink.agents.api.vectorstores.python.PythonCollectionManageableVectorStore; +import org.apache.flink.agents.plan.resource.python.PythonMCPServer; import pemja.core.object.PyObject; import java.lang.reflect.Constructor; @@ -44,6 +45,8 @@ import static org.apache.flink.util.Preconditions.checkState; * class, and initialization arguments.
 */ public class PythonResourceProvider extends ResourceProvider { + private static final String MCP_MODULE = "flink_agents.integrations.mcp.mcp"; + private static final String MCP_CLASS = "MCPServer"; private final ResourceDescriptor descriptor; private static final Map<ResourceType, Class<?>> RESOURCE_TYPE_TO_CLASS = @@ -52,7 +55,8 @@ public class PythonResourceProvider extends ResourceProvider { ResourceType.CHAT_MODEL_CONNECTION, PythonChatModelConnection.class, ResourceType.EMBEDDING_MODEL, PythonEmbeddingModelSetup.class, ResourceType.EMBEDDING_MODEL_CONNECTION, PythonEmbeddingModelConnection.class, - ResourceType.VECTOR_STORE, PythonCollectionManageableVectorStore.class); + ResourceType.VECTOR_STORE, PythonCollectionManageableVectorStore.class, + ResourceType.MCP_SERVER, PythonMCPServer.class); protected PythonResourceAdapter pythonResourceAdapter; @@ -84,6 +88,11 @@ public class PythonResourceProvider extends ResourceProvider { String pyModule = descriptor.getModule(); String pyClazz = descriptor.getClazz(); + /* MCP servers are always built from the Python MCPServer class in MCP_MODULE; any module/class set on the descriptor is overridden. */ if (getType() == ResourceType.MCP_SERVER) { + pyModule = MCP_MODULE; + pyClazz = MCP_CLASS; + } + // Extract module and class from kwargs if not provided in descriptor if (pyModule == null || pyModule.isEmpty()) { pyModule = (String) kwargs.remove("module"); diff --git a/python/flink_agents/api/tools/utils.py b/python/flink_agents/api/tools/utils.py index 6b5cd4b4..51ea1bd9 100644 --- a/python/flink_agents/api/tools/utils.py +++ b/python/flink_agents/api/tools/utils.py @@ -191,3 +191,40 @@ def create_model_from_java_tool_schema_str(name: str, schema_str: str) -> type[B type = TYPE_MAPPING.get(properties[param_name]["type"]) fields[param_name] = (type, FieldInfo(description=description)) return create_model(name, **fields) + +def create_java_tool_schema_str_from_model(model: type[BaseModel]) -> str: + """Create a Java tool input schema string from a Pydantic model. + + This is the (lossy) inverse of create_model_from_java_tool_schema_str: only each field's type and description are emitted.
+ + Args: + model: A Pydantic BaseModel class + + Returns: + A JSON string of the form {"properties": {<field>: {"type": <json type>, "description": <text>}}}. typing.Union/Optional annotations are reduced to their first non-None member, types missing from TYPE_MAPPING become "string", and a missing description becomes "Parameter: <field>". + """ + REVERSE_TYPE_MAPPING = {v: k for k, v in TYPE_MAPPING.items()} + + properties = {} + for field_name, field_info in model.model_fields.items(): + field_type = field_info.annotation + + origin = typing.get_origin(field_type) + if origin is not None: + if origin is typing.Union: + args = typing.get_args(field_type) + non_none_types = [arg for arg in args if arg is not type(None)] + if non_none_types: + field_type = non_none_types[0] + + json_type = REVERSE_TYPE_MAPPING.get(field_type, "string") + + description = field_info.description + if description is None: + description = f"Parameter: {field_name}" + + properties[field_name] = {"type": json_type, "description": description} + + json_schema = {"properties": properties} + + return json.dumps(json_schema, ensure_ascii=False, indent=2) diff --git a/python/flink_agents/runtime/python_java_utils.py b/python/flink_agents/runtime/python_java_utils.py index dc320e7c..42b4a8a6 100644 --- a/python/flink_agents/runtime/python_java_utils.py +++ b/python/flink_agents/runtime/python_java_utils.py @@ -17,6 +17,7 @@ ################################################################################# import importlib import json +import typing from typing import Any, Callable, Dict import cloudpickle @@ -24,8 +25,11 @@ import cloudpickle from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.events.event import Event, InputEvent from flink_agents.api.resource import Resource, ResourceType, get_resource_class -from flink_agents.api.tools.tool import ToolMetadata -from flink_agents.api.tools.utils import create_model_from_java_tool_schema_str +from flink_agents.api.tools.tool import Tool, ToolMetadata +from flink_agents.api.tools.utils import ( + create_java_tool_schema_str_from_model, + create_model_from_java_tool_schema_str, +) from
flink_agents.api.vector_stores.vector_store import ( Collection, Document, @@ -245,6 +249,14 @@ def from_java_collection(j_collection: Any) -> Collection: metadata=j_collection.getMetadata(), ) +def from_java_message_role(j_role: Any) -> MessageRole: + """Convert a Java message role to a Python MessageRole using the Java role's getValue() string.""" + return MessageRole(j_role.getValue()) + +def get_java_tool_metadata_from_tool(tool: Tool) -> typing.Dict[str, str]: + """Build Java-format tool metadata (name, description, inputSchema) from a Python tool; inputSchema is the JSON string generated from tool.metadata.args_schema.""" + return {"name": tool.name, "description": tool.metadata.description, "inputSchema": create_java_tool_schema_str_from_model(tool.metadata.args_schema)} + def get_mode_value(query: VectorStoreQuery) -> str: """Get the mode value of a VectorStoreQuery.""" return query.mode.value
