This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 901a96d97e9a8b8f0b74db6df781e300b9a5df68
Author: WenjinXie <[email protected]>
AuthorDate: Tue Nov 4 16:17:27 2025 +0800

    [api][runtime][java] Introduce sensory memory in java.
    
    fix
---
 .../apache/flink/agents/api/context/MemoryRef.java |  6 +-
 .../flink/agents/api/context/RunnerContext.java    |  8 ++
 .../agents/integration/test/MemoryObjectAgent.java | 91 ++++++++++++++++------
 .../agents/runtime/actionstate/ActionState.java    | 53 +++++++++----
 .../agents/runtime/context/RunnerContextImpl.java  | 46 ++++++++---
 .../agents/runtime/memory/CachedMemoryStore.java   |  5 ++
 .../flink/agents/runtime/memory/MemoryStore.java   |  3 +
 .../runtime/operator/ActionExecutionOperator.java  | 31 +++++++-
 .../python/context/PythonRunnerContextImpl.java    | 10 ++-
 .../runtime/actionstate/ActionStateSerdeTest.java  | 26 +++++--
 .../flink/agents/runtime/memory/MemoryRefTest.java |  7 +-
 .../operator/ActionExecutionOperatorTest.java      |  7 +-
 12 files changed, 226 insertions(+), 67 deletions(-)

diff --git 
a/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java 
b/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java
index 8f0a133..219909a 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/MemoryRef.java
@@ -47,12 +47,12 @@ public final class MemoryRef implements Serializable {
     /**
      * Resolves the reference using the provided RunnerContext to get the 
actual data.
      *
-     * @param ctx The current execution context, used to access Short-Term 
Memory.
+     * @param memory The memory this ref is based on.
      * @return The deserialized, original data object.
      * @throws Exception if the memory cannot be accessed or the data cannot 
be resolved.
      */
-    public MemoryObject resolve(RunnerContext ctx) throws Exception {
-        return ctx.getShortTermMemory().get(this);
+    public MemoryObject resolve(MemoryObject memory) throws Exception {
+        return memory.get(this);
     }
 
     public String getPath() {
diff --git 
a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java 
b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
index 5960124..1d64060 100644
--- a/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
+++ b/api/src/main/java/org/apache/flink/agents/api/context/RunnerContext.java
@@ -37,6 +37,14 @@ public interface RunnerContext {
      */
     void sendEvent(Event event);
 
+    /**
+     * Gets the sensory memory.
+     *
+     * @return MemoryObject the root of the sensory memory
+     * @throws Exception if the underlying state backend cannot be accessed
+     */
+    MemoryObject getSensoryMemory() throws Exception;
+
     /**
      * Gets the short-term memory.
      *
diff --git 
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
 
b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
index a54316c..da1a03a 100644
--- 
a/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
+++ 
b/e2e-test/integration-test/src/main/java/org/apache/flink/agents/integration/test/MemoryObjectAgent.java
@@ -30,6 +30,17 @@ import java.util.*;
 
 /** An example agent that tests usages of MemoryObject. */
 public class MemoryObjectAgent extends Agent {
+    public static class MyEvent extends Event {
+        private final String value;
+
+        public MyEvent(String value) {
+            this.value = value;
+        }
+
+        public String getValue() {
+            return value;
+        }
+    }
 
     /** A custom POJO for testing serialization. */
     public static class Person implements Serializable {
@@ -65,6 +76,8 @@ public class MemoryObjectAgent extends Agent {
     @Action(listenEvents = {InputEvent.class})
     public static void testMemoryObject(Event event, RunnerContext ctx) throws 
Exception {
         MemoryObject stm = ctx.getShortTermMemory();
+        MemoryObject sm = ctx.getSensoryMemory();
+
         Integer key = (Integer) ((InputEvent) event).getInput();
 
         int visitCount = 1;
@@ -73,42 +86,70 @@ public class MemoryObjectAgent extends Agent {
         }
         stm.set("visit_count", visitCount);
 
-        // isExist
-        stm.set("existing.path", true);
-        assertEquals(stm.isExist("existing.path"), true);
-        assertEquals(stm.isExist("non.existing.path"), false);
-
-        // getFieldNames and getFields
-        MemoryObject fieldsTestObj = stm.newObject("fieldsTest", true);
-        fieldsTestObj.set("x", 1);
-        fieldsTestObj.set("y", 2);
-        fieldsTestObj.newObject("obj", false);
-        List<String> names = fieldsTestObj.getFieldNames();
-        assertEquals(new HashSet<>(names).containsAll(Arrays.asList("x", "y", 
"obj")), true);
-        Map<String, Object> fields = fieldsTestObj.getFields();
-        assertEquals(1, ((Number) fields.get("x")).intValue());
-        assertEquals("NestedObject", fields.get("obj"));
-
-        // List
         List<String> tags = Arrays.asList("gamer", "developer", "flink-user");
-        stm.set("list", tags);
-        assertEquals(tags, stm.get("list").getValue());
 
-        // Map
         Map<String, Integer> inventory = new HashMap<>();
         inventory.put("potion", 10);
         inventory.put("gold", 500);
-        stm.set("map", inventory);
-        assertEquals(inventory, stm.get("map").getValue());
 
-        // Custom POJO
         Person person = new Person("Bob", 22);
-        stm.set("person", person);
-        assertEquals(person, stm.get("person").getValue());
+
+        if (visitCount == 1) {
+            // Test sensory memory
+            sm.set("existing.path", true);
+            assertEquals(sm.isExist("existing"), true);
+            assertEquals(sm.isExist("existing.path"), true);
+
+            // Test short-term memory
+            // exist
+            stm.set("existing.path", true);
+
+            // getFieldNames and getFields
+            MemoryObject fieldsTestObj = stm.newObject("fieldsTest", true);
+            fieldsTestObj.set("x", 1);
+            fieldsTestObj.set("y", 2);
+            fieldsTestObj.newObject("obj", false);
+
+            // List
+            stm.set("list", tags);
+
+            // Map
+            stm.set("map", inventory);
+
+            // Custom POJO
+            stm.set("person", person);
+        } else {
+            // Test sensory memory
+            assertEquals(sm.isExist("existing"), false);
+            assertEquals(sm.isExist("existing.path"), false);
+
+            // Test short-term memory
+            // exist
+            assertEquals(stm.isExist("existing.path"), true);
+            assertEquals(stm.isExist("non.existing.path"), false);
+
+            // getFieldNames and getFields
+            MemoryObject fieldsTestObj = stm.get("fieldsTest");
+            List<String> names = fieldsTestObj.getFieldNames();
+            assertEquals(new HashSet<>(names).containsAll(Arrays.asList("x", 
"y", "obj")), true);
+            Map<String, Object> fields = fieldsTestObj.getFields();
+            assertEquals(1, ((Number) fields.get("x")).intValue());
+            assertEquals("NestedObject", fields.get("obj"));
+
+            // List
+            assertEquals(tags, stm.get("list").getValue());
+
+            // Map
+            assertEquals(inventory, stm.get("map").getValue());
+
+            // Custom POJO
+            assertEquals(person, stm.get("person").getValue());
+        }
 
         String result =
                 String.format("All assertions passed for key: %d (visit #%d)", 
key, visitCount);
         String output = result + " [Agent Complete]";
+
         ctx.sendEvent(new OutputEvent(output));
     }
 
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
index 5a7b1ff..34eefb3 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java
@@ -26,27 +26,36 @@ import java.util.List;
 /** Class representing the state of an action after processing an event. */
 public class ActionState {
     private final Event taskEvent;
-    private final List<MemoryUpdate> memoryUpdates;
+    private final List<MemoryUpdate> sensoryMemoryUpdates;
+    private final List<MemoryUpdate> shortTermMemoryUpdates;
     private final List<Event> outputEvents;
 
     /** Constructs a new TaskActionState instance. */
     public ActionState(final Event taskEvent) {
         this.taskEvent = taskEvent;
-        this.memoryUpdates = new ArrayList<>();
+        this.sensoryMemoryUpdates = new ArrayList<>();
+        this.shortTermMemoryUpdates = new ArrayList<>();
         this.outputEvents = new ArrayList<>();
     }
 
     public ActionState() {
         this.taskEvent = null;
-        this.memoryUpdates = new ArrayList<>();
+        this.sensoryMemoryUpdates = new ArrayList<>();
+        this.shortTermMemoryUpdates = new ArrayList<>();
         this.outputEvents = new ArrayList<>();
     }
 
     /** Constructor for deserialization purposes. */
     public ActionState(
-            Event taskEvent, List<MemoryUpdate> memoryUpdates, List<Event> 
outputEvents) {
+            Event taskEvent,
+            List<MemoryUpdate> sensoryMemoryUpdates,
+            List<MemoryUpdate> shortTermMemoryUpdates,
+            List<Event> outputEvents) {
         this.taskEvent = taskEvent;
-        this.memoryUpdates = memoryUpdates != null ? memoryUpdates : new 
ArrayList<>();
+        this.sensoryMemoryUpdates =
+                sensoryMemoryUpdates != null ? sensoryMemoryUpdates : new 
ArrayList<>();
+        this.shortTermMemoryUpdates =
+                shortTermMemoryUpdates != null ? shortTermMemoryUpdates : new 
ArrayList<>();
         this.outputEvents = outputEvents != null ? outputEvents : new 
ArrayList<>();
     }
 
@@ -55,8 +64,12 @@ public class ActionState {
         return taskEvent;
     }
 
-    public List<MemoryUpdate> getMemoryUpdates() {
-        return memoryUpdates;
+    public List<MemoryUpdate> getSensoryMemoryUpdates() {
+        return sensoryMemoryUpdates;
+    }
+
+    public List<MemoryUpdate> getShortTermMemoryUpdates() {
+        return shortTermMemoryUpdates;
     }
 
     public List<Event> getOutputEvents() {
@@ -64,8 +77,13 @@ public class ActionState {
     }
 
     /** Setters for the fields */
-    public void addMemoryUpdate(MemoryUpdate memoryUpdate) {
-        memoryUpdates.add(memoryUpdate);
+    public void addSensoryMemoryUpdate(MemoryUpdate memoryUpdate) {
+        sensoryMemoryUpdates.add(memoryUpdate);
+    }
+
+    /** Appends an update to the short-term memory update list. */
+    public void addShortTermMemoryUpdate(MemoryUpdate memoryUpdate) {
+        shortTermMemoryUpdates.add(memoryUpdate);
     }
 
     public void addEvent(Event event) {
@@ -75,8 +93,15 @@ public class ActionState {
     @Override
     public int hashCode() {
         int result = taskEvent != null ? taskEvent.hashCode() : 0;
-        result = 31 * result + (memoryUpdates != null ? 
memoryUpdates.hashCode() : 0);
-        result = 31 * result + (outputEvents != null ? outputEvents.hashCode() 
: 0);
+        result =
+                31 * result
+                        + (sensoryMemoryUpdates.isEmpty() ? 0 : 
sensoryMemoryUpdates.hashCode());
+        result =
+                31 * result
+                        + (shortTermMemoryUpdates.isEmpty()
+                                ? 0
+                                : shortTermMemoryUpdates.hashCode());
+        result = 31 * result + (outputEvents.isEmpty() ? 0 : 
outputEvents.hashCode());
         return result;
     }
 
@@ -85,8 +110,10 @@ public class ActionState {
         return "TaskActionState{"
                 + "taskEvent="
                 + taskEvent
-                + ", memoryUpdates="
-                + memoryUpdates
+                + ", sensoryMemoryUpdates="
+                + sensoryMemoryUpdates
+                + ", shortTermMemoryUpdates="
+                + shortTermMemoryUpdates
                 + ", outputEvents="
                 + outputEvents
                 + '}';
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index d998af6..e02ec45 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -44,23 +44,28 @@ import java.util.Map;
 public class RunnerContextImpl implements RunnerContext {
 
     protected final List<Event> pendingEvents = new ArrayList<>();
-    protected final CachedMemoryStore store;
+    protected final CachedMemoryStore sensoryMemStore;
+    protected final CachedMemoryStore shortTermMemStore;
     protected final FlinkAgentsMetricGroupImpl agentMetricGroup;
     protected final Runnable mailboxThreadChecker;
     protected final AgentPlan agentPlan;
-    protected final List<MemoryUpdate> memoryUpdates;
+    protected final List<MemoryUpdate> sensoryMemoryUpdates;
+    protected final List<MemoryUpdate> shortTermMemoryUpdates;
     protected String actionName;
 
     public RunnerContextImpl(
-            CachedMemoryStore store,
+            CachedMemoryStore sensoryMemStore,
+            CachedMemoryStore shortTermMemStore,
             FlinkAgentsMetricGroupImpl agentMetricGroup,
             Runnable mailboxThreadChecker,
             AgentPlan agentPlan) {
-        this.store = store;
+        this.sensoryMemStore = sensoryMemStore;
+        this.shortTermMemStore = shortTermMemStore;
         this.agentMetricGroup = agentMetricGroup;
         this.mailboxThreadChecker = mailboxThreadChecker;
         this.agentPlan = agentPlan;
-        this.memoryUpdates = new LinkedList<>();
+        this.sensoryMemoryUpdates = new LinkedList<>();
+        this.shortTermMemoryUpdates = new LinkedList<>();
     }
 
     public void setActionName(String actionName) {
@@ -105,6 +110,11 @@ public class RunnerContextImpl implements RunnerContext {
                 this.pendingEvents.isEmpty(), "There are pending events 
remaining in the context.");
     }
 
+    public List<MemoryUpdate> getSensoryMemoryUpdates() {
+        mailboxThreadChecker.run();
+        return List.copyOf(sensoryMemoryUpdates);
+    }
+
     /**
      * Gets all the updates made to this MemoryObject since it was created or 
the last time this
      * method was called. This method lives here because it is internally used 
by the ActionTask to
@@ -112,16 +122,29 @@ public class RunnerContextImpl implements RunnerContext {
      *
      * @return list of memory updates
      */
-    public List<MemoryUpdate> getAllMemoryUpdates() {
+    public List<MemoryUpdate> getShortTermMemoryUpdates() {
+        mailboxThreadChecker.run();
+        return List.copyOf(shortTermMemoryUpdates);
+    }
+
+    @Override
+    public MemoryObject getSensoryMemory() throws Exception {
         mailboxThreadChecker.run();
-        return List.copyOf(memoryUpdates);
+        return new MemoryObjectImpl(
+                sensoryMemStore,
+                MemoryObjectImpl.ROOT_KEY,
+                mailboxThreadChecker,
+                sensoryMemoryUpdates);
     }
 
     @Override
     public MemoryObject getShortTermMemory() throws Exception {
         mailboxThreadChecker.run();
         return new MemoryObjectImpl(
-                store, MemoryObjectImpl.ROOT_KEY, mailboxThreadChecker, 
memoryUpdates);
+                shortTermMemStore,
+                MemoryObjectImpl.ROOT_KEY,
+                mailboxThreadChecker,
+                shortTermMemoryUpdates);
     }
 
     @Override
@@ -152,6 +175,11 @@ public class RunnerContextImpl implements RunnerContext {
     }
 
     public void persistMemory() throws Exception {
-        store.persistCache();
+        sensoryMemStore.persistCache();
+        shortTermMemStore.persistCache();
+    }
+
+    public void clearSensoryMemory() throws Exception {
+        sensoryMemStore.clear();
     }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
index 71eb8d2..36360cd 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/CachedMemoryStore.java
@@ -57,4 +57,9 @@ public class CachedMemoryStore implements MemoryStore {
         }
         cache.clear();
     }
+
+    public void clear() throws Exception {
+        cache.clear();
+        store.clear();
+    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
index f466750..28793fb 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryStore.java
@@ -45,4 +45,7 @@ public interface MemoryStore {
      * @return true if the MemoryItem exists, false otherwise
      */
     boolean contains(String key) throws Exception;
+
+    /** Removes all MemoryItems. */
+    void clear() throws Exception;
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 95b991b..2a54563 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -115,6 +115,8 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
     private transient StreamRecord<OUT> reusedStreamRecord;
 
+    private transient MapState<String, MemoryObjectImpl.MemoryItem> 
sensoryMemState;
+
     private transient MapState<String, MemoryObjectImpl.MemoryItem> 
shortTermMemState;
 
     // PythonActionExecutor for Python actions
@@ -182,6 +184,13 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
     public void open() throws Exception {
         super.open();
         reusedStreamRecord = new StreamRecord<>(null);
+        // init sensoryMemState
+        MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> 
sensoryMemStateDescriptor =
+                new MapStateDescriptor<>(
+                        "sensoryMemory",
+                        TypeInformation.of(String.class),
+                        TypeInformation.of(MemoryObjectImpl.MemoryItem.class));
+        sensoryMemState = 
getRuntimeContext().getMapState(sensoryMemStateDescriptor);
 
         // init shortTermMemState
         MapStateDescriptor<String, MemoryObjectImpl.MemoryItem> 
shortTermMemStateDescriptor =
@@ -397,12 +406,19 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         if (actionState != null) {
             isFinished = true;
             outputEvents = actionState.getOutputEvents();
-            for (MemoryUpdate memoryUpdate : actionState.getMemoryUpdates()) {
+            for (MemoryUpdate memoryUpdate : 
actionState.getShortTermMemoryUpdates()) {
                 actionTask
                         .getRunnerContext()
                         .getShortTermMemory()
                         .set(memoryUpdate.getPath(), memoryUpdate.getValue());
             }
+
+            for (MemoryUpdate memoryUpdate : 
actionState.getSensoryMemoryUpdates()) {
+                actionTask
+                        .getRunnerContext()
+                        .getSensoryMemory()
+                        .set(memoryUpdate.getPath(), memoryUpdate.getValue());
+            }
         } else {
             maybeInitActionState(key, sequenceNumber, actionTask.action, 
actionTask.event);
             ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke();
@@ -452,6 +468,9 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
         // 3. Process the next InputEvent or next action task
         if (currentInputEventFinished) {
+            // Clean up sensory memory once a single run has finished.
+            actionTask.getRunnerContext().clearSensoryMemory();
+
             // Once all sub-events and actions related to the current 
InputEvent are completed,
             // we can proceed to process the next InputEvent.
             int removedCount = 
removeFromListState(currentProcessingKeysOpState, key);
@@ -659,6 +678,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         } else if (actionTask.action.getExec() instanceof JavaFunction) {
             runnerContext =
                     new RunnerContextImpl(
+                            new CachedMemoryStore(sensoryMemState),
                             new CachedMemoryStore(shortTermMemState),
                             metricGroup,
                             this::checkMailboxThread,
@@ -666,6 +686,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         } else if (actionTask.action.getExec() instanceof PythonFunction) {
             runnerContext =
                     new PythonRunnerContextImpl(
+                            new CachedMemoryStore(sensoryMemState),
                             new CachedMemoryStore(shortTermMemState),
                             metricGroup,
                             this::checkMailboxThread,
@@ -755,8 +776,12 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
         ActionState actionState = actionStateStore.get(key, sequenceNum, 
action, event);
 
-        for (MemoryUpdate memoryUpdate : context.getAllMemoryUpdates()) {
-            actionState.addMemoryUpdate(memoryUpdate);
+        for (MemoryUpdate memoryUpdate : context.getSensoryMemoryUpdates()) {
+            actionState.addSensoryMemoryUpdate(memoryUpdate);
+        }
+
+        for (MemoryUpdate memoryUpdate : context.getShortTermMemoryUpdates()) {
+            actionState.addShortTermMemoryUpdate(memoryUpdate);
         }
 
         for (Event outputEvent : actionTaskResult.getOutputEvents()) {
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
index 89f741d..4bdb8d8 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
@@ -36,12 +36,18 @@ public class PythonRunnerContextImpl extends 
RunnerContextImpl {
     private final PythonActionExecutor pythonActionExecutor;
 
     public PythonRunnerContextImpl(
-            CachedMemoryStore store,
+            CachedMemoryStore sensoryMemStore,
+            CachedMemoryStore shortTermMemStore,
             FlinkAgentsMetricGroupImpl agentMetricGroup,
             Runnable mailboxThreadChecker,
             AgentPlan agentPlan,
             PythonActionExecutor pythonActionExecutor) {
-        super(store, agentMetricGroup, mailboxThreadChecker, agentPlan);
+        super(
+                sensoryMemStore,
+                shortTermMemStore,
+                agentMetricGroup,
+                mailboxThreadChecker,
+                agentPlan);
         this.pythonActionExecutor = pythonActionExecutor;
     }
 
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
index 4f9ca04..eac53d2 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateSerdeTest.java
@@ -40,11 +40,13 @@ public class ActionStateSerdeTest {
         OutputEvent outputEvent = new OutputEvent("test output");
         outputEvent.setAttr("outputAttr", 123);
 
-        MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test 
value");
+        MemoryUpdate sensoryMemoryUpdate = new MemoryUpdate("sm.test.path", 
"sm test value");
+        MemoryUpdate shortTermMemoryUpdate = new MemoryUpdate("stm.test.path", 
"stm test value");
 
         // Create ActionState
         ActionState originalState = new ActionState(inputEvent);
-        originalState.addMemoryUpdate(memoryUpdate);
+        originalState.addSensoryMemoryUpdate(sensoryMemoryUpdate);
+        originalState.addShortTermMemoryUpdate(shortTermMemoryUpdate);
         originalState.addEvent(outputEvent);
 
         // Test Kafka seder/deserializer
@@ -67,10 +69,16 @@ public class ActionStateSerdeTest {
         assertEquals("testValue", deserializedInputEvent.getAttr("testAttr"));
 
         // Verify memoryUpdates
-        assertEquals(1, deserializedState.getMemoryUpdates().size());
-        MemoryUpdate deserializedMemoryUpdate = 
deserializedState.getMemoryUpdates().get(0);
-        assertEquals("test.path", deserializedMemoryUpdate.getPath());
-        assertEquals("test value", deserializedMemoryUpdate.getValue());
+        assertEquals(1, deserializedState.getSensoryMemoryUpdates().size());
+        MemoryUpdate deserializedSensoryMemoryUpdate =
+                deserializedState.getSensoryMemoryUpdates().get(0);
+        assertEquals("sm.test.path", 
deserializedSensoryMemoryUpdate.getPath());
+        assertEquals("sm test value", 
deserializedSensoryMemoryUpdate.getValue());
+        assertEquals(1, deserializedState.getShortTermMemoryUpdates().size());
+        MemoryUpdate deserializedShortTermMemoryUpdate =
+                deserializedState.getShortTermMemoryUpdates().get(0);
+        assertEquals("stm.test.path", 
deserializedShortTermMemoryUpdate.getPath());
+        assertEquals("stm test value", 
deserializedShortTermMemoryUpdate.getValue());
 
         // Verify outputEvents
         assertEquals(1, deserializedState.getOutputEvents().size());
@@ -86,7 +94,8 @@ public class ActionStateSerdeTest {
         // Create ActionState with null taskEvent
         ActionState originalState = new ActionState();
         MemoryUpdate memoryUpdate = new MemoryUpdate("test.path", "test 
value");
-        originalState.addMemoryUpdate(memoryUpdate);
+        originalState.addShortTermMemoryUpdate(memoryUpdate);
+        originalState.addSensoryMemoryUpdate(memoryUpdate);
 
         // Test serialization/deserialization
         ActionStateKafkaSeder seder = new ActionStateKafkaSeder();
@@ -98,7 +107,8 @@ public class ActionStateSerdeTest {
         assertNull(deserializedState.getTaskEvent());
 
         // Verify other fields
-        assertEquals(1, deserializedState.getMemoryUpdates().size());
+        assertEquals(1, deserializedState.getSensoryMemoryUpdates().size());
+        assertEquals(1, deserializedState.getShortTermMemoryUpdates().size());
         assertEquals(0, deserializedState.getOutputEvents().size());
     }
 
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
index 780784f..46a68d9 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java
@@ -68,6 +68,11 @@ public class MemoryRefTest {
             this.memoryObject = memoryObject;
         }
 
+        @Override
+        public MemoryObject getSensoryMemory() throws Exception {
+            return memoryObject;
+        }
+
         @Override
         public MemoryObject getShortTermMemory() {
             return memoryObject;
@@ -176,7 +181,7 @@ public class MemoryRefTest {
         for (Map.Entry<String, Object> entry : testData.entrySet()) {
             MemoryRef ref = memory.set(entry.getKey(), entry.getValue());
 
-            Object resolvedValue = ref.resolve(ctx).getValue();
+            Object resolvedValue = 
ref.resolve(ctx.getShortTermMemory()).getValue();
             assertEquals(entry.getValue(), resolvedValue);
         }
     }
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index 4027589..4c58f98 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -267,10 +267,11 @@ public class ActionExecutionOperatorTest {
                     assertThat(taskEvent).isNotNull();
 
                     // Verify memory updates contain expected data
-                    if (!state.getMemoryUpdates().isEmpty()) {
+                    if (!state.getShortTermMemoryUpdates().isEmpty()) {
                         // For action1, memory should contain input + 1
-                        
assertThat(state.getMemoryUpdates().get(0).getPath()).isEqualTo("tmp");
-                        assertThat(state.getMemoryUpdates().get(0).getValue())
+                        
assertThat(state.getShortTermMemoryUpdates().get(0).getPath())
+                                .isEqualTo("tmp");
+                        
assertThat(state.getShortTermMemoryUpdates().get(0).getValue())
                                 .isEqualTo(inputValue + 1);
                     }
 

Reply via email to