This is an automated email from the ASF dual-hosted git repository. sxnan pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 15b1976087762f9c810c1ef37c4500f2a082d1f4 Author: sxnan <[email protected]> AuthorDate: Thu Jan 8 15:45:10 2026 +0800 [runtime] Introduce CallRecord to ActionState --- .../agents/runtime/actionstate/ActionState.java | 126 ++++++++- .../agents/runtime/actionstate/CallResult.java | 176 ++++++++++++ .../runtime/actionstate/ActionStateSerdeTest.java | 117 +++++++- .../runtime/actionstate/ActionStateTest.java | 310 +++++++++++++++++++++ .../agents/runtime/actionstate/CallResultTest.java | 157 +++++++++++ .../actionstate/KafkaActionStateStoreTest.java | 7 +- 6 files changed, 884 insertions(+), 9 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java index 34eefb35..031928ad 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java @@ -17,6 +17,7 @@ */ package org.apache.flink.agents.runtime.actionstate; +import com.fasterxml.jackson.annotation.JsonIgnore; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.context.MemoryUpdate; @@ -30,19 +31,32 @@ public class ActionState { private final List<MemoryUpdate> shortTermMemoryUpdates; private final List<Event> outputEvents; - /** Constructs a new TaskActionState instance. */ - public ActionState(final Event taskEvent) { - this.taskEvent = taskEvent; + /** + * Records of completed durable_execute/durable_execute_async calls for fine-grained recovery. + */ + private final List<CallResult> callResults; + + /** Indicates whether the action has completed execution. */ + private boolean completed; + + /** Default constructor for Jackson deserialization. */ + private ActionState() { + this.taskEvent = null; this.sensoryMemoryUpdates = new ArrayList<>(); this.shortTermMemoryUpdates = new ArrayList<>(); this.outputEvents = new ArrayList<>(); + this.callResults = new ArrayList<>(); + this.completed = false; } - public ActionState() { - this.taskEvent = null; + /** Constructs a new TaskActionState instance. */ + public ActionState(final Event taskEvent) { + this.taskEvent = taskEvent; this.sensoryMemoryUpdates = new ArrayList<>(); this.shortTermMemoryUpdates = new ArrayList<>(); this.outputEvents = new ArrayList<>(); + this.callResults = new ArrayList<>(); + this.completed = false; } /** Constructor for deserialization purposes. */ @@ -50,13 +64,17 @@ public class ActionState { Event taskEvent, List<MemoryUpdate> sensoryMemoryUpdates, List<MemoryUpdate> shortTermMemoryUpdates, - List<Event> outputEvents) { + List<Event> outputEvents, + List<CallResult> callResults, + boolean completed) { this.taskEvent = taskEvent; this.sensoryMemoryUpdates = sensoryMemoryUpdates != null ? sensoryMemoryUpdates : new ArrayList<>(); this.shortTermMemoryUpdates = shortTermMemoryUpdates != null ? shortTermMemoryUpdates : new ArrayList<>(); this.outputEvents = outputEvents != null ? outputEvents : new ArrayList<>(); + this.callResults = callResults != null ? callResults : new ArrayList<>(); + this.completed = completed; } /** Getters for the fields */ @@ -90,6 +108,77 @@ public class ActionState { outputEvents.add(event); } + /** Gets the list of call results for fine-grained durable execution. */ + public List<CallResult> getCallResults() { + return callResults; + } + + /** + * Adds a call result for a completed durable_execute/durable_execute_async call. + * + * @param callResult the call result to add + */ + public void addCallResult(CallResult callResult) { + callResults.add(callResult); + } + + /** + * Gets the call result at the specified index. + * + * @param index the index of the call result + * @return the call result at the specified index, or null if index is out of bounds + */ + public CallResult getCallResult(int index) { + if (index >= 0 && index < callResults.size()) { + return callResults.get(index); + } + return null; + } + + /** + * Gets the number of call results. + * + * @return the number of call results + */ + @JsonIgnore + public int getCallResultCount() { + return callResults.size(); + } + + /** + * Clears all call results. This should be called when the action completes to reduce storage + * overhead. + */ + public void clearCallResults() { + callResults.clear(); + } + + /** + * Clears call results from the specified index onwards. This is used when a non-deterministic + * call order is detected during recovery. + * + * @param fromIndex the index from which to clear results (inclusive) + */ + public void clearCallResultsFrom(int fromIndex) { + if (fromIndex >= 0 && fromIndex < callResults.size()) { + callResults.subList(fromIndex, callResults.size()).clear(); + } + } + + /** Returns whether the action has completed execution. */ + public boolean isCompleted() { + return completed; + } + + /** + * Marks the action as completed and clears call results. This should be called when the action + * finishes execution to indicate that recovery should skip the entire action. + */ + public void markCompleted() { + this.completed = true; + this.callResults.clear(); + } + @Override public int hashCode() { int result = taskEvent != null ? taskEvent.hashCode() : 0; @@ -102,12 +191,31 @@ public class ActionState { ? 0 : shortTermMemoryUpdates.hashCode()); result = 31 * result + (outputEvents.isEmpty() ? 0 : outputEvents.hashCode()); + result = 31 * result + (callResults.isEmpty() ? 0 : callResults.hashCode()); + result = 31 * result + (completed ? 1 : 0); return result; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ActionState that = (ActionState) o; + return completed == that.completed + && java.util.Objects.equals(taskEvent, that.taskEvent) + && java.util.Objects.equals(sensoryMemoryUpdates, that.sensoryMemoryUpdates) + && java.util.Objects.equals(shortTermMemoryUpdates, that.shortTermMemoryUpdates) + && java.util.Objects.equals(outputEvents, that.outputEvents) + && java.util.Objects.equals(callResults, that.callResults); + } + @Override public String toString() { - return "TaskActionState{" + return "ActionState{" + "taskEvent=" + taskEvent + ", sensoryMemoryUpdates=" @@ -116,6 +224,10 @@ public class ActionState { + shortTermMemoryUpdates + ", outputEvents=" + outputEvents + + ", callResults=" + + callResults + + ", completed=" + + completed + '}'; } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java new file mode 100644 index 00000000..cb9c5338 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallResult.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.Arrays; +import java.util.Objects; + +/** + * Represents a result of a function call execution for fine-grained durable execution. + * + * <p>This class stores the execution result of a single {@code durable_execute} or {@code + * durable_execute_async} call, enabling recovery without re-execution when the same call is + * encountered during job recovery. + * + * <p>During recovery, the success or failure of the original call is determined by checking whether + * {@code exceptionPayload} is null. + */ +public class CallResult { + + /** Function identifier: module+qualname for Python, or method signature for Java. */ + private final String functionId; + + /** Stable digest of the serialized arguments for validation during recovery. */ + private final String argsDigest; + + /** Serialized return value of the function call (null if the call threw an exception). */ + private final byte[] resultPayload; + + /** Serialized exception info if the call failed (null if the call succeeded). */ + private final byte[] exceptionPayload; + + /** Default constructor for deserialization. */ + public CallResult() { + this.functionId = null; + this.argsDigest = null; + this.resultPayload = null; + this.exceptionPayload = null; + } + + /** + * Constructs a CallResult for a successful function call. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param resultPayload the serialized return value + */ + public CallResult(String functionId, String argsDigest, byte[] resultPayload) { + this.functionId = functionId; + this.argsDigest = argsDigest; + this.resultPayload = resultPayload; + this.exceptionPayload = null; + } + + /** + * Constructs a CallResult with explicit result and exception payloads. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param resultPayload the serialized return value (null if exception occurred) + * @param exceptionPayload the serialized exception (null if call succeeded) + */ + public CallResult( + String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) { + this.functionId = functionId; + this.argsDigest = argsDigest; + this.resultPayload = resultPayload; + this.exceptionPayload = exceptionPayload; + } + + /** + * Creates a CallResult for a failed function call. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param exceptionPayload the serialized exception + * @return a new CallResult representing a failed call + */ + public static CallResult ofException( + String functionId, String argsDigest, byte[] exceptionPayload) { + return new CallResult(functionId, argsDigest, null, exceptionPayload); + } + + public String getFunctionId() { + return functionId; + } + + public String getArgsDigest() { + return argsDigest; + } + + public byte[] getResultPayload() { + return resultPayload; + } + + public byte[] getExceptionPayload() { + return exceptionPayload; + } + + /** + * Checks if this call result represents a successful execution. + * + * @return true if the call succeeded (no exception), false otherwise + */ + @JsonIgnore + public boolean isSuccess() { + return exceptionPayload == null; + } + + /** + * Validates if this CallResult matches the given function identifier and arguments digest. + * + * @param functionId the function identifier to match + * @param argsDigest the arguments digest to match + * @return true if both functionId and argsDigest match, false otherwise + */ + public boolean matches(String functionId, String argsDigest) { + return Objects.equals(this.functionId, functionId) + && Objects.equals(this.argsDigest, argsDigest); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + CallResult that = (CallResult) o; + return Objects.equals(functionId, that.functionId) + && Objects.equals(argsDigest, that.argsDigest) + && Arrays.equals(resultPayload, that.resultPayload) + && Arrays.equals(exceptionPayload, that.exceptionPayload); + } + + @Override + public int hashCode() { + int result = Objects.hash(functionId, argsDigest); + result = 31 * result + Arrays.hashCode(resultPayload); + result = 31 * result + Arrays.hashCode(exceptionPayload); + return result; + } + + @Override + public String toString() { + return "CallResult{" + + "functionId='" + + functionId + + '\'' + + ", argsDigest='" + + argsDigest + + '\'' + + ", resultPayload=" + + (resultPayload != null ? resultPayload.length + " bytes" : "null") + + ", exceptionPayload=" + + (exceptionPayload != null ? exceptionPayload.length + " bytes" : "null") + + '}'; + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java index eac53d2e..74181d0f 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java @@ -23,7 +23,9 @@ import org.apache.flink.agents.api.OutputEvent; import org.apache.flink.agents.api.context.MemoryUpdate; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.*; @@ -92,7 +94,7 @@ public class ActionStateSerdeTest { @Test public void testActionStateWithNullTaskEvent() throws Exception { // Create ActionState with null taskEvent - ActionState originalState = new ActionState(); + ActionState originalState = new ActionState(null, null, null, null, null, false); MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test value"); originalState.addShortTermMemoryUpdate(memoryUpdate); originalState.addSensoryMemoryUpdate(memoryUpdate); @@ -138,4 +140,117 @@ public class ActionStateSerdeTest { assertEquals("value", deserializedComplexAttr.get("nested")); assertEquals(42, deserializedComplexAttr.get("number")); } + + @Test + public void testActionStateWithCallResults() throws Exception { + // Create ActionState with call results + InputEvent inputEvent = new InputEvent("test input"); + ActionState originalState = new ActionState(inputEvent); + + // Add call results + CallResult result1 = new CallResult("module.func1", "digest1", "result1".getBytes()); + CallResult result2 = + CallResult.ofException("module.func2", "digest2", "exception".getBytes()); + originalState.addCallResult(result1); + originalState.addCallResult(result2); + + // Test serialization/deserialization + ActionStateKafkaSeder seder = new ActionStateKafkaSeder(); + + byte[] serialized = seder.serialize("test-topic", originalState); + ActionState deserializedState = seder.deserialize("test-topic", serialized); + + // Verify call results + assertEquals(2, deserializedState.getCallResultCount()); + + CallResult deserializedResult1 = deserializedState.getCallResult(0); + assertEquals("module.func1", deserializedResult1.getFunctionId()); + assertEquals("digest1", deserializedResult1.getArgsDigest()); + assertArrayEquals("result1".getBytes(), deserializedResult1.getResultPayload()); + assertNull(deserializedResult1.getExceptionPayload()); + assertTrue(deserializedResult1.isSuccess()); + + CallResult deserializedResult2 = deserializedState.getCallResult(1); + assertEquals("module.func2", deserializedResult2.getFunctionId()); + assertEquals("digest2", deserializedResult2.getArgsDigest()); + assertNull(deserializedResult2.getResultPayload()); + assertArrayEquals("exception".getBytes(), deserializedResult2.getExceptionPayload()); + assertFalse(deserializedResult2.isSuccess()); + } + + @Test + public void testActionStateWithCompletedFlag() throws Exception { + // Create completed ActionState + InputEvent inputEvent = new InputEvent("test input"); + List<MemoryUpdate> sensoryUpdates = new ArrayList<>(); + sensoryUpdates.add(new MemoryUpdate("sm.path", "value")); + List<MemoryUpdate> shortTermUpdates = new ArrayList<>(); + shortTermUpdates.add(new MemoryUpdate("stm.path", "value")); + List<Event> outputEvents = new ArrayList<>(); + outputEvents.add(new OutputEvent("output")); + + // Create with completed = true and empty callResults (simulating markCompleted) + ActionState originalState = + new ActionState( + inputEvent, sensoryUpdates, shortTermUpdates, outputEvents, null, true); + + // Test serialization/deserialization + ActionStateKafkaSeder seder = new ActionStateKafkaSeder(); + + byte[] serialized = seder.serialize("test-topic", originalState); + ActionState deserializedState = seder.deserialize("test-topic", serialized); + + // Verify completed flag + assertTrue(deserializedState.isCompleted()); + assertEquals(0, deserializedState.getCallResultCount()); + + // Verify other fields preserved + assertEquals(1, deserializedState.getSensoryMemoryUpdates().size()); + assertEquals(1, deserializedState.getShortTermMemoryUpdates().size()); + assertEquals(1, deserializedState.getOutputEvents().size()); + } + + @Test + public void testActionStateInProgressWithCallResults() throws Exception { + // Create in-progress ActionState with call results (simulating partial execution) + InputEvent inputEvent = new InputEvent("test input"); + List<CallResult> callResults = new ArrayList<>(); + callResults.add(new CallResult("func1", "hash1", "result1".getBytes())); + callResults.add(new CallResult("func2", "hash2", "result2".getBytes())); + + ActionState originalState = + new ActionState(inputEvent, null, null, null, callResults, false); + + // Test serialization/deserialization + ActionStateKafkaSeder seder = new ActionStateKafkaSeder(); + + byte[] serialized = seder.serialize("test-topic", originalState); + ActionState deserializedState = seder.deserialize("test-topic", serialized); + + // Verify state + assertFalse(deserializedState.isCompleted()); + assertEquals(2, deserializedState.getCallResultCount()); + assertTrue(deserializedState.getCallResult(0).matches("func1", "hash1")); + assertTrue(deserializedState.getCallResult(1).matches("func2", "hash2")); + } + + @Test + public void testCallResultWithNullPayloads() throws Exception { + // Test CallResult with null payloads + InputEvent inputEvent = new InputEvent("test"); + ActionState originalState = new ActionState(inputEvent); + originalState.addCallResult(new CallResult("func", "digest", null, null)); + + ActionStateKafkaSeder seder = new ActionStateKafkaSeder(); + + byte[] serialized = seder.serialize("test-topic", originalState); + ActionState deserializedState = seder.deserialize("test-topic", serialized); + + assertEquals(1, deserializedState.getCallResultCount()); + CallResult result = deserializedState.getCallResult(0); + assertEquals("func", result.getFunctionId()); + assertEquals("digest", result.getArgsDigest()); + assertNull(result.getResultPayload()); + assertNull(result.getExceptionPayload()); + } } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java new file mode 100644 index 00000000..aa00d119 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateTest.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.context.MemoryUpdate; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** Unit tests for {@link ActionState} with focus on fine-grained durable execution fields. */ +public class ActionStateTest { + + @Test + public void testConstructorWithEvent() { + InputEvent event = new InputEvent("test"); + ActionState state = new ActionState(event); + + assertEquals(event, state.getTaskEvent()); + assertTrue(state.getSensoryMemoryUpdates().isEmpty()); + assertTrue(state.getShortTermMemoryUpdates().isEmpty()); + assertTrue(state.getOutputEvents().isEmpty()); + assertTrue(state.getCallResults().isEmpty()); + assertFalse(state.isCompleted()); + } + + @Test + public void testFullConstructorWithCallResults() { + InputEvent taskEvent = new InputEvent("test"); + List<MemoryUpdate> sensoryUpdates = new ArrayList<>(); + sensoryUpdates.add(new MemoryUpdate("sm.path", "value")); + List<MemoryUpdate> shortTermUpdates = new ArrayList<>(); + shortTermUpdates.add(new MemoryUpdate("stm.path", "value")); + List<org.apache.flink.agents.api.Event> outputEvents = new ArrayList<>(); + outputEvents.add(new OutputEvent("output")); + List<CallResult> callResults = new ArrayList<>(); + callResults.add(new CallResult("func1", "digest1", "result1".getBytes())); + callResults.add(new CallResult("func2", "digest2", "result2".getBytes())); + boolean completed = true; + + ActionState state = + new ActionState( + taskEvent, + sensoryUpdates, + shortTermUpdates, + outputEvents, + callResults, + completed); + + assertEquals(taskEvent, state.getTaskEvent()); + assertEquals(1, state.getSensoryMemoryUpdates().size()); + assertEquals(1, state.getShortTermMemoryUpdates().size()); + assertEquals(1, state.getOutputEvents().size()); + assertEquals(2, state.getCallResults().size()); + assertTrue(state.isCompleted()); + } + + @Test + public void testAddCallResult() { + ActionState state = new ActionState(new InputEvent("test")); + + CallResult result1 = new CallResult("func1", "digest1", "result1".getBytes()); + CallResult result2 = new CallResult("func2", "digest2", "result2".getBytes()); + + state.addCallResult(result1); + assertEquals(1, state.getCallResultCount()); + assertEquals(result1, state.getCallResult(0)); + + state.addCallResult(result2); + assertEquals(2, state.getCallResultCount()); + assertEquals(result2, state.getCallResult(1)); + } + + @Test + public void testGetCallResultOutOfBounds() { + ActionState state = new ActionState(new InputEvent("test")); + + assertNull(state.getCallResult(-1)); + assertNull(state.getCallResult(0)); + assertNull(state.getCallResult(100)); + + state.addCallResult(new CallResult("func", "digest", "result".getBytes())); + assertNull(state.getCallResult(1)); + assertNotNull(state.getCallResult(0)); + } + + @Test + public void testClearCallResults() { + ActionState state = new ActionState(new InputEvent("test")); + state.addCallResult(new CallResult("func1", "digest1", "result1".getBytes())); + state.addCallResult(new CallResult("func2", "digest2", "result2".getBytes())); + assertEquals(2, state.getCallResultCount()); + + state.clearCallResults(); + assertEquals(0, state.getCallResultCount()); + assertTrue(state.getCallResults().isEmpty()); + } + + @Test + public void testClearCallResultsFrom() { + ActionState state = new ActionState(new InputEvent("test")); + state.addCallResult(new CallResult("func0", "digest0", "result0".getBytes())); + state.addCallResult(new CallResult("func1", "digest1", "result1".getBytes())); + state.addCallResult(new CallResult("func2", "digest2", "result2".getBytes())); + state.addCallResult(new CallResult("func3", "digest3", "result3".getBytes())); + assertEquals(4, state.getCallResultCount()); + + // Clear from index 2 onwards (keep func0, func1) + state.clearCallResultsFrom(2); + + assertEquals(2, state.getCallResultCount()); + assertEquals("func0", state.getCallResult(0).getFunctionId()); + assertEquals("func1", state.getCallResult(1).getFunctionId()); + } + + @Test + public void testClearCallResultsFromInvalidIndex() { + ActionState state = new ActionState(new InputEvent("test")); + state.addCallResult(new CallResult("func", "digest", "result".getBytes())); + + // Negative index - should do nothing + state.clearCallResultsFrom(-1); + assertEquals(1, state.getCallResultCount()); + + // Out of bounds index - should do nothing + state.clearCallResultsFrom(10); + assertEquals(1, state.getCallResultCount()); + } + + @Test + public void testClearCallResultsFromZero() { + ActionState state = new ActionState(new InputEvent("test")); + state.addCallResult(new CallResult("func1", "digest1", "result1".getBytes())); + state.addCallResult(new CallResult("func2", "digest2", "result2".getBytes())); + + // Clear from index 0 - should clear all + state.clearCallResultsFrom(0); + assertEquals(0, state.getCallResultCount()); + } + + @Test + public void testMarkCompleted() { + ActionState state = new ActionState(new InputEvent("test")); + state.addCallResult(new CallResult("func1", "digest1", "result1".getBytes())); + state.addCallResult(new CallResult("func2", "digest2", "result2".getBytes())); + + assertFalse(state.isCompleted()); + assertEquals(2, state.getCallResultCount()); + + state.markCompleted(); + + assertTrue(state.isCompleted()); + assertEquals(0, state.getCallResultCount()); + } + + @Test + public void testEqualsWithCallResultsAndCompleted() { + InputEvent event = new InputEvent("test"); + List<CallResult> callResults1 = new ArrayList<>(); + callResults1.add(new CallResult("func", "digest", "result".getBytes())); + + List<CallResult> callResults2 = new ArrayList<>(); + callResults2.add(new CallResult("func", "digest", "result".getBytes())); + + ActionState state1 = new ActionState(event, null, null, null, callResults1, true); + ActionState state2 = new ActionState(event, null, null, null, callResults2, true); + ActionState state3 = new ActionState(event, null, null, null, callResults1, false); + + assertEquals(state1, state2); + assertNotEquals(state1, state3); // Different completed flag + } + + @Test + public void testHashCodeWithCallResultsAndCompleted() { + InputEvent event = new InputEvent("test"); + List<CallResult> callResults = new ArrayList<>(); + callResults.add(new CallResult("func", "digest", "result".getBytes())); + + ActionState state1 = new ActionState(event, null, null, null, callResults, true); + ActionState state2 = + new ActionState(event, null, null, null, new ArrayList<>(callResults), true); + + assertEquals(state1.hashCode(), state2.hashCode()); + } + + @Test + public void testToStringIncludesNewFields() { + ActionState state = new ActionState(new InputEvent("test")); + state.addCallResult(new CallResult("func", "digest", "result".getBytes())); + state.markCompleted(); + + String str = state.toString(); + + assertTrue(str.contains("callResults")); + assertTrue(str.contains("completed=true")); + } + + @Test + public void testNullListsInFullConstructor() { + ActionState state = new ActionState(null, null, null, null, null, false); + + assertNull(state.getTaskEvent()); + assertNotNull(state.getSensoryMemoryUpdates()); + assertNotNull(state.getShortTermMemoryUpdates()); + assertNotNull(state.getOutputEvents()); + assertNotNull(state.getCallResults()); + assertTrue(state.getSensoryMemoryUpdates().isEmpty()); + assertTrue(state.getShortTermMemoryUpdates().isEmpty()); + assertTrue(state.getOutputEvents().isEmpty()); + assertTrue(state.getCallResults().isEmpty()); + } + + @Test + public void testIntegrationScenario() { + // Simulate a typical fine-grained durable execution flow + + // 1. Create initial state + ActionState state = new ActionState(new InputEvent("test")); + assertFalse(state.isCompleted()); + assertEquals(0, state.getCallResultCount()); + + // 2. First code block completes + CallResult result1 = new CallResult("llm.call", "hash1", "response1".getBytes()); + state.addCallResult(result1); + assertEquals(1, state.getCallResultCount()); + assertFalse(state.isCompleted()); + + // 3. Second code block completes + CallResult result2 = new CallResult("db.query", "hash2", "data".getBytes()); + state.addCallResult(result2); + assertEquals(2, state.getCallResultCount()); + + // 4. Action completes - mark completed and clear results + state.addSensoryMemoryUpdate(new MemoryUpdate("sm.key", "value")); + state.addShortTermMemoryUpdate(new MemoryUpdate("stm.key", "value")); + state.addEvent(new OutputEvent("final_output")); + state.markCompleted(); + + assertTrue(state.isCompleted()); + assertEquals(0, state.getCallResultCount()); // Results cleared + assertEquals(1, state.getSensoryMemoryUpdates().size()); // Memory preserved + assertEquals(1, state.getShortTermMemoryUpdates().size()); + assertEquals(1, state.getOutputEvents().size()); // Events preserved + } + + @Test + public void testRecoveryScenario() { + // Simulate recovery scenario where we need to check call results + + // State from before failure (with 2 completed code blocks) + ActionState recoveredState = new ActionState(new InputEvent("test")); + recoveredState.addCallResult(new CallResult("func1", "digest1", "result1".getBytes())); + recoveredState.addCallResult(new CallResult("func2", "digest2", "result2".getBytes())); + + // Check if action is completed + assertFalse(recoveredState.isCompleted()); + + // During re-execution, check if call result matches + CallResult result0 = recoveredState.getCallResult(0); + assertTrue(result0.matches("func1", "digest1")); + assertTrue(result0.isSuccess()); + + CallResult result1 = recoveredState.getCallResult(1); + assertTrue(result1.matches("func2", "digest2")); + + // Third call is new (not in results) + assertNull(recoveredState.getCallResult(2)); + } + + @Test + public void testNonDeterministicRecovery() { + // Simulate detection of non-deterministic call order + ActionState state = new ActionState(new InputEvent("test")); + state.addCallResult(new CallResult("func1", "digest1", "result1".getBytes())); + state.addCallResult(new CallResult("func2", "digest2", "result2".getBytes())); + state.addCallResult(new CallResult("func3", "digest3", "result3".getBytes())); + + // During recovery, call 1 matches + CallResult result0 = state.getCallResult(0); + assertTrue(result0.matches("func1", "digest1")); + + // Call 2 doesn't match (different function called) + CallResult result1 = state.getCallResult(1); + assertFalse(result1.matches("different_func", "digest2")); + + // Clear results from index 1 onwards + state.clearCallResultsFrom(1); + assertEquals(1, state.getCallResultCount()); + assertEquals("func1", state.getCallResult(0).getFunctionId()); + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java new file mode 100644 index 00000000..11d8eb14 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/CallResultTest.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** Unit tests for {@link CallResult}. */ +public class CallResultTest { + + @Test + public void testSuccessfulCallResult() { + String functionId = "my_module.my_function"; + String argsDigest = "abc123"; + byte[] resultPayload = "result".getBytes(); + + CallResult result = new CallResult(functionId, argsDigest, resultPayload); + + assertEquals(functionId, result.getFunctionId()); + assertEquals(argsDigest, result.getArgsDigest()); + assertArrayEquals(resultPayload, result.getResultPayload()); + assertNull(result.getExceptionPayload()); + assertTrue(result.isSuccess()); + } + + @Test + public void testFailedCallResult() { + String functionId = "my_module.my_function"; + String argsDigest = "abc123"; + byte[] exceptionPayload = "exception".getBytes(); + + CallResult result = CallResult.ofException(functionId, argsDigest, exceptionPayload); + + assertEquals(functionId, result.getFunctionId()); + assertEquals(argsDigest, result.getArgsDigest()); + assertNull(result.getResultPayload()); + assertArrayEquals(exceptionPayload, result.getExceptionPayload()); + assertFalse(result.isSuccess()); + } + + @Test + public void testFullConstructor() { + String functionId = "my_module.my_function"; + String argsDigest = "abc123"; + byte[] resultPayload = "result".getBytes(); + byte[] exceptionPayload = null; + + CallResult result = new CallResult(functionId, argsDigest, resultPayload, exceptionPayload); + + assertEquals(functionId, result.getFunctionId()); + assertEquals(argsDigest, result.getArgsDigest()); + assertArrayEquals(resultPayload, result.getResultPayload()); + assertNull(result.getExceptionPayload()); + assertTrue(result.isSuccess()); + } + + @Test + public void testMatches() { + String functionId = "my_module.my_function"; + String argsDigest = "abc123"; + byte[] resultPayload = "result".getBytes(); + + CallResult result = new CallResult(functionId, argsDigest, resultPayload); + + assertTrue(result.matches(functionId, argsDigest)); + assertFalse(result.matches("other_function", argsDigest)); + assertFalse(result.matches(functionId, "other_digest")); + assertFalse(result.matches("other_function", "other_digest")); + } + + @Test + public void testMatchesWithNullValues() { + CallResult result = new CallResult(); + + assertTrue(result.matches(null, null)); + assertFalse(result.matches("function", null)); + assertFalse(result.matches(null, "digest")); + } + + @Test + public void testEquals() { + String functionId = "my_module.my_function"; + String argsDigest = "abc123"; + byte[] resultPayload = "result".getBytes(); + + CallResult result1 = new CallResult(functionId, argsDigest, resultPayload); + CallResult result2 = new CallResult(functionId, argsDigest, resultPayload); + CallResult result3 = new CallResult("other", argsDigest, resultPayload); + + assertEquals(result1, result2); + assertNotEquals(result1, result3); + assertNotEquals(result1, null); + assertNotEquals(result1, "string"); + assertEquals(result1, result1); + } + + @Test + public void testHashCode() { + String functionId = "my_module.my_function"; + String argsDigest = "abc123"; + byte[] resultPayload = "result".getBytes(); + + CallResult result1 = new CallResult(functionId, argsDigest, resultPayload); + CallResult result2 = new CallResult(functionId, argsDigest, resultPayload); + + assertEquals(result1.hashCode(), result2.hashCode()); + } + + @Test + public void testToString() { + String functionId = "my_module.my_function"; + String argsDigest = "abc123"; + byte[] resultPayload = "result".getBytes(); + + CallResult result = new CallResult(functionId, argsDigest, resultPayload); + String str = result.toString(); + + assertTrue(str.contains(functionId)); + assertTrue(str.contains(argsDigest)); + assertTrue(str.contains("bytes")); + } + + @Test + public void testToStringWithNullPayloads() { + CallResult result = new CallResult("func", "digest", null, null); + String str = result.toString(); + + assertTrue(str.contains("null")); + } + + @Test + public void testDefaultConstructor() { + CallResult result = new CallResult(); + + assertNull(result.getFunctionId()); + assertNull(result.getArgsDigest()); + assertNull(result.getResultPayload()); + assertNull(result.getExceptionPayload()); + assertTrue(result.isSuccess()); // exceptionPayload is null + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java index c285adf3..cd32524b 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStoreTest.java @@ -159,7 +159,12 @@ public class KafkaActionStateStoreTest { 3L)); for (int i = 0; i < 5; i++) { mockConsumer.addRecord( - new ConsumerRecord<>(TEST_TOPIC, 0, i++, "key", new ActionState())); + new ConsumerRecord<>( + TEST_TOPIC, + 0, + i++, + "key", + new ActionState(null, null, null, null, null, false))); } // Test getting recovery marker after putting state Object secondMarker = actionStateStore.getRecoveryMarker();
