This is an automated email from the ASF dual-hosted git repository.
wenjin272 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push:
new 06e7722a [python] Normalize built-in tool-call context for checkpoint safety (#828)
06e7722a is described below
commit 06e7722aebfc33902e9c26dea4ecf51680ae20bc
Author: Weiqing Yang <[email protected]>
AuthorDate: Wed Jun 10 03:08:45 2026 -0700
[python] Normalize built-in tool-call context for checkpoint safety (#828)
The built-in chat-model action stored non-primitive Python objects in
sensory memory: UUID values, an OutputSchema, and ChatMessage lists.
Pemja wraps such objects as PyObject holders whose JNI pointers go stale
after a TaskManager/Python restart, so restoring the checkpointed tool
context crashes in JcpPyObject_FromJObject.
Normalize these values to a primitive-only form before they reach memory
and reconstruct the rich types on read. The change is confined to the
three tool-context helpers and two new private message (de)serialization
helpers; no callers or signatures change. Per value:
- ChatMessage lists are stored via model_dump(mode="json") and
reconstructed via ChatMessage.model_validate.
- initial_request_id is stored as str and reconstructed to UUID.
- output_schema is stored via OutputSchema.model_dump() and
reconstructed via OutputSchema.model_validate.
Dict keys were already strings and the retry-stats context already holds
only ints, so both are unchanged. prompt_args is user-supplied and
already round-trips as a ChatRequestEvent attribute, so it is left as-is.
---
.../flink_agents/plan/actions/chat_model_action.py | 33 ++++-
.../plan/tests/actions/test_chat_model_action.py | 157 ++++++++++++++++++++-
.../tests/actions/test_chat_model_action_retry.py | 6 +-
3 files changed, 187 insertions(+), 9 deletions(-)
diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py
index 188eb9e7..e4572056 100644
--- a/python/flink_agents/plan/actions/chat_model_action.py
+++ b/python/flink_agents/plan/actions/chat_model_action.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #################################################################################
-import copy
 import json
 import logging
 import re
@@ -59,6 +58,16 @@ _logger = logging.getLogger(__name__)
 # ============================================================================
 # Helper Functions for Tool Call Context Management
 # ============================================================================
+def _serialize_messages(messages: List[ChatMessage]) -> List[Dict]:
+    """Materialize chat messages into JSON-safe dicts for checkpoint-stable memory."""
+    return [message.model_dump(mode="json") for message in messages]
+
+
+def _deserialize_messages(messages: List[Dict]) -> List[ChatMessage]:
+    """Reconstruct chat messages from their stored JSON-safe dict form."""
+    return [ChatMessage.model_validate(message) for message in messages]
+
+
 def _update_tool_call_context(
     sensory_memory: MemoryObject,
     initial_request_id: UUID,
@@ -81,12 +90,12 @@ def _update_tool_call_context(
     key = str(initial_request_id)
     tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {}
     if key not in tool_call_context and initial_messages is not None:
-        tool_call_context[key] = copy.deepcopy(initial_messages)
+        tool_call_context[key] = _serialize_messages(initial_messages)
 
-    tool_call_context[key].extend(added_messages)
+    tool_call_context[key].extend(_serialize_messages(added_messages))
     sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
 
-    return tool_call_context[key]
+    return _deserialize_messages(tool_call_context[key])
 
 
 def _save_tool_request_event_context(
@@ -100,10 +109,12 @@ def _save_tool_request_event_context(
     """Save the context for a specific tool request event."""
     context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
     context[str(tool_request_event_id)] = {
-        "initial_request_id": initial_request_id,
+        "initial_request_id": str(initial_request_id),
         "model": model,
         _PROMPT_ARGS: prompt_args if prompt_args is not None else {},
-        "output_schema": output_schema,
+        "output_schema": output_schema.model_dump()
+        if output_schema is not None
+        else None,
     }
     sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context)
 
@@ -114,6 +125,16 @@ def _get_tool_request_event_context(
     """Get and remove the context for a specific tool request event."""
     context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
     removed_context = context.pop(str(request_id), {})
+    if removed_context:
+        removed_context["initial_request_id"] = UUID(
+            removed_context["initial_request_id"]
+        )
+        output_schema = removed_context["output_schema"]
+        removed_context["output_schema"] = (
+            OutputSchema.model_validate(output_schema)
+            if output_schema is not None
+            else None
+        )
     return removed_context
 
 
diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action.py b/python/flink_agents/plan/tests/actions/test_chat_model_action.py
index 72b223d2..94ff8ab9 100644
--- a/python/flink_agents/plan/tests/actions/test_chat_model_action.py
+++ b/python/flink_agents/plan/tests/actions/test_chat_model_action.py
@@ -15,7 +15,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #################################################################################
-from flink_agents.plan.actions.chat_model_action import _clean_llm_response
+from uuid import uuid4
+
+from pydantic import BaseModel
+from pyflink.common.typeinfo import BasicTypeInfo, RowTypeInfo
+
+from flink_agents.api.agents.react_agent import OutputSchema
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.memory_object import MemoryType
+from flink_agents.plan.actions.chat_model_action import (
+    _TOOL_CALL_CONTEXT,
+    _TOOL_REQUEST_EVENT_CONTEXT,
+    _clean_llm_response,
+    _get_tool_request_event_context,
+    _save_tool_request_event_context,
+    _update_tool_call_context,
+)
+from flink_agents.runtime.local_memory_object import LocalMemoryObject
+
+
+def _memory() -> LocalMemoryObject:
+    return LocalMemoryObject(MemoryType.SHORT_TERM, {})
+
+
+def _assert_primitive(obj) -> None:
+    if obj is None or isinstance(obj, bool | int | float | str | bytes):
+        return
+    if isinstance(obj, list):
+        for item in obj:
+            _assert_primitive(item)
+        return
+    if isinstance(obj, dict):
+        for k, v in obj.items():
+            assert isinstance(k, str | int | float | bool), (
+                f"non-primitive key: {k!r}"
+            )
+            _assert_primitive(v)
+        return
+    msg = f"non-primitive value of type {type(obj).__name__}: {obj!r}"
+    raise AssertionError(msg)
+
+
+class _Result(BaseModel):
+    result: int
 
 
 def test_clean_llm_response_with_json_block():
@@ -52,3 +94,116 @@ def test_clean_llm_response_with_multiple_lines_in_block():
     input_str = '```json\n{\n "key": "value"\n}\n```'
     expected = '{\n "key": "value"\n}'
     assert _clean_llm_response(input_str) == expected
+
+
+def test_update_tool_call_context_stores_primitive_only():
+    mem = _memory()
+    initial = [ChatMessage(role=MessageRole.USER, content="hi")]
+    added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")]
+    _update_tool_call_context(mem, uuid4(), initial, added)
+    _assert_primitive(mem.get(_TOOL_CALL_CONTEXT))
+
+
+def test_update_tool_call_context_returns_chat_messages():
+    mem = _memory()
+    initial = [ChatMessage(role=MessageRole.USER, content="hi")]
+    added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")]
+    result = _update_tool_call_context(mem, uuid4(), initial, added)
+    assert all(isinstance(message, ChatMessage) for message in result)
+    assert [(m.role, m.content) for m in result] == [
+        (MessageRole.USER, "hi"),
+        (MessageRole.ASSISTANT, "hello"),
+    ]
+
+
+def test_tool_request_event_context_stores_primitive_only():
+    mem = _memory()
+    _save_tool_request_event_context(
+        mem,
+        uuid4(),
+        uuid4(),
+        "ollama",
+        {"k": "v"},
+        OutputSchema(output_schema=_Result),
+    )
+    _assert_primitive(mem.get(_TOOL_REQUEST_EVENT_CONTEXT))
+
+
+def test_tool_request_event_context_round_trip():
+    mem = _memory()
+    event_id = uuid4()
+    initial_request_id = uuid4()
+    _save_tool_request_event_context(
+        mem,
+        event_id,
+        initial_request_id,
+        "ollama",
+        None,
+        OutputSchema(output_schema=_Result),
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context["initial_request_id"] == initial_request_id
+    assert isinstance(context["initial_request_id"], type(initial_request_id))
+    assert isinstance(context["output_schema"], OutputSchema)
+    assert context["output_schema"].output_schema is _Result
+    assert context["model"] == "ollama"
+
+
+def test_get_context_none_output_schema():
+    mem = _memory()
+    event_id = uuid4()
+    _save_tool_request_event_context(mem, event_id, uuid4(), "ollama", None, None)
+    assert mem.get(_TOOL_REQUEST_EVENT_CONTEXT)[str(event_id)]["output_schema"] is None
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context["output_schema"] is None
+
+
+def test_request_event_key_match_after_normalization():
+    mem = _memory()
+    event_id = uuid4()
+    _save_tool_request_event_context(
+        mem, event_id, uuid4(), "ollama", None, None
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context != {}
+    assert context["model"] == "ollama"
+
+
+def test_tool_call_context_key_match_after_normalization():
+    mem = _memory()
+    request_id = uuid4()
+    initial = [ChatMessage(role=MessageRole.USER, content="hi")]
+    _update_tool_call_context(mem, request_id, initial, [])
+    extra = ChatMessage(role=MessageRole.TOOL, content="result")
+    result = _update_tool_call_context(mem, request_id, None, [extra])
+    assert len(result) == 2
+    assert len(mem.get(_TOOL_CALL_CONTEXT)[str(request_id)]) == 2
+
+
+def test_output_schema_rowtypeinfo_round_trip():
+    mem = _memory()
+    event_id = uuid4()
+    schema = OutputSchema(
+        output_schema=RowTypeInfo(
+            [BasicTypeInfo.INT_TYPE_INFO()],
+            ["result"],
+        )
+    )
+    _save_tool_request_event_context(
+        mem, event_id, uuid4(), "ollama", None, schema
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert isinstance(context["output_schema"], OutputSchema)
+    assert context["output_schema"].output_schema.get_field_names() == ["result"]
+
+
+def test_save_get_preserves_model_and_prompt_args():
+    mem = _memory()
+    event_id = uuid4()
+    prompt_args = {"a": 1, "b": "x"}
+    _save_tool_request_event_context(
+        mem, event_id, uuid4(), "ollama", prompt_args, None
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context["model"] == "ollama"
+    assert context["prompt_args"] == prompt_args
diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
index e1f04f62..e9e799da 100644
--- a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
+++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
@@ -311,7 +311,7 @@ class TestProcessToolResponsePromptArgsForwarding:
             "_TOOL_REQUEST_EVENT_CONTEXT",
             {
                 str(tool_request_event_id): {
-                    "initial_request_id": initial_request_id,
+                    "initial_request_id": str(initial_request_id),
                     "model": "test-model",
                     "prompt_args": saved_prompt_args,
                     "output_schema": None,
@@ -325,7 +325,9 @@ class TestProcessToolResponsePromptArgsForwarding:
             "_TOOL_CALL_CONTEXT",
             {
                 str(initial_request_id): [
-                    ChatMessage(role=MessageRole.USER, content="hi")
+                    ChatMessage(
+                        role=MessageRole.USER, content="hi"
+                    ).model_dump(mode="json")
                 ]
             },
         )