This is an automated email from the ASF dual-hosted git repository.

sxnan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit bc5b5e62e66f1dd4966528b69327b6b756664734
Author: sxnan <[email protected]>
AuthorDate: Mon Jan 26 14:49:26 2026 +0800

    [test] Fix and add more test for durable execution
---
 .../runtime/tests/test_durable_execution.py        |  69 +++++
 .../context/DurableExecutionContextTest.java       | 105 ++++++++
 .../operator/ActionExecutionOperatorTest.java      | 298 +++++++++++++++++++++
 3 files changed, 472 insertions(+)

diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py
index e59e54cd..080ad846 100644
--- a/python/flink_agents/runtime/tests/test_durable_execution.py
+++ b/python/flink_agents/runtime/tests/test_durable_execution.py
@@ -148,3 +148,72 @@ def test_cloudpickle_serialization() -> None:
         assert str(deserialized_exc) == "test error"
         assert isinstance(deserialized_exc, ValueError)
 
+
+def test_cloudpickle_exception_roundtrip_various_types() -> None:
+    """Test that various exception types can be serialized and deserialized."""
+    # Test various exception types that might occur in durable execution
+    test_exceptions = [
+        ValueError("Invalid value: 42"),
+        RuntimeError("Connection timeout"),
+        TypeError("Expected int, got str"),
+        KeyError("missing_key"),
+        AttributeError("Object has no attribute 'foo'"),
+        Exception("Generic exception with special chars: \"quotes\" and 'apostrophes'"),
+    ]
+
+    for original_exc in test_exceptions:
+        serialized = cloudpickle.dumps(original_exc)
+        deserialized = cloudpickle.loads(serialized)
+
+        assert type(deserialized) is type(original_exc), (
+            f"Exception type mismatch: expected {type(original_exc)}, "
+            f"got {type(deserialized)}"
+        )
+        assert str(deserialized) == str(original_exc), (
+            f"Exception message mismatch: expected '{original_exc}', "
+            f"got '{deserialized}'"
+        )
+
+
+def test_cloudpickle_exception_with_custom_attributes() -> None:
+    """Test exceptions with custom attributes set after construction."""
+    # Create a standard exception and add custom attributes
+    original = RuntimeError("API error occurred")
+    original.error_code = 500
+    original.details = {"endpoint": "/api/chat", "retry_count": 3}
+
+    serialized = cloudpickle.dumps(original)
+    deserialized = cloudpickle.loads(serialized)
+
+    assert isinstance(deserialized, RuntimeError)
+    assert str(deserialized) == "API error occurred"
+    assert deserialized.error_code == 500
+    assert deserialized.details == {"endpoint": "/api/chat", "retry_count": 3}
+
+
+def test_cloudpickle_exception_basic_types_preserved() -> None:
+    """Test that common exception types are preserved through serialization."""
+    # Test that the exception type and message are preserved
+    # Note: cloudpickle may not preserve __cause__ chains
+
+    original = ValueError("Test value error")
+    serialized = cloudpickle.dumps(original)
+    deserialized = cloudpickle.loads(serialized)
+
+    assert isinstance(deserialized, ValueError)
+    assert str(deserialized) == "Test value error"
+    assert type(deserialized).__name__ == "ValueError"
+
+
+def test_cloudpickle_none_exception_message() -> None:
+    """Test that exceptions with None message can be serialized."""
+    # Some exceptions might have None as their message
+    original = RuntimeError(None)
+
+    serialized = cloudpickle.dumps(original)
+    deserialized = cloudpickle.loads(serialized)
+
+    assert isinstance(deserialized, RuntimeError)
+    # str() of an exception with None message is "None"
+    assert str(deserialized) == "None"
+
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
index f2701e50..9e7c2de0 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.agents.runtime.context;
 
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.plan.actions.Action;
 import org.apache.flink.agents.runtime.actionstate.ActionState;
@@ -203,4 +205,107 @@ class DurableExecutionContextTest {
         assertEquals(3, persistCallCount.get());
         assertEquals(3, actionState.getCallResults().size());
     }
+
+    // ==================== DurableExecutionException Tests ====================
+
+    @Test
+    void testDurableExecutionExceptionSerialization() throws Exception {
+        // Create exception
+        RuntimeException original = new RuntimeException("Test error message");
+        RunnerContextImpl.DurableExecutionException durableException =
+                RunnerContextImpl.DurableExecutionException.fromException(original);
+
+        // Serialize to JSON
+        ObjectMapper mapper = new ObjectMapper();
+        String json = mapper.writeValueAsString(durableException);
+
+        // Verify JSON field names are semantically correct
+        JsonNode node = mapper.readTree(json);
+        assertEquals(
+                "java.lang.RuntimeException",
+                node.get("exceptionClass").asText(),
+                "JSON field 'exceptionClass' should contain the exception class name");
+        assertEquals(
+                "Test error message",
+                node.get("message").asText(),
+                "JSON field 'message' should contain the error message");
+    }
+
+    @Test
+    void testDurableExecutionExceptionDeserialization() throws Exception {
+        // Create and serialize exception
+        IllegalArgumentException original = new IllegalArgumentException("Invalid argument: foo");
+        RunnerContextImpl.DurableExecutionException durableException =
+                RunnerContextImpl.DurableExecutionException.fromException(original);
+
+        ObjectMapper mapper = new ObjectMapper();
+        byte[] serialized = mapper.writeValueAsBytes(durableException);
+
+        // Deserialize
+        RunnerContextImpl.DurableExecutionException deserialized =
+                mapper.readValue(serialized, RunnerContextImpl.DurableExecutionException.class);
+
+        // Convert back to exception and verify content
+        Exception recovered = deserialized.toException();
+        assertTrue(
+                recovered.getMessage().contains("IllegalArgumentException"),
+                "Recovered exception should contain original class name");
+        assertTrue(
+                recovered.getMessage().contains("Invalid argument: foo"),
+                "Recovered exception should contain original message");
+    }
+
+    @Test
+    void testDurableExecutionExceptionRoundTrip() throws Exception {
+        // Test various exception types
+        Exception[] testExceptions = {
+            new RuntimeException("Runtime error"),
+            new IllegalStateException("Illegal state"),
+            new NullPointerException("Null value"),
+            new RuntimeException("Message with special chars: \"quotes\" and 'apostrophes'"),
+            new RuntimeException("") // Empty message
+        };
+
+        ObjectMapper mapper = new ObjectMapper();
+
+        for (Exception original : testExceptions) {
+            // Create DurableExecutionException
+            RunnerContextImpl.DurableExecutionException durableException =
+                    RunnerContextImpl.DurableExecutionException.fromException(original);
+
+            // Serialize and deserialize
+            byte[] serialized = mapper.writeValueAsBytes(durableException);
+            RunnerContextImpl.DurableExecutionException deserialized =
+                    mapper.readValue(serialized, RunnerContextImpl.DurableExecutionException.class);
+
+            // Verify round-trip
+            Exception recovered = deserialized.toException();
+            assertTrue(
+                    recovered.getMessage().contains(original.getClass().getName()),
+                    "Recovered exception should contain class: " + original.getClass().getName());
+            if (original.getMessage() != null && !original.getMessage().isEmpty()) {
+                assertTrue(
+                        recovered.getMessage().contains(original.getMessage()),
+                        "Recovered exception should contain message: " + original.getMessage());
+            }
+        }
+    }
+
+    @Test
+    void testDurableExecutionExceptionWithNullMessage() throws Exception {
+        // Create exception with null message
+        RuntimeException original = new RuntimeException((String) null);
+        RunnerContextImpl.DurableExecutionException durableException =
+                RunnerContextImpl.DurableExecutionException.fromException(original);
+
+        ObjectMapper mapper = new ObjectMapper();
+        byte[] serialized = mapper.writeValueAsBytes(durableException);
+
+        // Should not throw during serialization/deserialization
+        RunnerContextImpl.DurableExecutionException deserialized =
+                mapper.readValue(serialized, RunnerContextImpl.DurableExecutionException.class);
+
+        Exception recovered = deserialized.toException();
+        assertTrue(recovered.getMessage().contains("RuntimeException"));
+    }
 }
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index f2709b32..646d1e63 100644
--- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -785,6 +785,180 @@ public class ActionExecutionOperatorTest {
         }
     }
 
+    /**
+     * Tests that durableExecute exception can be serialized and recovered correctly when the action
+     * does NOT catch the exception (simulates built-in action behavior like ChatModelAction).
+     *
+     * <p>This test verifies that:
+     *
+     * <ul>
+     *   <li>DurableExecutionException can be properly serialized by Jackson
+     *   <li>On recovery, the cached exception is re-thrown without re-executing the supplier
+     *   <li>The exception content (class name and message) is preserved
+     * </ul>
+     */
+    @Test
+    void testDurableExecuteExceptionRecoveryWithUncaughtException() throws Exception {
+        AgentPlan agentPlan = TestAgent.getDurableExceptionUncaughtAgentPlan();
+        InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(false);
+
+        // Reset counter
+        TestAgent.UNCAUGHT_EXCEPTION_CALL_COUNTER.set(0);
+
+        String firstExecutionExceptionChain = null;
+
+        // First execution - will execute the supplier, throw exception, and store it
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(1L));
+
+            // This should throw because the exception is not caught in the action
+            try {
+                operator.waitInFlightEventsFinished();
+            } catch (Exception e) {
+                // Collect all exception messages in the chain
+                firstExecutionExceptionChain = ExceptionUtils.stringifyException(e);
+            }
+        }
+
+        // Verify supplier was called once
+        assertThat(TestAgent.UNCAUGHT_EXCEPTION_CALL_COUNTER.get()).isEqualTo(1);
+
+        // Verify exception was thrown and contains correct info somewhere in the chain
+        assertThat(firstExecutionExceptionChain).isNotNull();
+        assertThat(firstExecutionExceptionChain)
+                .as("Exception chain should contain original class name")
+                .contains("IllegalStateException");
+        assertThat(firstExecutionExceptionChain)
+                .as("Exception chain should contain original message")
+                .contains("Simulated LLM failure");
+
+        // Verify action state was stored with call result
+        assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+
+        String recoveryExceptionChain = null;
+
+        // Second execution - should recover cached exception without calling supplier
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(1L));
+
+            try {
+                operator.waitInFlightEventsFinished();
+            } catch (Exception e) {
+                // Collect all exception messages in the chain
+                recoveryExceptionChain = ExceptionUtils.stringifyException(e);
+            }
+        }
+
+        // CRITICAL: Verify supplier was NOT called during recovery
+        assertThat(TestAgent.UNCAUGHT_EXCEPTION_CALL_COUNTER.get())
+                .as("Supplier should NOT be called during exception recovery")
+                .isEqualTo(1);
+
+        // Verify recovered exception contains correct information in the chain
+        assertThat(recoveryExceptionChain).isNotNull();
+        assertThat(recoveryExceptionChain)
+                .as("Recovered exception chain should contain original class name")
+                .contains("IllegalStateException");
+        assertThat(recoveryExceptionChain)
+                .as("Recovered exception chain should contain original message")
+                .contains("Simulated LLM failure");
+    }
+
+    /**
+     * Tests that durableExecuteAsync exception can be serialized and recovered correctly.
+     *
+     * <p>This test verifies async exception handling works the same way as sync.
+     */
+    @Test
+    void testDurableExecuteAsyncExceptionRecovery() throws Exception {
+        AgentPlan agentPlan = TestAgent.getDurableAsyncExceptionAgentPlan();
+        InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(false);
+
+        // Reset counter
+        TestAgent.ASYNC_EXCEPTION_CALL_COUNTER.set(0);
+
+        String firstExecutionExceptionChain = null;
+
+        // First execution - will execute the async supplier, throw exception, and store it
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(1L));
+
+            try {
+                operator.waitInFlightEventsFinished();
+            } catch (Exception e) {
+                firstExecutionExceptionChain = ExceptionUtils.stringifyException(e);
+            }
+        }
+
+        // Verify supplier was called once
+        assertThat(TestAgent.ASYNC_EXCEPTION_CALL_COUNTER.get()).isEqualTo(1);
+
+        // Verify exception was thrown
+        assertThat(firstExecutionExceptionChain).isNotNull();
+        assertThat(firstExecutionExceptionChain)
+                .as("Exception chain should contain original message")
+                .contains("Async operation failed");
+
+        // Verify action state was stored
+        assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty();
+
+        String recoveryExceptionChain = null;
+
+        // Second execution - should recover cached exception without calling supplier
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            testHarness.processElement(new StreamRecord<>(1L));
+
+            try {
+                operator.waitInFlightEventsFinished();
+            } catch (Exception e) {
+                recoveryExceptionChain = ExceptionUtils.stringifyException(e);
+            }
+        }
+
+        // CRITICAL: Verify supplier was NOT called during recovery
+        assertThat(TestAgent.ASYNC_EXCEPTION_CALL_COUNTER.get())
+                .as("Supplier should NOT be called during async exception recovery")
+                .isEqualTo(1);
+
+        // Verify recovered exception contains correct information
+        assertThat(recoveryExceptionChain).isNotNull();
+        assertThat(recoveryExceptionChain)
+                .as("Recovered exception chain should contain original message")
+                .contains("Async operation failed");
+    }
+
     public static class TestAgent {
 
         /** Counter to track how many times the durable supplier is executed. */
@@ -1154,6 +1328,130 @@ public class ActionExecutionOperatorTest {
             }
             return null;
         }
+
+        // ==================== Actions for Exception Recovery Tests ====================
+
+        /**
+         * Counter to track how many times the uncaught exception supplier is executed. Used to
+         * verify that on recovery, the supplier is not re-executed.
+         */
+        public static final java.util.concurrent.atomic.AtomicInteger
+                UNCAUGHT_EXCEPTION_CALL_COUNTER = new java.util.concurrent.atomic.AtomicInteger(0);
+
+        /**
+         * Action that uses durableExecute and does NOT catch the exception. This simulates the
+         * behavior of built-in actions like ChatModelAction.
+         */
+        public static void durableExceptionUncaughtAction(InputEvent event, RunnerContext context) {
+            try {
+                context.durableExecute(
+                        new DurableCallable<String>() {
+                            @Override
+                            public String getId() {
+                                return "uncaught-exception-action";
+                            }
+
+                            @Override
+                            public Class<String> getResultClass() {
+                                return String.class;
+                            }
+
+                            @Override
+                            public String call() {
+                                UNCAUGHT_EXCEPTION_CALL_COUNTER.incrementAndGet();
+                                throw new IllegalStateException(
+                                        "Simulated LLM failure: Connection timeout");
+                            }
+                        });
+            } catch (Exception e) {
+                // Re-throw without wrapping - simulates built-in action behavior
+                ExceptionUtils.rethrow(e);
+            }
+        }
+
+        /**
+         * Counter to track how many times the async exception supplier is executed. Used to verify
+         * that on recovery, the supplier is not re-executed.
+         */
+        public static final java.util.concurrent.atomic.AtomicInteger ASYNC_EXCEPTION_CALL_COUNTER =
+                new java.util.concurrent.atomic.AtomicInteger(0);
+
+        /**
+         * Action that uses durableExecuteAsync and does NOT catch the exception. This simulates
+         * async operations that fail.
+         */
+        public static void durableAsyncExceptionAction(InputEvent event, RunnerContext context) {
+            try {
+                context.durableExecuteAsync(
+                        new DurableCallable<String>() {
+                            @Override
+                            public String getId() {
+                                return "async-exception-action";
+                            }
+
+                            @Override
+                            public Class<String> getResultClass() {
+                                return String.class;
+                            }
+
+                            @Override
+                            public String call() {
+                                ASYNC_EXCEPTION_CALL_COUNTER.incrementAndGet();
+                                throw new RuntimeException("Async operation failed: API error");
+                            }
+                        });
+            } catch (Exception e) {
+                ExceptionUtils.rethrow(e);
+            }
+        }
+
+        public static AgentPlan getDurableExceptionUncaughtAgentPlan() {
+            try {
+                Map<String, List<Action>> actionsByEvent = new HashMap<>();
+                Map<String, Action> actions = new HashMap<>();
+
+                Action exceptionAction =
+                        new Action(
+                                "durableExceptionUncaughtAction",
+                                new JavaFunction(
+                                        TestAgent.class,
+                                        "durableExceptionUncaughtAction",
+                                        new Class<?>[] {InputEvent.class, RunnerContext.class}),
+                                Collections.singletonList(InputEvent.class.getName()));
+                actionsByEvent.put(
+                        InputEvent.class.getName(), Collections.singletonList(exceptionAction));
+                actions.put(exceptionAction.getName(), exceptionAction);
+
+                return new AgentPlan(actions, actionsByEvent, new HashMap<>());
+            } catch (Exception e) {
+                ExceptionUtils.rethrow(e);
+            }
+            return null;
+        }
+
+        public static AgentPlan getDurableAsyncExceptionAgentPlan() {
+            try {
+                Map<String, List<Action>> actionsByEvent = new HashMap<>();
+                Map<String, Action> actions = new HashMap<>();
+
+                Action exceptionAction =
+                        new Action(
+                                "durableAsyncExceptionAction",
+                                new JavaFunction(
+                                        TestAgent.class,
+                                        "durableAsyncExceptionAction",
+                                        new Class<?>[] {InputEvent.class, RunnerContext.class}),
+                                Collections.singletonList(InputEvent.class.getName()));
+                actionsByEvent.put(
+                        InputEvent.class.getName(), Collections.singletonList(exceptionAction));
+                actions.put(exceptionAction.getName(), exceptionAction);
+
+                return new AgentPlan(actions, actionsByEvent, new HashMap<>());
+            } catch (Exception e) {
+                ExceptionUtils.rethrow(e);
+            }
+            return null;
+        }
     }
 
     private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize)

Reply via email to