This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 60059bb9adfab0f3d887252d8bd0e62f0b1d2e07
Author: Xu Huang <zuosi...@alibaba-inc.com>
AuthorDate: Fri Jul 25 17:28:50 2025 +0800

    [runtime] Check memory access in task mailbox thread
---
 .../agents/runtime/context/RunnerContextImpl.java  | 12 ++--
 .../agents/runtime/memory/MemoryObjectImpl.java    | 16 +++++
 .../runtime/operator/ActionExecutionOperator.java  | 38 +++++++++++-
 .../operator/ActionExecutionOperatorFactory.java   |  5 +-
 .../python/context/PythonRunnerContextImpl.java    | 11 +---
 .../runtime/python/utils/PythonActionExecutor.java |  6 +-
 .../flink/agents/runtime/CompileUtilsTest.java     |  2 +-
 .../operator/ActionExecutionOperatorTest.java      | 70 ++++++++++++++++++----
 8 files changed, 129 insertions(+), 31 deletions(-)

diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
index 75e2309..bef2775 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java
@@ -38,16 +38,17 @@ public class RunnerContextImpl implements RunnerContext {
 
     protected final List<Event> pendingEvents = new ArrayList<>();
     protected final MapState<String, MemoryObjectImpl.MemoryItem> store;
-
     protected final FlinkAgentsMetricGroupImpl agentMetricGroup;
-
+    protected final Runnable mailboxThreadChecker;
     protected String actionName;
 
     public RunnerContextImpl(
             MapState<String, MemoryObjectImpl.MemoryItem> store,
-            FlinkAgentsMetricGroupImpl agentMetricGroup) {
+            FlinkAgentsMetricGroupImpl agentMetricGroup,
+            Runnable mailboxThreadChecker) {
         this.store = store;
         this.agentMetricGroup = agentMetricGroup;
+        this.mailboxThreadChecker = mailboxThreadChecker;
     }
 
     public void setActionName(String actionName) {
@@ -66,6 +67,7 @@ public class RunnerContextImpl implements RunnerContext {
 
     @Override
     public void sendEvent(Event event) {
+        mailboxThreadChecker.run();
         try {
             JsonUtils.checkSerializable(event);
         } catch (JsonProcessingException e) {
@@ -77,6 +79,7 @@ public class RunnerContextImpl implements RunnerContext {
     }
 
     public List<Event> drainEvents() {
+        mailboxThreadChecker.run();
         List<Event> list = new ArrayList<>(this.pendingEvents);
         this.pendingEvents.clear();
         return list;
@@ -89,6 +92,7 @@ public class RunnerContextImpl implements RunnerContext {
 
     @Override
     public MemoryObject getShortTermMemory() throws Exception {
-        return new MemoryObjectImpl(store, MemoryObjectImpl.ROOT_KEY);
+        mailboxThreadChecker.run();
+        return new MemoryObjectImpl(store, MemoryObjectImpl.ROOT_KEY, 
mailboxThreadChecker);
     }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
index f4508f4..dfedf1f 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java
@@ -34,10 +34,18 @@ public class MemoryObjectImpl implements MemoryObject {
 
     private final MapState<String, MemoryItem> store;
     private final String prefix;
+    private final Runnable mailboxThreadChecker;
 
     public MemoryObjectImpl(MapState<String, MemoryItem> store, String prefix) 
throws Exception {
+        this(store, prefix, () -> {});
+    }
+
+    public MemoryObjectImpl(
+            MapState<String, MemoryItem> store, String prefix, Runnable 
mailboxThreadChecker)
+            throws Exception {
         this.store = store;
         this.prefix = prefix;
+        this.mailboxThreadChecker = mailboxThreadChecker;
         if (!store.contains(ROOT_KEY)) {
             store.put(ROOT_KEY, new MemoryItem());
         }
@@ -45,6 +53,7 @@ public class MemoryObjectImpl implements MemoryObject {
 
     @Override
     public MemoryObject get(String path) throws Exception {
+        mailboxThreadChecker.run();
         String absPath = fullPath(path);
         if (store.contains(absPath)) {
             return new MemoryObjectImpl(store, absPath);
@@ -54,6 +63,7 @@ public class MemoryObjectImpl implements MemoryObject {
 
     @Override
     public void set(String path, Object value) throws Exception {
+        mailboxThreadChecker.run();
         String absPath = fullPath(path);
         String[] parts = absPath.split("\\.");
         fillParents(parts);
@@ -77,6 +87,7 @@ public class MemoryObjectImpl implements MemoryObject {
 
     @Override
     public MemoryObject newObject(String path, boolean overwrite) throws 
Exception {
+        mailboxThreadChecker.run();
         String absPath = fullPath(path);
         String[] parts = absPath.split("\\.");
 
@@ -108,6 +119,7 @@ public class MemoryObjectImpl implements MemoryObject {
 
     @Override
     public boolean isExist(String path) {
+        mailboxThreadChecker.run();
         try {
             return store.contains(fullPath(path));
         } catch (Exception e) {
@@ -117,6 +129,7 @@ public class MemoryObjectImpl implements MemoryObject {
 
     @Override
     public List<String> getFieldNames() throws Exception {
+        mailboxThreadChecker.run();
         MemoryItem memItem = store.get(prefix);
         if (memItem != null && memItem.getType() == ItemType.OBJECT) {
             return new ArrayList<>(memItem.getSubKeys());
@@ -126,6 +139,7 @@ public class MemoryObjectImpl implements MemoryObject {
 
     @Override
     public Map<String, Object> getFields() throws Exception {
+        mailboxThreadChecker.run();
         Map<String, Object> result = new HashMap<>();
         for (String name : getFieldNames()) {
             String absPath = fullPath(name);
@@ -141,12 +155,14 @@ public class MemoryObjectImpl implements MemoryObject {
 
     @Override
     public boolean isNestedObject() throws Exception {
+        mailboxThreadChecker.run();
         MemoryItem memItem = store.get(prefix);
         return memItem != null && memItem.getType() == ItemType.OBJECT;
     }
 
     @Override
     public Object getValue() throws Exception {
+        mailboxThreadChecker.run();
         MemoryItem memItem = store.get(prefix);
         if (memItem != null && memItem.getType() == ItemType.VALUE) {
             return memItem.getValue();
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index cbc7e10..9eef4ff 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -32,6 +32,7 @@ import 
org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
 import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
 import org.apache.flink.agents.runtime.utils.EventUtil;
+import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.api.common.state.*;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -42,10 +43,13 @@ import 
org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl;
+import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor;
 import org.apache.flink.types.Row;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.lang.reflect.Field;
 import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
@@ -87,12 +91,25 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
     private transient BuiltInMetrics builtInMetrics;
 
+    private transient MailboxExecutor mailboxExecutor;
+
+    // We need to check whether the current thread is the mailbox thread using the mailbox
+    // processor.
+    // TODO: This is a temporary workaround. In the future, we should add an interface in
+    // MailboxExecutor to check whether a thread is a mailbox thread, rather than using reflection
+    // to obtain the MailboxProcessor instance and make the determination.
+    private transient MailboxProcessor mailboxProcessor;
+
     public ActionExecutionOperator(
-            AgentPlan agentPlan, Boolean inputIsJava, ProcessingTimeService 
processingTimeService) {
+            AgentPlan agentPlan,
+            Boolean inputIsJava,
+            ProcessingTimeService processingTimeService,
+            MailboxExecutor mailboxExecutor) {
         this.agentPlan = agentPlan;
         this.inputIsJava = inputIsJava;
         this.processingTimeService = processingTimeService;
         this.chainingStrategy = ChainingStrategy.ALWAYS;
+        this.mailboxExecutor = mailboxExecutor;
     }
 
     @Override
@@ -112,10 +129,13 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup());
         builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan);
 
-        runnerContext = new RunnerContextImpl(shortTermMemState, metricGroup);
+        runnerContext =
+                new RunnerContextImpl(shortTermMemState, metricGroup, 
this::checkMailboxThread);
 
         // init PythonActionExecutor
         initPythonActionExecutor();
+
+        mailboxProcessor = getMailboxProcessor();
     }
 
     @Override
@@ -207,7 +227,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
                     new PythonActionExecutor(
                             pythonEnvironmentManager,
                             new ObjectMapper().writeValueAsString(agentPlan));
-            pythonActionExecutor.open(shortTermMemState, metricGroup);
+            pythonActionExecutor.open(shortTermMemState, metricGroup, 
this::checkMailboxThread);
         }
     }
 
@@ -240,4 +260,16 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
             return agentPlan.getActionsTriggeredBy(event.getClass().getName());
         }
     }
+
+    private MailboxProcessor getMailboxProcessor() throws Exception {
+        Field field = 
MailboxExecutorImpl.class.getDeclaredField("mailboxProcessor");
+        field.setAccessible(true);
+        return (MailboxProcessor) field.get(mailboxExecutor);
+    }
+
+    private void checkMailboxThread() {
+        checkState(
+                mailboxProcessor.isMailboxThread(),
+                "Expected to be running on the task mailbox thread, but was 
not.");
+    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java
index 213361d..92b70a5 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java
@@ -41,7 +41,10 @@ public class ActionExecutionOperatorFactory<IN, OUT>
             StreamOperatorParameters<OUT> parameters) {
         ActionExecutionOperator<IN, OUT> op =
                 new ActionExecutionOperator<>(
-                        agentPlan, inputIsJava, 
parameters.getProcessingTimeService());
+                        agentPlan,
+                        inputIsJava,
+                        parameters.getProcessingTimeService(),
+                        parameters.getMailboxExecutor());
         op.setup(
                 parameters.getContainingTask(),
                 parameters.getStreamConfig(),
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
index d9a20fc..b3377cb 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java
@@ -18,7 +18,6 @@
 package org.apache.flink.agents.runtime.python.context;
 
 import org.apache.flink.agents.api.Event;
-import org.apache.flink.agents.api.context.MemoryObject;
 import org.apache.flink.agents.api.context.RunnerContext;
 import org.apache.flink.agents.runtime.context.RunnerContextImpl;
 import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
@@ -35,8 +34,9 @@ public class PythonRunnerContextImpl extends 
RunnerContextImpl {
 
     public PythonRunnerContextImpl(
             MapState<String, MemoryObjectImpl.MemoryItem> store,
-            FlinkAgentsMetricGroupImpl agentMetricGroup) {
-        super(store, agentMetricGroup);
+            FlinkAgentsMetricGroupImpl agentMetricGroup,
+            Runnable mailboxThreadChecker) {
+        super(store, agentMetricGroup, mailboxThreadChecker);
     }
 
     @Override
@@ -50,9 +50,4 @@ public class PythonRunnerContextImpl extends 
RunnerContextImpl {
         // this method will be invoked by PythonActionExecutor's python interpreter.
         sendEvent(new PythonEvent(event, type));
     }
-
-    @Override
-    public MemoryObject getShortTermMemory() throws Exception {
-        return new MemoryObjectImpl(store, MemoryObjectImpl.ROOT_KEY);
-    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
index 905ddb6..ebd8c05 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
@@ -62,7 +62,8 @@ public class PythonActionExecutor {
 
     public void open(
             MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState,
-            FlinkAgentsMetricGroupImpl metricGroup)
+            FlinkAgentsMetricGroupImpl metricGroup,
+            Runnable mailboxThreadChecker)
             throws Exception {
         environmentManager.open();
         EmbeddedPythonEnvironment env = environmentManager.createEnvironment();
@@ -70,7 +71,8 @@ public class PythonActionExecutor {
         interpreter = env.getInterpreter();
         interpreter.exec(PYTHON_IMPORTS);
 
-        runnerContext = new PythonRunnerContextImpl(shortTermMemState, 
metricGroup);
+        runnerContext =
+                new PythonRunnerContextImpl(shortTermMemState, metricGroup, 
mailboxThreadChecker);
 
         // TODO: remove the set and get runner context after updating pemja to version 0.5.3
         Object pythonRunnerContextObject =
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java 
b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java
index e064de1..1617a19 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java
@@ -41,7 +41,7 @@ public class CompileUtilsTest {
     private static final Long TEST_SEQUENCE_END = 100L;
     // Agent logic: x -> (x + 1) * 2
     private static final AgentPlan TEST_AGENT_PLAN =
-            ActionExecutionOperatorTest.TestAgent.getAgentPlan();
+            ActionExecutionOperatorTest.TestAgent.getAgentPlan(false);
 
     @Test
     void testJavaNoKeyedStreamConnectToAgent() throws Exception {
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index 7987bcd..2559bab 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -28,7 +28,6 @@ import org.apache.flink.agents.plan.JavaFunction;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.util.ExceptionUtils;
 import org.junit.jupiter.api.Test;
@@ -37,20 +36,21 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for {@link ActionExecutionOperator}. */
 public class ActionExecutionOperatorTest {
 
     @Test
     void testExecuteAgent() throws Exception {
-        ActionExecutionOperator<Long, Object> operator =
-                new ActionExecutionOperator<>(
-                        TestAgent.getAgentPlan(), true, new 
TestProcessingTimeService());
         try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> 
testHarness =
                 new KeyedOneInputStreamOperatorTestHarness<>(
-                        operator,
+                        new 
ActionExecutionOperatorFactory(TestAgent.getAgentPlan(false), true),
                         (KeySelector<Long, Long>) value -> value,
                         TypeInformation.of(Long.class))) {
             testHarness.open();
@@ -67,6 +67,21 @@ public class ActionExecutionOperatorTest {
         }
     }
 
+    @Test
+    void testMemoryAccessProhibitedOutsideMailboxThread() throws Exception {
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> 
testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new 
ActionExecutionOperatorFactory(TestAgent.getAgentPlan(true), true),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+
+            assertThatThrownBy(() -> testHarness.processElement(new 
StreamRecord<>(0L)))
+                    .rootCause()
+                    .hasMessageContaining("Expected to be running on the task 
mailbox thread");
+        }
+    }
+
     public static class TestAgent {
 
         public static class MiddleEvent extends Event {
@@ -82,7 +97,7 @@ public class ActionExecutionOperatorTest {
             }
         }
 
-        public static void processInputEvent(InputEvent event, RunnerContext 
context) {
+        public static void action1(InputEvent event, RunnerContext context) {
             Long inputData = (Long) event.getInput();
             try {
                 MemoryObject mem = context.getShortTermMemory();
@@ -93,7 +108,7 @@ public class ActionExecutionOperatorTest {
             context.sendEvent(new MiddleEvent(inputData + 1));
         }
 
-        public static void processMiddleEvent(MiddleEvent event, RunnerContext 
context) {
+        public static void action2(MiddleEvent event, RunnerContext context) {
             try {
                 MemoryObject mem = context.getShortTermMemory();
                 Long tmp = (Long) mem.get("tmp").getValue();
@@ -103,23 +118,37 @@ public class ActionExecutionOperatorTest {
             }
         }
 
-        public static AgentPlan getAgentPlan() {
+        public static void action3(MiddleEvent event, RunnerContext context) {
+            // Used to test that memory access from non-mailbox threads is disallowed.
+            try {
+                ExecutorService executor = Executors.newSingleThreadExecutor();
+                Future<Long> future =
+                        executor.submit(
+                                () -> (Long) 
context.getShortTermMemory().get("tmp").getValue());
+                Long tmp = future.get();
+                context.sendEvent(new OutputEvent(tmp * 2));
+            } catch (Exception e) {
+                ExceptionUtils.rethrow(e);
+            }
+        }
+
+        public static AgentPlan getAgentPlan(boolean 
testMemoryAccessOutOfMailbox) {
             try {
                 Map<String, List<Action>> actionsByEvent = new HashMap<>();
                 Action action1 =
                         new Action(
-                                "processInputEvent",
+                                "action1",
                                 new JavaFunction(
                                         TestAgent.class,
-                                        "processInputEvent",
+                                        "action1",
                                         new Class<?>[] {InputEvent.class, 
RunnerContext.class}),
                                 
Collections.singletonList(InputEvent.class.getName()));
                 Action action2 =
                         new Action(
-                                "processMiddleEvent",
+                                "action2",
                                 new JavaFunction(
                                         TestAgent.class,
-                                        "processMiddleEvent",
+                                        "action2",
                                         new Class<?>[] {MiddleEvent.class, 
RunnerContext.class}),
                                 
Collections.singletonList(MiddleEvent.class.getName()));
                 actionsByEvent.put(InputEvent.class.getName(), 
Collections.singletonList(action1));
@@ -127,6 +156,23 @@ public class ActionExecutionOperatorTest {
                 Map<String, Action> actions = new HashMap<>();
                 actions.put(action1.getName(), action1);
                 actions.put(action2.getName(), action2);
+
+                if (testMemoryAccessOutOfMailbox) {
+                    Action action3 =
+                            new Action(
+                                    "action3",
+                                    new JavaFunction(
+                                            TestAgent.class,
+                                            "action3",
+                                            new Class<?>[] {
+                                                MiddleEvent.class, 
RunnerContext.class
+                                            }),
+                                    
Collections.singletonList(MiddleEvent.class.getName()));
+                    actionsByEvent.put(
+                            MiddleEvent.class.getName(), 
Collections.singletonList(action3));
+                    actions.put(action3.getName(), action3);
+                }
+
                 return new AgentPlan(actions, actionsByEvent);
             } catch (Exception e) {
                 ExceptionUtils.rethrow(e);

Reply via email to