This is an automated email from the ASF dual-hosted git repository.
wenjin272 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new d7f89114 [python] Reject non-checkpoint-stable Python memory values at set() (#839)
d7f89114 is described below
commit d7f891149522100c9243f262ad01c08c19a53ca3
Author: Weiqing Yang <[email protected]>
AuthorDate: Fri Jun 12 00:19:00 2026 -0700
[python] Reject non-checkpoint-stable Python memory values at set() (#839)
---
python/flink_agents/api/memory_object.py | 60 ++++++++++++
.../flink_integration_agent.py | 6 +-
.../e2e_tests_integration/workflow_test.py | 9 +-
python/flink_agents/runtime/flink_memory_object.py | 7 +-
python/flink_agents/runtime/local_memory_object.py | 7 +-
.../runtime/tests/test_local_memory_object.py | 26 +----
.../runtime/tests/test_memory_reference.py | 18 ----
.../runtime/tests/test_memory_value_validation.py | 105 +++++++++++++++++++++
8 files changed, 187 insertions(+), 51 deletions(-)
diff --git a/python/flink_agents/api/memory_object.py b/python/flink_agents/api/memory_object.py
index 2361abeb..0a0a25ca 100644
--- a/python/flink_agents/api/memory_object.py
+++ b/python/flink_agents/api/memory_object.py
@@ -25,6 +25,66 @@ if TYPE_CHECKING:
from flink_agents.api.memory_reference import MemoryRef
+# Exact builtin types that Pemja materializes into native, checkpoint-stable JVM
+# values. Membership is tested by exact type (``type(v) in ...``), not isinstance:
+# a str/int Enum member or a numpy scalar (e.g. numpy.float64) subclasses one of
+# these builtins and passes isinstance, yet Pemja PyObject-wraps it, so accepting
+# subclasses would defeat the validator. bool is listed explicitly because an
+# exact-type lookup does not treat it as an int.
+_CHECKPOINT_STABLE_SCALARS = (bool, int, float, str)
+
+
+def validate_memory_value(path: str, value: Any) -> None:
+    """Ensure ``value`` can be stored in Python memory and survive checkpoints.
+
+    Values written through ``set()`` cross the Pemja boundary into Flink state.
+    Only data that Pemja converts into native JVM types is restored correctly;
+    any other object is kept as a stale PyObject wrapper that crashes on
+    restore. The check is recursive, so a nested offender is reported with a
+    breadcrumb pointing at its exact location.
+
+    Parameters
+    ----------
+    path: str
+        Memory path being written; it prefixes every error message.
+    value: Any
+        Candidate value. It must consist, recursively, of None and the exact
+        builtin types bool, int, float, str, list, and dict with str keys.
+
+    Raises
+    ------
+    TypeError
+        If ``value`` or any nested element is not checkpoint-stable. The message
+        names the offending location and type and suggests a conversion.
+    """
+    location = f"value at memory path {path!r}"
+    _validate(value, location)
+
+
+def _validate(value: Any, where: str, _ancestors: frozenset = frozenset()) -> None:
+    """Recursively check that ``value`` is checkpoint-stable.
+
+    Parameters
+    ----------
+    value: Any
+        The (possibly nested) value to check.
+    where: str
+        Human-readable breadcrumb locating ``value``; it prefixes every error
+        message and is extended with ``[index]`` / ``[key]`` when recursing.
+    _ancestors: frozenset
+        Internal. ``id()``s of the list/dict containers currently enclosing
+        ``value``, used to detect reference cycles. Callers omit it.
+
+    Raises
+    ------
+    TypeError
+        If ``value`` is a MemoryObject, has a type other than None, the exact
+        builtin scalars, list or dict, contains a dict with a non-str key, or
+        contains a list/dict that refers back to one of its enclosing containers.
+    """
+    if value is None or type(value) in _CHECKPOINT_STABLE_SCALARS:
+        return
+    if isinstance(value, MemoryObject):
+        msg = (
+            f"{where} is a MemoryObject; use new_object(...) to store a nested object "
+            f"instead of passing it to set()."
+        )
+        raise TypeError(msg)
+    if type(value) is not list and type(value) is not dict:
+        msg = (
+            f"{where} has type {type(value).__name__!r}, which is not checkpoint-stable. "
+            f"Python memory values must be recursively composed of None, bool, int, float, "
+            f"str, list, or dict with str keys, because they cross the Pemja boundary into "
+            f"Flink state and non-primitive objects cannot be safely checkpointed/restored. "
+            f"Materialize it first, e.g. str(value) for a UUID, value.model_dump(mode='json')"
+            f" for a Pydantic model, or list(value) for a tuple/set."
+        )
+        raise TypeError(msg)
+    # A container that (directly or indirectly) contains itself would make the
+    # recursion below run until RecursionError; report it as a TypeError naming
+    # the location instead. Only the current path is tracked, so one list/dict
+    # shared by several siblings (no cycle) is still accepted.
+    if id(value) in _ancestors:
+        msg = (
+            f"{where} refers back to a list/dict that contains it (reference "
+            f"cycle); memory values must be acyclic."
+        )
+        raise TypeError(msg)
+    enclosing = _ancestors | {id(value)}
+    if type(value) is list:
+        for i, item in enumerate(value):
+            _validate(item, f"{where}[{i}]", enclosing)
+        return
+    for key, val in value.items():
+        if type(key) is not str:
+            msg = (
+                f"{where} has a non-str key {key!r} ({type(key).__name__}); memory "
+                f"dict keys must be str. Convert with "
+                f"{{str(k): v for k, v in value.items()}}."
+            )
+            raise TypeError(msg)
+        _validate(val, f"{where}[{key!r}]", enclosing)
+
+
class MemoryType(Enum):
"""Memory types based on MemoryObject."""
diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py
index 2982c5aa..744eb9d9 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py
@@ -127,7 +127,9 @@ class DataStreamAgent(Agent):
content.review += " first action, log success=" + str(log_success) +
","
content.memory_info = {"total_reviews": total}
- data_ref = stm.set(f"processed_items.item_{content.id}", content)
+ data_ref = stm.set(
+ f"processed_items.item_{content.id}",
content.model_dump(mode="json")
+ )
ctx.send_event(MyEvent(value=data_ref))
@action(MyEvent.EVENT_TYPE)
@@ -135,7 +137,7 @@ class DataStreamAgent(Agent):
def second_action(event: Event, ctx: RunnerContext) -> None:
input_data = MyEvent.from_event(event).value
stm = ctx.short_term_memory
- resolved_data: ItemData = stm.get(input_data)
+ resolved_data = ItemData.model_validate(stm.get(input_data))
content = copy.deepcopy(resolved_data)
content.review += " second action"
diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py
index ab71fa7b..2eafea7c 100644
--- a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py
+++ b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py
@@ -76,12 +76,13 @@ class MyAgent(Agent):
memory = ctx.short_term_memory
data_path = f"user_data.{key}"
- previous_data: ProcessedData = memory.get(data_path)
+ stored = memory.get(data_path)
+ previous_data = ProcessedData.model_validate(stored) if stored else
None
current_count = previous_data.visit_count if previous_data else 0
new_count = current_count + 1
data_to_store = ProcessedData(content=input_message,
visit_count=new_count)
- data_ref = memory.set(data_path, data_to_store)
+ data_ref = memory.set(data_path, data_to_store.model_dump(mode="json"))
ctx.send_event(MyEvent(value=data_ref))
@@ -95,7 +96,7 @@ class MyAgent(Agent):
content_ref: MemoryRef = MyEvent.from_event(event).value
memory = ctx.short_term_memory
- processed_data: ProcessedData = memory.get(content_ref)
+ processed_data = ProcessedData.model_validate(memory.get(content_ref))
base_message = processed_data.content
current_count = processed_data.visit_count
@@ -104,7 +105,7 @@ class MyAgent(Agent):
updated_data_to_store = ProcessedData(
content=base_message, visit_count=new_count
)
- memory.set(content_ref.path, updated_data_to_store)
+ memory.set(content_ref.path,
updated_data_to_store.model_dump(mode="json"))
final_content = f"{base_message} -> processed by second_action"
key_with_count = f"(visit {new_count} times)"
diff --git a/python/flink_agents/runtime/flink_memory_object.py b/python/flink_agents/runtime/flink_memory_object.py
index c2840665..6143bc27 100644
--- a/python/flink_agents/runtime/flink_memory_object.py
+++ b/python/flink_agents/runtime/flink_memory_object.py
@@ -17,7 +17,11 @@
#################################################################################
from typing import Any, Dict, List
-from flink_agents.api.memory_object import MemoryObject, MemoryType
+from flink_agents.api.memory_object import (
+ MemoryObject,
+ MemoryType,
+ validate_memory_value,
+)
from flink_agents.api.memory_reference import MemoryRef
@@ -66,6 +70,7 @@ class FlinkMemoryObject(MemoryObject):
def set(self, path: str, value: Any) -> MemoryRef:
"""Set a value at the given path. Creates intermediate objects if
needed."""
+ validate_memory_value(path, value)
try:
j_ref = self._j_memory_object.set(path, value)
return MemoryRef.create(memory_type=self.__type,
path=j_ref.getPath())
diff --git a/python/flink_agents/runtime/local_memory_object.py b/python/flink_agents/runtime/local_memory_object.py
index a9fb44b8..1a50dfd8 100644
--- a/python/flink_agents/runtime/local_memory_object.py
+++ b/python/flink_agents/runtime/local_memory_object.py
@@ -17,7 +17,11 @@
#################################################################################
from typing import Any, ClassVar, Dict, List
-from flink_agents.api.memory_object import MemoryObject, MemoryType
+from flink_agents.api.memory_object import (
+ MemoryObject,
+ MemoryType,
+ validate_memory_value,
+)
from flink_agents.api.memory_reference import MemoryRef
@@ -106,6 +110,7 @@ class LocalMemoryObject(MemoryObject):
if isinstance(value, LocalMemoryObject):
msg = "Do not set a MemoryObject instance directly; use
new_object()."
raise TypeError(msg)
+ validate_memory_value(path, value)
abs_path = self._full_path(path)
diff --git a/python/flink_agents/runtime/tests/test_local_memory_object.py b/python/flink_agents/runtime/tests/test_local_memory_object.py
index 5268077a..18b56587 100644
--- a/python/flink_agents/runtime/tests/test_local_memory_object.py
+++ b/python/flink_agents/runtime/tests/test_local_memory_object.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
-from typing import Dict, List, Set
+from typing import Dict, List
from flink_agents.api.memory_object import MemoryType
from flink_agents.runtime.local_memory_object import LocalMemoryObject
@@ -26,20 +26,6 @@ def create_memory() -> LocalMemoryObject:
return LocalMemoryObject(MemoryType.SHORT_TERM, {})
-class User:
- def __init__(self, name: str, age: int) -> None:
- """Store for later comparison."""
- self.name = name
- self.age = age
-
- def __eq__(self, other: object) -> bool:
- return (
- isinstance(other, User)
- and other.name == self.name
- and other.age == self.age
- )
-
-
def test_basic_set_get_various_types() -> None:
mem = create_memory()
@@ -63,16 +49,6 @@ def test_basic_set_get_various_types() -> None:
mem.set("dict", d)
assert mem.get("dict") == d
- # set
- s: Set[int] = {1, 2, 3}
- mem.set("set", s)
- assert mem.get("set") == s
-
- # custom object
- user = User("Alice", 20)
- mem.set("user", user)
- assert mem.get("user") == user
-
def test_nested_set_and_get() -> None:
mem = create_memory()
diff --git a/python/flink_agents/runtime/tests/test_memory_reference.py b/python/flink_agents/runtime/tests/test_memory_reference.py
index 10d885b1..86428d2a 100644
--- a/python/flink_agents/runtime/tests/test_memory_reference.py
+++ b/python/flink_agents/runtime/tests/test_memory_reference.py
@@ -35,20 +35,6 @@ def create_memory() -> LocalMemoryObject:
return LocalMemoryObject(MemoryType.SHORT_TERM, {})
-class User:
- def __init__(self, name: str, age: int) -> None:
- """Store for later comparison."""
- self.name = name
- self.age = age
-
- def __eq__(self, other: object) -> bool:
- return (
- isinstance(other, User)
- and other.name == self.name
- and other.age == self.age
- )
-
-
def test_set_get_involved_ref() -> None:
mem = create_memory()
@@ -59,8 +45,6 @@ def test_set_get_involved_ref() -> None:
("my_str", "hello", "str"),
("my_list", ["a", "b"], "list"),
("my_dict", {"x": 10}, "dict"),
- ("my_set", {1, 2, 3}, "set"),
- ("my_user", User("Alice", 30), "User"),
]
for path, value, _expected_type_name in test_cases:
@@ -90,8 +74,6 @@ def test_memory_ref_resolve() -> None:
"my_str": "hello",
"my_list": ["a", "b"],
"my_dict": {"x": 10},
- "my_set": {1, 2, 3},
- "my_user": User("Charlie", 50),
}
for path, value in test_data.items():
diff --git a/python/flink_agents/runtime/tests/test_memory_value_validation.py b/python/flink_agents/runtime/tests/test_memory_value_validation.py
new file mode 100644
index 00000000..c6b4418a
--- /dev/null
+++ b/python/flink_agents/runtime/tests/test_memory_value_validation.py
@@ -0,0 +1,105 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#################################################################################
+import uuid
+from enum import Enum
+from unittest.mock import MagicMock
+
+import pytest
+from pydantic import BaseModel
+
+from flink_agents.api.memory_object import MemoryType, validate_memory_value
+from flink_agents.runtime.flink_memory_object import (
+ FlinkMemoryObject,
+ MemoryObjectError,
+)
+
+
+class _Model(BaseModel):
+    """Minimal Pydantic model; storing an instance directly must be rejected."""
+
+    name: str
+
+
+class _StrEnum(str, Enum):
+    """str-mixin Enum: passes isinstance(value, str) but its type is not str."""
+
+    A = "a"
+
+
+class _Plain:
+    """Arbitrary user-defined class with no conversion to a stable type."""
+
+
+def test_accepts_none_and_scalars() -> None:
+    """None and exact builtin scalars pass validation without raising."""
+    accepted = [None, True, False, 0, 1, -3, 3.14, "", "hello"]
+    for scalar in accepted:
+        validate_memory_value("p", scalar)
+
+
+def test_accepts_nested_list_and_dict() -> None:
+    """Nested lists/dicts whose leaves are all stable are accepted."""
+    nested_list = [1, "a", [2, 3], {"k": [4, None]}]
+    nested_dict = {"a": 1, "b": {"c": [True, "x"]}}
+    for payload in (nested_list, nested_dict):
+        validate_memory_value("p", payload)
+
+
+def test_rejects_pydantic_model() -> None:
+    """A Pydantic model is rejected with a hint to call model_dump."""
+    model = _Model(name="x")
+    with pytest.raises(TypeError, match="model_dump"):
+        validate_memory_value("p", model)
+
+
+def test_rejects_uuid() -> None:
+    """A UUID is rejected with a hint to convert it via str(value)."""
+    token = uuid.uuid4()
+    with pytest.raises(TypeError, match=r"str\(value\)"):
+        validate_memory_value("p", token)
+
+
+def test_rejects_tuple_set_frozenset() -> None:
+    """Tuples, sets and frozensets are rejected with a list(value) hint."""
+    non_list_collections = [(1, 2), {1, 2}, frozenset({1, 2})]
+    for collection in non_list_collections:
+        with pytest.raises(TypeError, match=r"list\(value\)"):
+            validate_memory_value("p", collection)
+
+
+def test_rejects_str_enum() -> None:
+    """A str-mixin Enum member must be rejected.
+
+    It passes isinstance(str) but Pemja PyObject-wraps it, so the validator's
+    exact-type check has to catch it.
+    """
+    member = _StrEnum.A
+    with pytest.raises(TypeError, match="not checkpoint-stable"):
+        validate_memory_value("p", member)
+
+
+def test_rejects_custom_class() -> None:
+    """Instances of arbitrary user classes are reported as not checkpoint-stable."""
+    instance = _Plain()
+    with pytest.raises(TypeError, match="not checkpoint-stable"):
+        validate_memory_value("p", instance)
+
+
+def test_rejects_non_str_dict_key() -> None:
+    """A dict whose key is not exactly str is rejected."""
+    int_keyed = {1: "v"}
+    with pytest.raises(TypeError, match="non-str key"):
+        validate_memory_value("p", int_keyed)
+
+
+def test_rejects_nested_value_reports_breadcrumb() -> None:
+    """The error for a nested offender names its index/key path."""
+    payload = [1, 2, {"bad": object()}]
+    with pytest.raises(TypeError, match=r"\[2\]\['bad'\]"):
+        validate_memory_value("p", payload)
+
+
+def test_memory_object_value_suggests_new_object() -> None:
+    """Passing a MemoryObject as a value points the caller to new_object()."""
+    nested_memory = FlinkMemoryObject(MemoryType.SHORT_TERM, MagicMock())
+    with pytest.raises(TypeError, match="new_object"):
+        validate_memory_value("p", nested_memory)
+
+
+def test_flink_set_raises_raw_type_error() -> None:
+    """FlinkMemoryObject.set validates before touching the Java object.
+
+    The failure must surface as the raw TypeError from validation (not a
+    MemoryObjectError), and the Java-side set must never be invoked.
+    """
+    java_memory = MagicMock()
+    memory = FlinkMemoryObject(MemoryType.SHORT_TERM, java_memory)
+    with pytest.raises(TypeError) as raised:
+        memory.set("p", uuid.uuid4())
+    assert not isinstance(raised.value, MemoryObjectError)
+    java_memory.set.assert_not_called()