This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 901a96d97e9a8b8f0b74db6df781e300b9a5df68 Author: WenjinXie <[email protected]> AuthorDate: Tue Nov 4 16:17:27 2025 +0800 [api][runtime][java] Introduce sensory memory in java. fix --- .../apache/flink/agents/api/context/MemoryRef.java | 6 +- .../flink/agents/api/context/RunnerContext.java | 8 ++ .../agents/integration/test/MemoryObjectAgent.java | 91 ++++++++++++++++------ .../agents/runtime/actionstate/ActionState.java | 53 +++++++++---- .../agents/runtime/context/RunnerContextImpl.java | 46 ++++++++--- .../agents/runtime/memory/CachedMemoryStore.java | 5 ++ .../flink/agents/runtime/memory/MemoryStore.java | 3 + .../runtime/operator/ActionExecutionOperator.java | 31 +++++++- .../python/context/PythonRunnerContextImpl.java | 10 ++- .../runtime/actionstate/ActionStateSerdeTest.java | 26 +++++-- .../flink/agents/runtime/memory/MemoryRefTest.java | 7 +- .../operator/ActionExecutionOperatorTest.java | 7 +- 12 files changed, 226 insertions(+), 67 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java b/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java index 8f0a133..219909a 100644 --- a/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java +++ b/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java @@ -47,12 +47,12 @@ public final class MemoryRef implements Serializable { /** * Resolves the reference using the provided RunnerContext to get the actual data. * - * @param ctx The current execution context, used to access Short-Term Memory. + * @param memory The memory this ref based on. * @return The deserialized, original data object. * @throws Exception if the memory cannot be accessed or the data cannot be resolved. */ - public MemoryObject resolve(RunnerContext ctx) throws Exception { - return ctx.getShortTermMemory().get(this); + public MemoryObject resolve(MemoryObject memory) throws Exception { + return memory.get(this); } public String getPath() { diff --git a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java index 5960124..1d64060 100644 --- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java +++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java @@ -37,6 +37,14 @@ public interface RunnerContext { */ void sendEvent(Event event); + /** + * Gets the sensory memory. + * + * @return MemoryObject the root of the sensory memory + * @throws Exception if the underlying state backend cannot be accessed + */ + MemoryObject getSensoryMemory() throws Exception; + /** * Gets the short-term memory. * diff --git a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java index a54316c..da1a03a 100644 --- a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java +++ b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java @@ -30,6 +30,17 @@ import java.util.*; /** An example agent that tests usages of MemoryObject. */ public class MemoryObjectAgent extends Agent { + public static class MyEvent extends Event { + private final String value; + + public MyEvent(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + } /** A custom POJO for testing serialization. */ public static class Person implements Serializable { @@ -65,6 +76,8 @@ public class MemoryObjectAgent extends Agent { @Action(listenEvents = {InputEvent.class}) public static void testMemoryObject(Event event, RunnerContext ctx) throws Exception { MemoryObject stm = ctx.getShortTermMemory(); + MemoryObject sm = ctx.getSensoryMemory(); + Integer key = (Integer) ((InputEvent) event).getInput(); int visitCount = 1; @@ -73,42 +86,70 @@ public class MemoryObjectAgent extends Agent { } stm.set("visit_count", visitCount); - // isExist - stm.set("existing.path", true); - assertEquals(stm.isExist("existing.path"), true); - assertEquals(stm.isExist("non.existing.path"), false); - - // getFieldNames and getFields - MemoryObject fieldsTestObj = stm.newObject("fieldsTest", true); - fieldsTestObj.set("x", 1); - fieldsTestObj.set("y", 2); - fieldsTestObj.newObject("obj", false); - List<String> names = fieldsTestObj.getFieldNames(); - assertEquals(new HashSet<>(names).containsAll(Arrays.asList("x", "y", "obj")), true); - Map<String, Object> fields = fieldsTestObj.getFields(); - assertEquals(1, ((Number) fields.get("x")).intValue()); - assertEquals("NestedObject", fields.get("obj")); - - // List List<String> tags = Arrays.asList("gamer", "developer", "flink-user"); - stm.set("list", tags); - assertEquals(tags, stm.get("list").getValue()); - // Map Map<String, Integer> inventory = new HashMap<>(); inventory.put("potion", 10); inventory.put("gold", 500); - stm.set("map", inventory); - assertEquals(inventory, stm.get("map").getValue()); - // Custom POJO Person person = new Person("Bob", 22); - stm.set("person", person); - assertEquals(person, stm.get("person").getValue()); + + if (visitCount == 1) { + // Test sensory memory + sm.set("existing.path", true); + assertEquals(sm.isExist("existing"), true); + assertEquals(sm.isExist("existing.path"), true); + + // Test short-term memory + // exist + stm.set("existing.path", true); + + // getFieldNames and getFields + MemoryObject fieldsTestObj = stm.newObject("fieldsTest", true); + fieldsTestObj.set("x", 1); + fieldsTestObj.set("y", 2); + fieldsTestObj.newObject("obj", false); + + // List + stm.set("list", tags); + + // Map + stm.set("map", inventory); + + // Custom POJO + stm.set("person", person); + } else { + // Test sensory memory + assertEquals(sm.isExist("existing"), false); + assertEquals(sm.isExist("existing.path"), false); + + // Test short-term memory + // exist + assertEquals(stm.isExist("existing.path"), true); + assertEquals(stm.isExist("non.existing.path"), false); + + // getFieldNames and getFields + MemoryObject fieldsTestObj = stm.get("fieldsTest"); + List<String> names = fieldsTestObj.getFieldNames(); + assertEquals(new HashSet<>(names).containsAll(Arrays.asList("x", "y", "obj")), true); + Map<String, Object> fields = fieldsTestObj.getFields(); + assertEquals(1, ((Number) fields.get("x")).intValue()); + assertEquals("NestedObject", fields.get("obj")); + + // List + assertEquals(tags, stm.get("list").getValue()); + + // Map + assertEquals(inventory, stm.get("map").getValue()); + + // Custom POJO + assertEquals(person, stm.get("person").getValue()); + } String result = String.format("All assertions passed for key: %d (visit #%d)", key, visitCount); String output = result + " [Agent Complete]"; + ctx.sendEvent(new OutputEvent(output)); } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java index 5a7b1ff..34eefb3 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java @@ -26,27 +26,36 @@ import java.util.List; /** Class representing the state of an action after processing an event. */ public class ActionState { private final Event taskEvent; - private final List<MemoryUpdate> memoryUpdates; + private final List<MemoryUpdate> sensoryMemoryUpdates; + private final List<MemoryUpdate> shortTermMemoryUpdates; private final List<Event> outputEvents; /** Constructs a new TaskActionState instance. */ public ActionState(final Event taskEvent) { this.taskEvent = taskEvent; - this.memoryUpdates = new ArrayList<>(); + this.sensoryMemoryUpdates = new ArrayList<>(); + this.shortTermMemoryUpdates = new ArrayList<>(); this.outputEvents = new ArrayList<>(); } public ActionState() { this.taskEvent = null; - this.memoryUpdates = new ArrayList<>(); + this.sensoryMemoryUpdates = new ArrayList<>(); + this.shortTermMemoryUpdates = new ArrayList<>(); this.outputEvents = new ArrayList<>(); } /** Constructor for deserialization purposes. */ public ActionState( - Event taskEvent, List<MemoryUpdate> memoryUpdates, List<Event> outputEvents) { + Event taskEvent, + List<MemoryUpdate> sensoryMemoryUpdates, + List<MemoryUpdate> shortTermMemoryUpdates, + List<Event> outputEvents) { this.taskEvent = taskEvent; - this.memoryUpdates = memoryUpdates != null ? memoryUpdates : new ArrayList<>(); + this.sensoryMemoryUpdates = + sensoryMemoryUpdates != null ? sensoryMemoryUpdates : new ArrayList<>(); + this.shortTermMemoryUpdates = + shortTermMemoryUpdates != null ? shortTermMemoryUpdates : new ArrayList<>(); this.outputEvents = outputEvents != null ? outputEvents : new ArrayList<>(); } @@ -55,8 +64,12 @@ public class ActionState { return taskEvent; } - public List<MemoryUpdate> getMemoryUpdates() { - return memoryUpdates; + public List<MemoryUpdate> getSensoryMemoryUpdates() { + return sensoryMemoryUpdates; + } + + public List<MemoryUpdate> getShortTermMemoryUpdates() { + return shortTermMemoryUpdates; } public List<Event> getOutputEvents() { @@ -64,8 +77,13 @@ public class ActionState { } /** Setters for the fields */ - public void addMemoryUpdate(MemoryUpdate memoryUpdate) { - memoryUpdates.add(memoryUpdate); + public void addSensoryMemoryUpdate(MemoryUpdate memoryUpdate) { + sensoryMemoryUpdates.add(memoryUpdate); + } + + /** Setters for the fields */ + public void addShortTermMemoryUpdate(MemoryUpdate memoryUpdate) { + shortTermMemoryUpdates.add(memoryUpdate); } public void addEvent(Event event) { @@ -75,8 +93,15 @@ public class ActionState { @Override public int hashCode() { int result = taskEvent != null ? taskEvent.hashCode() : 0; - result = 31 * result + (memoryUpdates != null ? memoryUpdates.hashCode() : 0); - result = 31 * result + (outputEvents != null ? outputEvents.hashCode() : 0); + result = + 31 * result + + (sensoryMemoryUpdates.isEmpty() ? 0 : sensoryMemoryUpdates.hashCode()); + result = + 31 * result + + (shortTermMemoryUpdates.isEmpty() + ? 0 + : shortTermMemoryUpdates.hashCode()); + result = 31 * result + (outputEvents.isEmpty() ? 0 : outputEvents.hashCode()); return result; } @@ -85,8 +110,10 @@ public class ActionState { return "TaskActionState{" + "taskEvent=" + taskEvent - + ", memoryUpdates=" - + memoryUpdates + + ", sensoryMemoryUpdates=" + + sensoryMemoryUpdates + + ", shortTermMemoryUpdates=" + + shortTermMemoryUpdates + ", outputEvents=" + outputEvents + '}'; diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index d998af6..e02ec45 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -44,23 +44,28 @@ import java.util.Map; public class RunnerContextImpl implements RunnerContext { protected final List<Event> pendingEvents = new ArrayList<>(); - protected final CachedMemoryStore store; + protected final CachedMemoryStore sensoryMemStore; + protected final CachedMemoryStore shortTermMemStore; protected final FlinkAgentsMetricGroupImpl agentMetricGroup; protected final Runnable mailboxThreadChecker; protected final AgentPlan agentPlan; - protected final List<MemoryUpdate> memoryUpdates; + protected final List<MemoryUpdate> sensoryMemoryUpdates; + protected final List<MemoryUpdate> shortTermMemoryUpdates; protected String actionName; public RunnerContextImpl( - CachedMemoryStore store, + CachedMemoryStore sensoryMemStore, + CachedMemoryStore shortTermMemStore, FlinkAgentsMetricGroupImpl agentMetricGroup, Runnable mailboxThreadChecker, AgentPlan agentPlan) { - this.store = store; + this.sensoryMemStore = sensoryMemStore; + this.shortTermMemStore = shortTermMemStore; this.agentMetricGroup = agentMetricGroup; this.mailboxThreadChecker = mailboxThreadChecker; this.agentPlan = agentPlan; - this.memoryUpdates = new LinkedList<>(); + this.sensoryMemoryUpdates = new LinkedList<>(); + this.shortTermMemoryUpdates = new LinkedList<>(); } public void setActionName(String actionName) { @@ -105,6 +110,11 @@ public class RunnerContextImpl implements RunnerContext { this.pendingEvents.isEmpty(), "There are pending events remaining in the context."); } + public List<MemoryUpdate> getSensoryMemoryUpdates() { + mailboxThreadChecker.run(); + return List.copyOf(sensoryMemoryUpdates); + } + /** * Gets all the updates made to this MemoryObject since it was created or the last time this * method was called. This method lives here because it is internally used by the ActionTask to @@ -112,16 +122,29 @@ public class RunnerContextImpl implements RunnerContext { * * @return list of memory updates */ - public List<MemoryUpdate> getAllMemoryUpdates() { + public List<MemoryUpdate> getShortTermMemoryUpdates() { + mailboxThreadChecker.run(); + return List.copyOf(shortTermMemoryUpdates); + } + + @Override + public MemoryObject getSensoryMemory() throws Exception { mailboxThreadChecker.run(); - return List.copyOf(memoryUpdates); + return new MemoryObjectImpl( + sensoryMemStore, + MemoryObjectImpl.ROOT_KEY, + mailboxThreadChecker, + sensoryMemoryUpdates); } @Override public MemoryObject getShortTermMemory() throws Exception { mailboxThreadChecker.run(); return new MemoryObjectImpl( - store, MemoryObjectImpl.ROOT_KEY, mailboxThreadChecker, memoryUpdates); + shortTermMemStore, + MemoryObjectImpl.ROOT_KEY, + mailboxThreadChecker, + shortTermMemoryUpdates); } @Override @@ -152,6 +175,11 @@ public class RunnerContextImpl implements RunnerContext { } public void persistMemory() throws Exception { - store.persistCache(); + sensoryMemStore.persistCache(); + shortTermMemStore.persistCache(); + } + + public void clearSensoryMemory() throws Exception { + sensoryMemStore.clear(); } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java index 71eb8d2..36360cd 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java @@ -57,4 +57,9 @@ public class CachedMemoryStore implements MemoryStore { } cache.clear(); } + + public void clear() throws Exception { + cache.clear(); + store.clear(); + } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java index f466750..28793fb 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java @@ -45,4 +45,7 @@ public interface MemoryStore { * @return true if the MemoryItem exists, false otherwise */ boolean contains(String key) throws Exception; + + /** Remove all the MemoryItem. */ + void clear() throws Exception; } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 95b991b..2a54563 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -115,6 +115,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private transient StreamRecord<OUT> reusedStreamRecord; + private transient MapState<String, MemoryObjectImpl.MemoryItem> sensoryMemState; + private transient MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState; // PythonActionExecutor for Python actions @@ -182,6 +184,13 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT public void open() throws Exception { super.open(); reusedStreamRecord = new StreamRecord<>(null); + // init sensoryMemState + MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> sensoryMemStateDescriptor = + new MapStateDescriptor<>( + "sensoryMemory", + TypeInformation.of(String.class), + TypeInformation.of(MemoryObjectImpl.MemoryItem.class)); + sensoryMemState = getRuntimeContext().getMapState(sensoryMemStateDescriptor); // init shortTermMemState MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> shortTermMemStateDescriptor = @@ -397,12 +406,19 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT if (actionState != null) { isFinished = true; outputEvents = actionState.getOutputEvents(); - for (MemoryUpdate memoryUpdate : actionState.getMemoryUpdates()) { + for (MemoryUpdate memoryUpdate : actionState.getShortTermMemoryUpdates()) { actionTask .getRunnerContext() .getShortTermMemory() .set(memoryUpdate.getPath(), memoryUpdate.getValue()); } + + for (MemoryUpdate memoryUpdate : actionState.getSensoryMemoryUpdates()) { + actionTask + .getRunnerContext() + .getSensoryMemory() + .set(memoryUpdate.getPath(), memoryUpdate.getValue()); + } } else { maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event); ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke(); @@ -452,6 +468,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // 3. Process the next InputEvent or next action task if (currentInputEventFinished) { + // Clean up sensory memory when a single run finished. + actionTask.getRunnerContext().clearSensoryMemory(); + // Once all sub-events and actions related to the current InputEvent are completed, // we can proceed to process the next InputEvent. int removedCount = removeFromListState(currentProcessingKeysOpState, key); @@ -659,6 +678,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT } else if (actionTask.action.getExec() instanceof JavaFunction) { runnerContext = new RunnerContextImpl( + new CachedMemoryStore(sensoryMemState), new CachedMemoryStore(shortTermMemState), metricGroup, this::checkMailboxThread, @@ -666,6 +686,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT } else if (actionTask.action.getExec() instanceof PythonFunction) { runnerContext = new PythonRunnerContextImpl( + new CachedMemoryStore(sensoryMemState), new CachedMemoryStore(shortTermMemState), metricGroup, this::checkMailboxThread, @@ -755,8 +776,12 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); - for (MemoryUpdate memoryUpdate : context.getAllMemoryUpdates()) { - actionState.addMemoryUpdate(memoryUpdate); + for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) { + actionState.addSensoryMemoryUpdate(memoryUpdate); + } + + for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) { + actionState.addShortTermMemoryUpdate(memoryUpdate); } for (Event outputEvent : actionTaskResult.getOutputEvents()) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java index 89f741d..4bdb8d8 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java @@ -36,12 +36,18 @@ public class PythonRunnerContextImpl extends RunnerContextImpl { private final PythonActionExecutor pythonActionExecutor; public PythonRunnerContextImpl( - CachedMemoryStore store, + CachedMemoryStore sensoryMemStore, + CachedMemoryStore shortTermMemStore, FlinkAgentsMetricGroupImpl agentMetricGroup, Runnable mailboxThreadChecker, AgentPlan agentPlan, PythonActionExecutor pythonActionExecutor) { - super(store, agentMetricGroup, mailboxThreadChecker, agentPlan); + super( + sensoryMemStore, + shortTermMemStore, + agentMetricGroup, + mailboxThreadChecker, + agentPlan); this.pythonActionExecutor = pythonActionExecutor; } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java index 4f9ca04..eac53d2 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java @@ -40,11 +40,13 @@ public class ActionStateSerdeTest { OutputEvent outputEvent = new OutputEvent("test output"); outputEvent.setAttr("outputAttr", 123); - MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test value"); + MemoryUpdate sensoryMemoryUpdate = new MemoryUpdate("sm.test.path", "sm test value"); + MemoryUpdate shortTermMemoryUpdate = new MemoryUpdate("stm.test.path", "stm test value"); // Create ActionState ActionState originalState = new ActionState(inputEvent); - originalState.addMemoryUpdate(memoryUpdate); + originalState.addSensoryMemoryUpdate(sensoryMemoryUpdate); + originalState.addShortTermMemoryUpdate(shortTermMemoryUpdate); originalState.addEvent(outputEvent); // Test Kafka seder/deserializer @@ -67,10 +69,16 @@ public class ActionStateSerdeTest { assertEquals("testValue", deserializedInputEvent.getAttr("testAttr")); // Verify memoryUpdates - assertEquals(1, deserializedState.getMemoryUpdates().size()); - MemoryUpdate deserializedMemoryUpdate = deserializedState.getMemoryUpdates().get(0); - assertEquals("test.path", deserializedMemoryUpdate.getPath()); - assertEquals("test value", deserializedMemoryUpdate.getValue()); + assertEquals(1, deserializedState.getSensoryMemoryUpdates().size()); + MemoryUpdate deserializedSensoryMemoryUpdate = + deserializedState.getSensoryMemoryUpdates().get(0); + assertEquals("sm.test.path", deserializedSensoryMemoryUpdate.getPath()); + assertEquals("sm test value", deserializedSensoryMemoryUpdate.getValue()); + assertEquals(1, deserializedState.getShortTermMemoryUpdates().size()); + MemoryUpdate deserializedShortTermMemoryUpdate = + deserializedState.getShortTermMemoryUpdates().get(0); + assertEquals("stm.test.path", deserializedShortTermMemoryUpdate.getPath()); + assertEquals("stm test value", deserializedShortTermMemoryUpdate.getValue()); // Verify outputEvents assertEquals(1, deserializedState.getOutputEvents().size()); @@ -86,7 +94,8 @@ public class ActionStateSerdeTest { // Create ActionState with null taskEvent ActionState originalState = new ActionState(); MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test value"); - originalState.addMemoryUpdate(memoryUpdate); + originalState.addShortTermMemoryUpdate(memoryUpdate); + originalState.addSensoryMemoryUpdate(memoryUpdate); // Test serialization/deserialization ActionStateKafkaSeder seder = new ActionStateKafkaSeder(); @@ -98,7 +107,8 @@ public class ActionStateSerdeTest { assertNull(deserializedState.getTaskEvent()); // Verify other fields - assertEquals(1, deserializedState.getMemoryUpdates().size()); + assertEquals(1, deserializedState.getSensoryMemoryUpdates().size()); + assertEquals(1, deserializedState.getShortTermMemoryUpdates().size()); assertEquals(0, deserializedState.getOutputEvents().size()); } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java index 780784f..46a68d9 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java @@ -68,6 +68,11 @@ public class MemoryRefTest { this.memoryObject = memoryObject; } + @Override + public MemoryObject getSensoryMemory() throws Exception { + return memoryObject; + } + @Override public MemoryObject getShortTermMemory() { return memoryObject; @@ -176,7 +181,7 @@ public class MemoryRefTest { for (Map.Entry<String, Object> entry : testData.entrySet()) { MemoryRef ref = memory.set(entry.getKey(), entry.getValue()); - Object resolvedValue = ref.resolve(ctx).getValue(); + Object resolvedValue = ref.resolve(ctx.getShortTermMemory()).getValue(); assertEquals(entry.getValue(), resolvedValue); } } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index 4027589..4c58f98 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -267,10 +267,11 @@ public class ActionExecutionOperatorTest { assertThat(taskEvent).isNotNull(); // Verify memory updates contain expected data - if (!state.getMemoryUpdates().isEmpty()) { + if (!state.getShortTermMemoryUpdates().isEmpty()) { // For action1, memory should contain input + 1 - assertThat(state.getMemoryUpdates().get(0).getPath()).isEqualTo("tmp"); - assertThat(state.getMemoryUpdates().get(0).getValue()) + assertThat(state.getShortTermMemoryUpdates().get(0).getPath()) + .isEqualTo("tmp"); + assertThat(state.getShortTermMemoryUpdates().get(0).getValue()) .isEqualTo(inputValue + 1); }
