This is an automated email from the ASF dual-hosted git repository.

wenjin272 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git


The following commit(s) were added to refs/heads/main by this push:
     new 06e7722a [python] Normalize built-in tool-call context for checkpoint safety (#828)
06e7722a is described below

commit 06e7722aebfc33902e9c26dea4ecf51680ae20bc
Author: Weiqing Yang <[email protected]>
AuthorDate: Wed Jun 10 03:08:45 2026 -0700

    [python] Normalize built-in tool-call context for checkpoint safety (#828)
    
    The built-in chat-model action stored non-primitive Python objects in
    sensory memory: UUID values, OutputSchema instances, and lists of
    ChatMessage objects. Pemja wraps such objects in PyObject holders whose
    JNI pointers become stale after a TaskManager or Python process restart,
    so restoring the checkpointed tool-call context crashes in
    JcpPyObject_FromJObject.
    
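    Normalize these values to a primitive-only form before they reach memory,
    and reconstruct the rich types when they are read back. The change is
    confined to the three tool-context helpers (_update_tool_call_context,
    _save_tool_request_event_context and _get_tool_request_event_context), so
    no callers or signatures change: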
    
    - ChatMessage lists are stored via model_dump(mode="json") and
      reconstructed via ChatMessage.model_validate.
    - initial_request_id is stored as a str and converted back to a UUID.
    - output_schema is stored via OutputSchema.model_dump() and
      reconstructed via OutputSchema.model_validate; a None schema is stored
      and returned as None.
    
    The context dicts were already keyed by str(UUID), and the retry-stats
    context already holds only ints, so neither is changed. prompt_args is
    user-supplied and already round-trips as a ChatRequestEvent attribute,
    so it is stored as-is.
---
 .../flink_agents/plan/actions/chat_model_action.py |  33 ++++-
 .../plan/tests/actions/test_chat_model_action.py   | 157 ++++++++++++++++++++-
 .../tests/actions/test_chat_model_action_retry.py  |   6 +-
 3 files changed, 187 insertions(+), 9 deletions(-)

diff --git a/python/flink_agents/plan/actions/chat_model_action.py 
b/python/flink_agents/plan/actions/chat_model_action.py
index 188eb9e7..e4572056 100644
--- a/python/flink_agents/plan/actions/chat_model_action.py
+++ b/python/flink_agents/plan/actions/chat_model_action.py
@@ -15,7 +15,6 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
-import copy
 import json
 import logging
 import re
@@ -59,6 +58,16 @@ _logger = logging.getLogger(__name__)
 # ============================================================================
 # Helper Functions for Tool Call Context Management
 # ============================================================================
+def _serialize_messages(messages: List[ChatMessage]) -> List[Dict]:
+    """Materialize chat messages into JSON-safe dicts for checkpoint-stable 
memory."""
+    return [message.model_dump(mode="json") for message in messages]
+
+
+def _deserialize_messages(messages: List[Dict]) -> List[ChatMessage]:
+    """Reconstruct chat messages from their stored JSON-safe dict form."""
+    return [ChatMessage.model_validate(message) for message in messages]
+
+
 def _update_tool_call_context(
     sensory_memory: MemoryObject,
     initial_request_id: UUID,
@@ -81,12 +90,12 @@ def _update_tool_call_context(
     key = str(initial_request_id)
     tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {}
     if key not in tool_call_context and initial_messages is not None:
-        tool_call_context[key] = copy.deepcopy(initial_messages)
+        tool_call_context[key] = _serialize_messages(initial_messages)
 
-    tool_call_context[key].extend(added_messages)
+    tool_call_context[key].extend(_serialize_messages(added_messages))
 
     sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
-    return tool_call_context[key]
+    return _deserialize_messages(tool_call_context[key])
 
 
 def _save_tool_request_event_context(
@@ -100,10 +109,12 @@ def _save_tool_request_event_context(
     """Save the context for a specific tool request event."""
     context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
     context[str(tool_request_event_id)] = {
-        "initial_request_id": initial_request_id,
+        "initial_request_id": str(initial_request_id),
         "model": model,
         _PROMPT_ARGS: prompt_args if prompt_args is not None else {},
-        "output_schema": output_schema,
+        "output_schema": output_schema.model_dump()
+        if output_schema is not None
+        else None,
     }
     sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context)
 
@@ -114,6 +125,16 @@ def _get_tool_request_event_context(
     """Get and remove the context for a specific tool request event."""
     context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
     removed_context = context.pop(str(request_id), {})
+    if removed_context:
+        removed_context["initial_request_id"] = UUID(
+            removed_context["initial_request_id"]
+        )
+        output_schema = removed_context["output_schema"]
+        removed_context["output_schema"] = (
+            OutputSchema.model_validate(output_schema)
+            if output_schema is not None
+            else None
+        )
     return removed_context
 
 
diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action.py 
b/python/flink_agents/plan/tests/actions/test_chat_model_action.py
index 72b223d2..94ff8ab9 100644
--- a/python/flink_agents/plan/tests/actions/test_chat_model_action.py
+++ b/python/flink_agents/plan/tests/actions/test_chat_model_action.py
@@ -15,7 +15,49 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 
#################################################################################
-from flink_agents.plan.actions.chat_model_action import _clean_llm_response
+from uuid import uuid4
+
+from pydantic import BaseModel
+from pyflink.common.typeinfo import BasicTypeInfo, RowTypeInfo
+
+from flink_agents.api.agents.react_agent import OutputSchema
+from flink_agents.api.chat_message import ChatMessage, MessageRole
+from flink_agents.api.memory_object import MemoryType
+from flink_agents.plan.actions.chat_model_action import (
+    _TOOL_CALL_CONTEXT,
+    _TOOL_REQUEST_EVENT_CONTEXT,
+    _clean_llm_response,
+    _get_tool_request_event_context,
+    _save_tool_request_event_context,
+    _update_tool_call_context,
+)
+from flink_agents.runtime.local_memory_object import LocalMemoryObject
+
+
+def _memory() -> LocalMemoryObject:
+    return LocalMemoryObject(MemoryType.SHORT_TERM, {})
+
+
+def _assert_primitive(obj) -> None:
+    if obj is None or isinstance(obj, bool | int | float | str | bytes):
+        return
+    if isinstance(obj, list):
+        for item in obj:
+            _assert_primitive(item)
+        return
+    if isinstance(obj, dict):
+        for k, v in obj.items():
+            assert isinstance(k, str | int | float | bool), (
+                f"non-primitive key: {k!r}"
+            )
+            _assert_primitive(v)
+        return
+    msg = f"non-primitive value of type {type(obj).__name__}: {obj!r}"
+    raise AssertionError(msg)
+
+
+class _Result(BaseModel):
+    result: int
 
 
 def test_clean_llm_response_with_json_block():
@@ -52,3 +94,116 @@ def test_clean_llm_response_with_multiple_lines_in_block():
     input_str = '```json\n{\n  "key": "value"\n}\n```'
     expected = '{\n  "key": "value"\n}'
     assert _clean_llm_response(input_str) == expected
+
+
+def test_update_tool_call_context_stores_primitive_only():
+    mem = _memory()
+    initial = [ChatMessage(role=MessageRole.USER, content="hi")]
+    added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")]
+    _update_tool_call_context(mem, uuid4(), initial, added)
+    _assert_primitive(mem.get(_TOOL_CALL_CONTEXT))
+
+
+def test_update_tool_call_context_returns_chat_messages():
+    mem = _memory()
+    initial = [ChatMessage(role=MessageRole.USER, content="hi")]
+    added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")]
+    result = _update_tool_call_context(mem, uuid4(), initial, added)
+    assert all(isinstance(message, ChatMessage) for message in result)
+    assert [(m.role, m.content) for m in result] == [
+        (MessageRole.USER, "hi"),
+        (MessageRole.ASSISTANT, "hello"),
+    ]
+
+
+def test_tool_request_event_context_stores_primitive_only():
+    mem = _memory()
+    _save_tool_request_event_context(
+        mem,
+        uuid4(),
+        uuid4(),
+        "ollama",
+        {"k": "v"},
+        OutputSchema(output_schema=_Result),
+    )
+    _assert_primitive(mem.get(_TOOL_REQUEST_EVENT_CONTEXT))
+
+
+def test_tool_request_event_context_round_trip():
+    mem = _memory()
+    event_id = uuid4()
+    initial_request_id = uuid4()
+    _save_tool_request_event_context(
+        mem,
+        event_id,
+        initial_request_id,
+        "ollama",
+        None,
+        OutputSchema(output_schema=_Result),
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context["initial_request_id"] == initial_request_id
+    assert isinstance(context["initial_request_id"], type(initial_request_id))
+    assert isinstance(context["output_schema"], OutputSchema)
+    assert context["output_schema"].output_schema is _Result
+    assert context["model"] == "ollama"
+
+
+def test_get_context_none_output_schema():
+    mem = _memory()
+    event_id = uuid4()
+    _save_tool_request_event_context(mem, event_id, uuid4(), "ollama", None, 
None)
+    assert 
mem.get(_TOOL_REQUEST_EVENT_CONTEXT)[str(event_id)]["output_schema"] is None
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context["output_schema"] is None
+
+
+def test_request_event_key_match_after_normalization():
+    mem = _memory()
+    event_id = uuid4()
+    _save_tool_request_event_context(
+        mem, event_id, uuid4(), "ollama", None, None
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context != {}
+    assert context["model"] == "ollama"
+
+
+def test_tool_call_context_key_match_after_normalization():
+    mem = _memory()
+    request_id = uuid4()
+    initial = [ChatMessage(role=MessageRole.USER, content="hi")]
+    _update_tool_call_context(mem, request_id, initial, [])
+    extra = ChatMessage(role=MessageRole.TOOL, content="result")
+    result = _update_tool_call_context(mem, request_id, None, [extra])
+    assert len(result) == 2
+    assert len(mem.get(_TOOL_CALL_CONTEXT)[str(request_id)]) == 2
+
+
+def test_output_schema_rowtypeinfo_round_trip():
+    mem = _memory()
+    event_id = uuid4()
+    schema = OutputSchema(
+        output_schema=RowTypeInfo(
+            [BasicTypeInfo.INT_TYPE_INFO()],
+            ["result"],
+        )
+    )
+    _save_tool_request_event_context(
+        mem, event_id, uuid4(), "ollama", None, schema
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert isinstance(context["output_schema"], OutputSchema)
+    assert context["output_schema"].output_schema.get_field_names() == 
["result"]
+
+
+def test_save_get_preserves_model_and_prompt_args():
+    mem = _memory()
+    event_id = uuid4()
+    prompt_args = {"a": 1, "b": "x"}
+    _save_tool_request_event_context(
+        mem, event_id, uuid4(), "ollama", prompt_args, None
+    )
+    context = _get_tool_request_event_context(mem, event_id)
+    assert context["model"] == "ollama"
+    assert context["prompt_args"] == prompt_args
diff --git 
a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py 
b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
index e1f04f62..e9e799da 100644
--- a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
+++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py
@@ -311,7 +311,7 @@ class TestProcessToolResponsePromptArgsForwarding:
             "_TOOL_REQUEST_EVENT_CONTEXT",
             {
                 str(tool_request_event_id): {
-                    "initial_request_id": initial_request_id,
+                    "initial_request_id": str(initial_request_id),
                     "model": "test-model",
                     "prompt_args": saved_prompt_args,
                     "output_schema": None,
@@ -325,7 +325,9 @@ class TestProcessToolResponsePromptArgsForwarding:
             "_TOOL_CALL_CONTEXT",
             {
                 str(initial_request_id): [
-                    ChatMessage(role=MessageRole.USER, content="hi")
+                    ChatMessage(
+                        role=MessageRole.USER, content="hi"
+                    ).model_dump(mode="json")
                 ]
             },
         )

Reply via email to