This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit a50137817a10c2bf192cfa6e5a97f613fd640837
Author: Xu Huang <zuosi...@alibaba-inc.com>
AuthorDate: Sat Jul 26 10:08:29 2025 +0800

    [runtime] Implement async execution in ActionExecutionOperator for Python 
actions
---
 .../org/apache/flink/agents/plan/JavaFunction.java |   9 +-
 python/flink_agents/examples/my_agent.py           |  12 +-
 python/flink_agents/plan/function.py               |  81 +++++-
 .../create_python_agent_plan_from_json.py          |   2 +-
 python/flink_agents/plan/tests/test_function.py    |   5 +-
 .../flink_agents/runtime/flink_runner_context.py   |  37 ++-
 .../runtime/operator/ActionExecutionOperator.java  | 301 ++++++++++++++++-----
 .../flink/agents/runtime/operator/ActionTask.java  | 114 ++++++++
 .../agents/runtime/operator/JavaActionTask.java    |  50 ++++
 .../runtime/python/operator/PythonActionTask.java  |  74 +++++
 .../python/operator/PythonGeneratorActionTask.java |  50 ++++
 .../runtime/python/utils/PythonActionExecutor.java | 119 ++++++--
 .../flink/agents/runtime/utils/StateUtil.java      |  91 +++++++
 .../flink/agents/runtime/CompileUtilsTest.java     |  37 ++-
 .../operator/ActionExecutionOperatorTest.java      |  91 ++++++-
 15 files changed, 944 insertions(+), 129 deletions(-)

diff --git a/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java 
b/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java
index 580004c..37f7f37 100644
--- a/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java
+++ b/plan/src/main/java/org/apache/flink/agents/plan/JavaFunction.java
@@ -38,7 +38,7 @@ public class JavaFunction implements Function {
     @JsonProperty(FIELD_NAME_PARAMETER_TYPES)
     private final Class<?>[] parameterTypes;
 
-    @JsonIgnore private final Method method;
+    @JsonIgnore private transient Method method;
 
     public JavaFunction(
             @JsonProperty(FIELD_NAME_QUAL_NAME) String qualName,
@@ -71,7 +71,10 @@ public class JavaFunction implements Function {
         return parameterTypes;
     }
 
-    public Method getMethod() {
+    public Method getMethod() throws ClassNotFoundException, 
NoSuchMethodException {
+        if (method == null) {
+            this.method = Class.forName(qualName).getMethod(methodName, 
parameterTypes);
+        }
         return method;
     }
 
@@ -93,7 +96,7 @@ public class JavaFunction implements Function {
 
     @Override
     public Object call(Object... args) throws Exception {
-        return method.invoke(null, args);
+        return getMethod().invoke(null, args);
     }
 
     @Override
diff --git a/python/flink_agents/examples/my_agent.py 
b/python/flink_agents/examples/my_agent.py
index b3947f3..b810d64 100644
--- a/python/flink_agents/examples/my_agent.py
+++ b/python/flink_agents/examples/my_agent.py
@@ -16,6 +16,8 @@
 # limitations under the License.
 
#################################################################################
 import copy
+import random
+import time
 from typing import Any, Optional
 
 from pydantic import BaseModel
@@ -60,6 +62,12 @@ class DataStreamAgent(Agent):
     @action(InputEvent)
     @staticmethod
     def first_action(event: Event, ctx: RunnerContext):  # noqa D102
+        def log_to_stdout(input: Any, total: int) -> bool:
+            # Simulating asynchronous time consumption
+            time.sleep(random.random())
+            print(f"[log_to_stdout] Logging input={input}, total reviews 
now={total}")
+            return True
+
         input = event.input
 
         stm = ctx.get_short_term_memory()
@@ -71,8 +79,10 @@ class DataStreamAgent(Agent):
         total += 1
         status.set("total_reviews", total)
 
+        log_success = yield from ctx.execute_async(log_to_stdout, input, total)
+
         content = copy.deepcopy(input)
-        content.review += " first action"
+        content.review += " first action, log success=" + str(log_success) + 
","
         ctx.send_event(MyEvent(value=content))
 
     @action(MyEvent)
diff --git a/python/flink_agents/plan/function.py 
b/python/flink_agents/plan/function.py
index ebb2df8..032eec6 100644
--- a/python/flink_agents/plan/function.py
+++ b/python/flink_agents/plan/function.py
@@ -16,10 +16,11 @@
 # limitations under the License.
 
#################################################################################
 
+import logging
 import importlib
 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Tuple, get_type_hints
+from typing import Any, Callable, Dict, List, Tuple, get_type_hints, Generator
 
 from pydantic import BaseModel
 
@@ -28,6 +29,11 @@ from flink_agents.plan.utils import check_type_match
 # Global cache for PythonFunction instances to avoid repeated creation
 _PYTHON_FUNCTION_CACHE: Dict[Tuple[str, str], "PythonFunction"] = {}
 
+# Handler/level configuration is left to the embedding application: calling
+# logging.basicConfig() here would configure the process-wide root logger as
+# a side effect of merely importing this module.
+logger = logging.getLogger(__name__)
+
 
 def _is_function_cacheable(func: Callable) -> bool:
     """Check if a function is safe to cache.
@@ -237,6 +243,28 @@ class JavaFunction(Function):
         """Check function signature is legal or not."""
 
 
+class PythonGeneratorWrapper:
+    """
+    A temporary wrapper class for Python generators to work around a
+    known issue in PEMJA, where the generator type is incorrectly handled.
+
+    TODO: This wrapper is intended to be a temporary solution. Once PEMJA
+    version 0.5.5 (or later) fixes the bug related to generator type conversion,
+    this wrapper should be removed. For more details, please refer to
+    https://github.com/apache/flink-agents/issues/83.
+    """
+
+    def __init__(self, generator: Generator) -> None:
+        """Initialize a PythonGeneratorWrapper."""
+        self.generator = generator
+
+    def __str__(self) -> str:
+        return "PythonGeneratorWrapper, generator=" + str(self.generator)
+
+    def __iter__(self) -> "PythonGeneratorWrapper":
+        # Completes the iterator protocol so the wrapper can also be consumed
+        # with ``for`` / ``list()``, not only stepped with ``next()``.
+        return self
+
+    def __next__(self) -> Any:
+        return next(self.generator)
+
+
 def call_python_function(module: str, qualname: str, func_args: Tuple[Any, 
...]) -> Any:
     """Used to call a Python function in the Pemja environment.
 
@@ -260,19 +288,19 @@ def call_python_function(module: str, qualname: str, 
func_args: Tuple[Any, ...])
     """
     cache_key = (module, qualname)
 
+    python_func = None
+
     if cache_key not in _PYTHON_FUNCTION_CACHE:
         python_func = PythonFunction(module=module, qualname=qualname)
-        try:
-            if python_func.is_cacheable():
-                _PYTHON_FUNCTION_CACHE[cache_key] = python_func
-            else:
-                return python_func(*func_args)
-        except Exception:
-            return python_func(*func_args)
+        if python_func.is_cacheable():
+            _PYTHON_FUNCTION_CACHE[cache_key] = python_func
+    else:
+        python_func = _PYTHON_FUNCTION_CACHE[cache_key]
 
-    # Use cached instance
-    func = _PYTHON_FUNCTION_CACHE[cache_key]
-    return func(*func_args)
+    func_result = python_func(*func_args)
+    if isinstance(func_result, Generator):
+        return PythonGeneratorWrapper(func_result)
+    return func_result
 
 
 def clear_python_function_cache() -> None:
@@ -305,3 +333,34 @@ def get_python_function_cache_keys() -> List[Tuple[str, 
str]]:
         List of (module, qualname) tuples representing cached functions.
     """
     return list(_PYTHON_FUNCTION_CACHE.keys())
+
+
+def call_python_generator(
+    generator_wrapper: PythonGeneratorWrapper,
+) -> Tuple[bool, Any]:
+    """
+    Invokes the next step of a wrapped Python generator and returns whether
+    it is done, along with the yielded or returned value.
+
+    Args:
+        generator_wrapper (PythonGeneratorWrapper): A wrapper object that
+        contains a `generator` attribute. This attribute should be an instance
+        of a Python generator.
+
+    Returns:
+        Tuple[bool, Any]:
+            - The first element is a boolean flag indicating whether the generator
+            has finished:
+                * False: The generator has more values to yield.
+                * True: The generator has completed.
+            - The second element is either:
+                * The value yielded by the generator (when not exhausted), or
+                * The return value of the generator (when it has finished).
+
+    Raises:
+        Exception: Any exception raised by the generator itself is logged and
+        re-raised unchanged.
+    """
+    try:
+        result = next(generator_wrapper.generator)
+    except StopIteration as e:
+        # StopIteration always carries ``value``; it is None when the generator
+        # finished without an explicit ``return <value>``.
+        return True, e.value
+    except Exception:
+        logger.exception("Error in generator execution")
+        raise
+    else:
+        return False, result
diff --git 
a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
 
b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
index 828ac7c..6ff356d 100644
--- 
a/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
+++ 
b/python/flink_agents/plan/tests/compatibility/create_python_agent_plan_from_json.py
@@ -75,7 +75,7 @@ if __name__ == "__main__":
     assert len(actions_by_event) == 2
 
     assert input_event in actions_by_event
-    assert actions_by_event[input_event] == ["firstAction", "secondAction"]
+    assert sorted(actions_by_event[input_event]) == ["firstAction", 
"secondAction"]
 
     assert my_event in actions_by_event
     assert actions_by_event[my_event] == ["secondAction"]
diff --git a/python/flink_agents/plan/tests/test_function.py 
b/python/flink_agents/plan/tests/test_function.py
index 945401d..f79e7bb 100644
--- a/python/flink_agents/plan/tests/test_function.py
+++ b/python/flink_agents/plan/tests/test_function.py
@@ -30,7 +30,7 @@ from flink_agents.plan.function import (
     call_python_function,
     clear_python_function_cache,
     get_python_function_cache_keys,
-    get_python_function_cache_size,
+    get_python_function_cache_size, PythonGeneratorWrapper,
 )
 
 if TYPE_CHECKING:
@@ -314,8 +314,9 @@ def test_selective_caching_generator_functions() -> None:
     result = call_python_function(
         "flink_agents.plan.tests.test_function", "generator_function", (3,)
     )
+    assert isinstance(result, PythonGeneratorWrapper)
     # Convert generator to list for testing
-    result_list = list(result)
+    result_list = list(result.generator)
     assert result_list == [0, 1, 2]
 
     # Should not be cached
diff --git a/python/flink_agents/runtime/flink_runner_context.py 
b/python/flink_agents/runtime/flink_runner_context.py
index 99745c3..5f3bb90 100644
--- a/python/flink_agents/runtime/flink_runner_context.py
+++ b/python/flink_agents/runtime/flink_runner_context.py
@@ -26,6 +26,8 @@ from flink_agents.api.runner_context import RunnerContext
 from flink_agents.plan.agent_plan import AgentPlan
 from flink_agents.runtime.flink_memory_object import FlinkMemoryObject
 from flink_agents.runtime.flink_metric_group import FlinkMetricGroup
+from concurrent.futures import ThreadPoolExecutor
+import os
 
 
 class FlinkRunnerContext(RunnerContext):
@@ -36,7 +38,9 @@ class FlinkRunnerContext(RunnerContext):
 
     __agent_plan: AgentPlan
 
-    def __init__(self, j_runner_context: Any, agent_plan_json: str) -> None:
+    def __init__(
+        self, j_runner_context: Any, agent_plan_json: str, executor: 
ThreadPoolExecutor
+    ) -> None:
         """Initialize a flink runner context with the given java runner 
context.
 
         Parameters
@@ -46,6 +50,7 @@ class FlinkRunnerContext(RunnerContext):
         """
         self._j_runner_context = j_runner_context
         self.__agent_plan = AgentPlan.model_validate_json(agent_plan_json)
+        self.executor = executor
 
     @override
     def send_event(self, event: Event) -> None:
@@ -115,9 +120,29 @@ class FlinkRunnerContext(RunnerContext):
         """Asynchronously execute the provided function. Access to memory
          is prohibited within the function.
         """
-        # TODO: Implement in a future commit.
-        pass
-
-def create_flink_runner_context(j_runner_context: Any, agent_plan_json: str) 
-> FlinkRunnerContext:
+        future = self.executor.submit(func, *args, **kwargs)
+        while not future.done():
+            # TODO: Currently, we are using a polling mechanism to check 
whether
+            #  the future has completed. This approach should be optimized in 
the
+            #  future by switching to a notification-based model, where the 
Flink
+            #  operator is notified directly once the future is completed.
+            yield
+        return future.result()
+
+
+def create_flink_runner_context(
+    j_runner_context: Any, agent_plan_json: str, executor: ThreadPoolExecutor
+) -> FlinkRunnerContext:
     """Used to create a FlinkRunnerContext Python object in Pemja 
environment."""
-    return FlinkRunnerContext(j_runner_context, agent_plan_json)
+    return FlinkRunnerContext(j_runner_context, agent_plan_json, executor)
+
+
+def create_async_thread_pool() -> ThreadPoolExecutor:
+    """Used to create a thread pool to execute asynchronous code blocks in actions.
+
+    The pool size is twice the number of CPUs. ``os.cpu_count()`` returns
+    ``None`` when the count cannot be determined; a single CPU is assumed then.
+    """
+    return ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 2)
+
+
+def close_async_thread_pool(executor: ThreadPoolExecutor) -> None:
+    """Used to close the thread pool.
+
+    ``executor.shutdown()`` is called with its default ``wait=True``, so this
+    blocks until every function already submitted to the pool has finished.
+    """
+    executor.shutdown()
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
index 9eef4ff..85544cd 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java
@@ -29,16 +29,22 @@ import 
org.apache.flink.agents.runtime.env.PythonEnvironmentManager;
 import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
 import org.apache.flink.agents.runtime.metrics.BuiltInMetrics;
 import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
+import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
+import org.apache.flink.agents.runtime.python.operator.PythonActionTask;
 import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
 import org.apache.flink.agents.runtime.utils.EventUtil;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.api.common.state.*;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.python.env.PythonDependencyInfo;
 import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -46,14 +52,19 @@ import 
org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl;
 import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxProcessor;
 import org.apache.flink.types.Row;
+import org.apache.flink.util.ExceptionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.lang.reflect.Field;
 import java.util.HashMap;
-import java.util.LinkedList;
 import java.util.List;
+import java.util.Optional;
 
+import static 
org.apache.flink.agents.runtime.utils.StateUtil.listStateNotEmpty;
+import static 
org.apache.flink.agents.runtime.utils.StateUtil.pollFromListState;
+import static 
org.apache.flink.agents.runtime.utils.StateUtil.removeFromListState;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -67,7 +78,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  * and the resulting output event is collected for further processing.
  */
 public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT>
-        implements OneInputStreamOperator<IN, OUT> {
+        implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
     private static final long serialVersionUID = 1L;
 
@@ -81,9 +92,6 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
     private transient MapState<String, MemoryObjectImpl.MemoryItem> 
shortTermMemState;
 
-    // RunnerContext for Java actions
-    private transient RunnerContextImpl runnerContext;
-
     // PythonActionExecutor for Python actions
     private transient PythonActionExecutor pythonActionExecutor;
 
@@ -91,7 +99,7 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
 
     private transient BuiltInMetrics builtInMetrics;
 
-    private transient MailboxExecutor mailboxExecutor;
+    private final transient MailboxExecutor mailboxExecutor;
 
     // We need to check whether the current thread is the mailbox thread using 
the mailbox
     // processor.
@@ -100,6 +108,19 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
     // to obtain the MailboxProcessor instance and make the determination.
     private transient MailboxProcessor mailboxProcessor;
 
+    // An action will be split into one or more ActionTask objects. We use a 
state to store the
+    // pending ActionTasks that are waiting to be executed.
+    private transient ListState<ActionTask> actionTasksKState;
+
+    // To avoid processing different InputEvents with the same key, we use a 
state to store pending
+    // InputEvents that are waiting to be processed.
+    private transient ListState<Event> pendingInputEventsKState;
+
+    // An operator state is used to track the currently processing keys. This 
is useful when
+    // receiving an EndOfInput signal, as we need to wait until all related 
events are fully
+    // processed.
+    private transient ListState<Object> currentProcessingKeysOpState;
+
     public ActionExecutionOperator(
             AgentPlan agentPlan,
             Boolean inputIsJava,
@@ -129,13 +150,37 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup());
         builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan);
 
-        runnerContext =
-                new RunnerContextImpl(shortTermMemState, metricGroup, 
this::checkMailboxThread);
+        // init agent processing related state
+        actionTasksKState =
+                getRuntimeContext()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "actionTasks", 
TypeInformation.of(ActionTask.class)));
+        pendingInputEventsKState =
+                getRuntimeContext()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "pendingInputEvents", 
TypeInformation.of(Event.class)));
+        // We use UnionList here to ensure that the task can access all keys 
after parallelism
+        // modifications.
+        // Subsequent steps {@link #tryResumeProcessActionTasks} will then 
filter out keys that do
+        // not belong to the key range of current task.
+        currentProcessingKeysOpState =
+                getOperatorStateBackend()
+                        .getUnionListState(
+                                new ListStateDescriptor<>(
+                                        "currentProcessingKeys", 
TypeInformation.of(Object.class)));
 
         // init PythonActionExecutor
         initPythonActionExecutor();
 
         mailboxProcessor = getMailboxProcessor();
+
+        // Since an operator restart may change the key range it manages due 
to changes in
+        // parallelism,
+        // and {@link tryProcessActionTaskForKey} mails might be lost,
+        // it is necessary to reprocess all keys to ensure correctness.
+        tryResumeProcessActionTasks();
     }
 
     @Override
@@ -143,65 +188,120 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
         IN input = record.getValue();
         LOG.debug("Receive an element {}", input);
 
-        // 1. wrap to InputEvent first
+        // wrap to InputEvent first
         Event inputEvent = wrapToInputEvent(input);
 
-        // 2. execute action
-        LinkedList<Event> events = new LinkedList<>();
-        events.push(inputEvent);
-        while (!events.isEmpty()) {
-            Event event = events.pop();
-            builtInMetrics.markEventProcessed();
-            List<Action> actions = getActionsTriggeredBy(event);
-            if (actions != null && !actions.isEmpty()) {
-                for (Action action : actions) {
-                    // TODO: Support multi-action execution for a single 
event. Example: A Java
-                    // event
-                    // should be processable by both Java and Python actions.
-                    // TODO: Implement asynchronous action execution.
-
-                    // execute action and collect output events
-                    String actionName = action.getName();
-                    LOG.debug("Try execute action {} for event {}.", 
actionName, event);
-                    List<Event> actionOutputEvents;
-                    if (action.getExec() instanceof JavaFunction) {
-                        runnerContext.setActionName(actionName);
-                        action.getExec().call(event, runnerContext);
-                        actionOutputEvents = runnerContext.drainEvents();
-                    } else if (action.getExec() instanceof PythonFunction) {
-                        checkState(event instanceof PythonEvent);
-                        actionOutputEvents =
-                                pythonActionExecutor.executePythonFunction(
-                                        (PythonFunction) action.getExec(),
-                                        (PythonEvent) event,
-                                        actionName);
-                    } else {
-                        throw new RuntimeException("Unsupported action type: " 
+ action.getClass());
-                    }
-                    builtInMetrics.markActionExecuted(actionName);
-
-                    for (Event actionOutputEvent : actionOutputEvents) {
-                        if (EventUtil.isOutputEvent(actionOutputEvent)) {
-                            builtInMetrics.markEventProcessed();
-                            OUT outputData = 
getOutputFromOutputEvent(actionOutputEvent);
-                            LOG.debug(
-                                    "Collect output data {} for input {} in 
action {}.",
-                                    outputData,
-                                    input,
-                                    action.getName());
-                            
output.collect(reusedStreamRecord.replace(outputData));
-                        } else {
-                            LOG.debug(
-                                    "Collect event {} for event {} in action 
{}.",
-                                    actionOutputEvent,
-                                    event,
-                                    action.getName());
-                            events.add(actionOutputEvent);
-                        }
-                    }
+        if (currentKeyHasMoreActionTask()) {
+            // If there are already actions being processed for the current 
key, the newly incoming
+            // event should be queued and processed later. Therefore, we add 
it to
+            // pendingInputEventsState.
+            pendingInputEventsKState.add(inputEvent);
+        } else {
+            // Otherwise, the new event is processed immediately.
+            processEvent(getCurrentKey(), inputEvent);
+        }
+    }
+
+    /**
+     * Processes an incoming event for the given key and may submit a new mail
+     * `tryProcessActionTaskForKey` to continue processing.
+     */
+    private void processEvent(Object key, Event event) throws Exception {
+        boolean isInputEvent = EventUtil.isInputEvent(event);
+        builtInMetrics.markEventProcessed();
+        if (EventUtil.isOutputEvent(event)) {
+            // If the event is an OutputEvent, we send it downstream.
+            OUT outputData = getOutputFromOutputEvent(event);
+            output.collect(reusedStreamRecord.replace(outputData));
+        } else {
+            if (isInputEvent) {
+                // If the event is an InputEvent, we mark that the key is 
currently being processed.
+                currentProcessingKeysOpState.add(key);
+            }
+            // We then obtain the triggered action and add ActionTasks to the 
waiting processing
+            // queue.
+            List<Action> triggerActions = getActionsTriggeredBy(event);
+            if (triggerActions != null && !triggerActions.isEmpty()) {
+                for (Action triggerAction : triggerActions) {
+                    actionTasksKState.add(createActionTask(key, triggerAction, 
event));
                 }
             }
         }
+
+        if (isInputEvent) {
+            // If the event is an InputEvent, we submit a new mail to try 
processing the actions.
+            mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), 
"process action task");
+        }
+    }
+
+    /**
+     * Mail body that processes the next action task of {@code key}.
+     *
+     * <p>This method runs inside mails enqueued with {@code mailboxExecutor.submit(...)}, whose
+     * returned Future is discarded, so an exception escaping here would only be recorded in that
+     * Future. Failures are therefore wrapped in {@link ActionTaskExecutionException} and
+     * re-thrown from a separate {@code execute(...)} mail so they propagate out of the mailbox.
+     */
+    private void tryProcessActionTaskForKey(Object key) {
+        try {
+            processActionTaskForKey(key);
+        } catch (Exception e) {
+            mailboxExecutor.execute(
+                    () ->
+                            ExceptionUtils.rethrow(
+                                    new ActionTaskExecutionException(
+                                            "Failed to execute action task", e)),
+                    "throw exception in mailbox");
+        }
+    }
+
+    /**
+     * Executes one pending {@link ActionTask} for {@code key} and schedules the follow-up work.
+     *
+     * <p>Output events of the task are fed back through {@link #processEvent}. An unfinished
+     * task must provide a continuation task, which is appended to the key's task queue. Once no
+     * action task is left for the key, the key is removed from {@code
+     * currentProcessingKeysOpState} and the next pending InputEvent (if any) is started;
+     * otherwise another processing mail is submitted.
+     */
+    private void processActionTaskForKey(Object key) throws Exception {
+        // 1. Get an action task for the key.
+        setCurrentKey(key);
+        ActionTask actionTask = pollFromListState(actionTasksKState);
+        if (actionTask == null) {
+            markKeyProcessingFinished(key);
+            return;
+        }
+
+        // 2. Invoke the action task.
+        createAndSetRunnerContext(actionTask);
+        ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke();
+        for (Event actionOutputEvent : actionTaskResult.getOutputEvents()) {
+            processEvent(key, actionOutputEvent);
+        }
+
+        boolean currentInputEventFinished = false;
+        if (actionTaskResult.isFinished()) {
+            builtInMetrics.markActionExecuted(actionTask.action.getName());
+            currentInputEventFinished = !currentKeyHasMoreActionTask();
+        } else {
+            // The action is not finished: queue the generated continuation task so the rest of
+            // the action runs in a later mail. The null check must target the task itself.
+            Optional<ActionTask> generatedActionTaskOpt = actionTaskResult.getGeneratedActionTask();
+            actionTasksKState.add(
+                    checkNotNull(
+                            generatedActionTaskOpt.orElse(null),
+                            "ActionTask not finished, but the generated action task is null."));
+        }
+
+        // 3. Process the next InputEvent or next action task
+        if (currentInputEventFinished) {
+            // Once all sub-events and actions related to the current InputEvent are completed,
+            // we can proceed to process the next InputEvent.
+            markKeyProcessingFinished(key);
+            Event pendingInputEvent = pollFromListState(pendingInputEventsKState);
+            if (pendingInputEvent != null) {
+                processEvent(key, pendingInputEvent);
+            }
+        } else if (currentKeyHasMoreActionTask()) {
+            // If the current key has additional action tasks remaining, we should submit a new
+            // mail to continue processing them.
+            mailboxExecutor.submit(() -> tryProcessActionTaskForKey(key), "process action task");
+        }
+    }
+
+    /**
+     * Removes {@code key} from {@code currentProcessingKeysOpState}, asserting that it was
+     * recorded exactly once.
+     */
+    private void markKeyProcessingFinished(Object key) throws Exception {
+        int removedCount = removeFromListState(currentProcessingKeysOpState, key);
+        checkState(
+                removedCount == 1,
+                "Current processing key count for key "
+                        + key
+                        + " should be 1, but got "
+                        + removedCount);
+    }
 
     private void initPythonActionExecutor() throws Exception {
@@ -227,10 +327,31 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
                     new PythonActionExecutor(
                             pythonEnvironmentManager,
                             new ObjectMapper().writeValueAsString(agentPlan));
-            pythonActionExecutor.open(shortTermMemState, metricGroup, 
this::checkMailboxThread);
+            pythonActionExecutor.open();
+        }
+    }
+
+    @Override
+    public void endInput() throws Exception {
+        waitInFlightEventsFinished();
+    }
+
+    @VisibleForTesting
+    public void waitInFlightEventsFinished() throws Exception {
+        while (listStateNotEmpty(currentProcessingKeysOpState)) {
+            mailboxExecutor.yield();
         }
     }
 
+    @Override
+    public void close() throws Exception {
+        // Release the operator's own resources even if shutting down the Python executor fails.
+        try {
+            if (pythonActionExecutor != null) {
+                pythonActionExecutor.close();
+            }
+        } finally {
+            super.close();
+        }
+    }
+
     private Event wrapToInputEvent(IN input) {
         if (inputIsJava) {
             return new InputEvent(input);
@@ -272,4 +393,58 @@ public class ActionExecutionOperator<IN, OUT> extends 
AbstractStreamOperator<OUT
                 mailboxProcessor.isMailboxThread(),
                 "Expected to be running on the task mailbox thread, but was 
not.");
     }
+
+    /**
+     * Creates the first {@link ActionTask} that runs {@code action} on {@code event} for {@code
+     * key}, picking the Java or Python implementation from the type of the action's exec function.
+     */
+    private ActionTask createActionTask(Object key, Action action, Event event) {
+        if (action.getExec() instanceof JavaFunction) {
+            return new JavaActionTask(key, event, action);
+        } else if (action.getExec() instanceof PythonFunction) {
+            return new PythonActionTask(key, event, action, pythonActionExecutor);
+        } else {
+            throw new IllegalStateException(
+                    "Unsupported action type: " + action.getExec().getClass());
+        }
+    }
+
+    /**
+     * Attaches a new runner context to {@code actionTask} unless it already carries one.
+     *
+     * <p>The context is bound to this operator's short-term memory state and metric group and
+     * receives {@link #checkMailboxThread} as its thread check. NOTE(review): tasks that already
+     * have a context are presumably continuation tasks reusing their predecessor's context —
+     * confirm in {@code ActionTask}.
+     */
+    private void createAndSetRunnerContext(ActionTask actionTask) {
+        if (actionTask.getRunnerContext() != null) {
+            return;
+        }
+
+        RunnerContextImpl runnerContext;
+        if (actionTask.action.getExec() instanceof JavaFunction) {
+            runnerContext =
+                    new RunnerContextImpl(shortTermMemState, metricGroup, this::checkMailboxThread);
+        } else if (actionTask.action.getExec() instanceof PythonFunction) {
+            runnerContext =
+                    new PythonRunnerContextImpl(
+                            shortTermMemState, metricGroup, this::checkMailboxThread);
+        } else {
+            throw new IllegalStateException(
+                    "Unsupported action type: " + actionTask.action.getExec().getClass());
+        }
+
+        runnerContext.setActionName(actionTask.action.getName());
+        actionTask.setRunnerContext(runnerContext);
+    }
+
+    /** Returns whether the current key still has queued action tasks. */
+    private boolean currentKeyHasMoreActionTask() throws Exception {
+        return listStateNotEmpty(actionTasksKState);
+    }
+
+    /**
+     * Submits a {@code tryProcessActionTaskForKey} mail for every key recorded in {@code
+     * currentProcessingKeysOpState}. Called from {@code open()} because mails enqueued before a
+     * restart may have been lost.
+     *
+     * <p>NOTE(review): the comment in {@code open()} says keys that do not belong to this
+     * subtask's key range are filtered out here, but this loop does not filter. The state is
+     * union list state, so on restore every subtask receives the keys recorded by all subtasks;
+     * a foreign key would then reach {@code setCurrentKey} and keyed-state access in {@code
+     * processActionTaskForKey} and would remain in this subtask's operator state. TODO: filter by
+     * key group (e.g. via {@code KeyGroupRangeAssignment}) and drop foreign keys from the state.
+     */
+    private void tryResumeProcessActionTasks() throws Exception {
+        Iterable<Object> keys = currentProcessingKeysOpState.get();
+        if (keys != null) {
+            for (Object key : keys) {
+                mailboxExecutor.submit(
+                        () -> tryProcessActionTaskForKey(key), "process action task");
+            }
+        }
+    }
+
+    /** Wraps a failure raised while executing an {@link ActionTask}. */
+    public static class ActionTaskExecutionException extends Exception {
+        public ActionTaskExecutionException(String message, Throwable cause) {
+            super(message, cause);
+        }
+    }
 }
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java
new file mode 100644
index 0000000..60ecab0
--- /dev/null
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionTask.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.Action;
+import org.apache.flink.agents.runtime.context.RunnerContextImpl;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * This class represents a task related to the execution of an action in {@link
+ * ActionExecutionOperator}.
+ *
+ * <p>An action is split into multiple code blocks, and each code block is represented by an
+ * {@code ActionTask}. Call {@link #invoke()} to execute a code block and obtain its {@link
+ * ActionTaskResult}. If the action contains additional code blocks, the next {@code ActionTask}
+ * is available via {@link ActionTaskResult#getGeneratedActionTask()} and can be invoked in turn.
+ */
+public abstract class ActionTask {
+
+    protected static final Logger LOG = LoggerFactory.getLogger(ActionTask.class);
+
+    /** The key this task was created for. */
+    protected final Object key;
+
+    /** The event that triggered the action. */
+    protected final Event event;
+
+    /** The action this task executes (or a code block of it). */
+    protected final Action action;
+
+    /**
+     * Since RunnerContextImpl contains references to the operator and its state, it must not be
+     * serialized and stored in state together with the ActionTask. Instead, a valid runner
+     * context is checked for before each invocation and a new one is created if necessary.
+     */
+    protected transient RunnerContextImpl runnerContext;
+
+    public ActionTask(Object key, Event event, Action action) {
+        this.key = key;
+        this.event = event;
+        this.action = action;
+    }
+
+    /** Returns the attached runner context, or {@code null} if none has been attached yet. */
+    public RunnerContextImpl getRunnerContext() {
+        return runnerContext;
+    }
+
+    /** Attaches the runner context that {@link #invoke()} will use. */
+    public void setRunnerContext(RunnerContextImpl runnerContext) {
+        this.runnerContext = runnerContext;
+    }
+
+    public Object getKey() {
+        return key;
+    }
+
+    /**
+     * Invokes the action task.
+     *
+     * <p>A runner context must have been attached via {@link #setRunnerContext} beforehand.
+     *
+     * @return the result of executing this code block
+     * @throws Exception if executing the code block fails
+     */
+    public abstract ActionTaskResult invoke() throws Exception;
+
+    /**
+     * The outcome of a single {@link ActionTask#invoke()} call.
+     *
+     * <p>Declared {@code static}: a result does not need, and should not retain, a hidden
+     * reference to the task that produced it; the follow-up task (if any) is carried explicitly.
+     */
+    public static class ActionTaskResult {
+        /** Whether the whole action has completed. */
+        private final boolean finished;
+
+        /** Events emitted while executing the code block. */
+        private final List<Event> outputEvents;
+
+        /** The task to invoke next while the action is not yet finished. */
+        private final Optional<ActionTask> generatedActionTaskOpt;
+
+        public ActionTaskResult(
+                boolean finished,
+                List<Event> outputEvents,
+                @Nullable ActionTask generatedActionTask) {
+            this.finished = finished;
+            this.outputEvents = outputEvents;
+            this.generatedActionTaskOpt = Optional.ofNullable(generatedActionTask);
+        }
+
+        public boolean isFinished() {
+            return finished;
+        }
+
+        public List<Event> getOutputEvents() {
+            return outputEvents;
+        }
+
+        public Optional<ActionTask> getGeneratedActionTask() {
+            return generatedActionTaskOpt;
+        }
+
+        @Override
+        public String toString() {
+            return "ActionTaskResult{"
+                    + "finished="
+                    + finished
+                    + ", outputEvents="
+                    + outputEvents
+                    + ", generatedActionTaskOpt="
+                    + generatedActionTaskOpt
+                    + '}';
+        }
+    }
+}
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java
new file mode 100644
index 0000000..71d04b8
--- /dev/null
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/JavaActionTask.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.Action;
+import org.apache.flink.agents.plan.JavaFunction;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * A special {@link ActionTask} designed to execute a Java action task.
+ *
+ * <p>Note that Java actions currently do not support asynchronous execution. As a result, a Java
+ * action task is invoked only once and always reports the action as finished.
+ */
+public class JavaActionTask extends ActionTask {
+
+    public JavaActionTask(Object key, Event event, Action action) {
+        super(key, event, action);
+        checkState(
+                action.getExec() instanceof JavaFunction,
+                "Java action task requires a JavaFunction, but got %s",
+                action.getExec());
+    }
+
+    @Override
+    public ActionTaskResult invoke() throws Exception {
+        LOG.debug(
+                "Try execute java action {} for event {} with key {}.",
+                action.getName(),
+                event,
+                key);
+        runnerContext.checkNoPendingEvents();
+        try {
+            action.getExec().call(event, runnerContext);
+        } catch (Exception e) {
+            // Discard events emitted before the failure so the context is left clean, matching
+            // how PythonActionExecutor handles a failing Python action.
+            runnerContext.drainEvents();
+            throw e;
+        }
+        return new ActionTaskResult(true, runnerContext.drainEvents(), null);
+    }
+}
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
new file mode 100644
index 0000000..53fb93b
--- /dev/null
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonActionTask.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.python.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.Action;
+import org.apache.flink.agents.plan.PythonFunction;
+import org.apache.flink.agents.runtime.operator.ActionTask;
+import org.apache.flink.agents.runtime.python.event.PythonEvent;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * A special {@link ActionTask} designed to execute a Python action task.
+ *
+ * <p>During asynchronous execution in Python, the {@link PythonActionTask} can produce a {@link
+ * PythonGeneratorActionTask} to represent the subsequent code block when needed.
+ */
+public class PythonActionTask extends ActionTask {
+
+    protected final PythonActionExecutor pythonActionExecutor;
+
+    public PythonActionTask(
+            Object key, Event event, Action action, PythonActionExecutor pythonActionExecutor) {
+        super(key, event, action);
+        checkState(
+                action.getExec() instanceof PythonFunction,
+                "Python action task requires a PythonFunction, but got %s",
+                action.getExec());
+        // Template form: the event is only rendered to a string when the check actually fails.
+        checkState(
+                event instanceof PythonEvent,
+                "Python action only accept python event, but got %s",
+                event);
+        this.pythonActionExecutor = pythonActionExecutor;
+    }
+
+    @Override
+    public ActionTaskResult invoke() throws Exception {
+        LOG.debug(
+                "Try execute python action {} for event {} with key {}.",
+                action.getName(),
+                event,
+                key);
+        runnerContext.checkNoPendingEvents();
+
+        String pythonGeneratorRef =
+                pythonActionExecutor.executePythonFunction(
+                        (PythonFunction) action.getExec(), (PythonEvent) event, runnerContext);
+        // If a user-defined action uses an interface to submit asynchronous tasks, it returns a
+        // Python generator object upon its first execution. Otherwise, no asynchronous tasks
+        // were submitted and the action has already completed.
+        if (pythonGeneratorRef != null) {
+            // The Python action produced a generator. Advance it once right away: this submits
+            // the first asynchronous task and reports whether the action has completed.
+            ActionTask tempGeneratedActionTask =
+                    new PythonGeneratorActionTask(
+                            key, event, action, pythonActionExecutor, pythonGeneratorRef);
+            tempGeneratedActionTask.setRunnerContext(runnerContext);
+            return tempGeneratedActionTask.invoke();
+        }
+        return new ActionTaskResult(true, runnerContext.drainEvents(), null);
+    }
+}
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
new file mode 100644
index 0000000..bc13f20
--- /dev/null
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/operator/PythonGeneratorActionTask.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.python.operator;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.Action;
+import org.apache.flink.agents.runtime.operator.ActionTask;
+import org.apache.flink.agents.runtime.python.utils.PythonActionExecutor;
+
+/**
+ * An {@link ActionTask} that wraps a Python generator, representing the remaining code blocks of
+ * a Python action.
+ *
+ * <p>Each {@link #invoke()} advances the generator by one step. While the generator reports that
+ * it has not finished, the task returns itself as the follow-up task.
+ */
+public class PythonGeneratorActionTask extends PythonActionTask {
+    /** Name of the interpreter global that holds the generator object. */
+    private final String pythonGeneratorRef;
+
+    public PythonGeneratorActionTask(
+            Object key,
+            Event event,
+            Action action,
+            PythonActionExecutor pythonActionExecutor,
+            String pythonGeneratorRef) {
+        super(key, event, action, pythonActionExecutor);
+        this.pythonGeneratorRef = pythonGeneratorRef;
+    }
+
+    @Override
+    public ActionTaskResult invoke() throws Exception {
+        LOG.debug(
+                "Try execute python generator action {} for event {} with key {}.",
+                action.getName(),
+                event,
+                key);
+        boolean finished;
+        try {
+            finished = pythonActionExecutor.callPythonGenerator(pythonGeneratorRef);
+        } catch (Exception e) {
+            // Discard events emitted before the failure, as PythonActionExecutor does when the
+            // initial Python call fails, so the (shared) runner context keeps no stale events.
+            runnerContext.drainEvents();
+            throw e;
+        }
+        ActionTask generatedActionTask = finished ? null : this;
+        return new ActionTaskResult(finished, runnerContext.drainEvents(), generatedActionTask);
+    }
+}
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
index ebd8c05..09b38ee 100644
--- 
a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
+++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/PythonActionExecutor.java
@@ -17,19 +17,15 @@
  */
 package org.apache.flink.agents.runtime.python.utils;
 
-import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.plan.PythonFunction;
+import org.apache.flink.agents.runtime.context.RunnerContextImpl;
 import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment;
 import org.apache.flink.agents.runtime.env.PythonEnvironmentManager;
-import org.apache.flink.agents.runtime.memory.MemoryObjectImpl;
-import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
-import org.apache.flink.agents.runtime.python.context.PythonRunnerContextImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
 import org.apache.flink.agents.runtime.utils.EventUtil;
-import org.apache.flink.api.common.state.MapState;
 import pemja.core.PythonInterpreter;
 
-import java.util.List;
+import java.util.concurrent.atomic.AtomicLong;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -40,65 +36,107 @@ public class PythonActionExecutor {
             "from flink_agents.plan import function\n"
                     + "from flink_agents.runtime import flink_runner_context\n"
                     + "from flink_agents.runtime import python_java_utils";
+
+    // =========== RUNNER CONTEXT ===========
     private static final String CREATE_FLINK_RUNNER_CONTEXT =
             "flink_runner_context.create_flink_runner_context";
+    private static final String FLINK_RUNNER_CONTEXT_REF_NAME_PREFIX = 
"flink_runner_context_";
+    private static final AtomicLong FLINK_RUNNER_CONTEXT_REF_ID = new 
AtomicLong(0);
+
+    // ========== ASYNC THREAD POOL ===========
+    private static final String CREATE_ASYNC_THREAD_POOL =
+            "flink_runner_context.create_async_thread_pool";
+    private static final String CLOSE_ASYNC_THREAD_POOL =
+            "flink_runner_context.close_async_thread_pool";
+    private static final String PYTHON_ASYNC_THREAD_POOL_REF_NAME = 
"python_async_thread_pool";
+    private static final AtomicLong PYTHON_ASYNC_THREAD_POOL_REF_ID = new 
AtomicLong(0);
+
+    // =========== PYTHON GENERATOR ===========
+    private static final String CALL_PYTHON_GENERATOR = 
"function.call_python_generator";
+    private static final String PYTHON_GENERATOR_VAR_NAME_PREFIX = 
"python_generator_";
+    private static final AtomicLong PYTHON_GENERATOR_VAR_ID = new 
AtomicLong(0);
+
+    // =========== PYTHON AND JAVA OBJECT CONVERT ===========
     private static final String CONVERT_TO_PYTHON_OBJECT =
             "python_java_utils.convert_to_python_object";
     private static final String WRAP_TO_INPUT_EVENT = 
"python_java_utils.wrap_to_input_event";
     private static final String GET_OUTPUT_FROM_OUTPUT_EVENT =
             "python_java_utils.get_output_from_output_event";
-    private static final String FLINK_RUNNER_CONTEXT_VAR_NAME = 
"flink_runner_context";
 
     private final PythonEnvironmentManager environmentManager;
     private final String agentPlanJson;
-    private PythonRunnerContextImpl runnerContext;
-
     private PythonInterpreter interpreter;
+    private String pythonAsyncThreadPoolObjectName;
 
     public PythonActionExecutor(PythonEnvironmentManager environmentManager, 
String agentPlanJson) {
+        // Only stores configuration; the Python interpreter is obtained later in open().
         this.environmentManager = environmentManager;
         this.agentPlanJson = agentPlanJson;
     }
 
-    public void open(
-            MapState<String, MemoryObjectImpl.MemoryItem> shortTermMemState,
-            FlinkAgentsMetricGroupImpl metricGroup,
-            Runnable mailboxThreadChecker)
-            throws Exception {
+    /**
+     * Opens the Python environment, runs the module imports and creates the asynchronous thread
+     * pool that is handed to every Python runner context created by this executor.
+     *
+     * <p>The pool is stored as a uniquely named interpreter global so that it stays referenced on
+     * the Python side (pemja reference-counting workaround, see the TODO below).
+     *
+     * @throws Exception if the environment cannot be opened or the Python setup fails
+     */
+    public void open() throws Exception {
         environmentManager.open();
         EmbeddedPythonEnvironment env = environmentManager.createEnvironment();
 
         interpreter = env.getInterpreter();
         interpreter.exec(PYTHON_IMPORTS);
 
-        runnerContext =
-                new PythonRunnerContextImpl(shortTermMemState, metricGroup, 
mailboxThreadChecker);
-
-        // TODO: remove the set and get runner context after updating pemja to 
version 0.5.3
-        Object pythonRunnerContextObject =
-                interpreter.invoke(CREATE_FLINK_RUNNER_CONTEXT, runnerContext, 
agentPlanJson);
-        interpreter.set(FLINK_RUNNER_CONTEXT_VAR_NAME, 
pythonRunnerContextObject);
+        // TODO: remove the set and get thread pool after updating pemja to 
version 0.5.3. For more
+        // details, please refer to
+        //    https://github.com/apache/flink-agents/issues/83.
+        Object pythonAsyncThreadPool = 
interpreter.invoke(CREATE_ASYNC_THREAD_POOL);
+        this.pythonAsyncThreadPoolObjectName =
+                PYTHON_ASYNC_THREAD_POOL_REF_NAME
+                        + PYTHON_ASYNC_THREAD_POOL_REF_ID.incrementAndGet();
+        interpreter.set(pythonAsyncThreadPoolObjectName, 
pythonAsyncThreadPool);
     }
 
-    public List<Event> executePythonFunction(
-            PythonFunction function, PythonEvent event, String actionName) 
throws Exception {
+    /**
+     * Execute the Python function, which may return a Python generator that 
needs to be processed
+     * in the future. Due to an issue in Pemja regarding incorrect object 
reference counting, this
+     * may lead to garbage collection of the object. To prevent this, we use 
the set and get methods
+     * to manually increment the object's reference count, then return the 
name of the Python
+     * generator variable.
+     *
+     * @param function the Python function to call
+     * @param event the input event; its payload is converted to a Python object before the call
+     * @param runnerContext Java-side runner context, wrapped into a new Python runner context
+     * @return The name of the Python generator variable. It may be null if 
the Python function does
+     *     not return a generator.
+     * @throws PythonActionExecutionException if calling the Python function fails; pending
+     *     events are drained from {@code runnerContext} before it is thrown
+     */
+    public String executePythonFunction(
+            PythonFunction function, PythonEvent event, RunnerContextImpl 
runnerContext)
+            throws Exception {
         runnerContext.checkNoPendingEvents();
-        runnerContext.setActionName(actionName);
         function.setInterpreter(interpreter);
 
-        // TODO: remove the set and get runner context after updating pemja to 
version 0.5.3
-        Object pythonRunnerContextObject = 
interpreter.get(FLINK_RUNNER_CONTEXT_VAR_NAME);
+        // TODO: remove the set and get runner context after updating pemja to 
version 0.5.3. For
+        // more details, please refer to 
https://github.com/apache/flink-agents/issues/83.
+        Object pythonRunnerContextObject =
+                interpreter.invoke(
+                        CREATE_FLINK_RUNNER_CONTEXT,
+                        runnerContext,
+                        agentPlanJson,
+                        interpreter.get(pythonAsyncThreadPoolObjectName));
+        String pythonRunnerContextObjectName =
+                FLINK_RUNNER_CONTEXT_REF_NAME_PREFIX
+                        + FLINK_RUNNER_CONTEXT_REF_ID.incrementAndGet();
+        interpreter.set(pythonRunnerContextObjectName, 
pythonRunnerContextObject);
+        // NOTE(review): no code visible in this class deletes this global again, so each call
+        // leaves one runner-context object in the interpreter's globals. Confirm whether it
+        // should be cleaned up once the action (or its generator) has finished.
 
         Object pythonEventObject = 
interpreter.invoke(CONVERT_TO_PYTHON_OBJECT, event.getEvent());
 
         try {
-            function.call(pythonEventObject, pythonRunnerContextObject);
+            Object calledResult = function.call(pythonEventObject, 
pythonRunnerContextObject);
+            if (calledResult == null) {
+                return null;
+            } else {
+                // must be a generator
+                String pythonGeneratorRef =
+                        PYTHON_GENERATOR_VAR_NAME_PREFIX
+                                + PYTHON_GENERATOR_VAR_ID.incrementAndGet();
+                interpreter.set(pythonGeneratorRef, calledResult);
+                // NOTE(review): like the runner context above, this global is not removed by
+                // any code visible here, even after the generator finishes. TODO confirm cleanup.
+                return pythonGeneratorRef;
+            }
         } catch (Exception e) {
             runnerContext.drainEvents();
             throw new PythonActionExecutionException("Failed to execute Python 
action", e);
         }
-
-        return runnerContext.drainEvents();
     }
 
     public PythonEvent wrapToInputEvent(Object eventData) {
@@ -112,6 +150,29 @@ public class PythonActionExecutor {
         return interpreter.invoke(GET_OUTPUT_FROM_OUTPUT_EVENT, 
pythonOutputEvent);
     }
 
+    /**
+     * Invokes the next step of a Python generator.
+     *
+     * <p>This method is typically used after initializing or resuming a Python generator that was
+     * created via a user-defined action involving asynchronous execution.
+     *
+     * @param pythonGeneratorRef the reference name of the Python generator object stored in the
+     *     interpreter's context
+     * @return true if the generator has completed; false otherwise
+     * @throws IllegalStateException if the Python helper does not return a two-element array
+     */
+    public boolean callPythonGenerator(String pythonGeneratorRef) {
+        // Calling next(generator) in Python returns a tuple of (finished, output).
+        Object pythonGenerator = interpreter.get(pythonGeneratorRef);
+        Object invokeResult = interpreter.invoke(CALL_PYTHON_GENERATOR, pythonGenerator);
+        // instanceof also rejects null and primitive arrays; the former getClass().isArray()
+        // check threw NPE on null and let primitive arrays through to a failing cast.
+        checkState(
+                invokeResult instanceof Object[] && ((Object[]) invokeResult).length == 2,
+                "Expected a (finished, output) pair from %s, but got %s",
+                CALL_PYTHON_GENERATOR,
+                invokeResult);
+        return (boolean) ((Object[]) invokeResult)[0];
+    }
+
+    /**
+     * Shuts down the Python asynchronous thread pool created in {@link #open()}.
+     *
+     * <p>This is a no-op when {@link #open()} was never called or failed before the pool was
+     * registered, and when the pool has already been closed.
+     *
+     * @throws Exception if closing the Python thread pool fails
+     */
+    public void close() throws Exception {
+        if (interpreter == null || pythonAsyncThreadPoolObjectName == null) {
+            // open() did not get far enough to create the pool, or close() already ran.
+            return;
+        }
+        try {
+            interpreter.invoke(
+                    CLOSE_ASYNC_THREAD_POOL, interpreter.get(pythonAsyncThreadPoolObjectName));
+        } finally {
+            pythonAsyncThreadPoolObjectName = null;
+        }
+    }
+
     /** Failed to execute Python action. */
     public static class PythonActionExecutionException extends Exception {
         public PythonActionExecutionException(String message, Throwable cause) 
{
diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/utils/StateUtil.java 
b/runtime/src/main/java/org/apache/flink/agents/runtime/utils/StateUtil.java
new file mode 100644
index 0000000..1beb12c
--- /dev/null
+++ b/runtime/src/main/java/org/apache/flink/agents/runtime/utils/StateUtil.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.utils;
+
+import org.apache.flink.api.common.state.ListState;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/** Some utilities related to Flink state. */
+public class StateUtil {
+
+    /**
+     * Checks whether the provided ListState contains any elements.
+     *
+     * @param listState the ListState instance to check
+     * @return true if the state is not empty; false otherwise
+     * @throws Exception if an I/O error occurs while reading state
+     */
+    public static boolean listStateNotEmpty(ListState<?> listState) throws Exception {
+        // Read the state only once: each get() goes to the state backend (which may deserialize
+        // the whole list), and get() may return null instead of an empty iterable.
+        Iterable<?> elements = listState.get();
+        return elements != null && elements.iterator().hasNext();
+    }
+
+    /**
+     * Removes all occurrences of the specified element from the given ListState.
+     *
+     * <p>Elements are compared with {@link Object#equals}. The state is only rewritten when at
+     * least one element was actually removed.
+     *
+     * @param listState the ListState instance to modify
+     * @param element the element to remove
+     * @return the number of elements removed
+     * @throws Exception if an I/O error occurs while reading/writing state
+     */
+    public static <T> int removeFromListState(ListState<T> listState, T element) throws Exception {
+        Iterable<T> elements = listState.get();
+        if (elements == null) {
+            return 0;
+        }
+
+        int removedElementCount = 0;
+        List<T> remaining = new ArrayList<>();
+        for (T next : elements) {
+            if (next.equals(element)) {
+                removedElementCount++;
+            } else {
+                remaining.add(next);
+            }
+        }
+        if (removedElementCount > 0) {
+            // update() replaces the whole list (an empty list clears the state), so no explicit
+            // clear() is needed beforehand.
+            listState.update(remaining);
+        }
+        return removedElementCount;
+    }
+
+    /**
+     * Removes and returns the first element from the ListState.
+     *
+     * @param listState the ListState instance to poll from
+     * @return the first element of the list, or null if the list is empty
+     * @throws Exception if an I/O error occurs while reading/writing state
+     */
+    public static <T> T pollFromListState(ListState<T> listState) throws Exception {
+        Iterable<T> elements = listState.get();
+        if (elements == null) {
+            return null;
+        }
+        Iterator<T> listStateIterator = elements.iterator();
+        if (!listStateIterator.hasNext()) {
+            return null;
+        }
+
+        T polled = listStateIterator.next();
+        List<T> remaining = new ArrayList<>();
+        listStateIterator.forEachRemaining(remaining::add);
+        // update() replaces the whole list (an empty list clears the state).
+        listState.update(remaining);
+        return polled;
+    }
+}
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java 
b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java
index 1617a19..ebfc15f 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/CompileUtilsTest.java
@@ -25,12 +25,12 @@ import 
org.apache.flink.streaming.api.datastream.DataStreamSource;
 import org.apache.flink.streaming.api.datastream.KeyedStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.util.CloseableIterator;
+import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
-import java.util.stream.LongStream;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -39,16 +39,23 @@ public class CompileUtilsTest {
 
     private static final Long TEST_SEQUENCE_START = 0L;
     private static final Long TEST_SEQUENCE_END = 100L;
+    private static final Long TEST_SEQUENCE_REPEAT = 3L;
     // Agent logic: x -> (x + 1) * 2
     private static final AgentPlan TEST_AGENT_PLAN =
             ActionExecutionOperatorTest.TestAgent.getAgentPlan(false);
+    private static List<Long> testSequence;
+
+    /** Builds the shared input sequence once and sorts it so expectations follow its order. */
+    @BeforeAll
+    static void setup() {
+        List<Long> sequence = getTestSequence();
+        sequence.sort(Long::compareTo);
+        testSequence = sequence;
+    }
 
     @Test
     void testJavaNoKeyedStreamConnectToAgent() throws Exception {
         StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
 
-        DataStreamSource<Long> inputStream =
-                env.fromSequence(TEST_SEQUENCE_START, TEST_SEQUENCE_END);
+        DataStreamSource<Long> inputStream = env.fromData(testSequence);
         DataStream<Object> agentOutputStream =
                 CompileUtils.connectToAgent(
                         inputStream,
@@ -75,11 +82,10 @@ public class CompileUtilsTest {
     void testJavaKeyedStreamConnectToAgent() throws Exception {
         StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
 
-        KeyedStream<Long, Long> keyedInputStream =
-                env.fromSequence(TEST_SEQUENCE_START, 
TEST_SEQUENCE_END).keyBy(x -> x);
-        DataStream<Object> agentOutputStream =
+        KeyedStream<Long, Long> keyedInputStream = 
env.fromData(testSequence).keyBy(x -> x);
+        DataStream<Object> workflowOutputStream =
                 CompileUtils.connectToAgent(keyedInputStream, TEST_AGENT_PLAN);
-        DataStream<Long> resultStream = agentOutputStream.map(x -> (long) x + 
1);
+        DataStream<Long> resultStream = workflowOutputStream.map(x -> (long) x 
+ 1);
 
         List<Long> resultList = new ArrayList<>();
         try (CloseableIterator<Long> iterator = 
resultStream.executeAndCollect()) {
@@ -90,12 +96,19 @@ public class CompileUtilsTest {
         checkResult(resultList);
     }
 
-    private void checkResult(List<Long> resultList) {
+    /**
+     * Returns the values TEST_SEQUENCE_START..TEST_SEQUENCE_END (inclusive), the whole range
+     * repeated TEST_SEQUENCE_REPEAT times in a row.
+     */
+    private static List<Long> getTestSequence() {
+        List<Long> sequence = new ArrayList<>();
+        int round = 0;
+        while (round < TEST_SEQUENCE_REPEAT) {
+            for (long value = TEST_SEQUENCE_START; value <= TEST_SEQUENCE_END; value++) {
+                sequence.add(value);
+            }
+            round++;
+        }
+        return sequence;
+    }
+
+    /**
+     * Asserts that {@code resultList} equals the outputs expected for {@code testSequence}: the
+     * agent computes {@code (x + 1) * 2} and the tests add {@code 1} in a {@code map} step.
+     *
+     * <p>The comparison is order-sensitive and {@code testSequence} is sorted in {@code setup()}.
+     * NOTE(review): this presumably relies on callers sorting {@code resultList} first; confirm.
+     */
+    private static void checkResult(List<Long> resultList) {
         List<Long> expectedResultList =
-                LongStream.rangeClosed(TEST_SEQUENCE_START, TEST_SEQUENCE_END)
-                        .boxed()
-                        .map(x -> (x + 1) * 2 + 1)
-                        .collect(Collectors.toList());
+                testSequence.stream().map(x -> (x + 1) * 2 + 
1).collect(Collectors.toList());
 
         assertThat(resultList).isEqualTo(expectedResultList);
     }
diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
 
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
index 2559bab..90c558c 100644
--- 
a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
+++ 
b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java
@@ -28,6 +28,7 @@ import org.apache.flink.agents.plan.JavaFunction;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.util.ExceptionUtils;
 import org.junit.jupiter.api.Test;
@@ -54,19 +55,95 @@ public class ActionExecutionOperatorTest {
                         (KeySelector<Long, Long>) value -> value,
                         TypeInformation.of(Long.class))) {
             testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) 
testHarness.getOperator();
+
             testHarness.processElement(new StreamRecord<>(0L));
+            operator.waitInFlightEventsFinished();
             List<StreamRecord<Object>> recordOutput =
                     (List<StreamRecord<Object>>) testHarness.getRecordOutput();
             assertThat(recordOutput.size()).isEqualTo(1);
             assertThat(recordOutput.get(0).getValue()).isEqualTo(2L);
 
             testHarness.processElement(new StreamRecord<>(1L));
+            operator.waitInFlightEventsFinished();
             recordOutput = (List<StreamRecord<Object>>) 
testHarness.getRecordOutput();
             assertThat(recordOutput.size()).isEqualTo(2);
             assertThat(recordOutput.get(1).getValue()).isEqualTo(4L);
         }
     }
 
+    @Test // Same-key inputs must run one at a time, in arrival order.
+    void testSameKeyDataAreProcessedInOrder() throws Exception {
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory(TestAgent.getAgentPlan(false), true),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            // Process input data 1 with key 0
+            testHarness.processElement(new StreamRecord<>(0L));
+            // Process input data 2, which has the same key (0)
+            testHarness.processElement(new StreamRecord<>(0L));
+            // Both inputs share key 0, so input data 2 is held back until input data 1 is done;
+            // only input data 1 is scheduled for now.
+            // This means we need one mail to execute action1 for input data 1.
+            assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 1);
+            // After executing this mail, we will have another mail to execute action2 for input
+            // data 1.
+            assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 1);
+            // Once the above mails are executed, we should get a single output result from input
+            // data 1.
+            List<StreamRecord<Object>> recordOutput =
+                    (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+            assertThat(recordOutput.size()).isEqualTo(1);
+            assertThat(recordOutput.get(0).getValue()).isEqualTo(2L);
+
+            // After the processing of input data 1 is finished, we can proceed to process input
+            // data 2 and obtain its result.
+            operator.waitInFlightEventsFinished();
+            recordOutput = (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+            assertThat(recordOutput.size()).isEqualTo(2);
+            assertThat(recordOutput.get(1).getValue()).isEqualTo(2L);
+        }
+    }
+
+    @Test // Different-key inputs run concurrently: their mails interleave in the mailbox.
+    void testDifferentKeyDataCanRunConcurrently() throws Exception {
+        try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness =
+                new KeyedOneInputStreamOperatorTestHarness<>(
+                        new ActionExecutionOperatorFactory(TestAgent.getAgentPlan(false), true),
+                        (KeySelector<Long, Long>) value -> value,
+                        TypeInformation.of(Long.class))) {
+            testHarness.open();
+            ActionExecutionOperator<Long, Object> operator = // NOTE(review): unused in this test
+                    (ActionExecutionOperator<Long, Object>) testHarness.getOperator();
+
+            // Process input data 1 with key 0
+            testHarness.processElement(new StreamRecord<>(0L));
+            // Process input data 2, which has a different key (1)
+            testHarness.processElement(new StreamRecord<>(1L));
+            // Since the two input data items have different keys, they can be processed in
+            // parallel.
+            // As a result, there should be two separate mails executing action1, one for each of
+            // them.
+            assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 2);
+            // After these two mails are executed, there should be another two mails — one for each
+            // input data item — to execute the corresponding action2.
+            assertMailboxSizeAndRun(testHarness.getTaskMailbox(), 2);
+            // Once both action2 operations are completed, we should receive two output data items,
+            // each corresponding to one of the original inputs.
+            List<StreamRecord<Object>> recordOutput =
+                    (List<StreamRecord<Object>>) testHarness.getRecordOutput();
+            assertThat(recordOutput.size()).isEqualTo(2);
+            assertThat(recordOutput.get(0).getValue()).isEqualTo(2L);
+            assertThat(recordOutput.get(1).getValue()).isEqualTo(4L);
+        }
+    }
+
     @Test
     void testMemoryAccessProhibitedOutsideMailboxThread() throws Exception {
         try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> 
testHarness =
@@ -75,8 +152,12 @@ public class ActionExecutionOperatorTest {
                         (KeySelector<Long, Long>) value -> value,
                         TypeInformation.of(Long.class))) {
             testHarness.open();
+            ActionExecutionOperator<Long, Object> operator =
+                    (ActionExecutionOperator<Long, Object>) 
testHarness.getOperator();
 
-            assertThatThrownBy(() -> testHarness.processElement(new 
StreamRecord<>(0L)))
+            testHarness.processElement(new StreamRecord<>(0L));
+            assertThatThrownBy(() -> operator.waitInFlightEventsFinished())
+                    
.hasCauseInstanceOf(ActionExecutionOperator.ActionTaskExecutionException.class)
                     .rootCause()
                     .hasMessageContaining("Expected to be running on the task 
mailbox thread");
         }
@@ -180,4 +261,12 @@ public class ActionExecutionOperatorTest {
             return null;
         }
     }
+
+    private static void assertMailboxSizeAndRun(TaskMailbox mailbox, int expectedSize)
+            throws Exception { // Assert expectedSize mails are queued, then take and run that many.
+        assertThat(mailbox.size()).isEqualTo(expectedSize);
+        for (int i = 0; i < expectedSize; i++) { // presumably FIFO: new mails wait for next call
+            mailbox.take(TaskMailbox.MIN_PRIORITY).run();
+        }
+    }
 }

Reply via email to