This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 60059bb9adfab0f3d887252d8bd0e62f0b1d2e07 Author: Xu Huang <zuosi...@alibaba-inc.com> AuthorDate: Fri Jul 25 17:28:50 2025 +0800 [runtime] Check memory access in task mailbox thread --- .../agents/runtime/context/RunnerContextImpl.java | 12 ++-- .../agents/runtime/memory/MemoryObjectImpl.java | 16 +++++ .../runtime/operator/ActionExecutionOperator.java | 38 +++++++++++- .../operator/ActionExecutionOperatorFactory.java | 5 +- .../python/context/PythonRunnerContextImpl.java | 11 +--- .../runtime/python/utils/PythonActionExecutor.java | 6 +- .../flink/agents/runtime/CompileUtilsTest.java | 2 +- .../operator/ActionExecutionOperatorTest.java | 70 ++++++++++++++++++---- 8 files changed, 129 insertions(+), 31 deletions(-) diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 75e2309..bef2775 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -38,16 +38,17 @@ public class RunnerContextImpl implements RunnerContext { protected final List<Event> pendingEvents = new ArrayList<>(); protected final MapState<String, MemoryObjectImpl.MemoryItem> store; - protected final FlinkAgentsMetricGroupImpl agentMetricGroup; - + protected final Runnable mailboxThreadChecker; protected String actionName; public RunnerContextImpl( MapState<String, MemoryObjectImpl.MemoryItem> store, - FlinkAgentsMetricGroupImpl agentMetricGroup) { + FlinkAgentsMetricGroupImpl agentMetricGroup, + Runnable mailboxThreadChecker) { this.store = store; this.agentMetricGroup = agentMetricGroup; + this.mailboxThreadChecker = mailboxThreadChecker; } public void setActionName(String actionName) { @@ -66,6 +67,7 @@ public class RunnerContextImpl implements RunnerContext { @Override public void sendEvent(Event event) { + mailboxThreadChecker.run(); try { JsonUtils.checkSerializable(event); } catch (JsonProcessingException e) { @@ -77,6 +79,7 @@ public class RunnerContextImpl implements RunnerContext { } public List<Event> drainEvents() { + mailboxThreadChecker.run(); List<Event> list = new ArrayList<>(this.pendingEvents); this.pendingEvents.clear(); return list; @@ -89,6 +92,7 @@ public class RunnerContextImpl implements RunnerContext { @Override public MemoryObject getShortTermMemory() throws Exception { - return new MemoryObjectImpl(store, MemoryObjectImpl.ROOT_KEY); + mailboxThreadChecker.run(); + return new MemoryObjectImpl(store, MemoryObjectImpl.ROOT_KEY, mailboxThreadChecker); } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java index f4508f4..dfedf1f 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java @@ -34,10 +34,18 @@ public class MemoryObjectImpl implements MemoryObject { private final MapState<String, MemoryItem> store; private final String prefix; + private final Runnable mailboxThreadChecker; public MemoryObjectImpl(MapState<String, MemoryItem> store, String prefix) throws Exception { + this(store, prefix, () -> {}); + } + + public MemoryObjectImpl( + MapState<String, MemoryItem> store, String prefix, Runnable mailboxThreadChecker) + throws Exception { this.store = store; this.prefix = prefix; + this.mailboxThreadChecker = mailboxThreadChecker; if (!store.contains(ROOT_KEY)) { store.put(ROOT_KEY, new MemoryItem()); } @@ -45,6 +53,7 @@ public class MemoryObjectImpl implements MemoryObject { @Override public MemoryObject get(String path) throws Exception { + mailboxThreadChecker.run(); String absPath = fullPath(path); if (store.contains(absPath)) { return new MemoryObjectImpl(store, absPath); @@ -54,6 +63,7 @@ public class MemoryObjectImpl implements MemoryObject { @Override public void set(String path, Object value) throws Exception { + mailboxThreadChecker.run(); String absPath = fullPath(path); String[] parts = absPath.split("\\."); fillParents(parts); @@ -77,6 +87,7 @@ public class MemoryObjectImpl implements MemoryObject { @Override public MemoryObject newObject(String path, boolean overwrite) throws Exception { + mailboxThreadChecker.run(); String absPath = fullPath(path); String[] parts = absPath.split("\\."); @@ -108,6 +119,7 @@ public class MemoryObjectImpl implements MemoryObject { @Override public boolean isExist(String path) { + mailboxThreadChecker.run(); try { return store.contains(fullPath(path)); } catch (Exception e) { @@ -117,6 +129,7 @@ public class MemoryObjectImpl implements MemoryObject { @Override public List<String> getFieldNames() throws Exception { + mailboxThreadChecker.run(); MemoryItem memItem = store.get(prefix); if (memItem != null && memItem.getType() == ItemType.OBJECT) { return new ArrayList<>(memItem.getSubKeys()); @@ -126,6 +139,7 @@ public class MemoryObjectImpl implements MemoryObject { @Override public Map<String, Object> getFields() throws Exception { + mailboxThreadChecker.run(); Map<String, Object> result = new HashMap<>(); for (String name : getFieldNames()) { String absPath = fullPath(name); @@ -141,12 +155,14 @@ public class MemoryObjectImpl implements MemoryObject { @Override public boolean isNestedObject() throws Exception { + mailboxThreadChecker.run(); MemoryItem memItem = store.get(prefix); return memItem != null && memItem.getType() == ItemType.OBJECT; } @Override public Object getValue() throws Exception { + mailboxThreadChecker.run(); MemoryItem memItem = store.get(prefix); if (memItem != null && memItem.getType() == ItemType.VALUE) { return memItem.getValue(); diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index cbc7e10..9eef4ff 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -32,6 +32,7 @@ import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; import org.apache.flink.agents.runtime.utils.EventUtil; +import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.state.*; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -42,10 +43,13 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; +import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl; +import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor; import org.apache.flink.types.Row; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.lang.reflect.Field; import java.util.HashMap; import java.util.LinkedList; import java.util.List; @@ -87,12 +91,25 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private transient BuiltInMetrics builtInMetrics; + private transient MailboxExecutor mailboxExecutor; + + // We need to check whether the current thread is the mailbox thread using the mailbox + // processor. + // TODO: This is a temporary workaround. In the future, we should add an interface in + // MailboxExecutor to check whether a thread is a mailbox thread, rather than using reflection + // to obtain the MailboxProcessor instance and make the determination. + private transient MailboxProcessor mailboxProcessor; + public ActionExecutionOperator( - AgentPlan agentPlan, Boolean inputIsJava, ProcessingTimeService processingTimeService) { + AgentPlan agentPlan, + Boolean inputIsJava, + ProcessingTimeService processingTimeService, + MailboxExecutor mailboxExecutor) { this.agentPlan = agentPlan; this.inputIsJava = inputIsJava; this.processingTimeService = processingTimeService; this.chainingStrategy = ChainingStrategy.ALWAYS; + this.mailboxExecutor = mailboxExecutor; } @Override @@ -112,10 +129,13 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup()); builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan); - runnerContext = new RunnerContextImpl(shortTermMemState, metricGroup); + runnerContext = + new RunnerContextImpl(shortTermMemState, metricGroup, this::checkMailboxThread); // init PythonActionExecutor initPythonActionExecutor(); + + mailboxProcessor = getMailboxProcessor(); } @Override @@ -207,7 +227,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT new PythonActionExecutor( pythonEnvironmentManager, new ObjectMapper().writeValueAsString(agentPlan)); - pythonActionExecutor.open(shortTermMemState, metricGroup); + pythonActionExecutor.open(shortTermMemState, metricGroup, this::checkMailboxThread); } } @@ -240,4 +260,16 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT return agentPlan.getActionsTriggeredBy(event.getClass().getName()); } } + + private MailboxProcessor getMailboxProcessor() throws Exception { + Field field = MailboxExecutorImpl.class.getDeclaredField("mailboxProcessor"); + field.setAccessible(true); + return (MailboxProcessor) field.get(mailboxExecutor); + } + + private void checkMailboxThread() { + checkState( + mailboxProcessor.isMailboxThread(), + "Expected to be running on the task mailbox thread, but was not."); + } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java index 213361d..92b70a5 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java @@ -41,7 +41,10 @@ public class ActionExecutionOperatorFactory<IN, OUT> StreamOperatorParameters<OUT> parameters) { ActionExecutionOperator<IN, OUT> op = new ActionExecutionOperator<>( - agentPlan, inputIsJava, parameters.getProcessingTimeService()); + agentPlan, + inputIsJava, + parameters.getProcessingTimeService(), + parameters.getMailboxExecutor()); op.setup( parameters.getContainingTask(), parameters.getStreamConfig(), diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java index d9a20fc..b3377cb 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java @@ -18,7 +18,6 @@ package org.apache.flink.agents.runtime.python.context; import org.apache.flink.agents.api.Event; -import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; @@ -35,8 +34,9 @@ public class PythonRunnerContextImpl extends RunnerContextImpl { public PythonRunnerContextImpl( MapState<String, MemoryObjectImpl.MemoryItem> store, - FlinkAgentsMetricGroupImpl agentMetricGroup) { - super(store, agentMetricGroup); + FlinkAgentsMetricGroupImpl agentMetricGroup, + Runnable mailboxThreadChecker) { + super(store, agentMetricGroup, mailboxThreadChecker); } @Override @@ -50,9 +50,4 @@ public class PythonRunnerContextImpl extends RunnerContextImpl { // this method will be invoked by PythonActionExecutor's python interpreter. sendEvent(new PythonEvent(event, type)); } - - @Override - public MemoryObject getShortTermMemory() throws Exception { - return new MemoryObjectImpl(store, MemoryObjectImpl.ROOT_KEY); - } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java index 905ddb6..ebd8c05 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java @@ -62,7 +62,8 @@ public class PythonActionExecutor { public void open( MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState, - FlinkAgentsMetricGroupImpl metricGroup) + FlinkAgentsMetricGroupImpl metricGroup, + Runnable mailboxThreadChecker) throws Exception { environmentManager.open(); EmbeddedPythonEnvironment env = environmentManager.createEnvironment(); @@ -70,7 +71,8 @@ public class PythonActionExecutor { interpreter = env.getInterpreter(); interpreter.exec(PYTHON_IMPORTS); - runnerContext = new PythonRunnerContextImpl(shortTermMemState, metricGroup); + runnerContext = + new PythonRunnerContextImpl(shortTermMemState, metricGroup, mailboxThreadChecker); // TODO: remove the set and get runner context after updating pemja to version 0.5.3 Object pythonRunnerContextObject = diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java index e064de1..1617a19 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java @@ -41,7 +41,7 @@ public class CompileUtilsTest { private static final Long TEST_SEQUENCE_END = 100L; // Agent logic: x -> (x + 1) * 2 private static final AgentPlan TEST_AGENT_PLAN = - ActionExecutionOperatorTest.TestAgent.getAgentPlan(); + ActionExecutionOperatorTest.TestAgent.getAgentPlan(false); @Test void testJavaNoKeyedStreamConnectToAgent() throws Exception { diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index 7987bcd..2559bab 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -28,7 +28,6 @@ import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService; import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.util.ExceptionUtils; import org.junit.jupiter.api.Test; @@ -37,20 +36,21 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** Tests for {@link ActionExecutionOperator}. */ public class ActionExecutionOperatorTest { @Test void testExecuteAgent() throws Exception { - ActionExecutionOperator<Long, Object> operator = - new ActionExecutionOperator<>( - TestAgent.getAgentPlan(), true, new TestProcessingTimeService()); try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( - operator, + new ActionExecutionOperatorFactory(TestAgent.getAgentPlan(false), true), (KeySelector<Long, Long>) value -> value, TypeInformation.of(Long.class))) { testHarness.open(); @@ -67,6 +67,21 @@ public class ActionExecutionOperatorTest { } } + @Test + void testMemoryAccessProhibitedOutsideMailboxThread() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory(TestAgent.getAgentPlan(true), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + + assertThatThrownBy(() -> testHarness.processElement(new StreamRecord<>(0L))) + .rootCause() + .hasMessageContaining("Expected to be running on the task mailbox thread"); + } + } + public static class TestAgent { public static class MiddleEvent extends Event { @@ -82,7 +97,7 @@ public class ActionExecutionOperatorTest { } } - public static void processInputEvent(InputEvent event, RunnerContext context) { + public static void action1(InputEvent event, RunnerContext context) { Long inputData = (Long) event.getInput(); try { MemoryObject mem = context.getShortTermMemory(); @@ -93,7 +108,7 @@ public class ActionExecutionOperatorTest { context.sendEvent(new MiddleEvent(inputData + 1)); } - public static void processMiddleEvent(MiddleEvent event, RunnerContext context) { + public static void action2(MiddleEvent event, RunnerContext context) { try { MemoryObject mem = context.getShortTermMemory(); Long tmp = (Long) mem.get("tmp").getValue(); @@ -103,23 +118,37 @@ public class ActionExecutionOperatorTest { } } - public static AgentPlan getAgentPlan() { + public static void action3(MiddleEvent event, RunnerContext context) { + // To test disallows memory access from non-mailbox threads. + try { + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future<Long> future = + executor.submit( + () -> (Long) context.getShortTermMemory().get("tmp").getValue()); + Long tmp = future.get(); + context.sendEvent(new OutputEvent(tmp * 2)); + } catch (Exception e) { + ExceptionUtils.rethrow(e); + } + } + + public static AgentPlan getAgentPlan(boolean testMemoryAccessOutOfMailbox) { try { Map<String, List<Action>> actionsByEvent = new HashMap<>(); Action action1 = new Action( - "processInputEvent", + "action1", new JavaFunction( TestAgent.class, - "processInputEvent", + "action1", new Class<?>[] {InputEvent.class, RunnerContext.class}), Collections.singletonList(InputEvent.class.getName())); Action action2 = new Action( - "processMiddleEvent", + "action2", new JavaFunction( TestAgent.class, - "processMiddleEvent", + "action2", new Class<?>[] {MiddleEvent.class, RunnerContext.class}), Collections.singletonList(MiddleEvent.class.getName())); actionsByEvent.put(InputEvent.class.getName(), Collections.singletonList(action1)); @@ -127,6 +156,23 @@ public class ActionExecutionOperatorTest { Map<String, Action> actions = new HashMap<>(); actions.put(action1.getName(), action1); actions.put(action2.getName(), action2); + + if (testMemoryAccessOutOfMailbox) { + Action action3 = + new Action( + "action3", + new JavaFunction( + TestAgent.class, + "action3", + new Class<?>[] { + MiddleEvent.class, RunnerContext.class + }), + Collections.singletonList(MiddleEvent.class.getName())); + actionsByEvent.put( + MiddleEvent.class.getName(), Collections.singletonList(action3)); + actions.put(action3.getName(), action3); + } + return new AgentPlan(actions, actionsByEvent); } catch (Exception e) { ExceptionUtils.rethrow(e);