xintongsong commented on code in PR #422: URL: https://github.com/apache/flink-agents/pull/422#discussion_r2674653474
########## runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallRecord.java: ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.Arrays; +import java.util.Objects; + +/** + * Represents a record of a function call execution for fine-grained durable execution. + * + * <p>This class stores the execution result of a single {@code execute} or {@code execute_async} + * call, enabling recovery without re-execution when the same call is encountered during job + * recovery. + * + * <p>During recovery, the success or failure of the original call is determined by checking whether + * {@code exceptionPayload} is null. + */ +public class CallRecord { Review Comment: I'd suggest naming this `CallResult`, which makes it explicit that this class represents the result of a call execution. 
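To make the naming suggestion above concrete, here is a minimal sketch of what a class named `CallResult` might look like. It is not the PR's code: the immutable fields, the `ofSuccess` factory, and `isSuccess()` are assumptions made for illustration, and the Jackson wiring needed for deserialization is left out.

```java
import java.util.Objects;

/**
 * Hypothetical sketch of the class under review, renamed to CallResult.
 * Assumption: all fields can be final, so there are no setters.
 */
final class CallResult {
    private final String functionId;
    private final String argsDigest;
    private final byte[] resultPayload;    // null if the call threw an exception
    private final byte[] exceptionPayload; // null if the call succeeded

    private CallResult(
            String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) {
        this.functionId = Objects.requireNonNull(functionId);
        this.argsDigest = Objects.requireNonNull(argsDigest);
        this.resultPayload = copy(resultPayload);
        this.exceptionPayload = copy(exceptionPayload);
    }

    /** Hypothetical factory for a successful call (the PR uses a constructor instead). */
    static CallResult ofSuccess(String functionId, String argsDigest, byte[] resultPayload) {
        return new CallResult(functionId, argsDigest, resultPayload, null);
    }

    /** Mirrors the PR's ofException factory for a failed call. */
    static CallResult ofException(String functionId, String argsDigest, byte[] exceptionPayload) {
        return new CallResult(
                functionId, argsDigest, null, Objects.requireNonNull(exceptionPayload));
    }

    /** As in the PR's Javadoc, success is determined by a null exception payload. */
    boolean isSuccess() {
        return exceptionPayload == null;
    }

    /** True if this result was recorded for the same function and argument digest. */
    boolean matches(String functionId, String argsDigest) {
        return this.functionId.equals(functionId) && this.argsDigest.equals(argsDigest);
    }

    byte[] getResultPayload() {
        return copy(resultPayload);
    }

    byte[] getExceptionPayload() {
        return copy(exceptionPayload);
    }

    // Defensive copy so callers cannot mutate the stored payloads.
    private static byte[] copy(byte[] bytes) {
        return bytes == null ? null : bytes.clone();
    }
}
```

With this shape, the name and the factories both say that an instance is the outcome of one call, either a value or an exception.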
########## runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java: ########## @@ -888,7 +904,59 @@ private void maybePersistTaskResult( for (Event outputEvent : actionTaskResult.getOutputEvents()) { actionState.addEvent(outputEvent); } + + // Mark the action as completed and clear call records + // This indicates that recovery should skip the entire action + actionState.markCompleted(); + actionStateStore.put(key, sequenceNum, action, event, actionState); + + // Clear recovery state for Python actions + if (context instanceof PythonRunnerContextImpl) { + ((PythonRunnerContextImpl) context).clearRecoveryState(); + } + } + + /** + * Sets up the recovery state for Python actions to enable fine-grained durable execution. + * + * <p>This method initializes the Python runner context with the ActionState, allowing Python + * execute/execute_async calls to: + * + * <ul> + * <li>Skip re-execution for already completed calls during recovery + * <li>Persist CallRecords after each code block completion + * </ul> + */ + private void setupRecoveryStateForPythonAction(ActionTask actionTask, ActionState actionState) { + Preconditions.checkState( + actionTask.getRunnerContext() instanceof PythonRunnerContextImpl, + "Python action tasks must have PythonRunnerContextImpl"); + + if (actionState == null || actionStateStore == null) { + return; + } + + PythonRunnerContextImpl pythonContext = + (PythonRunnerContextImpl) actionTask.getRunnerContext(); + + // Capture variables for the persister callback + final Object key = actionTask.getKey(); + final Action action = actionTask.action; + final Event event = actionTask.event; + + // Initialize recovery state with ActionState and a callback to persist after each code + // block + pythonContext.initRecoveryState( + actionState, + () -> { + try { + actionStateStore.put( + key, sequenceNumberKState.value(), action, event, actionState); Review Comment: `sequenceNumberKState` might change when the callback is 
called. ########## runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/CallRecord.java: ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import com.fasterxml.jackson.annotation.JsonIgnore; + +import java.util.Arrays; +import java.util.Objects; + +/** + * Represents a record of a function call execution for fine-grained durable execution. + * + * <p>This class stores the execution result of a single {@code execute} or {@code execute_async} + * call, enabling recovery without re-execution when the same call is encountered during job + * recovery. + * + * <p>During recovery, the success or failure of the original call is determined by checking whether + * {@code exceptionPayload} is null. + */ +public class CallRecord { + + /** Function identifier: module+qualname for Python, or method signature for Java. */ + private final String functionId; + + /** Stable digest of the serialized arguments for validation during recovery. */ + private final String argsDigest; + + /** Serialized return value of the function call (null if the call threw an exception). 
*/ + private byte[] resultPayload; + + /** Serialized exception info if the call failed (null if the call succeeded). */ + private byte[] exceptionPayload; + + /** Default constructor for deserialization. */ + public CallRecord() { + this.functionId = null; + this.argsDigest = null; + this.resultPayload = null; + this.exceptionPayload = null; + } + + /** + * Constructs a CallRecord for a successful function call. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param resultPayload the serialized return value + */ + public CallRecord(String functionId, String argsDigest, byte[] resultPayload) { + this.functionId = functionId; + this.argsDigest = argsDigest; + this.resultPayload = resultPayload; + this.exceptionPayload = null; + } + + /** + * Constructs a CallRecord with explicit result and exception payloads. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param resultPayload the serialized return value (null if exception occurred) + * @param exceptionPayload the serialized exception (null if call succeeded) + */ + public CallRecord( + String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) { + this.functionId = functionId; + this.argsDigest = argsDigest; + this.resultPayload = resultPayload; + this.exceptionPayload = exceptionPayload; + } + + /** + * Creates a CallRecord for a failed function call. 
+ * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param exceptionPayload the serialized exception + * @return a new CallRecord representing a failed call + */ + public static CallRecord ofException( + String functionId, String argsDigest, byte[] exceptionPayload) { + return new CallRecord(functionId, argsDigest, null, exceptionPayload); + } + + public String getFunctionId() { + return functionId; + } + + public String getArgsDigest() { + return argsDigest; + } + + public byte[] getResultPayload() { + return resultPayload; + } + + public void setResultPayload(byte[] resultPayload) { + this.resultPayload = resultPayload; + } + + public byte[] getExceptionPayload() { + return exceptionPayload; + } + + public void setExceptionPayload(byte[] exceptionPayload) { + this.exceptionPayload = exceptionPayload; + } Review Comment: Why do we need to set the payloads after construction? Can these fields be `final`? ########## python/flink_agents/runtime/local_runner.py: ########## @@ -179,6 +179,22 @@ def action_metric_group(self) -> MetricGroup: err_msg = "Metric mechanism is not supported for local agent execution yet." raise NotImplementedError(err_msg) + @override + def execute( + self, + func: Callable[[Any], Any], + *args: Any, + **kwargs: Any, + ) -> Any: + """Synchronously execute the provided function. Access to memory + is prohibited within the function. + + Note: Local runner does not support durable execution, so recovery + is not available. + """ + return func(*args, **kwargs) Review Comment: If we are renaming this to `durable_execute`, we also need to log a warning here. There's no need to fail the call, as we don't want users to have to change their code when switching between local and remote execution. ########## python/flink_agents/api/runner_context.py: ########## @@ -186,6 +186,48 @@ def action_metric_group(self) -> MetricGroup: The individual metric group specific to the current action. 
""" + @abstractmethod + def execute( Review Comment: I'd suggest the name `durable_execute()` ########## python/flink_agents/api/runner_context.py: ########## @@ -196,6 +238,14 @@ def execute_async( """Asynchronously execute the provided function. Access to memory is prohibited within the function. + The result of the function will be stored and returned when the same + execute_async call is made again during job recovery. The arguments + and the result must be serializable. + + The action that calls this API should be deterministic, meaning that it + will always make the execute_async call with the same arguments and in + the same order during job recovery. Otherwise, the behavior is undefined. + Review Comment: I don't see the necessity, but it's still possible that we support non-durable async execution in future, which simply execute in a separate thread/coroutine but the result should not be reused for replaying. E.g., getting a token from remote, which might be already expired during the replay. ########## python/flink_agents/api/runner_context.py: ########## @@ -196,6 +238,14 @@ def execute_async( """Asynchronously execute the provided function. Access to memory is prohibited within the function. + The result of the function will be stored and returned when the same + execute_async call is made again during job recovery. The arguments + and the result must be serializable. + + The action that calls this API should be deterministic, meaning that it + will always make the execute_async call with the same arguments and in + the same order during job recovery. Otherwise, the behavior is undefined. + Review Comment: I'd suggest to name this method `durable_execute_async`. ########## runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java: ########## @@ -48,4 +78,120 @@ public void sendEvent(String type, byte[] event, String eventString) { // this method will be invoked by PythonActionExecutor's python interpreter. 
sendEvent(new PythonEvent(event, type, eventString)); } + + /** + * Initializes the recovery state for the current action. + * + * @param actionState the ActionState for the current action (used for adding CallRecords) + * @param actionStatePersister callback to persist ActionState after each code block completion + */ + public void initRecoveryState(ActionState actionState, Runnable actionStatePersister) { + this.currentCallIndex = 0; + this.currentActionState = actionState; + this.actionStatePersister = actionStatePersister; + this.recoveryCallRecords = + actionState != null && actionState.getCallRecords() != null + ? new ArrayList<>(actionState.getCallRecords()) + : new ArrayList<>(); + } + + /** + * Gets the current call index. + * + * @return the current call index + */ + public int getCurrentCallIndex() { + return currentCallIndex; + } + + /** + * Tries to get a cached CallRecord for recovery. + * + * <p>This method is called by Python's execute/execute_async to check if a previous result + * exists. If found and validated, the cached result is returned. 
+ * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @return array containing [isHit (boolean), resultPayload (byte[]), exceptionPayload + * (byte[])], or null if miss + */ + public Object[] tryGetCachedCallRecord(String functionId, String argsDigest) { + mailboxThreadChecker.run(); + + if (currentCallIndex < recoveryCallRecords.size()) { + CallRecord record = recoveryCallRecords.get(currentCallIndex); + + if (record.matches(functionId, argsDigest)) { + LOG.debug( + "CallRecord hit at index {}: functionId={}, argsDigest={}", + currentCallIndex, + functionId, + argsDigest); + currentCallIndex++; + return new Object[] {true, record.getResultPayload(), record.getExceptionPayload()}; + } else { + // Non-deterministic call detected, clear subsequent records + LOG.warn( + "Non-deterministic call detected at index {}: expected functionId={}, argsDigest={}, " + + "but got functionId={}, argsDigest={}. Clearing subsequent records.", + currentCallIndex, + record.getFunctionId(), + record.getArgsDigest(), + functionId, + argsDigest); + clearCallRecordsFromCurrentIndex(); + } + } + + return null; + } + + /** + * Records a completed call and persists the ActionState. + * + * <p>This method is called by Python after each execute/execute_async call completes. 
+ * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param resultPayload the serialized result (null if exception) + * @param exceptionPayload the serialized exception (null if success) + */ + public void recordCallCompletion( + String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) { + mailboxThreadChecker.run(); + + CallRecord callRecord = + new CallRecord(functionId, argsDigest, resultPayload, exceptionPayload); + + if (currentActionState != null && actionStatePersister != null) { + currentActionState.addCallRecord(callRecord); + actionStatePersister.run(); + LOG.debug( + "Recorded and persisted CallRecord at index {}: functionId={}, argsDigest={}", + currentCallIndex, + functionId, + argsDigest); + } Review Comment: Why would `currentActionState` and `actionStatePersister` be `null` when this method is called? Shall we use an assertion here? ########## runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java: ########## @@ -48,4 +78,120 @@ public void sendEvent(String type, byte[] event, String eventString) { // this method will be invoked by PythonActionExecutor's python interpreter. sendEvent(new PythonEvent(event, type, eventString)); } + + /** + * Initializes the recovery state for the current action. + * + * @param actionState the ActionState for the current action (used for adding CallRecords) + * @param actionStatePersister callback to persist ActionState after each code block completion + */ + public void initRecoveryState(ActionState actionState, Runnable actionStatePersister) { + this.currentCallIndex = 0; + this.currentActionState = actionState; + this.actionStatePersister = actionStatePersister; + this.recoveryCallRecords = + actionState != null && actionState.getCallRecords() != null + ? new ArrayList<>(actionState.getCallRecords()) + : new ArrayList<>(); + } + + /** + * Gets the current call index. 
+ * + * @return the current call index + */ + public int getCurrentCallIndex() { + return currentCallIndex; + } + + /** + * Tries to get a cached CallRecord for recovery. + * + * <p>This method is called by Python's execute/execute_async to check if a previous result + * exists. If found and validated, the cached result is returned. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @return array containing [isHit (boolean), resultPayload (byte[]), exceptionPayload + * (byte[])], or null if miss + */ + public Object[] tryGetCachedCallRecord(String functionId, String argsDigest) { Review Comment: I'd suggest the name `matchNextOrClearSubsequentCallRecord`. Otherwise, the clearing is super implicit. ########## runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java: ########## @@ -20,21 +20,51 @@ import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.CallRecord; import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import org.apache.flink.agents.runtime.python.event.PythonEvent; import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; -/** A specialized {@link RunnerContext} that is specifically used when executing Python actions. */ +import java.util.ArrayList; +import java.util.List; + +/** + * A specialized {@link RunnerContext} that is specifically used when executing Python actions. 
+ * + * <p>This class provides fine-grained durable execution support by managing {@link CallRecord}s + * that store the results of {@code execute} and {@code execute_async} calls. During recovery, these + * records are used to skip re-execution of already completed calls. + */ @NotThreadSafe public class PythonRunnerContextImpl extends RunnerContextImpl { + + private static final Logger LOG = LoggerFactory.getLogger(PythonRunnerContextImpl.class); + + /** Current call index within the action, used for matching CallRecords during recovery. */ + private int currentCallIndex = 0; + + /** List of existing CallRecords loaded during recovery. */ + private List<CallRecord> recoveryCallRecords; + + /** The current ActionState being built during action execution. */ + @Nullable private ActionState currentActionState; + + /** Callback to persist ActionState after each code block completion. */ + @Nullable private Runnable actionStatePersister; Review Comment: It's kind of weird that these are maintained in `PythonRunnerContextImpl`. 1. These are not only needed by python. We will also support this for Java. 2. These are actually ActionTask-specific. I think we should handle this in `RunnerContextImpl.switchActionContext()`. We might want to actually introduce an `ActionTaskContext`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
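As an illustration of the last review comment above (moving the recovery state out of `PythonRunnerContextImpl` into something ActionTask-specific), the following is a rough sketch only. `ActionTaskContext` is the name floated in the comment, but its fields and methods here are assumptions, the nested `CallRecord` is a minimal stand-in for the PR's class, and advancing the index after recording a new call is a guess at the intended behavior.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/** Hypothetical per-ActionTask holder for recovery state; not part of the PR. */
final class ActionTaskContext {

    /** Minimal stand-in for the PR's CallRecord, with just enough for matching. */
    static final class CallRecord {
        final String functionId;
        final String argsDigest;
        final byte[] resultPayload;

        CallRecord(String functionId, String argsDigest, byte[] resultPayload) {
            this.functionId = Objects.requireNonNull(functionId);
            this.argsDigest = Objects.requireNonNull(argsDigest);
            this.resultPayload = resultPayload;
        }

        boolean matches(String functionId, String argsDigest) {
            return this.functionId.equals(functionId) && this.argsDigest.equals(argsDigest);
        }
    }

    private final List<CallRecord> callRecords;
    private final Runnable persister;
    private int nextCallIndex = 0;

    /** Non-null arguments are required up front, so later methods need no null checks. */
    ActionTaskContext(List<CallRecord> recoveredRecords, Runnable persister) {
        this.callRecords = new ArrayList<>(Objects.requireNonNull(recoveredRecords));
        this.persister = Objects.requireNonNull(persister);
    }

    /**
     * Returns the recorded call at the current index if it matches. On a mismatch,
     * drops that record and all later ones and returns null. The method name makes
     * the clearing explicit, as suggested in the review.
     */
    CallRecord matchNextOrClearSubsequentCallRecord(String functionId, String argsDigest) {
        if (nextCallIndex >= callRecords.size()) {
            return null; // nothing recorded for this call position
        }
        CallRecord record = callRecords.get(nextCallIndex);
        if (record.matches(functionId, argsDigest)) {
            nextCallIndex++;
            return record;
        }
        // Non-deterministic call: records from this index onward are no longer valid.
        callRecords.subList(nextCallIndex, callRecords.size()).clear();
        return null;
    }

    /** Appends a freshly executed call, advances the index, and persists (assumed behavior). */
    void recordCallCompletion(CallRecord record) {
        callRecords.add(Objects.requireNonNull(record));
        nextCallIndex++;
        persister.run();
    }

    int getNextCallIndex() {
        return nextCallIndex;
    }

    int getRecordCount() {
        return callRecords.size();
    }
}
```

A runner context for either Java or Python actions could then hold one such object per action task and swap it when the action context is switched, rather than keeping nullable fields on the Python-specific subclass.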
