xintongsong commented on code in PR #422:
URL: https://github.com/apache/flink-agents/pull/422#discussion_r2674653474


##########
runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallRecord.java:
##########
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.actionstate;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * Represents a record of a function call execution for fine-grained durable execution.
+ *
+ * <p>This class stores the execution result of a single {@code execute} or {@code execute_async}
+ * call, enabling recovery without re-execution when the same call is encountered during job
+ * recovery.
+ *
+ * <p>During recovery, the success or failure of the original call is determined by checking whether
+ * {@code exceptionPayload} is null.
+ */
+public class CallRecord {

Review Comment:
   I'd suggest to name this `CallResult`, which makes it explicit that this 
class represents a result of call execution.
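
The javadoc's recovery rule (a recorded call succeeded exactly when `exceptionPayload` is null) can be illustrated with a minimal Java sketch. `CallResultReplaySketch`, `replay`, `ReplayedCallException`, and the UTF-8 decoding of the exception payload are all hypothetical. In the PR, the payloads are produced and decoded on the Python side, which this sketch does not model.

```java
import java.nio.charset.StandardCharsets;

/**
 * Hypothetical sketch of the rule in the CallRecord javadoc: a stored call succeeded iff
 * exceptionPayload is null, so replay either returns the stored result or re-raises the
 * stored failure, without invoking the function again.
 */
class CallResultReplaySketch {
    /** Hypothetical exception used to re-surface a recorded failure during replay. */
    static final class ReplayedCallException extends RuntimeException {
        ReplayedCallException(String message) {
            super(message);
        }
    }

    static byte[] replay(byte[] resultPayload, byte[] exceptionPayload) {
        if (exceptionPayload == null) {
            return resultPayload; // the original call succeeded
        }
        // The original call failed; decoding the payload as UTF-8 is an assumption of this sketch.
        throw new ReplayedCallException(new String(exceptionPayload, StandardCharsets.UTF_8));
    }

    public static void main(String[] args) {
        byte[] ok = replay("42".getBytes(StandardCharsets.UTF_8), null);
        System.out.println(new String(ok, StandardCharsets.UTF_8)); // prints 42
        try {
            replay(null, "TimeoutError".getBytes(StandardCharsets.UTF_8));
        } catch (ReplayedCallException e) {
            System.out.println("replayed failure: " + e.getMessage()); // prints replayed failure: TimeoutError
        }
    }
}
```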



##########
runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java:
##########
@@ -888,7 +904,59 @@ private void maybePersistTaskResult(
         for (Event outputEvent : actionTaskResult.getOutputEvents()) {
             actionState.addEvent(outputEvent);
         }
+
+        // Mark the action as completed and clear call records
+        // This indicates that recovery should skip the entire action
+        actionState.markCompleted();
+
         actionStateStore.put(key, sequenceNum, action, event, actionState);
+
+        // Clear recovery state for Python actions
+        if (context instanceof PythonRunnerContextImpl) {
+            ((PythonRunnerContextImpl) context).clearRecoveryState();
+        }
+    }
+
+    /**
+     * Sets up the recovery state for Python actions to enable fine-grained durable execution.
+     *
+     * <p>This method initializes the Python runner context with the ActionState, allowing Python
+     * execute/execute_async calls to:
+     *
+     * <ul>
+     *   <li>Skip re-execution for already completed calls during recovery
+     *   <li>Persist CallRecords after each code block completion
+     * </ul>
+     */
+    private void setupRecoveryStateForPythonAction(ActionTask actionTask, ActionState actionState) {
+        Preconditions.checkState(
+                actionTask.getRunnerContext() instanceof PythonRunnerContextImpl,
+                "Python action tasks must have PythonRunnerContextImpl");
+
+        if (actionState == null || actionStateStore == null) {
+            return;
+        }
+
+        PythonRunnerContextImpl pythonContext =
+                (PythonRunnerContextImpl) actionTask.getRunnerContext();
+
+        // Capture variables for the persister callback
+        final Object key = actionTask.getKey();
+        final Action action = actionTask.action;
+        final Event event = actionTask.event;
+
+        // Initialize recovery state with ActionState and a callback to persist after each code
+        // block
+        pythonContext.initRecoveryState(
+                actionState,
+                () -> {
+                    try {
+                        actionStateStore.put(
+                                key, sequenceNumberKState.value(), action, event, actionState);

Review Comment:
   `sequenceNumberKState` might have changed by the time the callback is invoked.
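
Here is a minimal sketch of the concern. It assumes only that the state is read when the callback runs, not when the callback is created. A lazy callback persists whatever value the state holds at that moment, for example after the sequence number has advanced, or while the keyed backend points at a different key. Capturing the value up front, as the surrounding code already does for `key`, `action`, and `event`, pins it to its value at setup time. `SequenceCaptureSketch` and `MutableSequence` are hypothetical stand-ins for the operator and `sequenceNumberKState`.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch: a callback that reads mutable state when it runs, versus one that
 * captures the value when it is created.
 */
class SequenceCaptureSketch {
    /** Stand-in for sequenceNumberKState: a value the operator keeps advancing. */
    static final class MutableSequence {
        long value;
    }

    static List<Long> run() {
        List<Long> persistedWith = new ArrayList<>();
        MutableSequence sequenceNumberState = new MutableSequence();
        sequenceNumberState.value = 7; // sequence number when the action task was set up

        // Lazy: reads the state when the callback fires.
        Runnable lazyPersister = () -> persistedWith.add(sequenceNumberState.value);
        // Eager: captures the value now, alongside key, action, and event.
        final long capturedSequenceNumber = sequenceNumberState.value;
        Runnable eagerPersister = () -> persistedWith.add(capturedSequenceNumber);

        // The state moves on before the code block completes and the persister is called.
        sequenceNumberState.value = 8;
        lazyPersister.run();
        eagerPersister.run();
        return persistedWith;
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints [8, 7]
    }
}
```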



##########
runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallRecord.java:
##########
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.actionstate;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * Represents a record of a function call execution for fine-grained durable execution.
+ *
+ * <p>This class stores the execution result of a single {@code execute} or {@code execute_async}
+ * call, enabling recovery without re-execution when the same call is encountered during job
+ * recovery.
+ *
+ * <p>During recovery, the success or failure of the original call is determined by checking whether
+ * {@code exceptionPayload} is null.
+ */
+public class CallRecord {
+
+    /** Function identifier: module+qualname for Python, or method signature for Java. */
+    private final String functionId;
+
+    /** Stable digest of the serialized arguments for validation during recovery. */
+    private final String argsDigest;
+
+    /** Serialized return value of the function call (null if the call threw an exception). */
+    private byte[] resultPayload;
+
+    /** Serialized exception info if the call failed (null if the call succeeded). */
+    private byte[] exceptionPayload;
+
+    /** Default constructor for deserialization. */
+    public CallRecord() {
+        this.functionId = null;
+        this.argsDigest = null;
+        this.resultPayload = null;
+        this.exceptionPayload = null;
+    }
+
+    /**
+     * Constructs a CallRecord for a successful function call.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param resultPayload the serialized return value
+     */
+    public CallRecord(String functionId, String argsDigest, byte[] resultPayload) {
+        this.functionId = functionId;
+        this.argsDigest = argsDigest;
+        this.resultPayload = resultPayload;
+        this.exceptionPayload = null;
+    }
+
+    /**
+     * Constructs a CallRecord with explicit result and exception payloads.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param resultPayload the serialized return value (null if exception occurred)
+     * @param exceptionPayload the serialized exception (null if call succeeded)
+     */
+    public CallRecord(
+            String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) {
+        this.functionId = functionId;
+        this.argsDigest = argsDigest;
+        this.resultPayload = resultPayload;
+        this.exceptionPayload = exceptionPayload;
+    }
+
+    /**
+     * Creates a CallRecord for a failed function call.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param exceptionPayload the serialized exception
+     * @return a new CallRecord representing a failed call
+     */
+    public static CallRecord ofException(
+            String functionId, String argsDigest, byte[] exceptionPayload) {
+        return new CallRecord(functionId, argsDigest, null, exceptionPayload);
+    }
+
+    public String getFunctionId() {
+        return functionId;
+    }
+
+    public String getArgsDigest() {
+        return argsDigest;
+    }
+
+    public byte[] getResultPayload() {
+        return resultPayload;
+    }
+
+    public void setResultPayload(byte[] resultPayload) {
+        this.resultPayload = resultPayload;
+    }
+
+    public byte[] getExceptionPayload() {
+        return exceptionPayload;
+    }
+
+    public void setExceptionPayload(byte[] exceptionPayload) {
+        this.exceptionPayload = exceptionPayload;
+    }

Review Comment:
   Why do we need to set the payloads after construction? Can these fields be `final`?
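
One way the fields could become `final` is sketched below. It assumes the setters exist only to support deserialization. The outcome is chosen by a factory method, and `equals`/`hashCode` use `Arrays` for the byte-array fields. The `CallResult` name follows the earlier naming suggestion, and `ofSuccess` is hypothetical. If Jackson needs to construct the class, a constructor annotated with `@JsonCreator` and `@JsonProperty` parameters can populate final fields without setters. That is left out here to keep the sketch dependency-free.

```java
import java.util.Arrays;
import java.util.Objects;

/**
 * Hypothetical immutable variant of CallRecord. Every field is final; the outcome is chosen
 * by a factory method instead of by setters after construction.
 */
final class CallResult {
    private final String functionId;
    private final String argsDigest;
    private final byte[] resultPayload; // null if the call threw
    private final byte[] exceptionPayload; // null if the call succeeded

    private CallResult(
            String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) {
        this.functionId = Objects.requireNonNull(functionId, "functionId");
        this.argsDigest = Objects.requireNonNull(argsDigest, "argsDigest");
        this.resultPayload = resultPayload;
        this.exceptionPayload = exceptionPayload;
    }

    static CallResult ofSuccess(String functionId, String argsDigest, byte[] resultPayload) {
        return new CallResult(functionId, argsDigest, resultPayload, null);
    }

    static CallResult ofException(String functionId, String argsDigest, byte[] exceptionPayload) {
        // A failure must carry a payload; otherwise it would be indistinguishable from success.
        return new CallResult(
                functionId, argsDigest, null, Objects.requireNonNull(exceptionPayload, "exceptionPayload"));
    }

    boolean isSuccess() {
        return exceptionPayload == null;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof CallResult)) {
            return false;
        }
        CallResult that = (CallResult) o;
        return functionId.equals(that.functionId)
                && argsDigest.equals(that.argsDigest)
                && Arrays.equals(resultPayload, that.resultPayload)
                && Arrays.equals(exceptionPayload, that.exceptionPayload);
    }

    @Override
    public int hashCode() {
        int h = Objects.hash(functionId, argsDigest);
        h = 31 * h + Arrays.hashCode(resultPayload);
        return 31 * h + Arrays.hashCode(exceptionPayload);
    }

    public static void main(String[] args) {
        CallResult a = CallResult.ofSuccess("mod.fn", "digest-1", new byte[] {1, 2});
        CallResult b = CallResult.ofSuccess("mod.fn", "digest-1", new byte[] {1, 2});
        CallResult c = CallResult.ofException("mod.fn", "digest-1", new byte[] {9});
        System.out.println(a.equals(b) + " " + a.isSuccess() + " " + c.isSuccess()); // prints true true false
    }
}
```

The byte arrays themselves remain mutable. If that matters, the factories could store defensive copies.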



##########
python/flink_agents/runtime/local_runner.py:
##########
@@ -179,6 +179,22 @@ def action_metric_group(self) -> MetricGroup:
         err_msg = "Metric mechanism is not supported for local agent execution yet."
         raise NotImplementedError(err_msg)
 
+    @override
+    def execute(
+        self,
+        func: Callable[[Any], Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Synchronously execute the provided function. Access to memory
+        is prohibited within the function.
+
+        Note: Local runner does not support durable execution, so recovery
+        is not available.
+        """
+        return func(*args, **kwargs)

Review Comment:
   If we are renaming this to `durable_execute`, we also need to print a warning log here. There's no need to fail the call, as we don't want users to change their code when switching between local and remote execution.



##########
python/flink_agents/api/runner_context.py:
##########
@@ -186,6 +186,48 @@ def action_metric_group(self) -> MetricGroup:
             The individual metric group specific to the current action.
         """
 
+    @abstractmethod
+    def execute(

Review Comment:
   I'd suggest the name `durable_execute()`



##########
python/flink_agents/api/runner_context.py:
##########
@@ -196,6 +238,14 @@ def execute_async(
         """Asynchronously execute the provided function. Access to memory
         is prohibited within the function.
 
+        The result of the function will be stored and returned when the same
+        execute_async call is made again during job recovery. The arguments
+        and the result must be serializable.
+
+        The action that calls this API should be deterministic, meaning that it
+        will always make the execute_async call with the same arguments and in
+        the same order during job recovery. Otherwise, the behavior is undefined.
+

Review Comment:
   I don't see the need right now, but we may still support non-durable async execution in the future, which would simply execute the function in a separate thread/coroutine without reusing its result during replay.
   
   E.g., fetching a token from a remote service, which might already have expired by the time of the replay.
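
Here is a minimal Java sketch that makes the distinction concrete. `ExecutionModesSketch`, `durableExecute`, and `execute` are hypothetical stand-ins for the Python API under discussion, and an in-memory list replaces the persisted ActionState. Replay here reuses results purely by call order. It omits the function-id and argument-digest check that the PR performs.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

/**
 * Hypothetical contrast between durable and non-durable execution. durableExecute() records each
 * result in call order and, on replay, returns the recorded value instead of calling the function;
 * execute() always calls the function, so a replay sees a fresh result.
 */
class ExecutionModesSketch {
    private final List<Object> recordedResults; // stand-in for the persisted ActionState
    private int nextIndex = 0;

    ExecutionModesSketch(List<Object> recordedResults) {
        this.recordedResults = recordedResults;
    }

    @SuppressWarnings("unchecked")
    <T> T durableExecute(Supplier<T> fn) {
        if (nextIndex < recordedResults.size()) {
            return (T) recordedResults.get(nextIndex++); // replay: reuse the recorded result
        }
        T result = fn.get();
        recordedResults.add(result);
        nextIndex++;
        return result;
    }

    <T> T execute(Supplier<T> fn) {
        return fn.get(); // never recorded, so re-run on every attempt
    }

    public static void main(String[] args) {
        List<Object> log = new ArrayList<>();
        int[] expensiveCalls = {0};
        int[] tokensIssued = {0};
        Supplier<String> expensiveCall = () -> {
            expensiveCalls[0]++;
            return "answer";
        };
        Supplier<String> fetchToken = () -> "token-" + (++tokensIssued[0]);

        ExecutionModesSketch firstAttempt = new ExecutionModesSketch(log);
        firstAttempt.durableExecute(expensiveCall);
        String firstToken = firstAttempt.execute(fetchToken);

        // Simulated recovery: a new attempt replays against the same recorded results.
        ExecutionModesSketch replay = new ExecutionModesSketch(log);
        replay.durableExecute(expensiveCall);
        String replayToken = replay.execute(fetchToken);

        // prints: expensiveCall ran 1 time(s); tokens: token-1 then token-2
        System.out.println("expensiveCall ran " + expensiveCalls[0] + " time(s); tokens: "
                + firstToken + " then " + replayToken);
    }
}
```

The token case shows why a replayed result is not always desirable. The durable call runs once across both attempts, while the non-durable call issues a fresh token on replay.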



##########
python/flink_agents/api/runner_context.py:
##########
@@ -196,6 +238,14 @@ def execute_async(
         """Asynchronously execute the provided function. Access to memory
         is prohibited within the function.
 
+        The result of the function will be stored and returned when the same
+        execute_async call is made again during job recovery. The arguments
+        and the result must be serializable.
+
+        The action that calls this API should be deterministic, meaning that it
+        will always make the execute_async call with the same arguments and in
+        the same order during job recovery. Otherwise, the behavior is undefined.
+

Review Comment:
   I'd suggest naming this method `durable_execute_async`.



##########
runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java:
##########
@@ -48,4 +78,120 @@ public void sendEvent(String type, byte[] event, String eventString) {
         // this method will be invoked by PythonActionExecutor's python interpreter.
         sendEvent(new PythonEvent(event, type, eventString));
     }
+
+    /**
+     * Initializes the recovery state for the current action.
+     *
+     * @param actionState the ActionState for the current action (used for adding CallRecords)
+     * @param actionStatePersister callback to persist ActionState after each code block completion
+     */
+    public void initRecoveryState(ActionState actionState, Runnable actionStatePersister) {
+        this.currentCallIndex = 0;
+        this.currentActionState = actionState;
+        this.actionStatePersister = actionStatePersister;
+        this.recoveryCallRecords =
+                actionState != null && actionState.getCallRecords() != null
+                        ? new ArrayList<>(actionState.getCallRecords())
+                        : new ArrayList<>();
+    }
+
+    /**
+     * Gets the current call index.
+     *
+     * @return the current call index
+     */
+    public int getCurrentCallIndex() {
+        return currentCallIndex;
+    }
+
+    /**
+     * Tries to get a cached CallRecord for recovery.
+     *
+     * <p>This method is called by Python's execute/execute_async to check if a previous result
+     * exists. If found and validated, the cached result is returned.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @return array containing [isHit (boolean), resultPayload (byte[]), exceptionPayload
+     *     (byte[])], or null if miss
+     */
+    public Object[] tryGetCachedCallRecord(String functionId, String argsDigest) {
+        mailboxThreadChecker.run();
+
+        if (currentCallIndex < recoveryCallRecords.size()) {
+            CallRecord record = recoveryCallRecords.get(currentCallIndex);
+
+            if (record.matches(functionId, argsDigest)) {
+                LOG.debug(
+                        "CallRecord hit at index {}: functionId={}, argsDigest={}",
+                        currentCallIndex,
+                        functionId,
+                        argsDigest);
+                currentCallIndex++;
+                return new Object[] {true, record.getResultPayload(), record.getExceptionPayload()};
+            } else {
+                // Non-deterministic call detected, clear subsequent records
+                LOG.warn(
+                        "Non-deterministic call detected at index {}: expected functionId={}, argsDigest={}, "
+                                + "but got functionId={}, argsDigest={}. Clearing subsequent records.",
+                        currentCallIndex,
+                        record.getFunctionId(),
+                        record.getArgsDigest(),
+                        functionId,
+                        argsDigest);
+                clearCallRecordsFromCurrentIndex();
+            }
+        }
+
+        return null;
+    }
+
+    /**
+     * Records a completed call and persists the ActionState.
+     *
+     * <p>This method is called by Python after each execute/execute_async call completes.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @param resultPayload the serialized result (null if exception)
+     * @param exceptionPayload the serialized exception (null if success)
+     */
+    public void recordCallCompletion(
+            String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) {
+        mailboxThreadChecker.run();
+
+        CallRecord callRecord =
+                new CallRecord(functionId, argsDigest, resultPayload, exceptionPayload);
+
+        if (currentActionState != null && actionStatePersister != null) {
+            currentActionState.addCallRecord(callRecord);
+            actionStatePersister.run();
+            LOG.debug(
+                    "Recorded and persisted CallRecord at index {}: functionId={}, argsDigest={}",
+                    currentCallIndex,
+                    functionId,
+                    argsDigest);
+        }

Review Comment:
   Why would `currentActionState` and `actionStatePersister` be `null` when this method is called? Shall we use an assertion here?
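
If these fields are always set before Python calls back, failing fast makes a missing `initRecoveryState` call visible instead of silently dropping the record. The sketch below assumes that invariant. Note that `setupRecoveryStateForPythonAction` returns early when `actionState` or `actionStateStore` is null, so that path would need to be ruled out or handled first. `RecoveryStateSketch` is hypothetical and uses a `List<String>` in place of `ActionState`. It throws `IllegalStateException` directly, which is also what Flink's `Preconditions.checkState` (already used in `setupRecoveryStateForPythonAction`) throws.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: treat missing recovery state as a bug instead of silently skipping. */
class RecoveryStateSketch {
    private List<String> currentCallRecords; // stand-in for the ActionState's call records
    private Runnable actionStatePersister;

    void initRecoveryState(List<String> callRecords, Runnable persister) {
        this.currentCallRecords = callRecords;
        this.actionStatePersister = persister;
    }

    void recordCallCompletion(String functionId) {
        // Same effect as Preconditions.checkState(...): an IllegalStateException on violation.
        if (currentCallRecords == null || actionStatePersister == null) {
            throw new IllegalStateException("recordCallCompletion called before initRecoveryState");
        }
        currentCallRecords.add(functionId);
        actionStatePersister.run();
    }

    public static void main(String[] args) {
        RecoveryStateSketch ctx = new RecoveryStateSketch();
        try {
            ctx.recordCallCompletion("mod.fn");
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }

        List<String> records = new ArrayList<>();
        int[] persistCount = {0};
        ctx.initRecoveryState(records, () -> persistCount[0]++);
        ctx.recordCallCompletion("mod.fn");
        // prints [mod.fn] persisted 1 time(s)
        System.out.println(records + " persisted " + persistCount[0] + " time(s)");
    }
}
```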



##########
runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java:
##########
@@ -48,4 +78,120 @@ public void sendEvent(String type, byte[] event, String eventString) {
         // this method will be invoked by PythonActionExecutor's python interpreter.
         sendEvent(new PythonEvent(event, type, eventString));
     }
+
+    /**
+     * Initializes the recovery state for the current action.
+     *
+     * @param actionState the ActionState for the current action (used for adding CallRecords)
+     * @param actionStatePersister callback to persist ActionState after each code block completion
+     */
+    public void initRecoveryState(ActionState actionState, Runnable actionStatePersister) {
+        this.currentCallIndex = 0;
+        this.currentActionState = actionState;
+        this.actionStatePersister = actionStatePersister;
+        this.recoveryCallRecords =
+                actionState != null && actionState.getCallRecords() != null
+                        ? new ArrayList<>(actionState.getCallRecords())
+                        : new ArrayList<>();
+    }
+
+    /**
+     * Gets the current call index.
+     *
+     * @return the current call index
+     */
+    public int getCurrentCallIndex() {
+        return currentCallIndex;
+    }
+
+    /**
+     * Tries to get a cached CallRecord for recovery.
+     *
+     * <p>This method is called by Python's execute/execute_async to check if a previous result
+     * exists. If found and validated, the cached result is returned.
+     *
+     * @param functionId the function identifier
+     * @param argsDigest the digest of serialized arguments
+     * @return array containing [isHit (boolean), resultPayload (byte[]), exceptionPayload
+     *     (byte[])], or null if miss
+     */
+    public Object[] tryGetCachedCallRecord(String functionId, String argsDigest) {

Review Comment:
   I'd suggest the name `matchNextOrClearSubsequentCallRecord`. Otherwise, the clearing is super implicit.
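
With the suggested name, the side effect becomes part of the method's contract. Below is a minimal sketch of the same logic as `tryGetCachedCallRecord`. It assumes the mismatch handling shown in the diff: clear from the current index onward and return a miss. `CallRecordMatcherSketch` and `Rec` are hypothetical. `Optional` replaces the `Object[]` return value for readability, which may not suit a caller on the Python side.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Hypothetical sketch of the renamed lookup: the clearing side effect is visible in the name. */
class CallRecordMatcherSketch {
    /** Minimal stand-in for CallRecord: identity fields plus a result. */
    static final class Rec {
        final String functionId;
        final String argsDigest;
        final String result;

        Rec(String functionId, String argsDigest, String result) {
            this.functionId = functionId;
            this.argsDigest = argsDigest;
            this.result = result;
        }

        boolean matches(String otherFunctionId, String otherArgsDigest) {
            return functionId.equals(otherFunctionId) && argsDigest.equals(otherArgsDigest);
        }
    }

    private final List<Rec> records;
    private int currentCallIndex = 0;

    CallRecordMatcherSketch(List<Rec> recoveredRecords) {
        this.records = new ArrayList<>(recoveredRecords); // mutable copy, like recoveryCallRecords
    }

    /**
     * Returns the next recorded call if it matches; otherwise discards that record and all
     * later ones, since the replay has diverged from the recorded sequence.
     */
    Optional<Rec> matchNextOrClearSubsequentCallRecord(String functionId, String argsDigest) {
        if (currentCallIndex >= records.size()) {
            return Optional.empty(); // past the end of the recorded calls: a fresh call
        }
        Rec next = records.get(currentCallIndex);
        if (next.matches(functionId, argsDigest)) {
            currentCallIndex++;
            return Optional.of(next);
        }
        records.subList(currentCallIndex, records.size()).clear();
        return Optional.empty();
    }

    int remaining() {
        return records.size();
    }

    public static void main(String[] args) {
        CallRecordMatcherSketch matcher = new CallRecordMatcherSketch(List.of(
                new Rec("mod.a", "d1", "r1"),
                new Rec("mod.b", "d2", "r2"),
                new Rec("mod.c", "d3", "r3")));
        System.out.println(matcher.matchNextOrClearSubsequentCallRecord("mod.a", "d1")
                .map(r -> r.result).orElse("miss")); // prints r1
        System.out.println(matcher.matchNextOrClearSubsequentCallRecord("mod.b", "changed")
                .map(r -> r.result).orElse("miss")); // prints miss
        System.out.println(matcher.remaining()); // prints 1
    }
}
```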



##########
runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java:
##########
@@ -20,21 +20,51 @@
 import org.apache.flink.agents.api.Event;
 import org.apache.flink.agents.api.context.RunnerContext;
 import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.agents.runtime.actionstate.ActionState;
+import org.apache.flink.agents.runtime.actionstate.CallRecord;
 import org.apache.flink.agents.runtime.context.RunnerContextImpl;
 import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl;
 import org.apache.flink.agents.runtime.python.event.PythonEvent;
 import org.apache.flink.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nullable;
 import javax.annotation.concurrent.NotThreadSafe;
 
-/** A specialized {@link RunnerContext} that is specifically used when executing Python actions. */
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A specialized {@link RunnerContext} that is specifically used when executing Python actions.
+ *
+ * <p>This class provides fine-grained durable execution support by managing {@link CallRecord}s
+ * that store the results of {@code execute} and {@code execute_async} calls. During recovery, these
+ * records are used to skip re-execution of already completed calls.
+ */
 @NotThreadSafe
 public class PythonRunnerContextImpl extends RunnerContextImpl {
+
+    private static final Logger LOG = LoggerFactory.getLogger(PythonRunnerContextImpl.class);
+
+    /** Current call index within the action, used for matching CallRecords during recovery. */
+    private int currentCallIndex = 0;
+
+    /** List of existing CallRecords loaded during recovery. */
+    private List<CallRecord> recoveryCallRecords;
+
+    /** The current ActionState being built during action execution. */
+    @Nullable private ActionState currentActionState;
+
+    /** Callback to persist ActionState after each code block completion. */
+    @Nullable private Runnable actionStatePersister;

Review Comment:
   It's kind of weird that these are maintained in `PythonRunnerContextImpl`.
   1. These are not only needed by Python. We will also support this for Java.
   2. These are actually ActionTask-specific. I think we should handle this in `RunnerContextImpl.switchActionContext()`. We may actually want to introduce an `ActionTaskContext`.
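
Here is a rough sketch of the suggested direction. It assumes the operator can interleave ActionTasks, for example while one task is waiting on `execute_async`. Each task owns its call index and recovered records, and the runner context only points at the current task's context. `ActionTaskContext` is the name proposed above. Its fields, `RunnerContextSketch`, and this `switchActionContext` signature are hypothetical and do not reflect the real `RunnerContextImpl.switchActionContext()`. Nothing in the sketch is Python-specific, so a Java action could use the same context.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical runner context that delegates per-task state instead of owning it. */
class RunnerContextSketch {
    private ActionTaskContext current;

    /** Hypothetical: invoked whenever the operator resumes a different ActionTask. */
    void switchActionContext(ActionTaskContext taskContext) {
        this.current = taskContext;
    }

    /** Returns the index of the next call for the current task and advances it. */
    int nextCallIndex() {
        return current.currentCallIndex++;
    }

    public static void main(String[] args) {
        RunnerContextSketch runner = new RunnerContextSketch();
        ActionTaskContext taskA = new ActionTaskContext(List.of(), () -> {});
        ActionTaskContext taskB = new ActionTaskContext(List.of(), () -> {});

        runner.switchActionContext(taskA);
        runner.nextCallIndex(); // task A, call 0
        runner.nextCallIndex(); // task A, call 1
        runner.switchActionContext(taskB);
        int b = runner.nextCallIndex(); // task B starts from its own index 0
        runner.switchActionContext(taskA);
        int a = runner.nextCallIndex(); // task A resumes at 2
        System.out.println("A=" + a + " B=" + b); // prints A=2 B=0
    }
}

/** Hypothetical per-ActionTask state; language-agnostic, so Java actions could reuse it. */
final class ActionTaskContext {
    final List<String> recoveredCallRecords; // stand-in for List<CallRecord>
    final Runnable actionStatePersister;
    int currentCallIndex = 0;

    ActionTaskContext(List<String> recoveredCallRecords, Runnable actionStatePersister) {
        this.recoveredCallRecords = new ArrayList<>(recoveredCallRecords);
        this.actionStatePersister = actionStatePersister;
    }
}
```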


