This is an automated email from the ASF dual-hosted git repository.

wenjin272 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new d7f89114 [python] Reject non-checkpoint-stable Python memory values at set() (#839)
d7f89114 is described below

commit d7f891149522100c9243f262ad01c08c19a53ca3
Author: Weiqing Yang <[email protected]>
AuthorDate: Fri Jun 12 00:19:00 2026 -0700

    [python] Reject non-checkpoint-stable Python memory values at set() (#839)
---
 python/flink_agents/api/memory_object.py           |  60 ++++++++++++
 .../flink_integration_agent.py                     |   6 +-
 .../e2e_tests_integration/workflow_test.py         |   9 +-
 python/flink_agents/runtime/flink_memory_object.py |   7 +-
 python/flink_agents/runtime/local_memory_object.py |   7 +-
 .../runtime/tests/test_local_memory_object.py      |  26 +----
 .../runtime/tests/test_memory_reference.py         |  18 ----
 .../runtime/tests/test_memory_value_validation.py  | 105 +++++++++++++++++++++
 8 files changed, 187 insertions(+), 51 deletions(-)

diff --git a/python/flink_agents/api/memory_object.py b/python/flink_agents/api/memory_object.py
index 2361abeb..0a0a25ca 100644
--- a/python/flink_agents/api/memory_object.py
+++ b/python/flink_agents/api/memory_object.py
@@ -25,6 +25,66 @@ if TYPE_CHECKING:
     from flink_agents.api.memory_reference import MemoryRef
 
 
+# Exact builtin types Pemja materializes into native, checkpoint-stable JVM values.
+# Exact-type (not isinstance): a str/int Enum or numpy scalar is a subclass that Pemja
+# PyObject-wraps despite passing isinstance — accepting it would defeat the validator.
+_CHECKPOINT_STABLE_SCALARS = (bool, int, float, str)
+
+
+def validate_memory_value(path: str, value: Any) -> None:
+    """Reject memory values that are not recursively checkpoint-stable.
+
+    Python memory values cross the Pemja boundary into Flink state. Only values Pemja
+    materializes into native JVM types survive checkpoint and restore; anything else is
+    stored as a stale PyObject wrapper and crashes on restore. Raises TypeError with a
+    clear, actionable message naming the offending location, type, and a conversion.
+
+    Parameters
+    ----------
+    path: str
+        The memory path the value is being set at, used to build the error breadcrumb.
+    value: Any
+        The value to validate. Must be recursively composed of None, bool, int, float,
+        str, list, or dict with str keys.
+    """
+    _validate(value, f"value at memory path {path!r}")
+
+
+def _validate(value: Any, where: str) -> None:
+    if value is None or type(value) in _CHECKPOINT_STABLE_SCALARS:
+        return
+    if isinstance(value, MemoryObject):
+        msg = (
+            f"{where} is a MemoryObject; use new_object(...) to store a nested object "
+            f"instead of passing it to set()."
+        )
+        raise TypeError(msg)
+    if type(value) is list:
+        for i, item in enumerate(value):
+            _validate(item, f"{where}[{i}]")
+        return
+    if type(value) is dict:
+        for key, val in value.items():
+            if type(key) is not str:
+                msg = (
+                    f"{where} has a non-str key {key!r} ({type(key).__name__}); memory "
+                    f"dict keys must be str. Convert with "
+                    f"{{str(k): v for k, v in value.items()}}."
+                )
+                raise TypeError(msg)
+            _validate(val, f"{where}[{key!r}]")
+        return
+    msg = (
+        f"{where} has type {type(value).__name__!r}, which is not checkpoint-stable. "
+        f"Python memory values must be recursively composed of None, bool, int, float, "
+        f"str, list, or dict with str keys, because they cross the Pemja boundary into "
+        f"Flink state and non-primitive objects cannot be safely checkpointed/restored. "
+        f"Materialize it first, e.g. str(value) for a UUID, value.model_dump(mode='json')"
+        f" for a Pydantic model, or list(value) for a tuple/set."
+    )
+    raise TypeError(msg)
+
+
 class MemoryType(Enum):
     """Memory types based on MemoryObject."""
 
diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py
index 2982c5aa..744eb9d9 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py
@@ -127,7 +127,9 @@ class DataStreamAgent(Agent):
         content.review += " first action, log success=" + str(log_success) + ","
         content.memory_info = {"total_reviews": total}
 
-        data_ref = stm.set(f"processed_items.item_{content.id}", content)
+        data_ref = stm.set(
+            f"processed_items.item_{content.id}", content.model_dump(mode="json")
+        )
         ctx.send_event(MyEvent(value=data_ref))
 
     @action(MyEvent.EVENT_TYPE)
@@ -135,7 +137,7 @@ class DataStreamAgent(Agent):
     def second_action(event: Event, ctx: RunnerContext) -> None:
         input_data = MyEvent.from_event(event).value
         stm = ctx.short_term_memory
-        resolved_data: ItemData = stm.get(input_data)
+        resolved_data = ItemData.model_validate(stm.get(input_data))
 
         content = copy.deepcopy(resolved_data)
         content.review += " second action"
diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py
index ab71fa7b..2eafea7c 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py
@@ -76,12 +76,13 @@ class MyAgent(Agent):
         memory = ctx.short_term_memory
 
         data_path = f"user_data.{key}"
-        previous_data: ProcessedData = memory.get(data_path)
+        stored = memory.get(data_path)
+        previous_data = ProcessedData.model_validate(stored) if stored else None
         current_count = previous_data.visit_count if previous_data else 0
         new_count = current_count + 1
 
         data_to_store = ProcessedData(content=input_message, visit_count=new_count)
-        data_ref = memory.set(data_path, data_to_store)
+        data_ref = memory.set(data_path, data_to_store.model_dump(mode="json"))
 
         ctx.send_event(MyEvent(value=data_ref))
 
@@ -95,7 +96,7 @@ class MyAgent(Agent):
         content_ref: MemoryRef = MyEvent.from_event(event).value
         memory = ctx.short_term_memory
 
-        processed_data: ProcessedData = memory.get(content_ref)
+        processed_data = ProcessedData.model_validate(memory.get(content_ref))
 
         base_message = processed_data.content
         current_count = processed_data.visit_count
@@ -104,7 +105,7 @@ class MyAgent(Agent):
         updated_data_to_store = ProcessedData(
             content=base_message, visit_count=new_count
         )
-        memory.set(content_ref.path, updated_data_to_store)
+        memory.set(content_ref.path, updated_data_to_store.model_dump(mode="json"))
 
         final_content = f"{base_message} -> processed by second_action"
         key_with_count = f"(visit {new_count} times)"
diff --git a/python/flink_agents/runtime/flink_memory_object.py b/python/flink_agents/runtime/flink_memory_object.py
index c2840665..6143bc27 100644
--- a/python/flink_agents/runtime/flink_memory_object.py
+++ b/python/flink_agents/runtime/flink_memory_object.py
@@ -17,7 +17,11 @@
 #################################################################################
 from typing import Any, Dict, List
 
-from flink_agents.api.memory_object import MemoryObject, MemoryType
+from flink_agents.api.memory_object import (
+    MemoryObject,
+    MemoryType,
+    validate_memory_value,
+)
 from flink_agents.api.memory_reference import MemoryRef
 
 
@@ -66,6 +70,7 @@ class FlinkMemoryObject(MemoryObject):
 
     def set(self, path: str, value: Any) -> MemoryRef:
         """Set a value at the given path. Creates intermediate objects if needed."""
+        validate_memory_value(path, value)
         try:
             j_ref = self._j_memory_object.set(path, value)
             return MemoryRef.create(memory_type=self.__type, path=j_ref.getPath())
diff --git a/python/flink_agents/runtime/local_memory_object.py b/python/flink_agents/runtime/local_memory_object.py
index a9fb44b8..1a50dfd8 100644
--- a/python/flink_agents/runtime/local_memory_object.py
+++ b/python/flink_agents/runtime/local_memory_object.py
@@ -17,7 +17,11 @@
 #################################################################################
 from typing import Any, ClassVar, Dict, List
 
-from flink_agents.api.memory_object import MemoryObject, MemoryType
+from flink_agents.api.memory_object import (
+    MemoryObject,
+    MemoryType,
+    validate_memory_value,
+)
 from flink_agents.api.memory_reference import MemoryRef
 
 
@@ -106,6 +110,7 @@ class LocalMemoryObject(MemoryObject):
         if isinstance(value, LocalMemoryObject):
             msg = "Do not set a MemoryObject instance directly; use new_object()."
             raise TypeError(msg)
+        validate_memory_value(path, value)
 
         abs_path = self._full_path(path)
 
diff --git a/python/flink_agents/runtime/tests/test_local_memory_object.py b/python/flink_agents/runtime/tests/test_local_memory_object.py
index 5268077a..18b56587 100644
--- a/python/flink_agents/runtime/tests/test_local_memory_object.py
+++ b/python/flink_agents/runtime/tests/test_local_memory_object.py
@@ -15,7 +15,7 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 #################################################################################
-from typing import Dict, List, Set
+from typing import Dict, List
 
 from flink_agents.api.memory_object import MemoryType
 from flink_agents.runtime.local_memory_object import LocalMemoryObject
@@ -26,20 +26,6 @@ def create_memory() -> LocalMemoryObject:
     return LocalMemoryObject(MemoryType.SHORT_TERM, {})
 
 
-class User:
-    def __init__(self, name: str, age: int) -> None:
-        """Store for later comparison."""
-        self.name = name
-        self.age = age
-
-    def __eq__(self, other: object) -> bool:
-        return (
-            isinstance(other, User)
-            and other.name == self.name
-            and other.age == self.age
-        )
-
-
 def test_basic_set_get_various_types() -> None:
     mem = create_memory()
 
@@ -63,16 +49,6 @@ def test_basic_set_get_various_types() -> None:
     mem.set("dict", d)
     assert mem.get("dict") == d
 
-    # set
-    s: Set[int] = {1, 2, 3}
-    mem.set("set", s)
-    assert mem.get("set") == s
-
-    # custom object
-    user = User("Alice", 20)
-    mem.set("user", user)
-    assert mem.get("user") == user
-
 
 def test_nested_set_and_get() -> None:
     mem = create_memory()
diff --git a/python/flink_agents/runtime/tests/test_memory_reference.py b/python/flink_agents/runtime/tests/test_memory_reference.py
index 10d885b1..86428d2a 100644
--- a/python/flink_agents/runtime/tests/test_memory_reference.py
+++ b/python/flink_agents/runtime/tests/test_memory_reference.py
@@ -35,20 +35,6 @@ def create_memory() -> LocalMemoryObject:
     return LocalMemoryObject(MemoryType.SHORT_TERM, {})
 
 
-class User:
-    def __init__(self, name: str, age: int) -> None:
-        """Store for later comparison."""
-        self.name = name
-        self.age = age
-
-    def __eq__(self, other: object) -> bool:
-        return (
-            isinstance(other, User)
-            and other.name == self.name
-            and other.age == self.age
-        )
-
-
 def test_set_get_involved_ref() -> None:
     mem = create_memory()
 
@@ -59,8 +45,6 @@ def test_set_get_involved_ref() -> None:
         ("my_str", "hello", "str"),
         ("my_list", ["a", "b"], "list"),
         ("my_dict", {"x": 10}, "dict"),
-        ("my_set", {1, 2, 3}, "set"),
-        ("my_user", User("Alice", 30), "User"),
     ]
 
     for path, value, _expected_type_name in test_cases:
@@ -90,8 +74,6 @@ def test_memory_ref_resolve() -> None:
         "my_str": "hello",
         "my_list": ["a", "b"],
         "my_dict": {"x": 10},
-        "my_set": {1, 2, 3},
-        "my_user": User("Charlie", 50),
     }
 
     for path, value in test_data.items():
diff --git a/python/flink_agents/runtime/tests/test_memory_value_validation.py b/python/flink_agents/runtime/tests/test_memory_value_validation.py
new file mode 100644
index 00000000..c6b4418a
--- /dev/null
+++ b/python/flink_agents/runtime/tests/test_memory_value_validation.py
@@ -0,0 +1,105 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import uuid
+from enum import Enum
+from unittest.mock import MagicMock
+
+import pytest
+from pydantic import BaseModel
+
+from flink_agents.api.memory_object import MemoryType, validate_memory_value
+from flink_agents.runtime.flink_memory_object import (
+    FlinkMemoryObject,
+    MemoryObjectError,
+)
+
+
+class _Model(BaseModel):
+    name: str
+
+
+class _StrEnum(str, Enum):
+    A = "a"
+
+
+class _Plain:
+    pass
+
+
+def test_accepts_none_and_scalars() -> None:
+    for value in (None, True, False, 0, 1, -3, 3.14, "", "hello"):
+        validate_memory_value("p", value)
+
+
+def test_accepts_nested_list_and_dict() -> None:
+    validate_memory_value("p", [1, "a", [2, 3], {"k": [4, None]}])
+    validate_memory_value("p", {"a": 1, "b": {"c": [True, "x"]}})
+
+
+def test_rejects_pydantic_model() -> None:
+    with pytest.raises(TypeError, match="model_dump"):
+        validate_memory_value("p", _Model(name="x"))
+
+
+def test_rejects_uuid() -> None:
+    with pytest.raises(TypeError, match=r"str\(value\)"):
+        validate_memory_value("p", uuid.uuid4())
+
+
+def test_rejects_tuple_set_frozenset() -> None:
+    for value in ((1, 2), {1, 2}, frozenset({1, 2})):
+        with pytest.raises(TypeError, match=r"list\(value\)"):
+            validate_memory_value("p", value)
+
+
+def test_rejects_str_enum() -> None:
+    # str-Enum passes isinstance(str) but is PyObject-wrapped by Pemja; the
+    # exact-type check must reject it.
+    with pytest.raises(TypeError, match="not checkpoint-stable"):
+        validate_memory_value("p", _StrEnum.A)
+
+
+def test_rejects_custom_class() -> None:
+    with pytest.raises(TypeError, match="not checkpoint-stable"):
+        validate_memory_value("p", _Plain())
+
+
+def test_rejects_non_str_dict_key() -> None:
+    with pytest.raises(TypeError, match="non-str key"):
+        validate_memory_value("p", {1: "v"})
+
+
+def test_rejects_nested_value_reports_breadcrumb() -> None:
+    with pytest.raises(TypeError, match=r"\[2\]\['bad'\]"):
+        validate_memory_value("p", [1, 2, {"bad": object()}])
+
+
+def test_memory_object_value_suggests_new_object() -> None:
+    inner = FlinkMemoryObject(MemoryType.SHORT_TERM, MagicMock())
+    with pytest.raises(TypeError, match="new_object"):
+        validate_memory_value("p", inner)
+
+
+def test_flink_set_raises_raw_type_error() -> None:
+    j_obj = MagicMock()
+    mem = FlinkMemoryObject(MemoryType.SHORT_TERM, j_obj)
+    with pytest.raises(TypeError) as exc_info:
+        mem.set("p", uuid.uuid4())
+    # Validation fires before the Java call, raising a raw TypeError.
+    assert not isinstance(exc_info.value, MemoryObjectError)
+    j_obj.set.assert_not_called()

Reply via email to