This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 68e37e1a0f97389f10c0c7c1c835279deda48083 Author: WenjinXie <[email protected]> AuthorDate: Wed Nov 26 19:31:10 2025 +0800 fixup! [api][runtime][python] Introduce sensory memory in python. store memory type --- python/flink_agents/api/memory_object.py | 6 ++++++ python/flink_agents/api/memory_reference.py | 25 ++++++++++++++++------ python/flink_agents/runtime/flink_memory_object.py | 14 +++++++----- python/flink_agents/runtime/local_memory_object.py | 12 ++++++----- python/flink_agents/runtime/local_runner.py | 6 +++--- .../runtime/tests/test_local_memory_object.py | 3 ++- .../runtime/tests/test_memory_reference.py | 11 +++++----- 7 files changed, 51 insertions(+), 26 deletions(-) diff --git a/python/flink_agents/api/memory_object.py b/python/flink_agents/api/memory_object.py index a250199..5e700b0 100644 --- a/python/flink_agents/api/memory_object.py +++ b/python/flink_agents/api/memory_object.py @@ -16,6 +16,7 @@ # limitations under the License. ################################################################################# from abc import ABC, abstractmethod +from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Union from pydantic import BaseModel @@ -23,6 +24,11 @@ from pydantic import BaseModel if TYPE_CHECKING: from flink_agents.api.memory_reference import MemoryRef +class MemoryType(Enum): + """Memory types based on MemoryObject.""" + SENSORY = "sensory", + SHORT_TERM = "short_term" + class MemoryObject(BaseModel, ABC): """Representation of an object in the short-term memory. diff --git a/python/flink_agents/api/memory_reference.py b/python/flink_agents/api/memory_reference.py index fd520c3..5c37793 100644 --- a/python/flink_agents/api/memory_reference.py +++ b/python/flink_agents/api/memory_reference.py @@ -21,44 +21,55 @@ from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict +from flink_agents.api.memory_object import MemoryType + if TYPE_CHECKING: - from flink_agents.api.memory_object import MemoryObject + from flink_agents.api.runner_context import RunnerContext class MemoryRef(BaseModel): """Reference to a specific data item in the Short-Term Memory.""" + memory_type: MemoryType = MemoryType.SHORT_TERM path: str model_config = ConfigDict(frozen=True) @staticmethod - def create(path: str) -> MemoryRef: + def create(memory_type: MemoryType, path: str) -> MemoryRef: """Create a new MemoryRef instance based on the given path. Parameters ---------- path: str The absolute path of the data in the Short-Term Memory. + memory_type: + The type of the memory object this reference points to. Returns: ------- MemoryRef A new MemoryRef instance. """ - return MemoryRef(path=path) + return MemoryRef(memory_type=memory_type, path=path) - def resolve(self, memory: MemoryObject) -> Any: + def resolve(self, ctx: RunnerContext) -> Any: """Resolve the reference to get the actual data. Parameters ---------- - memory: MemoryObject - The memory object this ref points to. + ctx: RunnerContext + The current execution context, used to access Short-Term Memory. Returns: ------- Any The deserialized, original data object. """ - return memory.get(self) + if self.memory_type == MemoryType.SENSORY: + return ctx.sensory_memory.get(self) + elif self.memory_type == MemoryType.SHORT_TERM: + return ctx.short_term_memory.get(self) + else: + msg = f"Unknown memory type: {self.memory_type}" + raise RuntimeError(msg) diff --git a/python/flink_agents/runtime/flink_memory_object.py b/python/flink_agents/runtime/flink_memory_object.py index 0b166a1..f0a8e9d 100644 --- a/python/flink_agents/runtime/flink_memory_object.py +++ b/python/flink_agents/runtime/flink_memory_object.py @@ -17,7 +17,7 @@ ################################################################################# from typing import Any, Dict, List -from flink_agents.api.memory_object import MemoryObject +from flink_agents.api.memory_object import MemoryObject, MemoryType from flink_agents.api.memory_reference import MemoryRef @@ -29,9 +29,13 @@ class FlinkMemoryObject(MemoryObject): memory implemented in Java. """ - def __init__(self, j_memory_object: Any) -> None: + __type: MemoryType + + def __init__(self, type: MemoryType, j_memory_object: Any, /, **data: Any) -> None: """Initialize with a Java MemoryObject instance.""" + super().__init__(**data) self._j_memory_object = j_memory_object + self.__type = type def get(self, path_or_ref: str | MemoryRef) -> Any: """Get a nested object or value by path or MemoryRef. @@ -51,7 +55,7 @@ class FlinkMemoryObject(MemoryObject): if j_result is None: return None if j_result.isNestedObject(): - return FlinkMemoryObject(j_result) + return FlinkMemoryObject(self.__type, j_result) else: return j_result.getValue() except Exception as e: @@ -62,7 +66,7 @@ class FlinkMemoryObject(MemoryObject): """Set a value at the given path. Creates intermediate objects if needed.""" try: j_ref = self._j_memory_object.set(path, value) - return MemoryRef(path=j_ref.getPath()) + return MemoryRef.create(memory_type=self.__type, path=j_ref.getPath()) except Exception as e: msg = f"Failed to set value at path '{path}'" raise MemoryObjectError(msg) from e @@ -70,7 +74,7 @@ class FlinkMemoryObject(MemoryObject): def new_object(self, path: str, *, overwrite: bool = False) -> "FlinkMemoryObject": """Create a new object at the given path.""" try: - return FlinkMemoryObject(self._j_memory_object.newObject(path, overwrite)) + return FlinkMemoryObject(self.__type, self._j_memory_object.newObject(path, overwrite)) except Exception as e: msg = f"Failed to create new object at path '{path}'" raise MemoryObjectError(msg) from e diff --git a/python/flink_agents/runtime/local_memory_object.py b/python/flink_agents/runtime/local_memory_object.py index a1b4c1b..20e03f6 100644 --- a/python/flink_agents/runtime/local_memory_object.py +++ b/python/flink_agents/runtime/local_memory_object.py @@ -17,7 +17,7 @@ ################################################################################# from typing import Any, ClassVar, Dict, List -from flink_agents.api.memory_object import MemoryObject +from flink_agents.api.memory_object import MemoryObject, MemoryType from flink_agents.api.memory_reference import MemoryRef @@ -33,10 +33,11 @@ class LocalMemoryObject(MemoryObject): __SEPARATOR: ClassVar[str] = "." __NESTED_MARK: ClassVar[str] = "NestedObject" + __type: MemoryType __store: dict[str, Any] __prefix: str - def __init__(self, store: Dict[str, Any], prefix: str = ROOT_KEY) -> None: + def __init__(self, type: MemoryType, store: Dict[str, Any], prefix: str = ROOT_KEY) -> None: """Initialize a LocalMemoryObject. Parameters @@ -48,6 +49,7 @@ class LocalMemoryObject(MemoryObject): shared store. """ super().__init__() + self.__type = type self.__store = store if store is not None else {} self.__prefix = prefix @@ -79,7 +81,7 @@ class LocalMemoryObject(MemoryObject): if abs_path in self.__store: value = self.__store[abs_path] if self._is_nested_object(value): - return LocalMemoryObject(self.__store, abs_path) + return LocalMemoryObject(self.__type, self.__store, abs_path) return value return None @@ -115,7 +117,7 @@ class LocalMemoryObject(MemoryObject): self._add_subfield(parent_path, parts[-1]) self.__store[abs_path] = value - return MemoryRef(path=abs_path) + return MemoryRef(memory_type=self.__type, path=abs_path) def new_object(self, path: str, *, overwrite: bool = False) -> "LocalMemoryObject": """Create a new object as the value of an indirect field in the object. @@ -146,7 +148,7 @@ class LocalMemoryObject(MemoryObject): raise ValueError(msg) self.__store[abs_path] = _ObjMarker() - return LocalMemoryObject(self.__store, abs_path) + return LocalMemoryObject(self.__type, self.__store, abs_path) def is_exist(self, path: str) -> bool: """Check whether a (direct or indirect) field exist in the object. diff --git a/python/flink_agents/runtime/local_runner.py b/python/flink_agents/runtime/local_runner.py index 9109191..6b5f50b 100644 --- a/python/flink_agents/runtime/local_runner.py +++ b/python/flink_agents/runtime/local_runner.py @@ -24,7 +24,7 @@ from typing_extensions import override from flink_agents.api.agent import Agent from flink_agents.api.events.event import Event, InputEvent, OutputEvent -from flink_agents.api.memory_object import MemoryObject +from flink_agents.api.memory_object import MemoryObject, MemoryType from flink_agents.api.metric_group import MetricGroup from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.runner_context import RunnerContext @@ -81,10 +81,10 @@ class LocalRunnerContext(RunnerContext): self._sensory_mem_store = {} self._short_term_mem_store = {} self._sensory_memory = LocalMemoryObject( - self._sensory_mem_store, LocalMemoryObject.ROOT_KEY + MemoryType.SENSORY, self._sensory_mem_store, LocalMemoryObject.ROOT_KEY ) self._short_term_memory = LocalMemoryObject( - self._short_term_mem_store, LocalMemoryObject.ROOT_KEY + MemoryType.SHORT_TERM, self._short_term_mem_store, LocalMemoryObject.ROOT_KEY ) self._config = config diff --git a/python/flink_agents/runtime/tests/test_local_memory_object.py b/python/flink_agents/runtime/tests/test_local_memory_object.py index 0ab6f75..305aba2 100644 --- a/python/flink_agents/runtime/tests/test_local_memory_object.py +++ b/python/flink_agents/runtime/tests/test_local_memory_object.py @@ -17,12 +17,13 @@ ################################################################################# from typing import Dict, List, Set +from flink_agents.api.memory_object import MemoryType from flink_agents.runtime.local_memory_object import LocalMemoryObject def create_memory() -> LocalMemoryObject: """Return a MemoryObject for every test case.""" - return LocalMemoryObject({}) + return LocalMemoryObject(MemoryType.SHORT_TERM, {}) class User: # noqa: D101 diff --git a/python/flink_agents/runtime/tests/test_memory_reference.py b/python/flink_agents/runtime/tests/test_memory_reference.py index 7684565..497de1a 100644 --- a/python/flink_agents/runtime/tests/test_memory_reference.py +++ b/python/flink_agents/runtime/tests/test_memory_reference.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# +from flink_agents.api.memory_object import MemoryType from flink_agents.api.memory_reference import MemoryRef from flink_agents.runtime.local_memory_object import LocalMemoryObject @@ -31,7 +32,7 @@ class MockRunnerContext: # noqa D101 def create_memory() -> LocalMemoryObject: """Return a MemoryObject for every test case.""" - return LocalMemoryObject({}) + return LocalMemoryObject(MemoryType.SHORT_TERM, {}) class User: # noqa: D101 @@ -73,7 +74,7 @@ def test_set_get_involved_ref() -> None: # noqa: D103 def test_memory_ref_create() -> None: # noqa: D103 path = "a.b.c" - ref = MemoryRef.create(path) + ref = MemoryRef.create(MemoryType.SHORT_TERM, path) assert isinstance(ref, MemoryRef) assert ref.path == path @@ -95,7 +96,7 @@ def test_memory_ref_resolve() -> None: # noqa: D103 for path, value in test_data.items(): ref = mem.set(path, value) - resolved_value = ref.resolve(ctx.short_term_memory) + resolved_value = ref.resolve(ctx) assert resolved_value == value @@ -104,7 +105,7 @@ def test_get_with_ref_to_nested_object() -> None: # noqa: D103 obj = mem.new_object("a.b") obj.set("c", 10) - ref = MemoryRef.create("a") + ref = MemoryRef.create(MemoryType.SHORT_TERM, "a") resolved_obj = mem.get(ref) assert isinstance(resolved_obj, LocalMemoryObject) @@ -114,7 +115,7 @@ def test_get_with_ref_to_nested_object() -> None: # noqa: D103 def test_get_with_non_existent_ref() -> None: # noqa: D103 mem = create_memory() - non_existent_ref = MemoryRef.create("this.path.does.not.exist") + non_existent_ref = MemoryRef.create(MemoryType.SHORT_TERM, "this.path.does.not.exist") assert mem.get(non_existent_ref) is None
