This is an automated email from the ASF dual-hosted git repository. sxnan pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 53d76a75b0c7c0d34dbee902303311e3a57a8d7d Author: sxnan <[email protected]> AuthorDate: Thu Jan 8 18:00:37 2026 +0800 [runtime] Implement CallRecord persistence and restoration --- .../flink_agents/runtime/flink_runner_context.py | 282 ++++++++++++++++++++- .../runtime/tests/test_durable_execution.py | 150 +++++++++++ .../runtime/context/ActionStatePersister.java | 44 ++++ .../agents/runtime/context/RunnerContextImpl.java | 188 ++++++++++++++ .../runtime/operator/ActionExecutionOperator.java | 101 +++++++- .../python/context/PythonRunnerContextImpl.java | 1 + .../context/DurableExecutionContextTest.java | 206 +++++++++++++++ 7 files changed, 954 insertions(+), 18 deletions(-) diff --git a/python/flink_agents/runtime/flink_runner_context.py b/python/flink_agents/runtime/flink_runner_context.py index 257d22d4..dffdaf73 100644 --- a/python/flink_agents/runtime/flink_runner_context.py +++ b/python/flink_agents/runtime/flink_runner_context.py @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################# +import hashlib +import logging import os from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict @@ -39,11 +41,136 @@ from flink_agents.runtime.memory.vector_store_long_term_memory import ( VectorStoreLongTermMemory, ) +logger = logging.getLogger(__name__) + + +class _DurableExecutionResult: + """Wrapper that holds result and triggers recording when unwrapped.""" + + def __init__( + self, + func: Callable, + args: tuple, + kwargs: dict, + result: Any, + record_callback: Callable, + ) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + self.result = result + self.record_callback = record_callback + self._recorded = False + + def get_result(self) -> Any: + """Get the result and record completion if not already recorded.""" + if not self._recorded: + self.record_callback(self.func, self.args, self.kwargs, self.result, None) + self._recorded = True + return self.result + + +class _DurableExecutionException(Exception): + """Wrapper exception that holds exception info and triggers recording.""" + + def __init__( + self, + func: Callable, + args: tuple, + kwargs: dict, + result: Any, + exception: BaseException, + record_callback: Callable, + ) -> None: + super().__init__(str(exception)) + self.func = func + self.args = args + self.kwargs = kwargs + self.original_exception = exception + self.record_callback = record_callback + self._recorded = False + + def record_and_raise(self) -> None: + """Record completion and raise the original exception.""" + if not self._recorded: + self.record_callback( + self.func, self.args, self.kwargs, None, self.original_exception + ) + self._recorded = True + raise self.original_exception from None + + +class _CachedAsyncExecutionResult(AsyncExecutionResult): + """An AsyncExecutionResult that returns a cached value immediately.""" + + def __init__(self, cached_result: Any) -> None: + # Don't call 
super().__init__ as we don't need executor/func/args/kwargs + self._cached_result = cached_result + + def __await__(self) -> Any: + """Return the cached result immediately. + + This is a generator that yields nothing and returns the cached result. + """ + if False: + yield # Make this a generator function + return self._cached_result + + +class _DurableAsyncExecutionResult(AsyncExecutionResult): + """An AsyncExecutionResult that records completion after execution.""" + + def __init__( + self, executor: Any, func: Callable, args: tuple, kwargs: dict + ) -> None: + super().__init__(executor, func, args, kwargs) + + def __await__(self) -> Any: + """Execute and record completion when awaited.""" + future = self._executor.submit(self._func, *self._args, **self._kwargs) + while not future.done(): + yield + + result = future.result() + + # Handle the wrapped result/exception + if isinstance(result, _DurableExecutionResult): + return result.get_result() + elif isinstance(result, _DurableExecutionException): + result.record_and_raise() + else: + return result + + +def _compute_function_id(func: Callable) -> str: + """Compute a stable function identifier from a callable. + + Returns module.qualname for functions/methods. + """ + module = getattr(func, "__module__", "<unknown>") + qualname = getattr(func, "__qualname__", getattr(func, "__name__", "<unknown>")) + return f"{module}.{qualname}" + + +def _compute_args_digest(args: tuple, kwargs: dict) -> str: + """Compute a stable digest of the serialized arguments. + + The digest is used to validate that the same arguments are passed + during recovery as during the original execution. 
+ """ + try: + serialized = cloudpickle.dumps((args, kwargs)) + return hashlib.sha256(serialized).hexdigest()[:16] + except Exception: + # If serialization fails, return a fallback digest + return hashlib.sha256(str((args, kwargs)).encode()).hexdigest()[:16] + class FlinkRunnerContext(RunnerContext): """Providing context for agent execution in Flink Environment. - This context allows access to event handling. + This context allows access to event handling and provides fine-grained + durable execution support through execute() and execute_async() methods. """ __agent_plan: AgentPlan | None @@ -185,34 +312,167 @@ class FlinkRunnerContext(RunnerContext): """ return FlinkMetricGroup(self._j_runner_context.getActionMetricGroup()) + def _try_get_cached_result( + self, func: Callable, args: tuple, kwargs: dict + ) -> tuple[bool, Any]: + """Try to get a cached result from a previous execution. + + Returns: + ------- + tuple[bool, Any] + A tuple of (is_hit, result_or_exception). If is_hit is True, + the second element is the cached result or an exception to re-raise. 
+ """ + function_id = _compute_function_id(func) + args_digest = _compute_args_digest(args, kwargs) + + cached_exception: BaseException | None = None + try: + cached = self._j_runner_context.matchNextOrClearSubsequentCallResult( + function_id, args_digest + ) + if cached is not None: + is_hit, result_payload, exception_payload = cached + if is_hit: + if exception_payload is not None: + # Store cached exception to re-raise outside try block + cached_exception = cloudpickle.loads(bytes(exception_payload)) + elif result_payload is not None: + return True, cloudpickle.loads(bytes(result_payload)) + else: + return True, None + except Exception as e: + # If Java method doesn't exist (not supported), fall through to execute + if "matchNextOrClearSubsequentCallResult" in str(e): + logger.debug("Durable execution not supported, executing directly") + else: + raise + + # Re-raise cached exception outside try block + if cached_exception is not None: + raise cached_exception + + return False, None + + def _record_call_completion( + self, + func: Callable, + args: tuple, + kwargs: dict, + result: Any, + exception: BaseException | None, + ) -> None: + """Record the completion of a call for durable execution. + + Parameters + ---------- + func : Callable + The function that was executed. + args : tuple + Positional arguments passed to the function. + kwargs : dict + Keyword arguments passed to the function. + result : Any + The result of the function (None if exception occurred). + exception : BaseException | None + The exception raised by the function (None if successful). 
+ """ + function_id = _compute_function_id(func) + args_digest = _compute_args_digest(args, kwargs) + + try: + result_payload = None if exception else cloudpickle.dumps(result) + exception_payload = cloudpickle.dumps(exception) if exception else None + + self._j_runner_context.recordCallCompletion( + function_id, args_digest, result_payload, exception_payload + ) + except Exception as e: + # If Java method doesn't exist, silently ignore + if "recordCallCompletion" not in str(e): + logger.warning("Failed to record call completion: %s", e) + @override - def execute( + def durable_execute( self, func: Callable[[Any], Any], *args: Any, **kwargs: Any, ) -> Any: - """Synchronously execute the provided function. Access to memory - is prohibited within the function. + """Synchronously execute the provided function with durable execution support. + Access to memory is prohibited within the function. + + The result of the function will be stored and returned when the same + durable_execute call is made again during job recovery. The arguments and the + result must be serializable. The function is executed synchronously in the current thread, blocking the operator until completion. """ - # TODO: Add durable execution support (persist result for recovery) - return func(*args, **kwargs) + # Try to get cached result for recovery + is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) + if is_hit: + return cached_result + + # Execute the function + exception = None + result = None + try: + result = func(*args, **kwargs) + except BaseException as e: + exception = e + + # Record the completion + self._record_call_completion(func, args, kwargs, result, exception) + + if exception: + raise exception + return result @override - def execute_async( + def durable_execute_async( self, func: Callable[[Any], Any], *args: Any, **kwargs: Any, ) -> AsyncExecutionResult: - """Asynchronously execute the provided function. Access to memory - is prohibited within the function. 
+ """Asynchronously execute the provided function with durable execution support. + Access to memory is prohibited within the function. + + The result of the function will be stored and returned when the same + durable_execute_async call is made again during job recovery. The arguments + and the result must be serializable. + + Important: The result is only recorded when the returned AsyncExecutionResult + is awaited. Fire-and-forget calls (not awaiting the result) will NOT be + recorded and cannot be recovered. """ - # TODO: Add durable execution support (persist result for recovery) - return AsyncExecutionResult(self.executor, func, args, kwargs) + # Try to get cached result for recovery + is_hit, cached_result = self._try_get_cached_result(func, args, kwargs) + if is_hit: + # Return a pre-completed AsyncExecutionResult + return _CachedAsyncExecutionResult(cached_result) + + # Create a wrapper function that records completion + def wrapped_func(*a: Any, **kw: Any) -> Any: + exception = None + result = None + try: + result = func(*a, **kw) + except BaseException as e: + exception = e + + # Note: This runs in a thread pool, so we need to be careful + # The actual recording will happen when the result is awaited + if exception: + raise _DurableExecutionException( + func, args, kwargs, result, exception, self._record_call_completion + ) + return _DurableExecutionResult( + func, args, kwargs, result, self._record_call_completion + ) + + return _DurableAsyncExecutionResult(self.executor, wrapped_func, args, kwargs) @property @override diff --git a/python/flink_agents/runtime/tests/test_durable_execution.py b/python/flink_agents/runtime/tests/test_durable_execution.py new file mode 100644 index 00000000..e59e54cd --- /dev/null +++ b/python/flink_agents/runtime/tests/test_durable_execution.py @@ -0,0 +1,150 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Tests for durable execution helper functions.""" + +import cloudpickle + +from flink_agents.runtime.flink_runner_context import ( + _compute_args_digest, + _compute_function_id, +) + + +def sample_function(x: int, y: int) -> int: + """A sample function for testing.""" + return x + y + + +class SampleClass: + """A sample class for testing method function IDs.""" + + def instance_method(self, x: int) -> int: + """An instance method.""" + return x * 2 + + @staticmethod + def static_method(x: int) -> int: + """A static method.""" + return x * 3 + + @classmethod + def class_method(cls, x: int) -> int: + """A class method.""" + return x * 4 + + +def test_compute_function_id_for_function() -> None: + """Test function ID computation for regular functions.""" + func_id = _compute_function_id(sample_function) + assert "sample_function" in func_id + assert "test_durable_execution" in func_id + + +def test_compute_function_id_for_lambda() -> None: + """Test function ID computation for lambda functions.""" + lambda_func = lambda x: x + 1 # noqa: E731 + func_id = _compute_function_id(lambda_func) + assert "<lambda>" in func_id + + +def test_compute_function_id_for_method() -> None: + 
"""Test function ID computation for instance methods.""" + obj = SampleClass() + func_id = _compute_function_id(obj.instance_method) + assert "instance_method" in func_id + assert "SampleClass" in func_id + + +def test_compute_function_id_for_static_method() -> None: + """Test function ID computation for static methods.""" + func_id = _compute_function_id(SampleClass.static_method) + assert "static_method" in func_id + + +def test_compute_function_id_for_class_method() -> None: + """Test function ID computation for class methods.""" + func_id = _compute_function_id(SampleClass.class_method) + assert "class_method" in func_id + + +def test_compute_args_digest_basic() -> None: + """Test args digest computation for basic types.""" + digest1 = _compute_args_digest((1, 2), {"key": "value"}) + digest2 = _compute_args_digest((1, 2), {"key": "value"}) + # Same arguments should produce same digest + assert digest1 == digest2 + + # Different arguments should produce different digest + digest3 = _compute_args_digest((1, 3), {"key": "value"}) + assert digest1 != digest3 + + +def test_compute_args_digest_empty() -> None: + """Test args digest computation for empty arguments.""" + digest = _compute_args_digest((), {}) + assert len(digest) == 16 # SHA256 truncated to 16 chars + + +def test_compute_args_digest_complex_types() -> None: + """Test args digest computation for complex types.""" + complex_args = ( + {"nested": {"key": [1, 2, 3]}}, + [1, 2, {"inner": "value"}], + ) + complex_kwargs = {"data": {"x": 1, "y": 2}} + + digest1 = _compute_args_digest(complex_args, complex_kwargs) + digest2 = _compute_args_digest(complex_args, complex_kwargs) + assert digest1 == digest2 + + +def test_compute_args_digest_order_matters() -> None: + """Test that argument order affects the digest.""" + digest1 = _compute_args_digest((1, 2), {}) + digest2 = _compute_args_digest((2, 1), {}) + assert digest1 != digest2 + + +def test_compute_args_digest_kwargs_vs_args() -> None: + """Test that kwargs 
and args produce different digests.""" + digest1 = _compute_args_digest((1,), {"y": 2}) + digest2 = _compute_args_digest((1, 2), {}) + assert digest1 != digest2 + + +def test_cloudpickle_serialization() -> None: + """Test that results can be serialized and deserialized with cloudpickle.""" + # Test basic types + original = {"key": "value", "number": 42, "list": [1, 2, 3]} + serialized = cloudpickle.dumps(original) + deserialized = cloudpickle.loads(serialized) + assert deserialized == original + + # Test exception + def raise_test_error() -> None: + error_message = "test error" + raise ValueError(error_message) + + try: + raise_test_error() + except ValueError as e: + serialized_exc = cloudpickle.dumps(e) + deserialized_exc = cloudpickle.loads(serialized_exc) + assert str(deserialized_exc) == "test error" + assert isinstance(deserialized_exc, ValueError) + diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/ActionStatePersister.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/ActionStatePersister.java new file mode 100644 index 00000000..529098a8 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/ActionStatePersister.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.runtime.context; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.actionstate.ActionState; + +/** + * Interface for persisting {@link ActionState}. + * + * <p>This interface decouples the {@link RunnerContextImpl.DurableExecutionContext} from the + * storage layer. + */ +public interface ActionStatePersister { + + /** + * Persists the given ActionState. + * + * @param key the key for the action + * @param sequenceNumber the sequence number for ordering + * @param action the action being executed + * @param event the event that triggered the action + * @param actionState the ActionState to persist + */ + void persist( + Object key, long sequenceNumber, Action action, Event event, ActionState actionState); +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 2b8d8dc5..8a946f2d 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -28,13 +28,20 @@ import org.apache.flink.agents.api.memory.LongTermMemoryOptions; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.plan.actions.Action; import org.apache.flink.agents.plan.utils.JsonUtils; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.CallResult; import org.apache.flink.agents.runtime.memory.CachedMemoryStore; import org.apache.flink.agents.runtime.memory.InteranlBaseLongTermMemory; import 
org.apache.flink.agents.runtime.memory.MemoryObjectImpl; import org.apache.flink.agents.runtime.memory.VectorStoreLongTermMemory; import org.apache.flink.agents.runtime.metrics.FlinkAgentsMetricGroupImpl; import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.Nullable; import java.util.ArrayList; import java.util.LinkedList; @@ -77,6 +84,8 @@ public class RunnerContextImpl implements RunnerContext { } } + private static final Logger LOG = LoggerFactory.getLogger(RunnerContextImpl.class); + protected final List<Event> pendingEvents = new ArrayList<>(); protected final FlinkAgentsMetricGroupImpl agentMetricGroup; protected final Runnable mailboxThreadChecker; @@ -86,6 +95,9 @@ public class RunnerContextImpl implements RunnerContext { protected String actionName; protected InteranlBaseLongTermMemory ltm; + /** Context for fine-grained durable execution, may be null if not enabled. */ + @Nullable protected DurableExecutionContext durableExecutionContext; + public RunnerContextImpl( FlinkAgentsMetricGroupImpl agentMetricGroup, Runnable mailboxThreadChecker, @@ -247,4 +259,180 @@ public class RunnerContextImpl implements RunnerContext { public void clearSensoryMemory() throws Exception { memoryContext.getSensoryMemStore().clear(); } + + public void setDurableExecutionContext( + @Nullable DurableExecutionContext durableExecutionContext) { + this.durableExecutionContext = durableExecutionContext; + } + + @Nullable + public DurableExecutionContext getDurableExecutionContext() { + return durableExecutionContext; + } + + public void clearDurableExecutionContext() { + this.durableExecutionContext = null; + } + + /** + * Matches the next call result for recovery, or clears subsequent results if mismatch detected. + * + * <p>This method delegates to the {@link DurableExecutionContext} if present. 
+ * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @return array containing [isHit (boolean), resultPayload (byte[]), exceptionPayload + * (byte[])], or null if miss or durable execution is not enabled + */ + public Object[] matchNextOrClearSubsequentCallResult(String functionId, String argsDigest) { + mailboxThreadChecker.run(); + if (durableExecutionContext != null) { + return durableExecutionContext.matchNextOrClearSubsequentCallResult( + functionId, argsDigest); + } + return null; + } + + /** + * Records a completed call and persists the ActionState. + * + * <p>This method delegates to the {@link DurableExecutionContext} if present. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param resultPayload the serialized result (null if exception) + * @param exceptionPayload the serialized exception (null if success) + */ + public void recordCallCompletion( + String functionId, String argsDigest, byte[] resultPayload, byte[] exceptionPayload) { + mailboxThreadChecker.run(); + if (durableExecutionContext != null) { + durableExecutionContext.recordCallCompletion( + functionId, argsDigest, resultPayload, exceptionPayload); + } + } + + /** + * Context for fine-grained durable execution within an action. + * + * <p>This class encapsulates all state needed for {@code durable_execute}/{@code + * durable_execute_async} recovery. During normal execution, each call is recorded as a {@link + * CallResult}. During recovery, these results are used to skip re-execution of already + * completed calls. + */ + public static class DurableExecutionContext { + private final Object key; + private final long sequenceNumber; + private final Action action; + private final Event event; + private final ActionState actionState; + private final ActionStatePersister persister; + + /** Current call index within the action, used for matching CallResults during recovery. 
*/ + private int currentCallIndex; + + /** Snapshot of CallResults loaded during recovery. */ + private List<CallResult> recoveryCallResults; + + public DurableExecutionContext( + Object key, + long sequenceNumber, + Action action, + Event event, + ActionState actionState, + ActionStatePersister persister) { + this.key = key; + this.sequenceNumber = sequenceNumber; + this.action = action; + this.event = event; + this.actionState = actionState; + this.persister = persister; + this.currentCallIndex = 0; + this.recoveryCallResults = + actionState.getCallResults() != null + ? new ArrayList<>(actionState.getCallResults()) + : new ArrayList<>(); + } + + public int getCurrentCallIndex() { + return currentCallIndex; + } + + public ActionState getActionState() { + return actionState; + } + + /** + * Matches the next call result for recovery, or clears subsequent results if mismatch + * detected. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @return array containing [isHit, resultPayload, exceptionPayload], or null if miss + */ + public Object[] matchNextOrClearSubsequentCallResult(String functionId, String argsDigest) { + if (currentCallIndex < recoveryCallResults.size()) { + CallResult result = recoveryCallResults.get(currentCallIndex); + + if (result.matches(functionId, argsDigest)) { + LOG.debug( + "CallResult hit at index {}: functionId={}, argsDigest={}", + currentCallIndex, + functionId, + argsDigest); + currentCallIndex++; + return new Object[] { + true, result.getResultPayload(), result.getExceptionPayload() + }; + } else { + LOG.warn( + "Non-deterministic call detected at index {}: expected functionId={}, " + + "argsDigest={}, but got functionId={}, argsDigest={}. 
" + + "Clearing subsequent results.", + currentCallIndex, + result.getFunctionId(), + result.getArgsDigest(), + functionId, + argsDigest); + clearCallResultsFromCurrentIndex(); + } + } + return null; + } + + /** + * Records a completed call and persists the ActionState. + * + * @param functionId the function identifier + * @param argsDigest the digest of serialized arguments + * @param resultPayload the serialized result (null if exception) + * @param exceptionPayload the serialized exception (null if success) + */ + public void recordCallCompletion( + String functionId, + String argsDigest, + byte[] resultPayload, + byte[] exceptionPayload) { + CallResult callResult = + new CallResult(functionId, argsDigest, resultPayload, exceptionPayload); + + actionState.addCallResult(callResult); + persister.persist(key, sequenceNumber, action, event, actionState); + + LOG.debug( + "Recorded and persisted CallResult at index {}: functionId={}, argsDigest={}", + currentCallIndex, + functionId, + argsDigest); + + currentCallIndex++; + } + + private void clearCallResultsFromCurrentIndex() { + actionState.clearCallResultsFrom(currentCallIndex); + recoveryCallResults = + recoveryCallResults.subList( + 0, Math.min(currentCallIndex, recoveryCallResults.size())); + } + } } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 1b569ac8..60b7e329 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -38,6 +38,7 @@ import org.apache.flink.agents.plan.resourceprovider.PythonResourceProvider; import org.apache.flink.agents.runtime.actionstate.ActionState; import org.apache.flink.agents.runtime.actionstate.ActionStateStore; import 
org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; +import org.apache.flink.agents.runtime.context.ActionStatePersister; import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.agents.runtime.env.EmbeddedPythonEnvironment; import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; @@ -110,7 +111,7 @@ import static org.apache.flink.util.Preconditions.checkState; * and the resulting output event is collected for further processing. */ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT> - implements OneInputStreamOperator<IN, OUT>, BoundedOneInput { + implements OneInputStreamOperator<IN, OUT>, BoundedOneInput, ActionStatePersister { private static final long serialVersionUID = 1L; @@ -190,6 +191,11 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private final transient Map<ActionTask, RunnerContextImpl.MemoryContext> actionTaskMemoryContexts; + // This in memory map keeps track of the durable execution context for async action tasks + // that have not been finished, allowing recovery of currentCallIndex across invocations + private final transient Map<ActionTask, RunnerContextImpl.DurableExecutionContext> + actionTaskDurableContexts; + // Each job can only have one identifier and this identifier must be consistent across restarts. // We cannot use job id as the identifier here because user may change job id by // creating a savepoint, stop the job and then resume from savepoint. 
@@ -212,6 +218,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT this.actionStateStore = actionStateStore; this.checkpointIdToSeqNums = new HashMap<>(); this.actionTaskMemoryContexts = new HashMap<>(); + this.actionTaskDurableContexts = new HashMap<>(); OperatorUtils.setChainStrategy(this, ChainingStrategy.ALWAYS); } @@ -446,7 +453,10 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT Optional<ActionTask> generatedActionTaskOpt = Optional.empty(); ActionState actionState = maybeGetActionState(key, sequenceNumber, actionTask.action, actionTask.event); - if (actionState != null) { + + // Check if action is already completed + if (actionState != null && actionState.isCompleted()) { + // Action has completed, skip execution and replay memory/events isFinished = true; outputEvents = actionState.getOutputEvents(); for (MemoryUpdate memoryUpdate : actionState.getShortTermMemoryUpdates()) { @@ -463,16 +473,27 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT .set(memoryUpdate.getPath(), memoryUpdate.getValue()); } } else { - maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event); + // Initialize ActionState if not exists, or use existing one for recovery + if (actionState == null) { + maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event); + actionState = + maybeGetActionState( + key, sequenceNumber, actionTask.action, actionTask.event); + } + + // Set up durable execution context for fine-grained recovery + setupDurableExecutionContext(actionTask, actionState); + ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke( getRuntimeContext().getUserCodeClassLoader(), this.pythonActionExecutor); - // We remove the RunnerContext of the action task from the map after it is finished. The - // RunnerContext will be added later if the action task has a generated action task, - // meaning it is not finished. 
+ // We remove the contexts from the map after the task is processed. They will be added + // back later if the action task has a generated action task, meaning it is not + // finished. actionTaskMemoryContexts.remove(actionTask); + actionTaskDurableContexts.remove(actionTask); maybePersistTaskResult( key, sequenceNumber, @@ -505,10 +526,15 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // execution. ActionTask generatedActionTask = generatedActionTaskOpt.get(); - // If the action task is not finished, we keep the runner context in the memory for the + // If the action task is not finished, we keep the contexts in memory for the // next generated ActionTask to be invoked. actionTaskMemoryContexts.put( generatedActionTask, actionTask.getRunnerContext().getMemoryContext()); + RunnerContextImpl.DurableExecutionContext durableContext = + actionTask.getRunnerContext().getDurableExecutionContext(); + if (durableContext != null) { + actionTaskDurableContexts.put(generatedActionTask, durableContext); + } actionTasksKState.add(generatedActionTask); } @@ -916,7 +942,68 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT for (Event outputEvent : actionTaskResult.getOutputEvents()) { actionState.addEvent(outputEvent); } + + // Mark the action as completed and clear call records + // This indicates that recovery should skip the entire action + actionState.markCompleted(); + actionStateStore.put(key, sequenceNum, action, event, actionState); + + // Clear durable execution context + context.clearDurableExecutionContext(); + } + + /** + * Sets up the durable execution context for fine-grained recovery. 
+ * + * <p>This method initializes the runner context with a {@link + * RunnerContextImpl.DurableExecutionContext}, which enables execute/execute_async calls to: + * + * <ul> + * <li>Skip re-execution for already completed calls during recovery + * <li>Persist CallRecords after each code block completion + * </ul> + */ + private void setupDurableExecutionContext(ActionTask actionTask, ActionState actionState) { + if (actionStateStore == null) { + return; + } + + RunnerContextImpl.DurableExecutionContext durableContext; + if (actionTaskDurableContexts.containsKey(actionTask)) { + // Reuse existing context for async action continuation + durableContext = actionTaskDurableContexts.get(actionTask); + } else { + // Create new context for first invocation + final long sequenceNumber; + try { + sequenceNumber = sequenceNumberKState.value(); + } catch (Exception e) { + throw new RuntimeException("Failed to get sequence number from state", e); + } + + durableContext = + new RunnerContextImpl.DurableExecutionContext( + actionTask.getKey(), + sequenceNumber, + actionTask.action, + actionTask.event, + actionState, + this); + } + + actionTask.getRunnerContext().setDurableExecutionContext(durableContext); + } + + @Override + public void persist( + Object key, long sequenceNumber, Action action, Event event, ActionState actionState) { + try { + actionStateStore.put(key, sequenceNumber, action, event, actionState); + } catch (Exception e) { + LOG.error("Failed to persist ActionState", e); + throw new RuntimeException("Failed to persist ActionState", e); + } } private void maybePruneState(Object key, long sequenceNum) throws Exception { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java index 7df56e5e..ddabf503 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java +++ 
b/runtime/src/main/java/org/apache/flink/agents/runtime/python/context/PythonRunnerContextImpl.java @@ -15,6 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.flink.agents.runtime.python.context; import org.apache.flink.agents.api.Event; diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java new file mode 100644 index 00000000..f2701e50 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/context/DurableExecutionContextTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.agents.runtime.context; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.actions.Action; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.CallResult; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.mock; + +/** Unit tests for {@link RunnerContextImpl.DurableExecutionContext}. */ +class DurableExecutionContextTest { + + private ActionState actionState; + private AtomicInteger persistCallCount; + private ActionState lastPersistedState; + private Object testKey; + private long testSequenceNumber; + private Action mockAction; + private Event mockEvent; + + @BeforeEach + void setUp() { + actionState = new ActionState(null); + persistCallCount = new AtomicInteger(0); + lastPersistedState = null; + testKey = "testKey"; + testSequenceNumber = 1L; + mockAction = mock(Action.class); + mockEvent = mock(Event.class); + } + + private RunnerContextImpl.DurableExecutionContext createContext() { + ActionStatePersister persister = + (key, seqNum, action, event, state) -> { + persistCallCount.incrementAndGet(); + lastPersistedState = state; + }; + return new RunnerContextImpl.DurableExecutionContext( + testKey, testSequenceNumber, mockAction, mockEvent, actionState, persister); + } + + @Test + void testInitialization() { + actionState.addCallResult( + new CallResult("funcA", "digestA", "resultA".getBytes(StandardCharsets.UTF_8))); + actionState.addCallResult( + new CallResult("funcB", "digestB", "resultB".getBytes(StandardCharsets.UTF_8))); + + RunnerContextImpl.DurableExecutionContext context = createContext(); + + assertEquals(0, context.getCurrentCallIndex()); + assertSame(actionState, context.getActionState()); + } + + @Test + void 
testMatchNextOrClearSubsequentCallResultHit() { + byte[] expectedResult = "cached_result".getBytes(StandardCharsets.UTF_8); + actionState.addCallResult(new CallResult("funcA", "digestA", expectedResult)); + + RunnerContextImpl.DurableExecutionContext context = createContext(); + + Object[] result = context.matchNextOrClearSubsequentCallResult("funcA", "digestA"); + + assertNotNull(result); + assertEquals(3, result.length); + assertTrue((Boolean) result[0]); // isHit + assertArrayEquals(expectedResult, (byte[]) result[1]); // resultPayload + assertNull(result[2]); // exceptionPayload + assertEquals(1, context.getCurrentCallIndex()); + } + + @Test + void testMatchNextOrClearSubsequentCallResultMiss() { + RunnerContextImpl.DurableExecutionContext context = createContext(); + + Object[] result = context.matchNextOrClearSubsequentCallResult("funcA", "digestA"); + + assertNull(result); + assertEquals(0, context.getCurrentCallIndex()); + } + + @Test + void testMatchNextOrClearSubsequentCallResultMismatch() { + actionState.addCallResult(new CallResult("funcA", "digestA", "result".getBytes())); + actionState.addCallResult(new CallResult("funcB", "digestB", "result".getBytes())); + + RunnerContextImpl.DurableExecutionContext context = createContext(); + + // Call with mismatched functionId - should clear subsequent results and return null + Object[] result = context.matchNextOrClearSubsequentCallResult("funcX", "digestX"); + + assertNull(result); + // ActionState should have results cleared from index 0 + assertEquals(0, actionState.getCallResultCount()); + // Persist is not called here - it will be called in recordCallCompletion + assertEquals(0, persistCallCount.get()); + } + + @Test + void testRecordCallCompletionSuccess() { + RunnerContextImpl.DurableExecutionContext context = createContext(); + + byte[] resultPayload = "success_result".getBytes(StandardCharsets.UTF_8); + context.recordCallCompletion("funcA", "digestA", resultPayload, null); + + assertEquals(1, 
context.getCurrentCallIndex()); + assertEquals(1, actionState.getCallResults().size()); + assertEquals("funcA", actionState.getCallResults().get(0).getFunctionId()); + // Verify persister was called + assertEquals(1, persistCallCount.get()); + assertSame(actionState, lastPersistedState); + } + + @Test + void testRecordCallCompletionException() { + RunnerContextImpl.DurableExecutionContext context = createContext(); + + byte[] exceptionPayload = "exception_data".getBytes(StandardCharsets.UTF_8); + context.recordCallCompletion("funcA", "digestA", null, exceptionPayload); + + assertEquals(1, context.getCurrentCallIndex()); + CallResult recorded = actionState.getCallResults().get(0); + assertNull(recorded.getResultPayload()); + assertArrayEquals(exceptionPayload, recorded.getExceptionPayload()); + assertEquals(1, persistCallCount.get()); + } + + @Test + void testMultipleCallResultRecovery() { + byte[] result1 = "result1".getBytes(StandardCharsets.UTF_8); + byte[] result2 = "result2".getBytes(StandardCharsets.UTF_8); + actionState.addCallResult(new CallResult("func1", "digest1", result1)); + actionState.addCallResult(new CallResult("func2", "digest2", result2)); + + RunnerContextImpl.DurableExecutionContext context = createContext(); + + // First call should hit + Object[] hit1 = context.matchNextOrClearSubsequentCallResult("func1", "digest1"); + assertNotNull(hit1); + assertTrue((Boolean) hit1[0]); + assertArrayEquals(result1, (byte[]) hit1[1]); + + // Second call should hit + Object[] hit2 = context.matchNextOrClearSubsequentCallResult("func2", "digest2"); + assertNotNull(hit2); + assertTrue((Boolean) hit2[0]); + assertArrayEquals(result2, (byte[]) hit2[1]); + + // Third call should miss (no more results) + Object[] miss = context.matchNextOrClearSubsequentCallResult("func3", "digest3"); + assertNull(miss); + } + + @Test + void testRecoveryWithExceptionPayload() { + byte[] exceptionPayload = "exception_data".getBytes(StandardCharsets.UTF_8); + 
actionState.addCallResult(CallResult.ofException("funcA", "digestA", exceptionPayload)); + + RunnerContextImpl.DurableExecutionContext context = createContext(); + + Object[] result = context.matchNextOrClearSubsequentCallResult("funcA", "digestA"); + + assertNotNull(result); + assertTrue((Boolean) result[0]); // isHit + assertNull(result[1]); // resultPayload should be null + assertArrayEquals(exceptionPayload, (byte[]) result[2]); // exceptionPayload + } + + @Test + void testMultiplePersistCalls() { + RunnerContextImpl.DurableExecutionContext context = createContext(); + + // Record multiple completions + context.recordCallCompletion("func1", "digest1", "result1".getBytes(), null); + context.recordCallCompletion("func2", "digest2", "result2".getBytes(), null); + context.recordCallCompletion("func3", "digest3", "result3".getBytes(), null); + + // Each call should trigger persistence + assertEquals(3, persistCallCount.get()); + assertEquals(3, actionState.getCallResults().size()); + } +}
