This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit a50137817a10c2bf192cfa6e5a97f613fd640837 Author: Xu Huang <zuosi...@alibaba-inc.com> AuthorDate: Sat Jul 26 10:08:29 2025 +0800 [runtime] Implement async execution in ActionExecutionOperator for Python actions --- .../org/apache/flink/agents/plan/JavaFunction.java | 9 +- python/flink_agents/examples/my_agent.py | 12 +- python/flink_agents/plan/function.py | 81 +++++- .../create_python_agent_plan_from_json.py | 2 +- python/flink_agents/plan/tests/test_function.py | 5 +- .../flink_agents/runtime/flink_runner_context.py | 37 ++- .../runtime/operator/ActionExecutionOperator.java | 301 ++++++++++++++++----- .../flink/agents/runtime/operator/ActionTask.java | 114 ++++++++ .../agents/runtime/operator/JavaActionTask.java | 50 ++++ .../runtime/python/operator/PythonActionTask.java | 74 +++++ .../python/operator/PythonGeneratorActionTask.java | 50 ++++ .../runtime/python/utils/PythonActionExecutor.java | 119 ++++++-- .../flink/agents/runtime/utils/StateUtil.java | 91 +++++++ .../flink/agents/runtime/CompileUtilsTest.java | 37 ++- .../operator/ActionExecutionOperatorTest.java | 91 ++++++- 15 files changed, 944 insertions(+), 129 deletions(-) diff --git a/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java b/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java index 580004c..37f7f37 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java @@ -38,7 +38,7 @@ public class JavaFunction implements Function { @JsonProperty(FIELD_NAME_PARAMETER_TYPES) private final Class<?>[] parameterTypes; - @JsonIgnore private final Method method; + @JsonIgnore private transient Method method; public JavaFunction( @JsonProperty(FIELD_NAME_QUAL_NAME) String qualName, @@ -71,7 +71,10 @@ public class JavaFunction implements Function { return parameterTypes; } - public Method getMethod() { + public Method getMethod() throws ClassNotFoundException, NoSuchMethodException { + if 
(method == null) { + this.method = Class.forName(qualName).getMethod(methodName, parameterTypes); + } return method; } @@ -93,7 +96,7 @@ public class JavaFunction implements Function { @Override public Object call(Object... args) throws Exception { - return method.invoke(null, args); + return getMethod().invoke(null, args); } @Override diff --git a/python/flink_agents/examples/my_agent.py b/python/flink_agents/examples/my_agent.py index b3947f3..b810d64 100644 --- a/python/flink_agents/examples/my_agent.py +++ b/python/flink_agents/examples/my_agent.py @@ -16,6 +16,8 @@ # limitations under the License. ################################################################################# import copy +import random +import time from typing import Any, Optional from pydantic import BaseModel @@ -60,6 +62,12 @@ class DataStreamAgent(Agent): @action(InputEvent) @staticmethod def first_action(event: Event, ctx: RunnerContext): # noqa D102 + def log_to_stdout(input: Any, total: int) -> bool: + # Simulating asynchronous time consumption + time.sleep(random.random()) + print(f"[log_to_stdout] Logging input={input}, total reviews now={total}") + return True + input = event.input stm = ctx.get_short_term_memory() @@ -71,8 +79,10 @@ class DataStreamAgent(Agent): total += 1 status.set("total_reviews", total) + log_success = yield from ctx.execute_async(log_to_stdout, input, total) + content = copy.deepcopy(input) - content.review += " first action" + content.review += " first action, log success=" + str(log_success) + "," ctx.send_event(MyEvent(value=content)) @action(MyEvent) diff --git a/python/flink_agents/plan/function.py b/python/flink_agents/plan/function.py index ebb2df8..032eec6 100644 --- a/python/flink_agents/plan/function.py +++ b/python/flink_agents/plan/function.py @@ -16,10 +16,11 @@ # limitations under the License. 
################################################################################# +import logging import importlib import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Tuple, get_type_hints +from typing import Any, Callable, Dict, List, Tuple, get_type_hints, Generator from pydantic import BaseModel @@ -28,6 +29,11 @@ from flink_agents.plan.utils import check_type_match # Global cache for PythonFunction instances to avoid repeated creation _PYTHON_FUNCTION_CACHE: Dict[Tuple[str, str], "PythonFunction"] = {} +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + def _is_function_cacheable(func: Callable) -> bool: """Check if a function is safe to cache. @@ -237,6 +243,28 @@ class JavaFunction(Function): """Check function signature is legal or not.""" +class PythonGeneratorWrapper: + """ + A temporary wrapper class for Python generators to work around a + known issue in PEMJA, where the generator type is incorrectly handled. + + TODO: This wrapper is intended to be a temporary solution. Once PEMJA + version 0.5.5 (or later) fixes the bug related to generator type conversion, + this wrapper should be removed. For more details, please refer to + https://github.com/apache/flink-agents/issues/83. + """ + + def __init__(self, generator: Generator) -> None: + """Initialize a PythonGeneratorWrapper. """ + self.generator = generator + + def __str__(self) -> str: + return "PythonGeneratorWrapper, generator=" + str(self.generator) + + def __next__(self) -> Any: + return next(self.generator) + + def call_python_function(module: str, qualname: str, func_args: Tuple[Any, ...]) -> Any: """Used to call a Python function in the Pemja environment. 
@@ -260,19 +288,19 @@ def call_python_function(module: str, qualname: str, func_args: Tuple[Any, ...]) """ cache_key = (module, qualname) + python_func = None + if cache_key not in _PYTHON_FUNCTION_CACHE: python_func = PythonFunction(module=module, qualname=qualname) - try: - if python_func.is_cacheable(): - _PYTHON_FUNCTION_CACHE[cache_key] = python_func - else: - return python_func(*func_args) - except Exception: - return python_func(*func_args) + if python_func.is_cacheable(): + _PYTHON_FUNCTION_CACHE[cache_key] = python_func + else: + python_func = _PYTHON_FUNCTION_CACHE[cache_key] - # Use cached instance - func = _PYTHON_FUNCTION_CACHE[cache_key] - return func(*func_args) + func_result = python_func(*func_args) + if isinstance(func_result, Generator): + return PythonGeneratorWrapper(func_result) + return func_result def clear_python_function_cache() -> None: @@ -305,3 +333,34 @@ def get_python_function_cache_keys() -> List[Tuple[str, str]]: List of (module, qualname) tuples representing cached functions. """ return list(_PYTHON_FUNCTION_CACHE.keys()) + + +def call_python_generator(generator_wrapper: PythonGeneratorWrapper) -> (bool, Any): + """ + Invokes the next step of a wrapped Python generator and returns whether + it is done, along with the yielded or returned value. + + Args: + generator_wrapper (PythonGeneratorWrapper): A wrapper object that + contains a `generator` attribute. This attribute should be an instance + of a Python generator. + + Returns: + Tuple[bool, Any]: + - The first element is a boolean flag indicating whether the generator + has finished: + * False: The generator has more values to yield. + * True: The generator has completed. + - The second element is either: + * The value yielded by the generator (when not exhausted), or + * The return value of the generator (when it has finished). 
+ """ + try: + result = next(generator_wrapper.generator) + except StopIteration as e: + return True, e.value if hasattr(e, 'value') else None + except Exception: + logger.exception("Error in generator execution") + raise + else: + return False, result diff --git a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py index 828ac7c..6ff356d 100644 --- a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py +++ b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py @@ -75,7 +75,7 @@ if __name__ == "__main__": assert len(actions_by_event) == 2 assert input_event in actions_by_event - assert actions_by_event[input_event] == ["firstAction", "secondAction"] + assert sorted(actions_by_event[input_event]) == ["firstAction", "secondAction"] assert my_event in actions_by_event assert actions_by_event[my_event] == ["secondAction"] diff --git a/python/flink_agents/plan/tests/test_function.py b/python/flink_agents/plan/tests/test_function.py index 945401d..f79e7bb 100644 --- a/python/flink_agents/plan/tests/test_function.py +++ b/python/flink_agents/plan/tests/test_function.py @@ -30,7 +30,7 @@ from flink_agents.plan.function import ( call_python_function, clear_python_function_cache, get_python_function_cache_keys, - get_python_function_cache_size, + get_python_function_cache_size, PythonGeneratorWrapper, ) if TYPE_CHECKING: @@ -314,8 +314,9 @@ def test_selective_caching_generator_functions() -> None: result = call_python_function( "flink_agents.plan.tests.test_function", "generator_function", (3,) ) + assert isinstance(result, PythonGeneratorWrapper) # Convert generator to list for testing - result_list = list(result) + result_list = list(result.generator) assert result_list == [0, 1, 2] # Should not be cached diff --git a/python/flink_agents/runtime/flink_runner_context.py 
b/python/flink_agents/runtime/flink_runner_context.py index 99745c3..5f3bb90 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -26,6 +26,8 @@ from flink_agents.api.runner_context import RunnerContext from flink_agents.plan.agent_plan import AgentPlan from flink_agents.runtime.flink_memory_object import FlinkMemoryObject from flink_agents.runtime.flink_metric_group import FlinkMetricGroup +from concurrent.futures import ThreadPoolExecutor +import os class FlinkRunnerContext(RunnerContext): @@ -36,7 +38,9 @@ class FlinkRunnerContext(RunnerContext): __agent_plan: AgentPlan - def __init__(self, j_runner_context: Any, agent_plan_json: str) -> None: + def __init__( + self, j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor + ) -> None: """Initialize a flink runner context with the given java runner context. Parameters @@ -46,6 +50,7 @@ class FlinkRunnerContext(RunnerContext): """ self._j_runner_context = j_runner_context self.__agent_plan = AgentPlan.model_validate_json(agent_plan_json) + self.executor = executor @override def send_event(self, event: Event) -> None: @@ -115,9 +120,29 @@ class FlinkRunnerContext(RunnerContext): """Asynchronously execute the provided function. Access to memory is prohibited within the function. """ - # TODO: Implement in a future commit. - pass - -def create_flink_runner_context(j_runner_context: Any, agent_plan_json: str) -> FlinkRunnerContext: + future = self.executor.submit(func, *args, **kwargs) + while not future.done(): + # TODO: Currently, we are using a polling mechanism to check whether + # the future has completed. This approach should be optimized in the + # future by switching to a notification-based model, where the Flink + # operator is notified directly once the future is completed. 
+ yield + return future.result() + + +def create_flink_runner_context( + j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor +) -> FlinkRunnerContext: """Used to create a FlinkRunnerContext Python object in Pemja environment.""" - return FlinkRunnerContext(j_runner_context, agent_plan_json) + return FlinkRunnerContext(j_runner_context, agent_plan_json, executor) + + +def create_async_thread_pool() -> ThreadPoolExecutor: + """Used to create a thread pool to execute asynchronous + code block in action.""" + return ThreadPoolExecutor(max_workers=os.cpu_count() * 2) + + +def close_async_thread_pool(executor: ThreadPoolExecutor) -> None: + """Used to close the thread pool.""" + executor.shutdown() diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 9eef4ff..85544cd 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -29,16 +29,22 @@ import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.metrics.BuiltInMetrics; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; +import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; import org.apache.flink.agents.runtime.python.event.PythonEvent; +import org.apache.flink.agents.runtime.python.operator.PythonActionTask; import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor; import org.apache.flink.agents.runtime.utils.EventUtil; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.state.*; +import 
org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.python.env.PythonDependencyInfo; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -46,14 +52,19 @@ import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor; import org.apache.flink.types.Row; +import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.lang.reflect.Field; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; +import java.util.Optional; +import static org.apache.flink.agents.runtime.utils.StateUtil.listStateNotEmpty; +import static org.apache.flink.agents.runtime.utils.StateUtil.pollFromListState; +import static org.apache.flink.agents.runtime.utils.StateUtil.removeFromListState; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** @@ -67,7 +78,7 @@ import static org.apache.flink.util.Preconditions.checkState; * and the resulting output event is collected for further processing. 
*/ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT> - implements OneInputStreamOperator<IN, OUT> { + implements OneInputStreamOperator<IN, OUT>, BoundedOneInput { private static final long serialVersionUID = 1L; @@ -81,9 +92,6 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private transient MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState; - // RunnerContext for Java actions - private transient RunnerContextImpl runnerContext; - // PythonActionExecutor for Python actions private transient PythonActionExecutor pythonActionExecutor; @@ -91,7 +99,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private transient BuiltInMetrics builtInMetrics; - private transient MailboxExecutor mailboxExecutor; + private final transient MailboxExecutor mailboxExecutor; // We need to check whether the current thread is the mailbox thread using the mailbox // processor. @@ -100,6 +108,19 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // to obtain the MailboxProcessor instance and make the determination. private transient MailboxProcessor mailboxProcessor; + // An action will be split into one or more ActionTask objects. We use a state to store the + // pending ActionTasks that are waiting to be executed. + private transient ListState<ActionTask> actionTasksKState; + + // To avoid processing different InputEvents with the same key, we use a state to store pending + // InputEvents that are waiting to be processed. + private transient ListState<Event> pendingInputEventsKState; + + // An operator state is used to track the currently processing keys. This is useful when + // receiving an EndOfInput signal, as we need to wait until all related events are fully + // processed. 
+ private transient ListState<Object> currentProcessingKeysOpState; + public ActionExecutionOperator( AgentPlan agentPlan, Boolean inputIsJava, @@ -129,13 +150,37 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup()); builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan); - runnerContext = - new RunnerContextImpl(shortTermMemState, metricGroup, this::checkMailboxThread); + // init agent processing related state + actionTasksKState = + getRuntimeContext() + .getListState( + new ListStateDescriptor<>( + "actionTasks", TypeInformation.of(ActionTask.class))); + pendingInputEventsKState = + getRuntimeContext() + .getListState( + new ListStateDescriptor<>( + "pendingInputEvents", TypeInformation.of(Event.class))); + // We use UnionList here to ensure that the task can access all keys after parallelism + // modifications. + // Subsequent steps {@link #tryResumeProcessActionTasks} will then filter out keys that do + // not belong to the key range of current task. + currentProcessingKeysOpState = + getOperatorStateBackend() + .getUnionListState( + new ListStateDescriptor<>( + "currentProcessingKeys", TypeInformation.of(Object.class))); // init PythonActionExecutor initPythonActionExecutor(); mailboxProcessor = getMailboxProcessor(); + + // Since an operator restart may change the key range it manages due to changes in + // parallelism, + // and {@link tryProcessActionTaskForKey} mails might be lost, + // it is necessary to reprocess all keys to ensure correctness. + tryResumeProcessActionTasks(); } @Override @@ -143,65 +188,120 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT IN input = record.getValue(); LOG.debug("Receive an element {}", input); - // 1. wrap to InputEvent first + // wrap to InputEvent first Event inputEvent = wrapToInputEvent(input); - // 2. 
execute action - LinkedList<Event> events = new LinkedList<>(); - events.push(inputEvent); - while (!events.isEmpty()) { - Event event = events.pop(); - builtInMetrics.markEventProcessed(); - List<Action> actions = getActionsTriggeredBy(event); - if (actions != null && !actions.isEmpty()) { - for (Action action : actions) { - // TODO: Support multi-action execution for a single event. Example: A Java - // event - // should be processable by both Java and Python actions. - // TODO: Implement asynchronous action execution. - - // execute action and collect output events - String actionName = action.getName(); - LOG.debug("Try execute action {} for event {}.", actionName, event); - List<Event> actionOutputEvents; - if (action.getExec() instanceof JavaFunction) { - runnerContext.setActionName(actionName); - action.getExec().call(event, runnerContext); - actionOutputEvents = runnerContext.drainEvents(); - } else if (action.getExec() instanceof PythonFunction) { - checkState(event instanceof PythonEvent); - actionOutputEvents = - pythonActionExecutor.executePythonFunction( - (PythonFunction) action.getExec(), - (PythonEvent) event, - actionName); - } else { - throw new RuntimeException("Unsupported action type: " + action.getClass()); - } - builtInMetrics.markActionExecuted(actionName); - - for (Event actionOutputEvent : actionOutputEvents) { - if (EventUtil.isOutputEvent(actionOutputEvent)) { - builtInMetrics.markEventProcessed(); - OUT outputData = getOutputFromOutputEvent(actionOutputEvent); - LOG.debug( - "Collect output data {} for input {} in action {}.", - outputData, - input, - action.getName()); - output.collect(reusedStreamRecord.replace(outputData)); - } else { - LOG.debug( - "Collect event {} for event {} in action {}.", - actionOutputEvent, - event, - action.getName()); - events.add(actionOutputEvent); - } - } + if (currentKeyHasMoreActionTask()) { + // If there are already actions being processed for the current key, the newly incoming + // event should be 
queued and processed later. Therefore, we add it to + // pendingInputEventsState. + pendingInputEventsKState.add(inputEvent); + } else { + // Otherwise, the new event is processed immediately. + processEvent(getCurrentKey(), inputEvent); + } + } + + /** + * Processes an incoming event for the given key and may submit a new mail + * `tryProcessActionTaskForKey` to continue processing. + */ + private void processEvent(Object key, Event event) throws Exception { + boolean isInputEvent = EventUtil.isInputEvent(event); + builtInMetrics.markEventProcessed(); + if (EventUtil.isOutputEvent(event)) { + // If the event is an OutputEvent, we send it downstream. + OUT outputData = getOutputFromOutputEvent(event); + output.collect(reusedStreamRecord.replace(outputData)); + } else { + if (isInputEvent) { + // If the event is an InputEvent, we mark that the key is currently being processed. + currentProcessingKeysOpState.add(key); + } + // We then obtain the triggered action and add ActionTasks to the waiting processing + // queue. + List<Action> triggerActions = getActionsTriggeredBy(event); + if (triggerActions != null && !triggerActions.isEmpty()) { + for (Action triggerAction : triggerActions) { + actionTasksKState.add(createActionTask(key, triggerAction, event)); } } } + + if (isInputEvent) { + // If the event is an InputEvent, we submit a new mail to try processing the actions. + mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); + } + } + + private void tryProcessActionTaskForKey(Object key) { + try { + processActionTaskForKey(key); + } catch (Exception e) { + mailboxExecutor.execute( + () -> + ExceptionUtils.rethrow( + new ActionTaskExecutionException( + "Failed to execute action task", e)), + "throw exception in mailbox"); + } + } + + private void processActionTaskForKey(Object key) throws Exception { + // 1. Get an action task for the key. 
+ setCurrentKey(key); + ActionTask actionTask = pollFromListState(actionTasksKState); + if (actionTask == null) { + int removedCount = removeFromListState(currentProcessingKeysOpState, key); + checkState( + removedCount == 1, + "Current processing key count for key " + + key + + " should be 1, but got " + + removedCount); + return; + } + + // 2. Invoke the action task. + createAndSetRunnerContext(actionTask); + ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke(); + for (Event actionOutputEvent : actionTaskResult.getOutputEvents()) { + processEvent(key, actionOutputEvent); + } + + boolean currentInputEventFinished = false; + if (actionTaskResult.isFinished()) { + builtInMetrics.markActionExecuted(actionTask.action.getName()); + currentInputEventFinished = !currentKeyHasMoreActionTask(); + } else { + // If the action task not finished, we should get a new action task to execute continue. + Optional<ActionTask> generatedActionTaskOpt = actionTaskResult.getGeneratedActionTask(); + checkNotNull( + generatedActionTaskOpt.isPresent(), + "ActionTask not finished, but the generated action task is null."); + actionTasksKState.add(generatedActionTaskOpt.get()); + } + + // 3. Process the next InputEvent or next action task + if (currentInputEventFinished) { + // Once all sub-events and actions related to the current InputEvent are completed, + // we can proceed to process the next InputEvent. + int removedCount = removeFromListState(currentProcessingKeysOpState, key); + checkState( + removedCount == 1, + "Current processing key count for key " + + key + + " should be 1, but got " + + removedCount); + Event pendingInputEvent = pollFromListState(pendingInputEventsKState); + if (pendingInputEvent != null) { + processEvent(key, pendingInputEvent); + } + } else if (currentKeyHasMoreActionTask()) { + // If the current key has additional action tasks remaining, we should submit a new mail + // to continue processing them. 
+ mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task"); + } } private void initPythonActionExecutor() throws Exception { @@ -227,10 +327,31 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT new PythonActionExecutor( pythonEnvironmentManager, new ObjectMapper().writeValueAsString(agentPlan)); - pythonActionExecutor.open(shortTermMemState, metricGroup, this::checkMailboxThread); + pythonActionExecutor.open(); + } + } + + @Override + public void endInput() throws Exception { + waitInFlightEventsFinished(); + } + + @VisibleForTesting + public void waitInFlightEventsFinished() throws Exception { + while (listStateNotEmpty(currentProcessingKeysOpState)) { + mailboxExecutor.yield(); } } + @Override + public void close() throws Exception { + if (pythonActionExecutor != null) { + pythonActionExecutor.close(); + } + + super.close(); + } + private Event wrapToInputEvent(IN input) { if (inputIsJava) { return new InputEvent(input); @@ -272,4 +393,58 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT mailboxProcessor.isMailboxThread(), "Expected to be running on the task mailbox thread, but was not."); } + + private ActionTask createActionTask(Object key, Action action, Event event) { + if (action.getExec() instanceof JavaFunction) { + return new JavaActionTask(key, event, action); + } else if (action.getExec() instanceof PythonFunction) { + return new PythonActionTask(key, event, action, pythonActionExecutor); + } else { + throw new IllegalStateException( + "Unsupported action type: " + action.getExec().getClass()); + } + } + + private void createAndSetRunnerContext(ActionTask actionTask) { + if (actionTask.getRunnerContext() != null) { + return; + } + + RunnerContextImpl runnerContext; + if (actionTask.action.getExec() instanceof JavaFunction) { + runnerContext = + new RunnerContextImpl(shortTermMemState, metricGroup, this::checkMailboxThread); + } else if 
(actionTask.action.getExec() instanceof PythonFunction) { + runnerContext = + new PythonRunnerContextImpl( + shortTermMemState, metricGroup, this::checkMailboxThread); + } else { + throw new IllegalStateException( + "Unsupported action type: " + actionTask.action.getExec().getClass()); + } + + runnerContext.setActionName(actionTask.action.getName()); + actionTask.setRunnerContext(runnerContext); + } + + private boolean currentKeyHasMoreActionTask() throws Exception { + return listStateNotEmpty(actionTasksKState); + } + + private void tryResumeProcessActionTasks() throws Exception { + Iterable<Object> keys = currentProcessingKeysOpState.get(); + if (keys != null) { + for (Object key : keys) { + mailboxExecutor.submit( + () -> tryProcessActionTaskForKey(key), "process action task"); + } + } + } + + /** Failed to execute Action task. */ + public static class ActionTaskExecutionException extends Exception { + public ActionTaskExecutionException(String message, Throwable cause) { + super(message, cause); + } + } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java new file mode 100644 index 0000000..60ecab0 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.Action; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Optional; + +/** + * This class represents a task related to the execution of an action in {@link + * ActionExecutionOperator}. + * + * <p>An action is split into multiple code blocks, and each code block is represented by an {@code + * ActionTask}. You can call {@link #invoke()} to execute a code block and obtain invoke result + * {@link ActionTaskResult}. If the action contains additional code blocks, you can obtain the next + * {@code ActionTask} via {@link ActionTaskResult#getGeneratedActionTask()} and continue executing + * it. + */ +public abstract class ActionTask { + + protected static final Logger LOG = LoggerFactory.getLogger(ActionTask.class); + + protected final Object key; + protected final Event event; + protected final Action action; + /** + * Since RunnerContextImpl contains references to the Operator and state, it should not be + * serialized and included in the state with ActionTask. Instead, we should check if a valid + * RunnerContext exists before each ActionTask invocation and create a new one if necessary. 
+ */ + protected transient RunnerContextImpl runnerContext; + + public ActionTask(Object key, Event event, Action action) { + this.key = key; + this.event = event; + this.action = action; + } + + public RunnerContextImpl getRunnerContext() { + return runnerContext; + } + + public void setRunnerContext(RunnerContextImpl runnerContext) { + this.runnerContext = runnerContext; + } + + public Object getKey() { + return key; + } + + /** Invokes the action task. */ + public abstract ActionTaskResult invoke() throws Exception; + + public class ActionTaskResult { + private final boolean finished; + private final List<Event> outputEvents; + private final Optional<ActionTask> generatedActionTaskOpt; + + public ActionTaskResult( + boolean finished, + List<Event> outputEvents, + @Nullable ActionTask generatedActionTask) { + this.finished = finished; + this.outputEvents = outputEvents; + this.generatedActionTaskOpt = Optional.ofNullable(generatedActionTask); + } + + public boolean isFinished() { + return finished; + } + + public List<Event> getOutputEvents() { + return outputEvents; + } + + public Optional<ActionTask> getGeneratedActionTask() { + return generatedActionTaskOpt; + } + + @Override + public String toString() { + return "ActionTaskResult{" + + "finished=" + + finished + + ", outputEvents=" + + outputEvents + + ", generatedActionTaskOpt=" + + generatedActionTaskOpt + + '}'; + } + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java new file mode 100644 index 0000000..71d04b8 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.Action;
+import org.apache.flink.agents.plan.JavaFunction;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * An {@link ActionTask} that executes a Java action.
+ *
+ * <p>Java actions do not support asynchronous execution yet, so a Java action runs to completion
+ * in a single invocation and never produces a follow-up task.
+ */
+public class JavaActionTask extends ActionTask {
+
+    public JavaActionTask(Object key, Event event, Action action) {
+        super(key, event, action);
+        checkState(
+                action.getExec() instanceof JavaFunction,
+                "Java action task requires a JavaFunction, but got %s",
+                action.getExec());
+    }
+
+    /**
+     * Executes the Java action once and collects the events it emitted.
+     *
+     * @return a finished {@link ActionTaskResult} without a follow-up task
+     * @throws Exception if the action itself fails
+     */
+    @Override
+    public ActionTaskResult invoke() throws Exception {
+        LOG.debug(
+                "Try execute java action {} for event {} with key {}.",
+                action.getName(),
+                event,
+                key);
+        // The runner context is transient and has to be (re)attached before every invocation;
+        // report a missing context clearly instead of failing with a NullPointerException.
+        checkState(runnerContext != null, "Runner context must be set before invoking the task.");
+        runnerContext.checkNoPendingEvents();
+        action.getExec().call(event, runnerContext);
+        return new ActionTaskResult(true, runnerContext.drainEvents(), null);
+    }
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
new file mode 100644
index 0000000..53fb93b
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.python.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.Action;
+import org.apache.flink.agents.plan.PythonFunction;
+import org.apache.flink.agents.runtime.operator.ActionTask;
+import org.apache.flink.agents.runtime.python.event.PythonEvent;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * An {@link ActionTask} that executes a Python action.
+ *
+ * <p>If the Python action submits asynchronous work, its first invocation yields a Python
+ * generator; the remaining code blocks are then represented by a {@link
+ * PythonGeneratorActionTask}.
+ */
+public class PythonActionTask extends ActionTask {
+
+    protected final PythonActionExecutor pythonActionExecutor;
+
+    public PythonActionTask(
+            Object key, Event event, Action action, PythonActionExecutor pythonActionExecutor) {
+        super(key, event, action);
+        checkState(
+                action.getExec() instanceof PythonFunction,
+                "Python action task requires a PythonFunction, but got %s",
+                action.getExec());
+        checkState(
+                event instanceof PythonEvent,
+                "Python action only accept python event, but got " + event);
+        this.pythonActionExecutor = pythonActionExecutor;
+    }
+
+    /**
+     * Executes the Python action.
+     *
+     * <p>If the action returns a Python generator (because it submitted asynchronous work), the
+     * first step of that generator is run right away through a {@link PythonGeneratorActionTask}
+     * and its result is returned. Otherwise the action is complete and a finished result is
+     * returned.
+     *
+     * @throws Exception if executing the Python action fails
+     */
+    @Override
+    public ActionTaskResult invoke() throws Exception {
+        LOG.debug(
+                "Try execute python action {} for event {} with key {}.",
+                action.getName(),
+                event,
+                key);
+        // The runner context is transient and has to be (re)attached before every invocation;
+        // report a missing context clearly instead of failing with a NullPointerException.
+        checkState(runnerContext != null, "Runner context must be set before invoking the task.");
+        runnerContext.checkNoPendingEvents();
+
+        String pythonGeneratorRef =
+                pythonActionExecutor.executePythonFunction(
+                        (PythonFunction) action.getExec(), (PythonEvent) event, runnerContext);
+        // If a user-defined action uses an interface to submit asynchronous tasks, it will return a
+        // Python generator object instance upon its first execution. Otherwise, it means that no
+        // asynchronous tasks were submitted and the action has already completed.
+        if (pythonGeneratorRef != null) {
+            // The Python action produced a generator. We need to execute it once, which will
+            // submit an asynchronous task and report whether the action has been completed.
+            ActionTask tempGeneratedActionTask =
+                    new PythonGeneratorActionTask(
+                            key, event, action, pythonActionExecutor, pythonGeneratorRef);
+            tempGeneratedActionTask.setRunnerContext(runnerContext);
+            return tempGeneratedActionTask.invoke();
+        }
+        return new ActionTaskResult(true, runnerContext.drainEvents(), null);
+    }
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
new file mode 100644
index 0000000..bc13f20
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.python.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.Action;
+import org.apache.flink.agents.runtime.operator.ActionTask;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * An {@link ActionTask} wrapping a Python generator; each invocation runs the next code block of a
+ * Python action that submitted asynchronous work.
+ */
+public class PythonGeneratorActionTask extends PythonActionTask {
+    /** Name under which the Python generator object is stored in the Python interpreter. */
+    private final String pythonGeneratorRef;
+
+    public PythonGeneratorActionTask(
+            Object key,
+            Event event,
+            Action action,
+            PythonActionExecutor pythonActionExecutor,
+            String pythonGeneratorRef) {
+        super(key, event, action, pythonActionExecutor);
+        this.pythonGeneratorRef = pythonGeneratorRef;
+    }
+
+    /**
+     * Advances the wrapped Python generator by one step.
+     *
+     * @return a finished result once the generator is exhausted; otherwise an unfinished result
+     *     whose follow-up task is this same task
+     * @throws Exception if advancing the generator fails
+     */
+    @Override
+    public ActionTaskResult invoke() throws Exception {
+        LOG.debug(
+                "Try execute python generator action {} for event {} with key {}.",
+                action.getName(),
+                event,
+                key);
+        // The runner context is transient and has to be (re)attached before every invocation;
+        // report a missing context clearly instead of failing with a NullPointerException.
+        checkState(runnerContext != null, "Runner context must be set before invoking the task.");
+        boolean finished = pythonActionExecutor.callPythonGenerator(pythonGeneratorRef);
+        // While the generator is not exhausted, this same task represents the next code block.
+        ActionTask generatedActionTask = finished ? null : this;
+        return new ActionTaskResult(finished, runnerContext.drainEvents(), generatedActionTask);
+    }
+}
diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java index ebd8c05..09b38ee 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java @@ -17,19 +17,15 @@ */ package org.apache.flink.agents.runtime.python.utils; -import org.apache.flink.agents.api.Event; import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; -import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; -import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; -import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl; import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.agents.runtime.utils.EventUtil; -import org.apache.flink.api.common.state.MapState; import pemja.core.PythonInterpreter; -import java.util.List; +import java.util.concurrent.atomic.AtomicLong; import static org.apache.flink.util.Preconditions.checkState; @@ -40,65 +36,107 @@ public class PythonActionExecutor { "from flink_agents.plan import function\n" + "from flink_agents.runtime import flink_runner_context\n" + "from flink_agents.runtime import python_java_utils"; + + // =========== RUNNER CONTEXT =========== private static final String CREATE_FLINK_RUNNER_CONTEXT = "flink_runner_context.create_flink_runner_context"; + private static final String FLINK_RUNNER_CONTEXT_REF_NAME_PREFIX = "flink_runner_context_"; + private static final AtomicLong
FLINK_RUNNER_CONTEXT_REF_ID = new AtomicLong(0); + + // ========== ASYNC THREAD POOL =========== + private static final String CREATE_ASYNC_THREAD_POOL = + "flink_runner_context.create_async_thread_pool"; + private static final String CLOSE_ASYNC_THREAD_POOL = + "flink_runner_context.close_async_thread_pool"; + private static final String PYTHON_ASYNC_THREAD_POOL_REF_NAME = "python_async_thread_pool"; + private static final AtomicLong PYTHON_ASYNC_THREAD_POOL_REF_ID = new AtomicLong(0); + + // =========== PYTHON GENERATOR =========== + private static final String CALL_PYTHON_GENERATOR = "function.call_python_generator"; + private static final String PYTHON_GENERATOR_VAR_NAME_PREFIX = "python_generator_"; + private static final AtomicLong PYTHON_GENERATOR_VAR_ID = new AtomicLong(0); + + // =========== PYTHON AND JAVA OBJECT CONVERT =========== private static final String CONVERT_TO_PYTHON_OBJECT = "python_java_utils.convert_to_python_object"; private static final String WRAP_TO_INPUT_EVENT = "python_java_utils.wrap_to_input_event"; private static final String GET_OUTPUT_FROM_OUTPUT_EVENT = "python_java_utils.get_output_from_output_event"; - private static final String FLINK_RUNNER_CONTEXT_VAR_NAME = "flink_runner_context"; private final PythonEnvironmentManager environmentManager; private final String agentPlanJson; - private PythonRunnerContextImpl runnerContext; - private PythonInterpreter interpreter; + private String pythonAsyncThreadPoolObjectName; public PythonActionExecutor(PythonEnvironmentManager environmentManager, String agentPlanJson) { this.environmentManager = environmentManager; this.agentPlanJson = agentPlanJson; } - public void open( - MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState, - FlinkAgentsMetricGroupImpl metricGroup, - Runnable mailboxThreadChecker) - throws Exception { + public void open() throws Exception { environmentManager.open(); EmbeddedPythonEnvironment env = environmentManager.createEnvironment(); interpreter = 
env.getInterpreter(); interpreter.exec(PYTHON_IMPORTS); - runnerContext = - new PythonRunnerContextImpl(shortTermMemState, metricGroup, mailboxThreadChecker); - - // TODO: remove the set and get runner context after updating pemja to version 0.5.3 - Object pythonRunnerContextObject = - interpreter.invoke(CREATE_FLINK_RUNNER_CONTEXT, runnerContext, agentPlanJson); - interpreter.set(FLINK_RUNNER_CONTEXT_VAR_NAME, pythonRunnerContextObject); + // TODO: remove the set and get thread pool after updating pemja to version 0.5.3. For more + // details, please refer to + // https://github.com/apache/flink-agents/issues/83. + Object pythonAsyncThreadPool = interpreter.invoke(CREATE_ASYNC_THREAD_POOL); + this.pythonAsyncThreadPoolObjectName = + PYTHON_ASYNC_THREAD_POOL_REF_NAME + + PYTHON_ASYNC_THREAD_POOL_REF_ID.incrementAndGet(); + interpreter.set(pythonAsyncThreadPoolObjectName, pythonAsyncThreadPool); } - public List<Event> executePythonFunction( - PythonFunction function, PythonEvent event, String actionName) throws Exception { + /** + * Execute the Python function, which may return a Python generator that needs to be processed + * in the future. Due to an issue in Pemja regarding incorrect object reference counting, this + * may lead to garbage collection of the object. To prevent this, we use the set and get methods + * to manually increment the object's reference count, then return the name of the Python + * generator variable. + * + * @return The name of the Python generator variable. It may be null if the Python function does + * not return a generator. 
+ */ + public String executePythonFunction( + PythonFunction function, PythonEvent event, RunnerContextImpl runnerContext) + throws Exception { runnerContext.checkNoPendingEvents(); - runnerContext.setActionName(actionName); function.setInterpreter(interpreter); - // TODO: remove the set and get runner context after updating pemja to version 0.5.3 - Object pythonRunnerContextObject = interpreter.get(FLINK_RUNNER_CONTEXT_VAR_NAME); + // TODO: remove the set and get runner context after updating pemja to version 0.5.3. For + // more details, please refer to https://github.com/apache/flink-agents/issues/83. + Object pythonRunnerContextObject = + interpreter.invoke( + CREATE_FLINK_RUNNER_CONTEXT, + runnerContext, + agentPlanJson, + interpreter.get(pythonAsyncThreadPoolObjectName)); + String pythonRunnerContextObjectName = + FLINK_RUNNER_CONTEXT_REF_NAME_PREFIX + + FLINK_RUNNER_CONTEXT_REF_ID.incrementAndGet(); + interpreter.set(pythonRunnerContextObjectName, pythonRunnerContextObject); Object pythonEventObject = interpreter.invoke(CONVERT_TO_PYTHON_OBJECT, event.getEvent()); try { - function.call(pythonEventObject, pythonRunnerContextObject); + Object calledResult = function.call(pythonEventObject, pythonRunnerContextObject); + if (calledResult == null) { + return null; + } else { + // must be a generator + String pythonGeneratorRef = + PYTHON_GENERATOR_VAR_NAME_PREFIX + + PYTHON_GENERATOR_VAR_ID.incrementAndGet(); + interpreter.set(pythonGeneratorRef, calledResult); + return pythonGeneratorRef; + } } catch (Exception e) { runnerContext.drainEvents(); throw new PythonActionExecutionException("Failed to execute Python action", e); } - - return runnerContext.drainEvents(); } public PythonEvent wrapToInputEvent(Object eventData) { @@ -112,6 +150,29 @@ public class PythonActionExecutor { return interpreter.invoke(GET_OUTPUT_FROM_OUTPUT_EVENT, pythonOutputEvent); } + /** + * Invokes the next step of a Python generator. 
+ * + * <p>This method is typically used after initializing or resuming a Python generator that was + * created via a user-defined action involving asynchronous execution. + * + * @param pythonGeneratorRef the reference name of the Python generator object stored in the + * interpreter's context + * @return true if the generator has completed; false otherwise + */ + public boolean callPythonGenerator(String pythonGeneratorRef) { + // Calling next(generator) in Python returns a tuple of (finished, output). + Object pythonGenerator = interpreter.get(pythonGeneratorRef); + Object invokeResult = interpreter.invoke(CALL_PYTHON_GENERATOR, pythonGenerator); + checkState(invokeResult.getClass().isArray() && ((Object[]) invokeResult).length == 2); + return (boolean) ((Object[]) invokeResult)[0]; + } + + public void close() throws Exception { + interpreter.invoke( + CLOSE_ASYNC_THREAD_POOL, interpreter.get(pythonAsyncThreadPoolObjectName)); + } + /** Failed to execute Python action. */ public static class PythonActionExecutionException extends Exception { public PythonActionExecutionException(String message, Throwable cause) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/utils/StateUtil.java b/runtime/src/main/java/org/apache/flink/agents/runtime/utils/StateUtil.java new file mode 100644 index 0000000..1beb12c --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/utils/StateUtil.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.utils;
+
+import org.apache.flink.api.common.state.ListState;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+/** Some utilities related to Flink state. */
+public class StateUtil {
+
+    /**
+     * Checks whether the provided ListState contains any elements.
+     *
+     * @param listState the ListState instance to check
+     * @return true if the state is not empty; false otherwise
+     * @throws Exception if an I/O error occurs while reading state
+     */
+    public static boolean listStateNotEmpty(ListState<?> listState) throws Exception {
+        // Read the state only once: get() may deserialize the whole list on some backends.
+        return elementsOf(listState).iterator().hasNext();
+    }
+
+    /**
+     * Removes all occurrences of the specified element from the given ListState.
+     *
+     * <p>Empty state (including a {@code null} result from {@link ListState#get()}) is handled
+     * gracefully, and the state is only rewritten if at least one element was removed.
+     *
+     * @param listState the ListState instance to modify
+     * @param element the element to remove
+     * @return the number of elements removed
+     * @throws Exception if an I/O error occurs while reading/writing state
+     */
+    public static <T> int removeFromListState(ListState<T> listState, T element) throws Exception {
+        int removedElementCount = 0;
+        List<T> remaining = new ArrayList<>();
+        for (T next : elementsOf(listState)) {
+            if (next.equals(element)) {
+                removedElementCount++;
+            } else {
+                remaining.add(next);
+            }
+        }
+        // Skip the state write entirely when nothing was removed.
+        if (removedElementCount > 0) {
+            // update() replaces the whole list (an empty list clears the state), so no clear()
+            // is needed beforehand.
+            listState.update(remaining);
+        }
+        return removedElementCount;
+    }
+
+    /**
+     * Removes and returns the first element from the ListState.
+     *
+     * @param listState the ListState instance to poll from
+     * @return the first element of the list, or null if the list is empty
+     * @throws Exception if an I/O error occurs while reading/writing state
+     */
+    public static <T> T pollFromListState(ListState<T> listState) throws Exception {
+        Iterator<T> listStateIterator = elementsOf(listState).iterator();
+        if (!listStateIterator.hasNext()) {
+            return null;
+        }
+
+        T polled = listStateIterator.next();
+        List<T> remaining = new ArrayList<>();
+        listStateIterator.forEachRemaining(remaining::add);
+        // update() replaces the whole list (an empty list clears the state).
+        listState.update(remaining);
+        return polled;
+    }
+
+    /**
+     * Returns the elements of {@code listState}, treating a {@code null} result of {@link
+     * ListState#get()} (returned for empty state) as an empty list.
+     */
+    private static <T> Iterable<T> elementsOf(ListState<T> listState) throws Exception {
+        Iterable<T> elements = listState.get();
+        return elements == null ? Collections.<T>emptyList() : elements;
+    }
+}
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java index 1617a19..ebfc15f 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java @@ -25,12 +25,12 @@ import org.apache.flink.streaming.api.datastream.DataStreamSource; import
org.apache.flink.streaming.api.datastream.KeyedStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.util.CloseableIterator; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; -import java.util.stream.LongStream; import static org.assertj.core.api.Assertions.assertThat; @@ -39,16 +39,23 @@ public class CompileUtilsTest { private static final Long TEST_SEQUENCE_START = 0L; private static final Long TEST_SEQUENCE_END = 100L; + private static final Long TEST_SEQUENCE_REPEAT = 3L; // Agent logic: x -> (x + 1) * 2 private static final AgentPlan TEST_AGENT_PLAN = ActionExecutionOperatorTest.TestAgent.getAgentPlan(false); + private static List<Long> testSequence; + + @BeforeAll + static void setup() { + testSequence = getTestSequence(); + testSequence.sort(Long::compareTo); + } @Test void testJavaNoKeyedStreamConnectToAgent() throws Exception { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - DataStreamSource<Long> inputStream = - env.fromSequence(TEST_SEQUENCE_START, TEST_SEQUENCE_END); + DataStreamSource<Long> inputStream = env.fromData(testSequence); DataStream<Object> agentOutputStream = CompileUtils.connectToAgent( inputStream, @@ -75,11 +82,10 @@ public class CompileUtilsTest { void testJavaKeyedStreamConnectToAgent() throws Exception { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); - KeyedStream<Long, Long> keyedInputStream = - env.fromSequence(TEST_SEQUENCE_START, TEST_SEQUENCE_END).keyBy(x -> x); - DataStream<Object> agentOutputStream = + KeyedStream<Long, Long> keyedInputStream = env.fromData(testSequence).keyBy(x -> x); + DataStream<Object> workflowOutputStream = CompileUtils.connectToAgent(keyedInputStream, TEST_AGENT_PLAN); - DataStream<Long> resultStream = agentOutputStream.map(x -> (long) x + 1); + 
DataStream<Long> resultStream = workflowOutputStream.map(x -> (long) x + 1); List<Long> resultList = new ArrayList<>(); try (CloseableIterator<Long> iterator = resultStream.executeAndCollect()) { @@ -90,12 +96,19 @@ public class CompileUtilsTest { checkResult(resultList); } - private void checkResult(List<Long> resultList) { + private static List<Long> getTestSequence() { + List<Long> testSequence = new ArrayList<>(); + for (int i = 0; i < TEST_SEQUENCE_REPEAT; i++) { + for (long j = TEST_SEQUENCE_START; j <= TEST_SEQUENCE_END; j++) { + testSequence.add(j); + } + } + return testSequence; + } + + private static void checkResult(List<Long> resultList) { List<Long> expectedResultList = - LongStream.rangeClosed(TEST_SEQUENCE_START, TEST_SEQUENCE_END) - .boxed() - .map(x -> (x + 1) * 2 + 1) - .collect(Collectors.toList()); + testSequence.stream().map(x -> (x + 1) * 2 + 1).collect(Collectors.toList()); assertThat(resultList).isEqualTo(expectedResultList); } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index 2559bab..90c558c 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -28,6 +28,7 @@ import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox; import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.util.ExceptionUtils; import org.junit.jupiter.api.Test; @@ -54,19 +55,95 @@ public class ActionExecutionOperatorTest { (KeySelector<Long, Long>) value -> value, 
TypeInformation.of(Long.class))) { testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + testHarness.processElement(new StreamRecord<>(0L)); + operator.waitInFlightEventsFinished(); List<StreamRecord<Object>> recordOutput = (List<StreamRecord<Object>>) testHarness.getRecordOutput(); assertThat(recordOutput.size()).isEqualTo(1); assertThat(recordOutput.get(0).getValue()).isEqualTo(2L); testHarness.processElement(new StreamRecord<>(1L)); + operator.waitInFlightEventsFinished(); recordOutput = (List<StreamRecord<Object>>) testHarness.getRecordOutput(); assertThat(recordOutput.size()).isEqualTo(2); assertThat(recordOutput.get(1).getValue()).isEqualTo(4L); } } + @Test + void testSameKeyDataAreProcessedInOrder() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory(TestAgent.getAgentPlan(false), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Process input data 1 with key 0 + testHarness.processElement(new StreamRecord<>(0L)); + // Process input data 2, which has the same key (0) + testHarness.processElement(new StreamRecord<>(0L)); + // Since both pieces of data share the same key, we should consolidate them and process + // only input data 1. + // This means we need one mail to execute the action1 action for input data 1. + assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 1); + // After executing this mail, we will have another mail to execute the action2 action + // for input data 1. + assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 1); + // Once the above mails are executed, we should get a single output result from input + // data 1. 
+ List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo(2L); + + // After the processing of input data 1 is finished, we can proceed to process input + // data 2 and obtain its result. + operator.waitInFlightEventsFinished(); + recordOutput = (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(2); + assertThat(recordOutput.get(1).getValue()).isEqualTo(2L); + } + } + + @Test + void testDifferentKeyDataCanRunConcurrently() throws Exception { + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory(TestAgent.getAgentPlan(false), true), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Process input data 1 with key 0 + testHarness.processElement(new StreamRecord<>(0L)); + // Process input data 2, which has the different key (1) + testHarness.processElement(new StreamRecord<>(1L)); + // Since the two input data items have different keys, they can be processed in + // parallel. + // As a result, we should have two separate mails to execute the action1 for each of + // them. + assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 2); + // After these two mails are executed, there should be another two mails — one for each + // input data item — to execute the corresponding action2. + assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 2); + // Once both action2 operations are completed, we should receive two output data items, + // each corresponding to one of the original inputs. 
+ List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(2); + assertThat(recordOutput.get(0).getValue()).isEqualTo(2L); + assertThat(recordOutput.get(1).getValue()).isEqualTo(4L); + } + } + @Test void testMemoryAccessProhibitedOutsideMailboxThread() throws Exception { try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = @@ -75,8 +152,12 @@ public class ActionExecutionOperatorTest { (KeySelector<Long, Long>) value -> value, TypeInformation.of(Long.class))) { testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); - assertThatThrownBy(() -> testHarness.processElement(new StreamRecord<>(0L))) + testHarness.processElement(new StreamRecord<>(0L)); + assertThatThrownBy(() -> operator.waitInFlightEventsFinished()) + .hasCauseInstanceOf(ActionExecutionOperator.ActionTaskExecutionException.class) .rootCause() .hasMessageContaining("Expected to be running on the task mailbox thread"); } @@ -180,4 +261,12 @@ public class ActionExecutionOperatorTest { return null; } } + + private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize) + throws Exception { + assertThat(mailbox.size()).isEqualTo(expectedSize); + for (int i = 0; i < expectedSize; i++) { + mailbox.take(TaskMailbox.MIN_PRIORITY).run(); + } + } }