This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 68e37e1a0f97389f10c0c7c1c835279deda48083
Author: WenjinXie <[email protected]>
AuthorDate: Wed Nov 26 19:31:10 2025 +0800

    fixup! [api][runtime][python] Introduce sensory memory in python.
    
    Store the memory type (sensory or short-term) in MemoryRef so that a reference resolves against the matching memory.
---
 python/flink_agents/api/memory_object.py           |  6 ++++++
 python/flink_agents/api/memory_reference.py        | 25 ++++++++++++++++------
 python/flink_agents/runtime/flink_memory_object.py | 14 +++++++-----
 python/flink_agents/runtime/local_memory_object.py | 12 ++++++-----
 python/flink_agents/runtime/local_runner.py        |  6 +++---
 .../runtime/tests/test_local_memory_object.py      |  3 ++-
 .../runtime/tests/test_memory_reference.py         | 11 +++++-----
 7 files changed, 51 insertions(+), 26 deletions(-)

diff --git a/python/flink_agents/api/memory_object.py b/python/flink_agents/api/memory_object.py
index a250199..5e700b0 100644
--- a/python/flink_agents/api/memory_object.py
+++ b/python/flink_agents/api/memory_object.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #################################################################################
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from pydantic import BaseModel
@@ -23,6 +24,11 @@ from pydantic import BaseModel
 if TYPE_CHECKING:
     from flink_agents.api.memory_reference import MemoryRef
 
+class MemoryType(Enum):
+    """Kind of memory (sensory or short-term) a MemoryObject belongs to."""
+    SENSORY = "sensory"
+    SHORT_TERM = "short_term"
+
 class MemoryObject(BaseModel, ABC):
     """Representation of an object in the short-term memory.
 
diff --git a/python/flink_agents/api/memory_reference.py b/python/flink_agents/api/memory_reference.py
index fd520c3..5c37793 100644
--- a/python/flink_agents/api/memory_reference.py
+++ b/python/flink_agents/api/memory_reference.py
@@ -21,44 +21,55 @@ from typing import TYPE_CHECKING, Any
 
 from pydantic import BaseModel, ConfigDict
 
+from flink_agents.api.memory_object import MemoryType
+
 if TYPE_CHECKING:
-    from flink_agents.api.memory_object import MemoryObject
+    from flink_agents.api.runner_context import RunnerContext
 
 
 class MemoryRef(BaseModel):
     """Reference to a specific data item in the Short-Term Memory."""
 
+    memory_type: MemoryType = MemoryType.SHORT_TERM
     path: str
 
     model_config = ConfigDict(frozen=True)
 
     @staticmethod
-    def create(path: str) -> MemoryRef:
+    def create(memory_type: MemoryType, path: str) -> MemoryRef:
         """Create a new MemoryRef instance based on the given path.
 
         Parameters
         ----------
+        memory_type: MemoryType
+            The type of memory (sensory or short-term) the path belongs to.
         path: str
             The absolute path of the data in the Short-Term Memory.
 
         Returns:
         -------
         MemoryRef
             A new MemoryRef instance.
         """
-        return MemoryRef(path=path)
+        return MemoryRef(memory_type=memory_type, path=path)
 
-    def resolve(self, memory: MemoryObject) -> Any:
+    def resolve(self, ctx: RunnerContext) -> Any:
         """Resolve the reference to get the actual data.
 
         Parameters
         ----------
-        memory: MemoryObject
-            The memory object this ref points to.
+        ctx: RunnerContext
+            The current execution context, providing the sensory and short-term memory.
 
         Returns:
         -------
         Any
             The deserialized, original data object.
         """
-        return memory.get(self)
+        if self.memory_type == MemoryType.SENSORY:
+            return ctx.sensory_memory.get(self)
+        elif self.memory_type == MemoryType.SHORT_TERM:
+            return ctx.short_term_memory.get(self)
+        else:
+            msg = f"Unknown memory type: {self.memory_type}"
+            raise RuntimeError(msg)
diff --git a/python/flink_agents/runtime/flink_memory_object.py b/python/flink_agents/runtime/flink_memory_object.py
index 0b166a1..f0a8e9d 100644
--- a/python/flink_agents/runtime/flink_memory_object.py
+++ b/python/flink_agents/runtime/flink_memory_object.py
@@ -17,7 +17,7 @@
 #################################################################################
 from typing import Any, Dict, List
 
-from flink_agents.api.memory_object import MemoryObject
+from flink_agents.api.memory_object import MemoryObject, MemoryType
 from flink_agents.api.memory_reference import MemoryRef
 
 
@@ -29,9 +29,13 @@ class FlinkMemoryObject(MemoryObject):
     memory implemented in Java.
     """
 
-    def __init__(self, j_memory_object: Any) -> None:
+    __type: MemoryType
+
+    def __init__(self, memory_type: MemoryType, j_memory_object: Any, /, **data: Any) -> None:
         """Initialize with a Java MemoryObject instance."""
+        super().__init__(**data)
         self._j_memory_object = j_memory_object
+        self.__type = memory_type
 
     def get(self, path_or_ref: str | MemoryRef) -> Any:
         """Get a nested object or value by path or MemoryRef.
@@ -51,7 +55,7 @@ class FlinkMemoryObject(MemoryObject):
             if j_result is None:
                 return None
             if j_result.isNestedObject():
-                return FlinkMemoryObject(j_result)
+                return FlinkMemoryObject(self.__type, j_result)
             else:
                 return j_result.getValue()
         except Exception as e:
@@ -62,7 +66,7 @@ class FlinkMemoryObject(MemoryObject):
         """Set a value at the given path. Creates intermediate objects if needed."""
         try:
             j_ref = self._j_memory_object.set(path, value)
-            return MemoryRef(path=j_ref.getPath())
+            return MemoryRef.create(memory_type=self.__type, path=j_ref.getPath())
         except Exception as e:
             msg = f"Failed to set value at path '{path}'"
             raise MemoryObjectError(msg) from e
@@ -70,7 +74,7 @@ class FlinkMemoryObject(MemoryObject):
     def new_object(self, path: str, *, overwrite: bool = False) -> "FlinkMemoryObject":
         """Create a new object at the given path."""
         try:
-            return FlinkMemoryObject(self._j_memory_object.newObject(path, overwrite))
+            return FlinkMemoryObject(self.__type, self._j_memory_object.newObject(path, overwrite))
         except Exception as e:
             msg = f"Failed to create new object at path '{path}'"
             raise MemoryObjectError(msg) from e
diff --git a/python/flink_agents/runtime/local_memory_object.py b/python/flink_agents/runtime/local_memory_object.py
index a1b4c1b..20e03f6 100644
--- a/python/flink_agents/runtime/local_memory_object.py
+++ b/python/flink_agents/runtime/local_memory_object.py
@@ -17,7 +17,7 @@
 #################################################################################
 from typing import Any, ClassVar, Dict, List
 
-from flink_agents.api.memory_object import MemoryObject
+from flink_agents.api.memory_object import MemoryObject, MemoryType
 from flink_agents.api.memory_reference import MemoryRef
 
 
@@ -33,10 +33,11 @@ class LocalMemoryObject(MemoryObject):
     __SEPARATOR: ClassVar[str] = "."
     __NESTED_MARK: ClassVar[str] = "NestedObject"
 
+    __type: MemoryType
     __store: dict[str, Any]
     __prefix: str
 
-    def __init__(self, store: Dict[str, Any], prefix: str = ROOT_KEY) -> None:
+    def __init__(self, memory_type: MemoryType, store: Dict[str, Any], prefix: str = ROOT_KEY) -> None:
         """Initialize a LocalMemoryObject.
 
         Parameters
@@ -48,6 +49,7 @@ class LocalMemoryObject(MemoryObject):
             shared store.
         """
         super().__init__()
+        self.__type = memory_type
         self.__store = store if store is not None else {}
         self.__prefix = prefix
 
@@ -79,7 +81,7 @@ class LocalMemoryObject(MemoryObject):
         if abs_path in self.__store:
             value = self.__store[abs_path]
             if self._is_nested_object(value):
-                return LocalMemoryObject(self.__store, abs_path)
+                return LocalMemoryObject(self.__type, self.__store, abs_path)
             return value
         return None
 
@@ -115,7 +117,7 @@ class LocalMemoryObject(MemoryObject):
         self._add_subfield(parent_path, parts[-1])
 
         self.__store[abs_path] = value
-        return MemoryRef(path=abs_path)
+        return MemoryRef(memory_type=self.__type, path=abs_path)
 
     def new_object(self, path: str, *, overwrite: bool = False) -> "LocalMemoryObject":
         """Create a new object as the value of an indirect field in the object.
@@ -146,7 +148,7 @@ class LocalMemoryObject(MemoryObject):
                 raise ValueError(msg)
 
         self.__store[abs_path] = _ObjMarker()
-        return LocalMemoryObject(self.__store, abs_path)
+        return LocalMemoryObject(self.__type, self.__store, abs_path)
 
     def is_exist(self, path: str) -> bool:
         """Check whether a (direct or indirect) field exist in the object.
diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py
index 9109191..6b5f50b 100644
--- a/python/flink_agents/runtime/local_runner.py
+++ b/python/flink_agents/runtime/local_runner.py
@@ -24,7 +24,7 @@ from typing_extensions import override
 
 from flink_agents.api.agent import Agent
 from flink_agents.api.events.event import Event, InputEvent, OutputEvent
-from flink_agents.api.memory_object import MemoryObject
+from flink_agents.api.memory_object import MemoryObject, MemoryType
 from flink_agents.api.metric_group import MetricGroup
 from flink_agents.api.resource import Resource, ResourceType
 from flink_agents.api.runner_context import RunnerContext
@@ -81,10 +81,10 @@ class LocalRunnerContext(RunnerContext):
         self._sensory_mem_store = {}
         self._short_term_mem_store = {}
         self._sensory_memory = LocalMemoryObject(
-            self._sensory_mem_store, LocalMemoryObject.ROOT_KEY
+            MemoryType.SENSORY, self._sensory_mem_store, LocalMemoryObject.ROOT_KEY
         )
         self._short_term_memory = LocalMemoryObject(
-            self._short_term_mem_store, LocalMemoryObject.ROOT_KEY
+            MemoryType.SHORT_TERM, self._short_term_mem_store, LocalMemoryObject.ROOT_KEY
         )
         self._config = config
 
diff --git a/python/flink_agents/runtime/tests/test_local_memory_object.py b/python/flink_agents/runtime/tests/test_local_memory_object.py
index 0ab6f75..305aba2 100644
--- a/python/flink_agents/runtime/tests/test_local_memory_object.py
+++ b/python/flink_agents/runtime/tests/test_local_memory_object.py
@@ -17,12 +17,13 @@
 #################################################################################
 from typing import Dict, List, Set
 
+from flink_agents.api.memory_object import MemoryType
 from flink_agents.runtime.local_memory_object import LocalMemoryObject
 
 
 def create_memory() -> LocalMemoryObject:
     """Return a MemoryObject for every test case."""
-    return LocalMemoryObject({})
+    return LocalMemoryObject(MemoryType.SHORT_TERM, {})
 
 
 class User:  # noqa: D101
diff --git a/python/flink_agents/runtime/tests/test_memory_reference.py b/python/flink_agents/runtime/tests/test_memory_reference.py
index 7684565..497de1a 100644
--- a/python/flink_agents/runtime/tests/test_memory_reference.py
+++ b/python/flink_agents/runtime/tests/test_memory_reference.py
@@ -15,6 +15,7 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 #################################################################################
+from flink_agents.api.memory_object import MemoryType
 from flink_agents.api.memory_reference import MemoryRef
 from flink_agents.runtime.local_memory_object import LocalMemoryObject
 
@@ -31,7 +32,7 @@ class MockRunnerContext:  # noqa D101
 
 def create_memory() -> LocalMemoryObject:
     """Return a MemoryObject for every test case."""
-    return LocalMemoryObject({})
+    return LocalMemoryObject(MemoryType.SHORT_TERM, {})
 
 
 class User:  # noqa: D101
@@ -73,7 +74,7 @@ def test_set_get_involved_ref() -> None:  # noqa: D103
 
 def test_memory_ref_create() -> None:  # noqa: D103
     path = "a.b.c"
-    ref = MemoryRef.create(path)
+    ref = MemoryRef.create(MemoryType.SHORT_TERM, path)
 
     assert isinstance(ref, MemoryRef)
     assert ref.path == path
@@ -95,7 +96,7 @@ def test_memory_ref_resolve() -> None:  # noqa: D103
 
     for path, value in test_data.items():
         ref = mem.set(path, value)
-        resolved_value = ref.resolve(ctx.short_term_memory)
+        resolved_value = ref.resolve(ctx)
         assert resolved_value == value
 
 
@@ -104,7 +105,7 @@ def test_get_with_ref_to_nested_object() -> None:  # noqa: D103
     obj = mem.new_object("a.b")
     obj.set("c", 10)
 
-    ref = MemoryRef.create("a")
+    ref = MemoryRef.create(MemoryType.SHORT_TERM, "a")
 
     resolved_obj = mem.get(ref)
     assert isinstance(resolved_obj, LocalMemoryObject)
@@ -114,7 +115,7 @@ def test_get_with_ref_to_nested_object() -> None:  # noqa: D103
 def test_get_with_non_existent_ref() -> None:  # noqa: D103
     mem = create_memory()
 
-    non_existent_ref = MemoryRef.create("this.path.does.not.exist")
+    non_existent_ref = MemoryRef.create(MemoryType.SHORT_TERM, "this.path.does.not.exist")
 
     assert mem.get(non_existent_ref) is None
 

Reply via email to