xintongsong commented on code in PR #80: URL: https://github.com/apache/flink-agents/pull/80#discussion_r2247225634
########## python/flink_agents/plan/function.py: ########## @@ -237,6 +237,27 @@ def check_signature(self, *args: Tuple[Any, ...]) -> None: """Check function signature is legal or not.""" +class PythonGeneratorWrapper: + """ + A temporary wrapper class for Python generators to work around a + known issue in PEMJA, where the generator type is incorrectly handled. + + TODO: This wrapper is intended to be a temporary solution. Once PEMJA + version 0.6.0 (or later) fixes the bug related to generator type conversion, + this wrapper should be removed. + """ Review Comment: Better link to the tracking issue. ########## python/flink_agents/plan/function.py: ########## @@ -305,3 +326,31 @@ def get_python_function_cache_keys() -> List[Tuple[str, str]]: List of (module, qualname) tuples representing cached functions. """ return list(_PYTHON_FUNCTION_CACHE.keys()) + + +def call_python_generator(generator_wrapper: PythonGeneratorWrapper) -> (bool, Any): + """ + Invokes the next step of a wrapped Python generator and returns whether + it is done, along with the yielded or returned value. + + Args: + generator_wrapper (PythonGeneratorWrapper): A wrapper object that + contains a `generator` attribute. This attribute should be an instance + of a Python generator. + + Returns: + Tuple[bool, Any]: + - The first element is a boolean flag indicating whether the generator + has finished: + * False: The generator has more values to yield. + * True: The generator has completed. + - The second element is either: + * The value yielded by the generator (when not exhausted), or + * The return value of the generator (when it has finished). 
+ """ + try: + result = next(generator_wrapper.generator) + except StopIteration as e: + return True, e.value + else: + return False, result Review Comment: ```suggestion def call_python_generator(generator_wrapper: PythonGeneratorWrapper) -> (bool, Any): try: result = next(generator_wrapper.generator) return False, result except StopIteration as e: return True, e.value if hasattr(e, 'value') else None except Exception as e: # log and throw error logger.error(f"Error in generator execution: {e}") raise ``` ########## python/flink_agents/runtime/local_runner.py: ########## @@ -203,7 +220,13 @@ def run(self, **data: Dict[str, Any]) -> Any: logger.info( "key: %s, performing action: %s", key, action.name ) - action.exec(event, context) + func_result = action.exec(event, context) + if isinstance(func_result, Generator): + try: + while True: + next(func_result) + except StopIteration: + pass Review Comment: ```suggestion func_result = action.exec(event, context) if isinstance(func_result, Generator): try: for _ in func_result: pass except Exception as e: logger.error(f"Error in async execution: {e}") raise ``` ########## runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java: ########## @@ -40,63 +36,103 @@ public class PythonActionExecutor { "from flink_agents.plan import function\n" + "from flink_agents.runtime import flink_runner_context\n" + "from flink_agents.runtime import python_java_utils"; + + // =========== RUNNER CONTEXT =========== private static final String CREATE_FLINK_RUNNER_CONTEXT = "flink_runner_context.create_flink_runner_context"; + private static final String FLINK_RUNNER_CONTEXT_REF_NAME_PREFIX = "flink_runner_context_"; + private static final AtomicLong FLINK_RUNNER_CONTEXT_REF_ID = new AtomicLong(0); + + // ========== ASYNC THREAD POOL =========== + private static volatile Object pythonAsyncThreadPoolRef; + private static final String CREATE_ASYNC_THREAD_POOL = + 
"flink_runner_context.create_async_thread_pool"; + private static final String PYTHON_ASYNC_THREAD_POOL_REF_NAME = "python_async_thread_pool"; + private static final AtomicLong PYTHON_ASYNC_THREAD_POOL_REF_ID = new AtomicLong(0); + + // =========== PYTHON GENERATOR =========== + private static final String CALL_PYTHON_GENERATOR = "function.call_python_generator"; + private static final String PYTHON_GENERATOR_VAR_NAME_PREFIX = "python_generator_"; + private static final AtomicLong PYTHON_GENERATOR_VAR_ID = new AtomicLong(0); + + // =========== PYTHON AND JAVA OBJECT CONVERT =========== private static final String CONVERT_TO_PYTHON_OBJECT = "python_java_utils.convert_to_python_object"; private static final String WRAP_TO_INPUT_EVENT = "python_java_utils.wrap_to_input_event"; private static final String GET_OUTPUT_FROM_OUTPUT_EVENT = "python_java_utils.get_output_from_output_event"; - private static final String FLINK_RUNNER_CONTEXT_VAR_NAME = "flink_runner_context"; private final PythonEnvironmentManager environmentManager; private final String agentPlanJson; - private PythonRunnerContextImpl runnerContext; - private PythonInterpreter interpreter; + private String pythonAsyncThreadPoolObjectName; public PythonActionExecutor(PythonEnvironmentManager environmentManager, String agentPlanJson) { this.environmentManager = environmentManager; this.agentPlanJson = agentPlanJson; } - public void open( - MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState, - FlinkAgentsMetricGroupImpl metricGroup) - throws Exception { + public void open() throws Exception { environmentManager.open(); EmbeddedPythonEnvironment env = environmentManager.createEnvironment(); interpreter = env.getInterpreter(); interpreter.exec(PYTHON_IMPORTS); - runnerContext = new PythonRunnerContextImpl(shortTermMemState, metricGroup); - - // TODO: remove the set and get runner context after updating pemja to version 0.5.3 - Object pythonRunnerContextObject = - 
interpreter.invoke(CREATE_FLINK_RUNNER_CONTEXT, runnerContext, agentPlanJson); - interpreter.set(FLINK_RUNNER_CONTEXT_VAR_NAME, pythonRunnerContextObject); + // TODO: remove the set and get thread pool after updating pemja to version 0.5.3 Review Comment: Same here: please link to the tracking issue.
wrap to InputEvent first + // wrap to InputEvent first Event inputEvent = wrapToInputEvent(input); - // 2. execute action - LinkedList<Event> events = new LinkedList<>(); - events.push(inputEvent); - while (!events.isEmpty()) { - Event event = events.pop(); - builtInMetrics.markEventProcessed(); - List<Action> actions = getActionsTriggeredBy(event); - if (actions != null && !actions.isEmpty()) { - for (Action action : actions) { - // TODO: Support multi-action execution for a single event. Example: A Java - // event - // should be processable by both Java and Python actions. - // TODO: Implement asynchronous action execution. - - // execute action and collect output events - String actionName = action.getName(); - LOG.debug("Try execute action {} for event {}.", actionName, event); - List<Event> actionOutputEvents; - if (action.getExec() instanceof JavaFunction) { - runnerContext.setActionName(actionName); - action.getExec().call(event, runnerContext); - actionOutputEvents = runnerContext.drainEvents(); - } else if (action.getExec() instanceof PythonFunction) { - checkState(event instanceof PythonEvent); - actionOutputEvents = - pythonActionExecutor.executePythonFunction( - (PythonFunction) action.getExec(), - (PythonEvent) event, - actionName); - } else { - throw new RuntimeException("Unsupported action type: " + action.getClass()); - } - builtInMetrics.markActionExecuted(actionName); - - for (Event actionOutputEvent : actionOutputEvents) { - if (EventUtil.isOutputEvent(actionOutputEvent)) { - builtInMetrics.markEventProcessed(); - OUT outputData = getOutputFromOutputEvent(actionOutputEvent); - LOG.debug( - "Collect output data {} for input {} in action {}.", - outputData, - input, - action.getName()); - output.collect(reusedStreamRecord.replace(outputData)); - } else { - LOG.debug( - "Collect event {} for event {} in action {}.", - actionOutputEvent, - event, - action.getName()); - events.add(actionOutputEvent); - } - } + if (currentKeyHasMoreActionTask()) { + // 
If there are already actions being processed for the current key, the newly incoming + // event should be queued and processed later. Therefore, we add it to + // pendingInputEventsState. + pendingInputEventsKState.add(inputEvent); + } else { + // Otherwise, the new event is processed immediately. + processEvent(getCurrentKey(), inputEvent, true); + } + } + + private void processEvent(Object key, Event event, boolean shouldSubmitMail) throws Exception { Review Comment: In terms of readability, I find this `shouldSubmitMail` a bit hard to understand. It's not obvious why `processEvent` sometimes needs to submit the mail to the mailbox and sometimes not. I think the rules for notifying the mailbox executor about pending action tasks for a key are quite simple. We should notify in two cases: 1. When `actionTasksKState` changes from empty to non-empty. That is when the InputEvent is processed, because it is guaranteed that `actionTasksKState` is empty before processing the next InputEvent. 2. When we have processed an action task from `actionTasksKState` and more tasks remain. A straightforward approach would be for `processEvent` to always submit to the mailbox when an input event is processed, and for `processActionTaskForKey` to always submit to the mailbox when there are remaining action tasks. In this way, we won't need the `shouldSubmitMail` argument here. IIUC, the problem arises when we finish all action tasks of the previous input event and pick up a new pending input event. Currently, we check the remaining action tasks at the end of `processActionTaskForKey`. If we move that check to before processing a new pending input event, the problem should be solved.
########## python/flink_agents/runtime/flink_runner_context.py: ########## @@ -105,7 +110,34 @@ def get_action_metric_group(self) -> FlinkMetricGroup: """ return FlinkMetricGroup(self._j_runner_context.getActionMetricGroup()) - -def create_flink_runner_context(j_runner_context: Any, agent_plan_json: str) -> FlinkRunnerContext: + @override + def execute_async( + self, + func: Callable[[Any], Any], + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], + ) -> Any: + """Asynchronously execute the provided function. Access to memory + is prohibited within the function. + """ + future = self.executor.submit(func, *args, **kwargs) + while not future.done(): + # TODO: Currently, we are using a polling mechanism to check whether + # the future has completed. This approach should be optimized in the + # future by switching to a notification-based model, where the Flink + # operator is notified directly once the future is completed. + yield + return future.result() + + +def create_flink_runner_context( + j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor +) -> FlinkRunnerContext: """Used to create a FlinkRunnerContext Python object in Pemja environment.""" - return FlinkRunnerContext(j_runner_context, agent_plan_json) + return FlinkRunnerContext(j_runner_context, agent_plan_json, executor) + + +def create_async_thread_pool() -> ThreadPoolExecutor: + """Used to create a thread pool to execute asynchronous + code block in action.""" + return ThreadPoolExecutor(max_workers=os.cpu_count() * 2) Review Comment: How do we clean up the thread pool?
########## plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java: ########## @@ -71,7 +71,10 @@ public Class<?>[] getParameterTypes() { return parameterTypes; } - public Method getMethod() { + public Method getMethod() throws ClassNotFoundException, NoSuchMethodException { + if (method == null) { + this.method = Class.forName(qualName).getMethod(methodName, parameterTypes); + } Review Comment: Why do we need this lazy initialization? ########## runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java: ########## @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.operator; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.Action; +import org.apache.flink.agents.runtime.context.RunnerContextImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * This class represents a task related to the execution of an action in {@link + * ActionExecutionOperator}. + * + * <p>An action is split into multiple code blocks, and each code block is represented by an {@code + * ActionTask}. 
You can call {@link #invoke()} to execute a code block and obtain invoke result + * {@link ActionTaskResult}. If the action contains additional code blocks, you can obtain the next + * {@code ActionTask} via {@link ActionTaskResult#getGeneratedActionTask()} and continue executing + * it. + */ +public abstract class ActionTask { + + protected static final Logger LOG = LoggerFactory.getLogger(ActionTask.class); + + protected final Object key; + protected final Event event; + protected final Action action; + /** + * Since RunnerContextImpl contains references to the Operator and state, it should not be + * serialized and included in the state with ActionTask. Instead, we should check if a valid + * RunnerContext exists before each ActionTask invocation and create a new one if necessary. + */ + protected transient RunnerContextImpl runnerContext; + + public ActionTask(Object key, Event event, Action action) { + this.key = key; + this.event = event; + this.action = action; + } + + public RunnerContextImpl getRunnerContext() { + return runnerContext; + } + + public void setRunnerContext(RunnerContextImpl runnerContext) { + this.runnerContext = runnerContext; + } + + public Object getKey() { + return key; + } + + /** Invokes the action task. */ + public abstract ActionTaskResult invoke() throws Exception; + + public class ActionTaskResult { + private final boolean finished; + private final List<Event> outputEvents; + private final ActionTask generatedActionTask; + + public ActionTaskResult( + boolean finished, List<Event> outputEvents, ActionTask generatedActionTask) { + this.finished = finished; + this.outputEvents = outputEvents; + this.generatedActionTask = generatedActionTask; + } + + public boolean isFinished() { + return finished; + } + + public List<Event> getOutputEvents() { + return outputEvents; + } + + public ActionTask getGeneratedActionTask() { Review Comment: Maybe return an `Optional<ActionTask>` and force the checking and null-handling at the caller side? 
########## python/flink_agents/runtime/local_runner.py: ########## @@ -128,6 +128,23 @@ def get_action_metric_group(self) -> MetricGroup: err_msg = "Metric mechanism is not supported for local agent execution yet." raise NotImplementedError(err_msg) + def execute_async( + self, + func: Callable[[Any], Any], + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], + ) -> Any: + """Asynchronously execute the provided function. Access to memory + is prohibited within the function. + """ + logger.warning( + "Local runner does not support asynchronous execution; falling back to synchronous execution." + ) + func_result = func(*args, **kwargs) + yield func_result + return func_result Review Comment: Should not need this `return` statement. ########## runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java: ########## @@ -240,4 +362,70 @@ private List<Action> getActionsTriggeredBy(Event event) { return agentPlan.getActionsTriggeredBy(event.getClass().getName()); } } + + private MailboxProcessor getMailboxProcessor() throws Exception { + Field field = MailboxExecutorImpl.class.getDeclaredField("mailboxProcessor"); + field.setAccessible(true); + return (MailboxProcessor) field.get(mailboxExecutor); + } + + private void checkMailboxThread() { + checkState( + mailboxProcessor.isMailboxThread(), + "Expected to be running on the task mailbox thread, but was not."); + } + + private ActionTask createActionTask(Object key, Action action, Event event) { + if (action.getExec() instanceof JavaFunction) { + return new JavaActionTask(key, event, action); + } else if (action.getExec() instanceof PythonFunction) { + return new PythonActionTask(key, event, action, pythonActionExecutor); + } else { + throw new IllegalStateException( + "Unsupported action type: " + action.getExec().getClass()); + } + } + + private void createAndSetRunnerContext(ActionTask actionTask) { + if (actionTask.getRunnerContext() != null) { + return; + } + + RunnerContextImpl 
runnerContext; + if (actionTask.action.getExec() instanceof JavaFunction) { + runnerContext = + new RunnerContextImpl(shortTermMemState, metricGroup, this::checkMailboxThread); + } else if (actionTask.action.getExec() instanceof PythonFunction) { + runnerContext = + new PythonRunnerContextImpl( + shortTermMemState, metricGroup, this::checkMailboxThread); + } else { + throw new IllegalStateException( + "Unsupported action type: " + actionTask.action.getExec().getClass()); + } + + runnerContext.setActionName(actionTask.action.getName()); + actionTask.setRunnerContext(runnerContext); + } + + private boolean currentKeyHasMoreActionTask() throws Exception { + return listStateNotEmpty(actionTasksKState); + } + + private void tryResumeProcessActionTasks() throws Exception { + Iterable<Object> keys = currentProcessingKeysOpState.get(); + if (keys != null) { + for (Object key : keys) { + mailboxExecutor.submit( + () -> tryProcessActionTaskForKey(key), "process action task"); + } + } + } Review Comment: This is smart! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org