This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ad8745d [Feature][Java] Added retries to remote calls in MCP server (#531)
6ad8745d is described below

commit 6ad8745d23d69f88131f97dfdb550921d2fed9f3
Author: Yash Anand <[email protected]>
AuthorDate: Sun Mar 1 21:22:16 2026 -0800

    [Feature][Java] Added retries to remote calls in MCP server (#531)
---
 .../org/apache/flink/agents/api/RetryExecutor.java | 207 ++++++++++++++
 .../apache/flink/agents/api/RetryExecutorTest.java | 277 +++++++++++++++++++
 .../flink/agents/integrations/mcp/MCPServer.java   | 299 +++++++++++++++------
 .../agents/integrations/mcp/MCPServerTest.java     |  27 ++
 4 files changed, 733 insertions(+), 77 deletions(-)

diff --git a/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java
new file mode 100644
index 00000000..b369e1d1
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api;
+
+import java.net.ConnectException;
+import java.net.SocketTimeoutException;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.function.Predicate;
+
+/**
+ * A reusable utility for executing operations with retry logic and binary exponential backoff.
+ *
+ * <p>By default, an exception is considered retryable when it, or any exception in its cause
+ * chain, matches one of the following:
+ *
+ * <ul>
+ *   <li>{@link SocketTimeoutException}
+ *   <li>{@link ConnectException}
+ *   <li>Exceptions whose message contains "HTTP 503" or "HTTP 429"
+ *   <li>Exceptions whose message contains "Connection reset", "Connection refused", or "Connection
+ *       timed out"
+ * </ul>
+ *
+ * <p>The cause chain is inspected because client libraries frequently wrap low-level network
+ * exceptions in their own exception types.
+ *
+ * <p>A custom retryable predicate can be provided to override this behavior. Regardless of the
+ * predicate, an {@link InterruptedException} thrown by the operation is never retried: the
+ * thread's interrupt status is restored and the failure is rethrown immediately.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * RetryExecutor executor = RetryExecutor.builder()
+ *     .maxRetries(3)
+ *     .initialBackoffMs(100)
+ *     .maxBackoffMs(10000)
+ *     .build();
+ *
+ * String result = executor.execute(() -> callRemoteService(), "callRemoteService");
+ * }</pre>
+ */
+public class RetryExecutor {
+
+    private static final Random RANDOM = new Random();
+
+    /** Bounds the cause-chain walk in {@link #isRetryableDefault} to guard against cycles. */
+    private static final int MAX_CAUSE_DEPTH = 16;
+
+    private static final int DEFAULT_MAX_RETRIES = 3;
+    private static final long DEFAULT_INITIAL_BACKOFF_MS = 100;
+    private static final long DEFAULT_MAX_BACKOFF_MS = 10000;
+
+    private final int maxRetries;
+    private final long initialBackoffMs;
+    private final long maxBackoffMs;
+    private final Predicate<Exception> retryablePredicate;
+
+    private RetryExecutor(
+            int maxRetries,
+            long initialBackoffMs,
+            long maxBackoffMs,
+            Predicate<Exception> retryablePredicate) {
+        // A negative retry count would skip the operation entirely, and negative backoffs could
+        // make Thread.sleep throw, so reject them up front.
+        if (maxRetries < 0) {
+            throw new IllegalArgumentException("maxRetries must be >= 0, but was " + maxRetries);
+        }
+        if (initialBackoffMs < 0) {
+            throw new IllegalArgumentException(
+                    "initialBackoffMs must be >= 0, but was " + initialBackoffMs);
+        }
+        if (maxBackoffMs < 0) {
+            throw new IllegalArgumentException(
+                    "maxBackoffMs must be >= 0, but was " + maxBackoffMs);
+        }
+        this.maxRetries = maxRetries;
+        this.initialBackoffMs = initialBackoffMs;
+        this.maxBackoffMs = maxBackoffMs;
+        this.retryablePredicate =
+                retryablePredicate != null ? retryablePredicate : RetryExecutor::isRetryableDefault;
+    }
+
+    /** Creates a builder for {@link RetryExecutor}. */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /** Creates a {@link RetryExecutor} with default settings. */
+    public static RetryExecutor withDefaults() {
+        return builder().build();
+    }
+
+    /**
+     * Execute an operation with retry logic.
+     *
+     * <p>The operation runs at most {@code maxRetries + 1} times. Before each retry the calling
+     * thread sleeps for a random duration in {@code [0, window]} milliseconds, where the window
+     * starts at {@code min(initialBackoffMs, maxBackoffMs)} and doubles after every retry, capped
+     * at {@code maxBackoffMs}.
+     *
+     * @param operation The operation to execute
+     * @param operationName Name of the operation for error messages
+     * @return The result of the operation
+     * @throws RuntimeException if all retries fail, a non-retryable exception occurs, or the
+     *     thread is interrupted (in which case its interrupt status is restored)
+     */
+    public <T> T execute(Callable<T> operation, String operationName) {
+        // The very first window must already respect the configured cap.
+        long window = Math.min(initialBackoffMs, maxBackoffMs);
+        Exception lastException = null;
+
+        for (int attempt = 0; attempt <= maxRetries; attempt++) {
+            try {
+                return operation.call();
+            } catch (InterruptedException e) {
+                // Never retry an interrupted operation; restore the flag so callers can see it.
+                Thread.currentThread().interrupt();
+                throw new RuntimeException(
+                        String.format("Operation '%s' was interrupted", operationName), e);
+            } catch (Exception e) {
+                lastException = e;
+
+                if (!retryablePredicate.test(e)) {
+                    throw new RuntimeException(
+                            String.format(
+                                    "Operation '%s' failed: %s", operationName, e.getMessage()),
+                            e);
+                }
+
+                if (attempt == maxRetries) {
+                    break;
+                }
+
+                sleepBeforeRetry(window, operationName);
+                window = nextWindow(window);
+            }
+        }
+
+        // maxRetries >= 0 is enforced, so the loop ran and lastException is non-null here.
+        throw new RuntimeException(
+                String.format(
+                        "Operation '%s' failed after %d retries: %s",
+                        operationName, maxRetries, lastException.getMessage()),
+                lastException);
+    }
+
+    /**
+     * Binary exponential backoff: sleeps for a uniformly random duration in {@code [0, window]}
+     * milliseconds.
+     */
+    private static void sleepBeforeRetry(long window, String operationName) {
+        // Adding 1 in double arithmetic avoids overflowing when window == Long.MAX_VALUE.
+        long sleepTime = (long) (RANDOM.nextDouble() * ((double) window + 1));
+        try {
+            Thread.sleep(sleepTime);
+        } catch (InterruptedException ie) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException(
+                    "Interrupted while retrying operation: " + operationName, ie);
+        }
+    }
+
+    /** Doubles the backoff window, saturating at {@code maxBackoffMs} instead of overflowing. */
+    private long nextWindow(long window) {
+        return window > maxBackoffMs / 2 ? maxBackoffMs : window * 2;
+    }
+
+    public int getMaxRetries() {
+        return maxRetries;
+    }
+
+    public long getInitialBackoffMs() {
+        return initialBackoffMs;
+    }
+
+    public long getMaxBackoffMs() {
+        return maxBackoffMs;
+    }
+
+    /**
+     * Default retryable check.
+     *
+     * <p>Walks {@code e} and its causes (up to a fixed depth). It returns true as soon as one of
+     * them is a {@link SocketTimeoutException} or {@link ConnectException}, or carries a message
+     * indicating a transient HTTP status (503, 429) or a connection failure.
+     *
+     * @param e The exception to check
+     * @return true if the operation should be retried
+     */
+    static boolean isRetryableDefault(Exception e) {
+        Throwable current = e;
+        for (int depth = 0; current != null && depth < MAX_CAUSE_DEPTH; depth++) {
+            if (isTransient(current)) {
+                return true;
+            }
+            current = current.getCause();
+        }
+        return false;
+    }
+
+    /** Checks a single throwable (ignoring its causes) against the default retryable rules. */
+    private static boolean isTransient(Throwable t) {
+        if (t instanceof SocketTimeoutException || t instanceof ConnectException) {
+            return true;
+        }
+        String message = t.getMessage();
+        if (message == null) {
+            return false;
+        }
+        return message.contains("HTTP 503")
+                || message.contains("HTTP 429")
+                || message.contains("Connection reset")
+                || message.contains("Connection refused")
+                || message.contains("Connection timed out");
+    }
+
+    /** Builder for {@link RetryExecutor}. */
+    public static class Builder {
+        private int maxRetries = DEFAULT_MAX_RETRIES;
+        private long initialBackoffMs = DEFAULT_INITIAL_BACKOFF_MS;
+        private long maxBackoffMs = DEFAULT_MAX_BACKOFF_MS;
+        private Predicate<Exception> retryablePredicate;
+
+        /** Sets the number of retries after the first attempt; must be non-negative. */
+        public Builder maxRetries(int maxRetries) {
+            this.maxRetries = maxRetries;
+            return this;
+        }
+
+        /** Sets the backoff window before the first retry, in milliseconds; must be >= 0. */
+        public Builder initialBackoffMs(long initialBackoffMs) {
+            this.initialBackoffMs = initialBackoffMs;
+            return this;
+        }
+
+        /** Sets the upper bound of the backoff window, in milliseconds; must be >= 0. */
+        public Builder maxBackoffMs(long maxBackoffMs) {
+            this.maxBackoffMs = maxBackoffMs;
+            return this;
+        }
+
+        /**
+         * Sets the predicate that decides whether a failure is retried. {@code null} (the
+         * default) selects {@link RetryExecutor#isRetryableDefault}.
+         */
+        public Builder retryablePredicate(Predicate<Exception> retryablePredicate) {
+            this.retryablePredicate = retryablePredicate;
+            return this;
+        }
+
+        /**
+         * Builds the executor.
+         *
+         * @throws IllegalArgumentException if {@code maxRetries}, {@code initialBackoffMs} or
+         *     {@code maxBackoffMs} is negative
+         */
+        public RetryExecutor build() {
+            return new RetryExecutor(
+                    maxRetries, initialBackoffMs, maxBackoffMs, retryablePredicate);
+        }
+    }
+}
diff --git a/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java
new file mode 100644
index 00000000..982ea3ce
--- /dev/null
+++ b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api;
+
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.net.ConnectException;
+import java.net.SocketTimeoutException;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link RetryExecutor}.
+ *
+ * <p>Most tests use a 10 ms initial backoff so that retry paths complete quickly.
+ */
+class RetryExecutorTest {
+
+    @Test
+    @DisplayName("Immediate success without retry")
+    void testImmediateSuccess() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(1);
+    }
+
+    @Test
+    @DisplayName("Retry on SocketTimeoutException and succeed")
+    void testRetryOnSocketTimeout() {
+        RetryExecutor executor =
+                RetryExecutor.builder()
+                        .maxRetries(3)
+                        .initialBackoffMs(10)
+                        .maxBackoffMs(100)
+                        .build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt < 3) {
+                        throw new SocketTimeoutException("Connection timeout");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(3);
+    }
+
+    @Test
+    @DisplayName("Retry on ConnectException and succeed")
+    void testRetryOnConnectException() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt < 2) {
+                        throw new ConnectException("Connection refused");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+
+    @Test
+    @DisplayName("Retry on HTTP 503 Service Unavailable")
+    void testRetryOn503Error() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt == 1) {
+                        throw new RuntimeException("HTTP 503 Service Unavailable");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+
+    @Test
+    @DisplayName("Retry on HTTP 429 Too Many Requests")
+    void testRetryOn429Error() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt == 1) {
+                        throw new RuntimeException("HTTP 429 Too Many Requests");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+
+    @Test
+    @DisplayName("No retry on non-retryable exception")
+    void testNoRetryOnNonRetryableException() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new IllegalArgumentException("Invalid input");
+                };
+
+        assertThatThrownBy(() -> executor.execute(operation, "testOperation"))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Operation 'testOperation' failed")
+                .hasMessageContaining("Invalid input");
+
+        // Should only try once (no retries for non-retryable)
+        assertThat(attempts.get()).isEqualTo(1);
+    }
+
+    @Test
+    @DisplayName("No retry on 4xx client errors")
+    void testNoRetryOn4xxError() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new RuntimeException("HTTP 400 Bad Request");
+                };
+
+        assertThatThrownBy(() -> executor.execute(operation, "testOperation"))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Operation 'testOperation' failed")
+                .hasMessageContaining("400 Bad Request");
+
+        assertThat(attempts.get()).isEqualTo(1);
+    }
+
+    @Test
+    @DisplayName("Fail after max retries exceeded")
+    void testFailAfterMaxRetries() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new SocketTimeoutException("Always fails");
+                };
+
+        assertThatThrownBy(() -> executor.execute(operation, "testOperation"))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("failed after 2 retries");
+
+        // Should try initial attempt + 2 retries
+        assertThat(attempts.get()).isEqualTo(3);
+    }
+
+    @Test
+    @DisplayName("InterruptedException stops retry")
+    void testInterruptedExceptionStopsRetry() throws Exception {
+        // The operation always fails and every backoff window is 10 s, so the worker is certainly
+        // still inside the retry loop (calling the operation or sleeping) when interrupted.
+        RetryExecutor executor =
+                RetryExecutor.builder()
+                        .maxRetries(10)
+                        .initialBackoffMs(10000)
+                        .maxBackoffMs(10000)
+                        .build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new SocketTimeoutException("Timeout");
+                };
+
+        // Written by the worker, read after it has terminated: observing termination via
+        // isAlive()/join() makes these writes visible to this thread.
+        Throwable[] thrown = new Throwable[1];
+        boolean[] interruptRestored = new boolean[1];
+        Thread testThread =
+                new Thread(
+                        () -> {
+                            try {
+                                executor.execute(operation, "testOperation");
+                            } catch (Exception e) {
+                                thrown[0] = e;
+                                interruptRestored[0] = Thread.currentThread().isInterrupted();
+                            }
+                        });
+
+        testThread.start();
+        Thread.sleep(50);
+        testThread.interrupt();
+        testThread.join(2000);
+
+        assertThat(testThread.isAlive()).isFalse();
+        assertThat(thrown[0])
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Interrupted while retrying operation: testOperation")
+                .hasCauseInstanceOf(InterruptedException.class);
+        // The executor must re-assert the interrupt status it consumed.
+        assertThat(interruptRestored[0]).isTrue();
+        // The interrupt, not exhaustion of all 11 attempts, ended the loop.
+        assertThat(attempts.get()).isLessThan(11);
+    }
+
+    @Test
+    @DisplayName("Default configuration values")
+    void testDefaultConfiguration() {
+        RetryExecutor executor = RetryExecutor.withDefaults();
+
+        assertThat(executor.getMaxRetries()).isEqualTo(3);
+        assertThat(executor.getInitialBackoffMs()).isEqualTo(100);
+        assertThat(executor.getMaxBackoffMs()).isEqualTo(10000);
+    }
+
+    @Test
+    @DisplayName("Custom retryable predicate")
+    void testCustomRetryablePredicate() {
+        RetryExecutor executor =
+                RetryExecutor.builder()
+                        .maxRetries(2)
+                        .initialBackoffMs(10)
+                        .retryablePredicate(e -> e instanceof IllegalStateException)
+                        .build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt < 2) {
+                        throw new IllegalStateException("Temporary error");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+}
diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java
index 64e71a32..faa03411 100644
--- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java
+++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java
@@ -26,6 +26,7 @@ import io.modelcontextprotocol.client.McpClient;
 import io.modelcontextprotocol.client.McpSyncClient;
 import 
io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
 import io.modelcontextprotocol.spec.McpSchema;
+import org.apache.flink.agents.api.RetryExecutor;
 import org.apache.flink.agents.api.chat.messages.ChatMessage;
 import org.apache.flink.agents.api.chat.messages.MessageRole;
 import org.apache.flink.agents.api.resource.Resource;
@@ -81,6 +82,11 @@ public class MCPServer extends Resource {
     private static final String FIELD_HEADERS = "headers";
     private static final String FIELD_TIMEOUT_SECONDS = "timeoutSeconds";
     private static final String FIELD_AUTH = "auth";
+    private static final String FIELD_MAX_RETRIES = "maxRetries";
+    private static final String FIELD_INITIAL_BACKOFF_MS = "initialBackoffMs";
+    private static final String FIELD_MAX_BACKOFF_MS = "maxBackoffMs";
+
+    private static final long DEFAULT_TIMEOUT_VALUE = 30L;
 
     @JsonProperty(FIELD_ENDPOINT)
     private final String endpoint;
@@ -94,14 +100,25 @@ public class MCPServer extends Resource {
     @JsonProperty(FIELD_AUTH)
     private final Auth auth;
 
+    private final Integer maxRetries;
+
+    private final Long initialBackoffMs;
+
+    private final Long maxBackoffMs;
+
+    @JsonIgnore private transient RetryExecutor retryExecutor;
+
     @JsonIgnore private transient McpSyncClient client;
 
     /** Builder for MCPServer with fluent API. */
     public static class Builder {
         private String endpoint;
         private final Map<String, String> headers = new HashMap<>();
-        private long timeoutSeconds = 30;
+        private long timeoutSeconds = DEFAULT_TIMEOUT_VALUE;
         private Auth auth = null;
+        private Integer maxRetries;
+        private Long initialBackoffMs;
+        private Long maxBackoffMs;
 
         public Builder endpoint(String endpoint) {
             this.endpoint = endpoint;
@@ -128,8 +145,30 @@ public class MCPServer extends Resource {
             return this;
         }
 
+        /**
+         * Sets how many times a failed remote call is retried after its first attempt.
+         *
+         * @param maxRetries number of retries; when never set, the RetryExecutor default applies
+         * @return this builder
+         */
+        public Builder maxRetries(int maxRetries) {
+            this.maxRetries = Integer.valueOf(maxRetries);
+            return this;
+        }
+
+        /**
+         * Sets the backoff window used before the first retry.
+         *
+         * @param backoff initial backoff, truncated to whole milliseconds
+         * @return this builder
+         */
+        public Builder initialBackoff(Duration backoff) {
+            long millis = backoff.toMillis();
+            this.initialBackoffMs = millis;
+            return this;
+        }
+
+        /**
+         * Sets the upper bound that the exponentially growing backoff window may reach.
+         *
+         * @param backoff maximum backoff, truncated to whole milliseconds
+         * @return this builder
+         */
+        public Builder maxBackoff(Duration backoff) {
+            long millis = backoff.toMillis();
+            this.maxBackoffMs = millis;
+            return this;
+        }
+
         public MCPServer build() {
-            return new MCPServer(endpoint, headers, timeoutSeconds, auth);
+            // Retry options that were never set stay null; MCPServer then uses the defaults.
+            MCPServer server =
+                    new MCPServer(
+                            endpoint,
+                            headers,
+                            timeoutSeconds,
+                            auth,
+                            maxRetries,
+                            initialBackoffMs,
+                            maxBackoffMs);
+            return server;
         }
     }
 
@@ -138,11 +177,29 @@ public class MCPServer extends Resource {
         super(descriptor, getResource);
         this.endpoint =
                 Objects.requireNonNull(
-                        descriptor.getArgument("endpoint"), "endpoint cannot 
be null");
-        Map<String, String> headers = descriptor.getArgument("headers");
+                        descriptor.getArgument(FIELD_ENDPOINT), "endpoint 
cannot be null");
+        Map<String, String> headers = descriptor.getArgument(FIELD_HEADERS);
         this.headers = headers != null ? new HashMap<>(headers) : new 
HashMap<>();
-        this.timeoutSeconds = (int) descriptor.getArgument("timeout");
-        this.auth = descriptor.getArgument("auth");
+        Object timeoutArg = descriptor.getArgument(FIELD_TIMEOUT_SECONDS);
+        this.timeoutSeconds =
+                timeoutArg instanceof Number
+                        ? ((Number) timeoutArg).longValue()
+                        : DEFAULT_TIMEOUT_VALUE;
+        this.auth = descriptor.getArgument(FIELD_AUTH);
+
+        Object maxRetriesArg = descriptor.getArgument(FIELD_MAX_RETRIES);
+        this.maxRetries =
+                maxRetriesArg instanceof Number ? ((Number) 
maxRetriesArg).intValue() : null;
+
+        Object initialBackoffArg = 
descriptor.getArgument(FIELD_INITIAL_BACKOFF_MS);
+        this.initialBackoffMs =
+                initialBackoffArg instanceof Number
+                        ? ((Number) initialBackoffArg).longValue()
+                        : null;
+
+        Object maxBackoffArg = descriptor.getArgument(FIELD_MAX_BACKOFF_MS);
+        this.maxBackoffMs =
+                maxBackoffArg instanceof Number ? ((Number) 
maxBackoffArg).longValue() : null;
     }
 
     /**
@@ -151,19 +208,25 @@ public class MCPServer extends Resource {
      * @param endpoint The HTTP endpoint of the MCP server
      */
     public MCPServer(String endpoint) {
-        this(endpoint, new HashMap<>(), 30, null);
+        // No headers, no auth; retry settings are left null so the RetryExecutor defaults apply.
+        this(
+                endpoint,
+                new HashMap<>(),
+                DEFAULT_TIMEOUT_VALUE,
+                /* auth= */ null,
+                /* maxRetries= */ null,
+                /* initialBackoffMs= */ null,
+                /* maxBackoffMs= */ null);
     }
 
     @JsonCreator
     public MCPServer(
             @JsonProperty(FIELD_ENDPOINT) String endpoint,
             @JsonProperty(FIELD_HEADERS) Map<String, String> headers,
-            @JsonProperty(FIELD_TIMEOUT_SECONDS) long timeoutSeconds,
-            @JsonProperty(FIELD_AUTH) Auth auth) {
+            @JsonProperty(FIELD_TIMEOUT_SECONDS) Long timeoutSeconds,
+            @JsonProperty(FIELD_AUTH) Auth auth,
+            @JsonProperty(FIELD_MAX_RETRIES) Integer maxRetries,
+            @JsonProperty(FIELD_INITIAL_BACKOFF_MS) Long initialBackoffMs,
+            @JsonProperty(FIELD_MAX_BACKOFF_MS) Long maxBackoffMs) {
         this.endpoint = Objects.requireNonNull(endpoint, "endpoint cannot be 
null");
         this.headers = headers != null ? new HashMap<>(headers) : new 
HashMap<>();
-        this.timeoutSeconds = timeoutSeconds;
+        this.timeoutSeconds = timeoutSeconds != null ? timeoutSeconds : 
DEFAULT_TIMEOUT_VALUE;
         this.auth = auth;
+        this.maxRetries = maxRetries;
+        this.initialBackoffMs = initialBackoffMs;
+        this.maxBackoffMs = maxBackoffMs;
     }
 
     public static Builder builder(String endpoint) {
@@ -192,6 +255,46 @@ public class MCPServer extends Resource {
         return auth;
     }
 
+    /**
+     * Returns the configured retry count, or the RetryExecutor default when none was configured.
+     */
+    @JsonProperty(FIELD_MAX_RETRIES)
+    public int getMaxRetries() {
+        if (maxRetries == null) {
+            return getRetryExecutor().getMaxRetries();
+        }
+        return maxRetries;
+    }
+
+    /**
+     * Returns the configured initial backoff in milliseconds, or the RetryExecutor default when
+     * none was configured.
+     */
+    @JsonProperty(FIELD_INITIAL_BACKOFF_MS)
+    public long getInitialBackoffMs() {
+        if (initialBackoffMs == null) {
+            return getRetryExecutor().getInitialBackoffMs();
+        }
+        return initialBackoffMs;
+    }
+
+    /**
+     * Returns the configured maximum backoff in milliseconds, or the RetryExecutor default when
+     * none was configured.
+     */
+    @JsonProperty(FIELD_MAX_BACKOFF_MS)
+    public long getMaxBackoffMs() {
+        if (maxBackoffMs == null) {
+            return getRetryExecutor().getMaxBackoffMs();
+        }
+        return maxBackoffMs;
+    }
+
+    /**
+     * Returns the retry executor used for remote calls, building it on first use.
+     *
+     * <p>Only the retry settings that were explicitly configured are passed to the builder; any
+     * setting left {@code null} keeps the RetryExecutor default.
+     *
+     * @return The retry executor
+     */
+    @JsonIgnore
+    private synchronized RetryExecutor getRetryExecutor() {
+        if (retryExecutor != null) {
+            return retryExecutor;
+        }
+        RetryExecutor.Builder retryBuilder = RetryExecutor.builder();
+        if (maxRetries != null) {
+            retryBuilder = retryBuilder.maxRetries(maxRetries);
+        }
+        if (initialBackoffMs != null) {
+            retryBuilder = retryBuilder.initialBackoffMs(initialBackoffMs);
+        }
+        if (maxBackoffMs != null) {
+            retryBuilder = retryBuilder.maxBackoffMs(maxBackoffMs);
+        }
+        retryExecutor = retryBuilder.build();
+        return retryExecutor;
+    }
+
     /**
      * Get or create a synchronized MCP client.
      *
@@ -200,7 +303,7 @@ public class MCPServer extends Resource {
     @JsonIgnore
     private synchronized McpSyncClient getClient() {
         if (client == null) {
-            client = createClient();
+            // Client creation may contact the remote server, so it is retried like other calls.
+            RetryExecutor executor = getRetryExecutor();
+            client = executor.execute(() -> createClient(), "createClient");
         }
         return client;
     }
@@ -263,22 +366,29 @@ public class MCPServer extends Resource {
      * @return List of MCPTool instances
      */
     public List<MCPTool> listTools() {
-        McpSyncClient mcpClient = getClient();
-        McpSchema.ListToolsResult toolsResult = mcpClient.listTools();
-
-        List<MCPTool> tools = new ArrayList<>();
-        for (McpSchema.Tool toolData : toolsResult.tools()) {
-            ToolMetadata metadata =
-                    new ToolMetadata(
-                            toolData.name(),
-                            toolData.description() != null ? toolData.description() : "",
-                            serializeInputSchema(toolData.inputSchema()));
-
-            MCPTool tool = new MCPTool(metadata, this);
-            tools.add(tool);
-        }
-
-        return tools;
+        // Obtain the client outside the retried operation: getClient() goes through the retry
+        // executor itself, and nesting it inside this retry loop can multiply the number of
+        // connection attempts (and their timeouts) when the server is unreachable.
+        McpSyncClient mcpClient = getClient();
+        return getRetryExecutor()
+                .execute(
+                        () -> {
+                            McpSchema.ListToolsResult toolsResult = mcpClient.listTools();
+
+                            List<MCPTool> tools = new ArrayList<>();
+                            for (McpSchema.Tool toolData : toolsResult.tools()) {
+                                ToolMetadata metadata =
+                                        new ToolMetadata(
+                                                toolData.name(),
+                                                toolData.description() != null
+                                                        ? toolData.description()
+                                                        : "",
+                                                serializeInputSchema(toolData.inputSchema()));
+
+                                tools.add(new MCPTool(metadata, this));
+                            }
+
+                            return tools;
+                        },
+                        "listTools");
     }
 
     /**
@@ -320,18 +430,24 @@ public class MCPServer extends Resource {
      * @return The result as a list of content items
      */
     public List<Object> callTool(String toolName, Map<String, Object> arguments) {
-        McpSyncClient mcpClient = getClient();
-        McpSchema.CallToolRequest request =
-                new McpSchema.CallToolRequest(
-                        toolName, arguments != null ? arguments : new HashMap<>());
-        McpSchema.CallToolResult result = mcpClient.callTool(request);
-
-        List<Object> content = new ArrayList<>();
-        for (var item : result.content()) {
-            content.add(MCPContentExtractor.extractContentItem(item));
-        }
-
-        return content;
+        // Obtain the client outside the retried operation: getClient() goes through the retry
+        // executor itself, and nesting it inside this retry loop can multiply the number of
+        // connection attempts (and their timeouts) when the server is unreachable.
+        McpSyncClient mcpClient = getClient();
+        // The request does not depend on the attempt, so build it once and reuse it.
+        McpSchema.CallToolRequest request =
+                new McpSchema.CallToolRequest(
+                        toolName, arguments != null ? arguments : new HashMap<>());
+        // NOTE(review): if a failure happens after the server received the request (e.g. a read
+        // timeout), a retry may execute the tool a second time. Confirm that tools exposed by
+        // this server are idempotent, or consider making retries opt-in for tool calls.
+        return getRetryExecutor()
+                .execute(
+                        () -> {
+                            McpSchema.CallToolResult result = mcpClient.callTool(request);
+
+                            List<Object> content = new ArrayList<>();
+                            for (var item : result.content()) {
+                                content.add(MCPContentExtractor.extractContentItem(item));
+                            }
+
+                            return content;
+                        },
+                        "callTool:" + toolName);
     }
 
     /**
@@ -340,27 +456,39 @@ public class MCPServer extends Resource {
      * @return List of MCPPrompt instances
      */
     public List<MCPPrompt> listPrompts() {
-        McpSyncClient mcpClient = getClient();
-        McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts();
-
-        List<MCPPrompt> prompts = new ArrayList<>();
-        for (McpSchema.Prompt promptData : promptsResult.prompts()) {
-            Map<String, MCPPrompt.PromptArgument> argumentsMap = new 
HashMap<>();
-            if (promptData.arguments() != null) {
-                for (var arg : promptData.arguments()) {
-                    argumentsMap.put(
-                            arg.name(),
-                            new MCPPrompt.PromptArgument(
-                                    arg.name(), arg.description(), 
arg.required()));
-                }
-            }
-
-            MCPPrompt prompt =
-                    new MCPPrompt(promptData.name(), promptData.description(), 
argumentsMap, this);
-            prompts.add(prompt);
-        }
-
-        return prompts;
+        return getRetryExecutor()
+                .execute(
+                        () -> {
+                            McpSyncClient mcpClient = getClient();
+                            McpSchema.ListPromptsResult promptsResult = 
mcpClient.listPrompts();
+
+                            List<MCPPrompt> prompts = new ArrayList<>();
+                            for (McpSchema.Prompt promptData : 
promptsResult.prompts()) {
+                                Map<String, MCPPrompt.PromptArgument> 
argumentsMap =
+                                        new HashMap<>();
+                                if (promptData.arguments() != null) {
+                                    for (var arg : promptData.arguments()) {
+                                        argumentsMap.put(
+                                                arg.name(),
+                                                new MCPPrompt.PromptArgument(
+                                                        arg.name(),
+                                                        arg.description(),
+                                                        arg.required()));
+                                    }
+                                }
+
+                                MCPPrompt prompt =
+                                        new MCPPrompt(
+                                                promptData.name(),
+                                                promptData.description(),
+                                                argumentsMap,
+                                                this);
+                                prompts.add(prompt);
+                            }
+
+                            return prompts;
+                        },
+                        "listPrompts");
     }
 
     /**
@@ -371,22 +499,29 @@ public class MCPServer extends Resource {
      * @return List of chat messages
      */
     public List<ChatMessage> getPrompt(String name, Map<String, Object> 
arguments) {
-        McpSyncClient mcpClient = getClient();
-        McpSchema.GetPromptRequest request =
-                new McpSchema.GetPromptRequest(
-                        name, arguments != null ? arguments : new HashMap<>());
-        McpSchema.GetPromptResult result = mcpClient.getPrompt(request);
-
-        List<ChatMessage> chatMessages = new ArrayList<>();
-        for (var message : result.messages()) {
-            if (message.content() instanceof McpSchema.TextContent) {
-                var textContent = (McpSchema.TextContent) message.content();
-                MessageRole role = 
MessageRole.valueOf(message.role().name().toUpperCase());
-                chatMessages.add(new ChatMessage(role, textContent.text()));
-            }
-        }
-
-        return chatMessages;
+        return getRetryExecutor()
+                .execute(
+                        () -> {
+                            McpSyncClient mcpClient = getClient();
+                            McpSchema.GetPromptRequest request =
+                                    new McpSchema.GetPromptRequest(
+                                            name, arguments != null ? 
arguments : new HashMap<>());
+                            McpSchema.GetPromptResult result = 
mcpClient.getPrompt(request);
+
+                            List<ChatMessage> chatMessages = new ArrayList<>();
+                            for (var message : result.messages()) {
+                                if (message.content() instanceof 
McpSchema.TextContent) {
+                                    var textContent = (McpSchema.TextContent) 
message.content();
+                                    MessageRole role =
+                                            MessageRole.valueOf(
+                                                    
message.role().name().toUpperCase());
+                                    chatMessages.add(new ChatMessage(role, 
textContent.text()));
+                                }
+                            }
+
+                            return chatMessages;
+                        },
+                        "getPrompt:" + name);
     }
 
     /** Close the MCP client and clean up resources. */
@@ -420,6 +555,9 @@ public class MCPServer extends Resource {
         if (o == null || getClass() != o.getClass()) return false;
         MCPServer that = (MCPServer) o;
         return timeoutSeconds == that.timeoutSeconds
+                && getMaxRetries() == that.getMaxRetries()
+                && getInitialBackoffMs() == that.getInitialBackoffMs()
+                && getMaxBackoffMs() == that.getMaxBackoffMs()
                 && Objects.equals(endpoint, that.endpoint)
                 && Objects.equals(headers, that.headers)
                 && Objects.equals(auth, that.auth);
@@ -427,7 +565,14 @@ public class MCPServer extends Resource {
 
     @Override
     public int hashCode() {
-        return Objects.hash(endpoint, headers, timeoutSeconds, auth);
+        return Objects.hash(
+                endpoint,
+                headers,
+                timeoutSeconds,
+                auth,
+                getMaxRetries(),
+                getInitialBackoffMs(),
+                getMaxBackoffMs());
     }
 
     @Override
diff --git 
a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java
 
b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java
index 5fe52bcc..85300386 100644
--- 
a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java
+++ 
b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java
@@ -243,4 +243,31 @@ class MCPServerTest {
         server.close();
         server.close(); // Calling twice should be safe
     }
+
+    @Test
+    @DisabledOnJre(JRE.JAVA_11)
+    @DisplayName("Default retry configuration")
+    void testDefaultRetryConfiguration() {
+        MCPServer server = MCPServer.builder(DEFAULT_ENDPOINT).build();
+
+        assertThat(server.getMaxRetries()).isEqualTo(3);
+        assertThat(server.getInitialBackoffMs()).isEqualTo(100);
+        assertThat(server.getMaxBackoffMs()).isEqualTo(10000);
+    }
+
+    @Test
+    @DisabledOnJre(JRE.JAVA_11)
+    @DisplayName("Custom retry configuration via builder")
+    void testCustomRetryConfiguration() {
+        MCPServer server =
+                MCPServer.builder(DEFAULT_ENDPOINT)
+                        .maxRetries(5)
+                        .initialBackoff(Duration.ofMillis(200))
+                        .maxBackoff(Duration.ofMillis(5000))
+                        .build();
+
+        assertThat(server.getMaxRetries()).isEqualTo(5);
+        assertThat(server.getInitialBackoffMs()).isEqualTo(200);
+        assertThat(server.getMaxBackoffMs()).isEqualTo(5000);
+    }
 }


Reply via email to