addu390 commented on code in PR #709:
URL: https://github.com/apache/flink-agents/pull/709#discussion_r3318598273


##########
api/src/test/java/org/apache/flink/agents/api/CrossLanguageEventSnapshotTest.java:
##########
@@ -0,0 +1,489 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.flink.agents.api.agents.OutputSchema;
+import org.apache.flink.agents.api.chat.messages.ChatMessage;
+import org.apache.flink.agents.api.chat.messages.MessageRole;
+import org.apache.flink.agents.api.event.ChatRequestEvent;
+import org.apache.flink.agents.api.event.ChatResponseEvent;
+import org.apache.flink.agents.api.event.ContextRetrievalRequestEvent;
+import org.apache.flink.agents.api.event.ContextRetrievalResponseEvent;
+import org.apache.flink.agents.api.event.ToolRequestEvent;
+import org.apache.flink.agents.api.event.ToolResponseEvent;
+import org.apache.flink.agents.api.tools.ToolResponse;
+import org.apache.flink.agents.api.vectorstores.Document;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assumptions.assumeTrue;
+
+/** Cross-language event SerDe snapshot tests. */
+class CrossLanguageEventSnapshotTest {
+
+    private static final ObjectMapper MAPPER = new ObjectMapper();
+
+    private static final UUID FIXED_EVENT_ID =
+            UUID.fromString("00000000-0000-0000-0000-000000000001");
+    private static final UUID FIXED_REQUEST_ID =
+            UUID.fromString("00000000-0000-0000-0000-000000000002");
+    private static final String FIXED_TOOL_CALL_ID = "call_aaaa";
+    private static final long FIXED_TIMESTAMP = 1_700_000_000_000L;
+
+    private static Path snapshotDir;
+
+    @BeforeAll
+    static void resolveSnapshotDir() {
+        Path repoRoot = Paths.get(System.getProperty("user.dir")).getParent();
+        snapshotDir = 
repoRoot.resolve("e2e-test/cross-language-event-snapshots");
+    }
+
+    // ── Helpers ────────────────────────────────────────────────────────────
+
+    private static boolean regenerateRequested() {
+        return Boolean.parseBoolean(System.getProperty("regenerate.snapshots", 
"false"));
+    }
+
+    private static void writeJavaSnapshot(String fileName, Event event) throws 
Exception {
+        String json = 
MAPPER.writerWithDefaultPrettyPrinter().writeValueAsString(event);
+        Path target = snapshotDir.resolve("java/" + fileName);
+        Files.createDirectories(target.getParent());
+        Files.writeString(target, json + "\n");
+    }
+
+    private static void assertJavaSnapshotStable(String fileName, Event event) 
throws Exception {
+        String actualJson = MAPPER.writeValueAsString(event);
+        JsonNode actual = MAPPER.readTree(actualJson);
+
+        Path committed = snapshotDir.resolve("java/" + fileName);
+        assertTrue(
+                Files.exists(committed),
+                "Java snapshot "
+                        + fileName
+                        + " missing from "
+                        + committed
+                        + ". If you added a new event, regenerate with 
-Dregenerate.snapshots=true and commit alongside the test.");
+        JsonNode expected = MAPPER.readTree(Files.readString(committed));
+
+        assertEquals(
+                expected,
+                actual,
+                "Java serialization of "
+                        + fileName
+                        + " drifted from committed snapshot; if intentional, 
regenerate.");
+    }
+
+    private static Event readPythonSnapshot(String fileName) throws Exception {
+        Path pythonSnapshot = snapshotDir.resolve("python/" + fileName);
+        assertTrue(
+                Files.exists(pythonSnapshot),
+                "Python snapshot "
+                        + fileName
+                        + " missing from "
+                        + pythonSnapshot
+                        + ". Regenerate the Python side with 
REGENERATE_SNAPSHOTS=1 and commit alongside this test.");
+        return Event.fromJson(Files.readString(pythonSnapshot));
+    }
+
+    // ── InputEvent ─────────────────────────────────────────────────────────
+
+    private static InputEvent buildInputEvent() {
+        Map<String, Object> attrs = new HashMap<>();
+        attrs.put("input", "hello");
+        return new InputEvent(FIXED_EVENT_ID, attrs);
+    }
+
+    @Test
+    void regenerateInputEventJavaSnapshot() throws Exception {
+        assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to 
refresh.");
+        writeJavaSnapshot("input_event.json", buildInputEvent());
+    }
+
+    @Test
+    void inputEventJavaSnapshotIsStable() throws Exception {
+        assertJavaSnapshotStable("input_event.json", buildInputEvent());
+    }
+
+    @Test
+    void javaCanDeserializeInputEventFromPythonSnapshot() throws Exception {
+        Event base = readPythonSnapshot("input_event.json");
+        InputEvent typed = InputEvent.fromEvent(base);
+
+        assertEquals(
+                FIXED_EVENT_ID, typed.getId(), "ID lost when deserializing 
Python InputEvent.");
+        assertEquals(InputEvent.EVENT_TYPE, typed.getType());
+        assertEquals("hello", typed.getInput(), "InputEvent.input mismatch.");
+    }
+
+    // ── OutputEvent ────────────────────────────────────────────────────────
+
+    private static OutputEvent buildOutputEvent() {
+        Map<String, Object> attrs = new HashMap<>();
+        attrs.put("output", "world");
+        return new OutputEvent(FIXED_EVENT_ID, attrs);
+    }
+
+    @Test
+    void regenerateOutputEventJavaSnapshot() throws Exception {
+        assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to 
refresh.");
+        writeJavaSnapshot("output_event.json", buildOutputEvent());
+    }
+
+    @Test
+    void outputEventJavaSnapshotIsStable() throws Exception {
+        assertJavaSnapshotStable("output_event.json", buildOutputEvent());
+    }
+
+    @Test
+    void javaCanDeserializeOutputEventFromPythonSnapshot() throws Exception {
+        Event base = readPythonSnapshot("output_event.json");
+        OutputEvent typed = OutputEvent.fromEvent(base);
+
+        assertEquals(
+                FIXED_EVENT_ID, typed.getId(), "ID lost when deserializing 
Python OutputEvent.");
+        assertEquals(OutputEvent.EVENT_TYPE, typed.getType());
+        assertEquals("world", typed.getOutput(), "OutputEvent.output 
mismatch.");
+    }
+
+    // ── ChatRequestEvent ───────────────────────────────────────────────────
+
+    private static ChatRequestEvent buildChatRequestEvent() {
+        Map<String, Object> attrs = new LinkedHashMap<>();
+        attrs.put("model", "test-model");
+        attrs.put("messages", List.of(new ChatMessage(MessageRole.USER, "hello 
world")));
+        return new ChatRequestEvent(FIXED_EVENT_ID, attrs);
+    }
+
+    @Test
+    void regenerateChatRequestEventJavaSnapshot() throws Exception {
+        assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to 
refresh.");
+        writeJavaSnapshot("chat_request_event.json", buildChatRequestEvent());
+    }
+
+    @Test
+    void chatRequestEventJavaSnapshotIsStable() throws Exception {
+        assertJavaSnapshotStable("chat_request_event.json", 
buildChatRequestEvent());
+    }
+
+    @Test
+    void javaCanDeserializeChatRequestEventFromPythonSnapshot() throws 
Exception {
+        Event base = readPythonSnapshot("chat_request_event.json");
+        ChatRequestEvent typed = ChatRequestEvent.fromEvent(base);
+
+        assertEquals(FIXED_EVENT_ID, typed.getId());
+        assertEquals(ChatRequestEvent.EVENT_TYPE, typed.getType());
+        assertEquals("test-model", typed.getModel());
+        assertNotNull(typed.getMessages());
+        assertEquals(1, typed.getMessages().size(), "Expected one message.");
+        ChatMessage msg = typed.getMessages().get(0);
+        assertEquals(MessageRole.USER, msg.getRole(), "Role mismatch on 
Python-produced message.");
+        assertEquals("hello world", msg.getContent());
+    }
+
+    @Test
+    void chatRequestOutputSchemaWireFormatIsJavaShaped() throws Exception {
+        OutputSchema schema =
+                new OutputSchema(
+                        new RowTypeInfo(
+                                new TypeInformation[] 
{BasicTypeInfo.STRING_TYPE_INFO},
+                                new String[] {"name"}));
+        ChatRequestEvent event =
+                new ChatRequestEvent(
+                        "test-model", List.of(new 
ChatMessage(MessageRole.USER, "hi")), schema);
+        String json = MAPPER.writeValueAsString(event);
+
+        assertTrue(json.contains("\"fieldNames\""), "Java wire format uses 
`fieldNames`.");
+        assertFalse(json.contains("\"names\""), "Java wire format does not use 
Python's `names`.");
+    }
+
+    // ── ChatResponseEvent ──────────────────────────────────────────────────
+
+    private static ChatResponseEvent buildChatResponseEvent() {
+        Map<String, Object> attrs = new LinkedHashMap<>();
+        attrs.put("request_id", FIXED_REQUEST_ID);
+        attrs.put("response", new ChatMessage(MessageRole.ASSISTANT, "hi 
there"));
+        attrs.put("retry_count", 0);
+        attrs.put("total_retry_wait_sec", 0);
+        return new ChatResponseEvent(FIXED_EVENT_ID, attrs);
+    }
+
+    @Test
+    void regenerateChatResponseEventJavaSnapshot() throws Exception {
+        assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to 
refresh.");
+        writeJavaSnapshot("chat_response_event.json", 
buildChatResponseEvent());
+    }
+
+    @Test
+    void chatResponseEventJavaSnapshotIsStable() throws Exception {
+        assertJavaSnapshotStable("chat_response_event.json", 
buildChatResponseEvent());
+    }
+
+    @Test
+    void javaCanDeserializeChatResponseEventFromPythonSnapshot() throws 
Exception {
+        Event base = readPythonSnapshot("chat_response_event.json");
+        ChatResponseEvent typed = ChatResponseEvent.fromEvent(base);
+
+        assertEquals(FIXED_EVENT_ID, typed.getId());
+        assertEquals(ChatResponseEvent.EVENT_TYPE, typed.getType());
+        assertEquals(FIXED_REQUEST_ID, typed.getRequestId(), "request_id 
mismatch.");
+        ChatMessage response = typed.getResponse();
+        assertNotNull(response, "response field is null.");
+        assertEquals(MessageRole.ASSISTANT, response.getRole(), "Role mismatch 
on response.");
+        assertEquals("hi there", response.getContent());
+    }
+
+    // ── ToolRequestEvent ───────────────────────────────────────────────────
+
+    private static ToolRequestEvent buildToolRequestEvent() {
+        Map<String, Object> toolCall = new LinkedHashMap<>();
+        toolCall.put("id", FIXED_TOOL_CALL_ID);
+        toolCall.put("name", "echo");
+        toolCall.put("arguments", Map.of("value", "ping"));
+
+        Map<String, Object> attrs = new LinkedHashMap<>();
+        attrs.put("model", "test-model");
+        attrs.put("tool_calls", List.of(toolCall));
+        return new ToolRequestEvent(FIXED_EVENT_ID, attrs);
+    }
+
+    @Test
+    void regenerateToolRequestEventJavaSnapshot() throws Exception {
+        assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to 
refresh.");
+        writeJavaSnapshot("tool_request_event.json", buildToolRequestEvent());
+    }
+
+    @Test
+    void toolRequestEventJavaSnapshotIsStable() throws Exception {
+        assertJavaSnapshotStable("tool_request_event.json", 
buildToolRequestEvent());
+    }
+
+    @Test
+    void javaCanDeserializeToolRequestEventFromPythonSnapshot() throws 
Exception {
+        Event base = readPythonSnapshot("tool_request_event.json");
+        ToolRequestEvent typed = ToolRequestEvent.fromEvent(base);
+
+        assertEquals(FIXED_EVENT_ID, typed.getId());
+        assertEquals(ToolRequestEvent.EVENT_TYPE, typed.getType());
+        assertEquals("test-model", typed.getModel());
+        List<Map<String, Object>> toolCalls = typed.getToolCalls();
+        assertNotNull(toolCalls);
+        assertEquals(1, toolCalls.size());
+        assertEquals(FIXED_TOOL_CALL_ID, toolCalls.get(0).get("id"));
+    }
+
+    // ── ToolResponseEvent ──────────────────────────────────────────────────
+
+    private static ToolResponseEvent buildToolResponseEvent() {
+        Map<String, Object> attrs = new LinkedHashMap<>();
+        attrs.put("request_id", FIXED_REQUEST_ID);
+        attrs.put("responses", Map.of(FIXED_TOOL_CALL_ID, 
ToolResponse.success("pong")));
+        attrs.put("success", Map.of(FIXED_TOOL_CALL_ID, true));
+        attrs.put("error", new HashMap<String, String>());
+        attrs.put("external_ids", new HashMap<String, String>());
+        attrs.put("timestamp", FIXED_TIMESTAMP);
+        return new ToolResponseEvent(FIXED_EVENT_ID, attrs);
+    }
+
+    @Test
+    void regenerateToolResponseEventJavaSnapshot() throws Exception {
+        assumeTrue(regenerateRequested(), "Set -Dregenerate.snapshots=true to 
refresh.");
+        writeJavaSnapshot("tool_response_event.json", 
buildToolResponseEvent());
+    }
+
+    @Test
+    void toolResponseEventJavaSnapshotIsStable() throws Exception {
+        assertJavaSnapshotStable("tool_response_event.json", 
buildToolResponseEvent());
+    }
+
+    @Test
+    void pythonToolResponseEventLosesDataWhenConsumedByJava() throws Exception 
{

Review Comment:
   Fixed. `ToolResponseEvent.fromEvent` now wraps any value that is neither a
`ToolResponse` nor a `Map` via `ToolResponse.success(v)`, so primitive values
(string, number, boolean) now round-trip. I renamed the test to reflect this and
extended the Python snapshot with numeric and boolean entries to pin all three cases.
   
   Root cause: `responses` is typed `Dict[UUID, Any]` in Python but `Map<String,
ToolResponse>` in Java, so raw scalars satisfied Python's schema but had no
matching type to deserialize into on the Java side. Tightening Python's type to
`Dict[UUID, ToolResponse]` would be the cleaner long-term fix, as a follow-up.
   
   One caveat: when Python encodes an error case as a plain string in `responses`
(e.g. `tool_call_action.py:48`), Java now wraps it as `success(error_string)`,
because the Python wire shape does not distinguish success from error. The same
follow-up (tightening the type of Python's `responses`) would resolve this as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to