This is an automated email from the ASF dual-hosted git repository. sxnan pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit bc5b5e62e66f1dd4966528b69327b6b756664734 Author: sxnan <[email protected]> AuthorDate: Mon Jan 26 14:49:26 2026 +0800 [test] Fix and add more test for durable execution --- .../runtime/tests/test_durable_execution.py | 69 +++++ .../context/DurableExecutionContextTest.java | 105 ++++++++ .../operator/ActionExecutionOperatorTest.java | 298 +++++++++++++++++++++ 3 files changed, 472 insertions(+) diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py index e59e54cd..080ad846 100644 --- a/python/flink_agents/runtime/tests/test_durable_execution.py +++ b/python/flink_agents/runtime/tests/test_durable_execution.py @@ -148,3 +148,72 @@ def test_cloudpickle_serialization() -> None: assert str(deserialized_exc) == "test error" assert isinstance(deserialized_exc, ValueError) + +def test_cloudpickle_exception_roundtrip_various_types() -> None: + """Test that various exception types can be serialized and deserialized.""" + # Test various exception types that might occur in durable execution + test_exceptions = [ + ValueError("Invalid value: 42"), + RuntimeError("Connection timeout"), + TypeError("Expected int, got str"), + KeyError("missing_key"), + AttributeError("Object has no attribute 'foo'"), + Exception("Generic exception with special chars: \"quotes\" and 'apostrophes'"), + ] + + for original_exc in test_exceptions: + serialized = cloudpickle.dumps(original_exc) + deserialized = cloudpickle.loads(serialized) + + assert type(deserialized) is type(original_exc), ( + f"Exception type mismatch: expected {type(original_exc)}, " + f"got {type(deserialized)}" + ) + assert str(deserialized) == str(original_exc), ( + f"Exception message mismatch: expected '{original_exc}', " + f"got '{deserialized}'" + ) + + +def test_cloudpickle_exception_with_custom_attributes() -> None: + """Test exceptions with custom attributes set after construction.""" + # Create a standard exception and add custom attributes + original = RuntimeError("API error occurred") + original.error_code = 500 + original.details = {"endpoint": "/api/chat", "retry_count": 3} + + serialized = cloudpickle.dumps(original) + deserialized = cloudpickle.loads(serialized) + + assert isinstance(deserialized, RuntimeError) + assert str(deserialized) == "API error occurred" + assert deserialized.error_code == 500 + assert deserialized.details == {"endpoint": "/api/chat", "retry_count": 3} + + +def test_cloudpickle_exception_basic_types_preserved() -> None: + """Test that common exception types are preserved through serialization.""" + # Test that the exception type and message are preserved + # Note: cloudpickle may not preserve __cause__ chains + + original = ValueError("Test value error") + serialized = cloudpickle.dumps(original) + deserialized = cloudpickle.loads(serialized) + + assert isinstance(deserialized, ValueError) + assert str(deserialized) == "Test value error" + assert type(deserialized).__name__ == "ValueError" + + +def test_cloudpickle_none_exception_message() -> None: + """Test that exceptions with None message can be serialized.""" + # Some exceptions might have None as their message + original = RuntimeError(None) + + serialized = cloudpickle.dumps(original) + deserialized = cloudpickle.loads(serialized) + + assert isinstance(deserialized, RuntimeError) + # str() of an exception with None message is "None" + assert str(deserialized) == "None" + diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java index f2701e50..9e7c2de0 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java @@ -18,6 +18,8 @@ package org.apache.flink.agents.runtime.context; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.plan.actions.Action; import org.apache.flink.agents.runtime.actionstate.ActionState; @@ -203,4 +205,107 @@ class DurableExecutionContextTest { assertEquals(3, persistCallCount.get()); assertEquals(3, actionState.getCallResults().size()); } + + // ==================== DurableExecutionException Tests ==================== + + @Test + void testDurableExecutionExceptionSerialization() throws Exception { + // Create exception + RuntimeException original = new RuntimeException("Test error message"); + RunnerContextImpl.DurableExecutionException durableException = + RunnerContextImpl.DurableExecutionException.fromException(original); + + // Serialize to JSON + ObjectMapper mapper = new ObjectMapper(); + String json = mapper.writeValueAsString(durableException); + + // Verify JSON field names are semantically correct + JsonNode node = mapper.readTree(json); + assertEquals( + "java.lang.RuntimeException", + node.get("exceptionClass").asText(), + "JSON field 'exceptionClass' should contain the exception class name"); + assertEquals( + "Test error message", + node.get("message").asText(), + "JSON field 'message' should contain the error message"); + } + + @Test + void testDurableExecutionExceptionDeserialization() throws Exception { + // Create and serialize exception + IllegalArgumentException original = new IllegalArgumentException("Invalid argument: foo"); + RunnerContextImpl.DurableExecutionException durableException = + RunnerContextImpl.DurableExecutionException.fromException(original); + + ObjectMapper mapper = new ObjectMapper(); + byte[] serialized = mapper.writeValueAsBytes(durableException); + + // Deserialize + RunnerContextImpl.DurableExecutionException deserialized = + mapper.readValue(serialized, RunnerContextImpl.DurableExecutionException.class); + + // Convert back to exception and verify content + Exception recovered = deserialized.toException(); + assertTrue( + recovered.getMessage().contains("IllegalArgumentException"), + "Recovered exception should contain original class name"); + assertTrue( + recovered.getMessage().contains("Invalid argument: foo"), + "Recovered exception should contain original message"); + } + + @Test + void testDurableExecutionExceptionRoundTrip() throws Exception { + // Test various exception types + Exception[] testExceptions = { + new RuntimeException("Runtime error"), + new IllegalStateException("Illegal state"), + new NullPointerException("Null value"), + new RuntimeException("Message with special chars: \"quotes\" and 'apostrophes'"), + new RuntimeException("") // Empty message + }; + + ObjectMapper mapper = new ObjectMapper(); + + for (Exception original : testExceptions) { + // Create DurableExecutionException + RunnerContextImpl.DurableExecutionException durableException = + RunnerContextImpl.DurableExecutionException.fromException(original); + + // Serialize and deserialize + byte[] serialized = mapper.writeValueAsBytes(durableException); + RunnerContextImpl.DurableExecutionException deserialized = + mapper.readValue(serialized, RunnerContextImpl.DurableExecutionException.class); + + // Verify round-trip + Exception recovered = deserialized.toException(); + assertTrue( + recovered.getMessage().contains(original.getClass().getName()), + "Recovered exception should contain class: " + original.getClass().getName()); + if (original.getMessage() != null && !original.getMessage().isEmpty()) { + assertTrue( + recovered.getMessage().contains(original.getMessage()), + "Recovered exception should contain message: " + original.getMessage()); + } + } + } + + @Test + void testDurableExecutionExceptionWithNullMessage() throws Exception { + // Create exception with null message + RuntimeException original = new RuntimeException((String) null); + RunnerContextImpl.DurableExecutionException durableException = + RunnerContextImpl.DurableExecutionException.fromException(original); + + ObjectMapper mapper = new ObjectMapper(); + byte[] serialized = mapper.writeValueAsBytes(durableException); + + // Should not throw during serialization/deserialization + RunnerContextImpl.DurableExecutionException deserialized = + mapper.readValue(serialized, RunnerContextImpl.DurableExecutionException.class); + + Exception recovered = deserialized.toException(); + assertTrue(recovered.getMessage().contains("RuntimeException")); + } } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index f2709b32..646d1e63 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -785,6 +785,180 @@ public class ActionExecutionOperatorTest { } } + /** + * Tests that durableExecute exception can be serialized and recovered correctly when the action + * does NOT catch the exception (simulates built-in action behavior like ChatModelAction). + * + * <p>This test verifies that: + * + * <ul> + * <li>DurableExecutionException can be properly serialized by Jackson + * <li>On recovery, the cached exception is re-thrown without re-executing the supplier + * <li>The exception content (class name and message) is preserved + * </ul> + */ + @Test + void testDurableExecuteExceptionRecoveryWithUncaughtException() throws Exception { + AgentPlan agentPlan = TestAgent.getDurableExceptionUncaughtAgentPlan(); + InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(false); + + // Reset counter + TestAgent.UNCAUGHT_EXCEPTION_CALL_COUNTER.set(0); + + String firstExecutionExceptionChain = null; + + // First execution - will execute the supplier, throw exception, and store it + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(1L)); + + // This should throw because the exception is not caught in the action + try { + operator.waitInFlightEventsFinished(); + } catch (Exception e) { + // Collect all exception messages in the chain + firstExecutionExceptionChain = ExceptionUtils.stringifyException(e); + } + } + + // Verify supplier was called once + assertThat(TestAgent.UNCAUGHT_EXCEPTION_CALL_COUNTER.get()).isEqualTo(1); + + // Verify exception was thrown and contains correct info somewhere in the chain + assertThat(firstExecutionExceptionChain).isNotNull(); + assertThat(firstExecutionExceptionChain) + .as("Exception chain should contain original class name") + .contains("IllegalStateException"); + assertThat(firstExecutionExceptionChain) + .as("Exception chain should contain original message") + .contains("Simulated LLM failure"); + + // Verify action state was stored with call result + assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty(); + + String recoveryExceptionChain = null; + + // Second execution - should recover cached exception without calling supplier + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(1L)); + + try { + operator.waitInFlightEventsFinished(); + } catch (Exception e) { + // Collect all exception messages in the chain + recoveryExceptionChain = ExceptionUtils.stringifyException(e); + } + } + + // CRITICAL: Verify supplier was NOT called during recovery + assertThat(TestAgent.UNCAUGHT_EXCEPTION_CALL_COUNTER.get()) + .as("Supplier should NOT be called during exception recovery") + .isEqualTo(1); + + // Verify recovered exception contains correct information in the chain + assertThat(recoveryExceptionChain).isNotNull(); + assertThat(recoveryExceptionChain) + .as("Recovered exception chain should contain original class name") + .contains("IllegalStateException"); + assertThat(recoveryExceptionChain) + .as("Recovered exception chain should contain original message") + .contains("Simulated LLM failure"); + } + + /** + * Tests that durableExecuteAsync exception can be serialized and recovered correctly. + * + * <p>This test verifies async exception handling works the same way as sync. + */ + @Test + void testDurableExecuteAsyncExceptionRecovery() throws Exception { + AgentPlan agentPlan = TestAgent.getDurableAsyncExceptionAgentPlan(); + InMemoryActionStateStore actionStateStore = new InMemoryActionStateStore(false); + + // Reset counter + TestAgent.ASYNC_EXCEPTION_CALL_COUNTER.set(0); + + String firstExecutionExceptionChain = null; + + // First execution - will execute the async supplier, throw exception, and store it + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(1L)); + + try { + operator.waitInFlightEventsFinished(); + } catch (Exception e) { + firstExecutionExceptionChain = ExceptionUtils.stringifyException(e); + } + } + + // Verify supplier was called once + assertThat(TestAgent.ASYNC_EXCEPTION_CALL_COUNTER.get()).isEqualTo(1); + + // Verify exception was thrown + assertThat(firstExecutionExceptionChain).isNotNull(); + assertThat(firstExecutionExceptionChain) + .as("Exception chain should contain original message") + .contains("Async operation failed"); + + // Verify action state was stored + assertThat(actionStateStore.getKeyedActionStates()).isNotEmpty(); + + String recoveryExceptionChain = null; + + // Second execution - should recover cached exception without calling supplier + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>(agentPlan, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + testHarness.processElement(new StreamRecord<>(1L)); + + try { + operator.waitInFlightEventsFinished(); + } catch (Exception e) { + recoveryExceptionChain = ExceptionUtils.stringifyException(e); + } + } + + // CRITICAL: Verify supplier was NOT called during recovery + assertThat(TestAgent.ASYNC_EXCEPTION_CALL_COUNTER.get()) + .as("Supplier should NOT be called during async exception recovery") + .isEqualTo(1); + + // Verify recovered exception contains correct information + assertThat(recoveryExceptionChain).isNotNull(); + assertThat(recoveryExceptionChain) + .as("Recovered exception chain should contain original message") + .contains("Async operation failed"); + } + public static class TestAgent { /** Counter to track how many times the durable supplier is executed. */ @@ -1154,6 +1328,130 @@ public class ActionExecutionOperatorTest { } return null; } + + // ==================== Actions for Exception Recovery Tests ==================== + + /** + * Counter to track how many times the uncaught exception supplier is executed. Used to + * verify that on recovery, the supplier is not re-executed. + */ + public static final java.util.concurrent.atomic.AtomicInteger + UNCAUGHT_EXCEPTION_CALL_COUNTER = new java.util.concurrent.atomic.AtomicInteger(0); + + /** + * Action that uses durableExecute and does NOT catch the exception. This simulates the + * behavior of built-in actions like ChatModelAction. + */ + public static void durableExceptionUncaughtAction(InputEvent event, RunnerContext context) { + try { + context.durableExecute( + new DurableCallable<String>() { + @Override + public String getId() { + return "uncaught-exception-action"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + UNCAUGHT_EXCEPTION_CALL_COUNTER.incrementAndGet(); + throw new IllegalStateException( + "Simulated LLM failure: Connection timeout"); + } + }); + } catch (Exception e) { + // Re-throw without wrapping - simulates built-in action behavior + ExceptionUtils.rethrow(e); + } + } + + /** + * Counter to track how many times the async exception supplier is executed. Used to verify + * that on recovery, the supplier is not re-executed. + */ + public static final java.util.concurrent.atomic.AtomicInteger ASYNC_EXCEPTION_CALL_COUNTER = + new java.util.concurrent.atomic.AtomicInteger(0); + + /** + * Action that uses durableExecuteAsync and does NOT catch the exception. This simulates + * async operations that fail. + */ + public static void durableAsyncExceptionAction(InputEvent event, RunnerContext context) { + try { + context.durableExecuteAsync( + new DurableCallable<String>() { + @Override + public String getId() { + return "async-exception-action"; + } + + @Override + public Class<String> getResultClass() { + return String.class; + } + + @Override + public String call() { + ASYNC_EXCEPTION_CALL_COUNTER.incrementAndGet(); + throw new RuntimeException("Async operation failed: API error"); + } + }); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + } + + public static AgentPlan getDurableExceptionUncaughtAgentPlan() { + try { + Map<String, List<Action>> actionsByEvent = new HashMap<>(); + Map<String, Action> actions = new HashMap<>(); + + Action exceptionAction = + new Action( + "durableExceptionUncaughtAction", + new JavaFunction( + TestAgent.class, + "durableExceptionUncaughtAction", + new Class<?>[] {InputEvent.class, RunnerContext.class}), + Collections.singletonList(InputEvent.class.getName())); + actionsByEvent.put( + InputEvent.class.getName(), Collections.singletonList(exceptionAction)); + actions.put(exceptionAction.getName(), exceptionAction); + + return new AgentPlan(actions, actionsByEvent, new HashMap<>()); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + return null; + } + + public static AgentPlan getDurableAsyncExceptionAgentPlan() { + try { + Map<String, List<Action>> actionsByEvent = new HashMap<>(); + Map<String, Action> actions = new HashMap<>(); + + Action exceptionAction = + new Action( + "durableAsyncExceptionAction", + new JavaFunction( + TestAgent.class, + "durableAsyncExceptionAction", + new Class<?>[] {InputEvent.class, RunnerContext.class}), + Collections.singletonList(InputEvent.class.getName())); + actionsByEvent.put( + InputEvent.class.getName(), Collections.singletonList(exceptionAction)); + actions.put(exceptionAction.getName(), exceptionAction); + + return new AgentPlan(actions, actionsByEvent, new HashMap<>()); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + return null; + } } private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize)
