This is an automated email from the ASF dual-hosted git repository.

vatsrahul1001 pushed a commit to branch backport-173c2a1-v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 65504bd1d4fac592bdfc1e3ddfd1d46f9ce8d957
Author: Jarek Potiuk <[email protected]>
AuthorDate: Tue May 19 20:35:42 2026 +0200

    Recover stuck TIs when direct terminal-state API call fails (#66574)
    
    * Recover stuck TIs when direct terminal-state API call fails
    
    The supervisor's _handle_request for SucceedTask, RetryTask, DeferTask,
    and RescheduleTask set _terminal_state BEFORE calling the matching
    client.task_instances.{succeed,retry,defer,reschedule}() API. If that
    API call raised (transient network blip, server 5xx, etc.),
    _terminal_state was set on the supervisor but the server never saw
    the transition. The supervisor's update_task_state_if_needed then
    saw final_state in STATES_SENT_DIRECTLY and short-circuited the
    recovery finish() call -- leaving the TaskInstance stuck RUNNING
    on the server forever, blocking downstream dependencies and
    triggering false alerts.
    
    Two-part fix:
    
    1. Make the direct API call FIRST. The request is recorded in the new
       _pending_terminal_state_msg field before the call; _terminal_state
       is set, and the pending field cleared, only after the call returns
       successfully. If the API raises, _terminal_state stays unset, the
       pending message stays recorded, and the exception propagates to
       handle_requests, where the existing catch-all sends an
       ErrorResponse to the task subprocess.
    
    2. Have update_task_state_if_needed re-issue the matching dedicated
       API call (succeed / retry / defer / reschedule) whenever
       _pending_terminal_state_msg is still set on subprocess exit.
       finish() cannot substitute here: the server-side finish endpoint
       does not accept SUCCESS / DEFERRED / SERVER_TERMINATED
       transitions. Pre-existing semantics for the no-direct-API states
       (FAILED, UP_FOR_RETRY without RetryTask, etc.) preserved -- those
       still go through finish().
    
    Tests added:
    
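    - _terminal_state not set when the direct API call raises; the
      request stays recorded in _pending_terminal_state_msg.
    - update_task_state_if_needed replays the dedicated API call (not
      finish()) when a pending message is present.
    - update_task_state_if_needed calls neither finish() nor the
      dedicated API when nothing is pending and the state was already
      sent directly (preserves the existing happy-path optimisation).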
    
    Reported by the L3 ASVS sweep at apache/tooling-agents#24 (FINDING-007).
    
    * Refactor terminal-state dispatch and parametrize tests across all 4 states
    
    Address review feedback on #66574:
    
    - Extract `_send_terminal_state_msg` helper so the per-msg-type dispatch
      for succeed / retry / defer / reschedule lives in one place. Both
      `_handle_request` and `_replay_pending_terminal_state_msg` now go
      through it instead of duplicating the four-branch isinstance chain.
    - Parametrize the two recovery tests over all four terminal-state
      message types (was only Succeed + Defer); add UP_FOR_RETRY and
      UP_FOR_RESCHEDULE coverage.
    
    * Narrow _pending_terminal_state_msg type to satisfy mypy
    
    The field was annotated as BaseModel | None, but _send_terminal_state_msg
    expects SucceedTask | RetryTask | DeferTask | RescheduleTask. mypy
    couldn't prove the narrowing at the _replay_pending_terminal_state_msg
    call site. Tighten the field type to the exact union the setter assigns
    and the consumer accepts.
    
    ---------
    
    Co-authored-by: vatsrahul1001 <[email protected]>
    Co-authored-by: Rahul Vats <[email protected]>
    (cherry picked from commit 173c2a1806dd087272ec287fb923917630ef8f81)
---
 .../src/airflow/sdk/execution_time/supervisor.py   | 110 +++++++++++++----
 .../task_sdk/execution_time/test_supervisor.py     | 131 +++++++++++++++++++++
 2 files changed, 219 insertions(+), 22 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index cd68dc85255..a9b12dc1521 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -1089,6 +1089,18 @@ class ActivitySubprocess(WatchedSubprocess):
 
     _terminal_state: str | None = attrs.field(default=None, init=False)
     _final_state: str | None = attrs.field(default=None, init=False)
+    # The terminal-state message currently being processed by `_handle_request`,
+    # captured BEFORE the dedicated API call (succeed / retry / defer /
+    # reschedule). If the API call raises (network blip, server 5xx, etc.),
+    # this attribute stays set and the dispatcher in
+    # `update_task_state_if_needed` re-issues the matching API call on
+    # subprocess exit — re-attempting the original transition rather than
+    # falling back to `finish()`, which doesn't accept SUCCESS / DEFERRED /
+    # SERVER_TERMINATED on the server side. Cleared (and `_terminal_state`
+    # set) only after the API call returns successfully.
+    _pending_terminal_state_msg: SucceedTask | RetryTask | DeferTask | 
RescheduleTask | None = attrs.field(
+        default=None, init=False
+    )
 
     _last_successful_heartbeat: float = attrs.field(default=0, init=False)
     _last_heartbeat_attempt: float = attrs.field(default=0, init=False)
@@ -1206,10 +1218,23 @@ class ActivitySubprocess(WatchedSubprocess):
         return self._exit_code
 
     def update_task_state_if_needed(self):
-        # If the process has finished non-directly patched state (directly means deferred, reschedule, etc.),
-        # update the state of the TaskInstance to reflect the final state of the process.
-        # For states like `deferred`, `up_for_reschedule`, the process will exit with 0, but the state will be updated
-        # by the subprocess in the `handle_requests` method.
+        # If a direct-state API call (succeed / retry / defer / reschedule)
+        # was attempted but raised, `_pending_terminal_state_msg` still holds
+        # the original request. Re-issue the matching dedicated API call so
+        # the server learns the terminal state we couldn't deliver earlier.
+        # Without this recovery, a transient API failure during the direct
+        # call would leave the TI stuck RUNNING on the server — `finish()`
+        # cannot substitute because the server-side `finish` endpoint does
+        # not accept SUCCESS / DEFERRED / SERVER_TERMINATED transitions.
+        if self._pending_terminal_state_msg is not None:
+            self._replay_pending_terminal_state_msg()
+            return
+
+        # If the process has finished a non-directly-patched state (e.g.
+        # FAILED, UP_FOR_RETRY without RetryTask), `finish()` is the
+        # dedicated endpoint for those transitions. For states already in
+        # STATES_SENT_DIRECTLY whose direct API call succeeded, no further
+        # action is needed.
         if self.final_state not in STATES_SENT_DIRECTLY:
             self.client.task_instances.finish(
                 id=self.id,
@@ -1218,6 +1243,58 @@ class ActivitySubprocess(WatchedSubprocess):
                 rendered_map_index=self._rendered_map_index,
             )
 
+    def _send_terminal_state_msg(self, msg: SucceedTask | RetryTask | 
DeferTask | RescheduleTask) -> None:
+        # Capture the message BEFORE the API call so the recovery dispatcher
+        # in `update_task_state_if_needed` can re-issue it if the call raises
+        # (network blip, transient server 5xx). Clear the pending slot and
+        # record the resulting state only after the call returns successfully.
+        self._pending_terminal_state_msg = msg
+        if isinstance(msg, SucceedTask):
+            self.client.task_instances.succeed(
+                id=self.id,
+                when=msg.end_date,
+                task_outlets=msg.task_outlets,
+                outlet_events=msg.outlet_events,
+                rendered_map_index=self._rendered_map_index,
+            )
+            self._terminal_state = msg.state
+        elif isinstance(msg, RetryTask):
+            self.client.task_instances.retry(
+                id=self.id,
+                end_date=msg.end_date,
+                rendered_map_index=self._rendered_map_index,
+                retry_delay_seconds=getattr(msg, "retry_delay_seconds", None),
+                retry_reason=getattr(msg, "retry_reason", None),
+            )
+            self._terminal_state = msg.state
+        elif isinstance(msg, DeferTask):
+            self.client.task_instances.defer(self.id, msg)
+            self._terminal_state = TaskInstanceState.DEFERRED
+        elif isinstance(msg, RescheduleTask):
+            self.client.task_instances.reschedule(self.id, msg)
+            self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
+        self._pending_terminal_state_msg = None
+
+    def _replay_pending_terminal_state_msg(self) -> None:
+        """
+        Re-issue the dedicated API call for an unsynced terminal-state msg.
+
+        Best-effort — if the second attempt also fails the exception is
+        logged and we move on; the supervisor's overall failure handling
+        (heartbeat, exit-code reporting) will eventually surface the issue.
+        """
+        msg = self._pending_terminal_state_msg
+        if msg is None:
+            return
+        try:
+            self._send_terminal_state_msg(msg)
+        except Exception:
+            log.exception(
+                "Recovery retry of terminal-state API call failed; TI may be 
stuck on the server",
+                ti_id=self.id,
+                msg_type=type(msg).__name__,
+            )
+
     def _upload_logs(self):
         """
         Upload all log files found to the remote storage.
@@ -1389,29 +1466,20 @@ class ActivitySubprocess(WatchedSubprocess):
         resp: BaseModel | None = None
         dump_opts = {}
         if isinstance(msg, TaskState):
+            # No direct API call here — the recovery path in
+            # `update_task_state_if_needed` will call `finish()` for
+            # non-direct states (FAILED, etc.) once the subprocess exits.
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
             self._rendered_map_index = msg.rendered_map_index
         elif isinstance(msg, SucceedTask):
-            self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
             self._rendered_map_index = msg.rendered_map_index
-            self.client.task_instances.succeed(
-                id=self.id,
-                when=msg.end_date,
-                task_outlets=msg.task_outlets,
-                outlet_events=msg.outlet_events,
-                rendered_map_index=self._rendered_map_index,
-            )
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, RetryTask):
-            self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
             self._rendered_map_index = msg.rendered_map_index
-            self.client.task_instances.retry(
-                id=self.id,
-                end_date=msg.end_date,
-                rendered_map_index=self._rendered_map_index,
-            )
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, GetConnection):
             conn = self.client.connections.get(msg.conn_id)
             if isinstance(conn, ConnectionResponse):
@@ -1463,12 +1531,10 @@ class ActivitySubprocess(WatchedSubprocess):
             )
             resp = XComSequenceSliceResult.from_response(xcoms)
         elif isinstance(msg, DeferTask):
-            self._terminal_state = TaskInstanceState.DEFERRED
             self._rendered_map_index = msg.rendered_map_index
-            self.client.task_instances.defer(self.id, msg)
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, RescheduleTask):
-            self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
-            self.client.task_instances.reschedule(self.id, msg)
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, SkipDownstreamTasks):
             self.client.task_instances.skip_downstream_tasks(self.id, msg)
         elif isinstance(msg, SetXCom):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index f61b257b71b..05aa87c5cfb 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2793,6 +2793,137 @@ class TestHandleRequest:
         # Should not raise StopIteration (which would mean the loop crashed).
         generator.send(req2)
 
+    @pytest.mark.parametrize(
+        ("msg", "api_method", "expected_state"),
+        [
+            pytest.param(
+                SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "succeed",
+                TaskInstanceState.SUCCESS,
+                id="succeed",
+            ),
+            pytest.param(
+                RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "retry",
+                TaskInstanceState.UP_FOR_RETRY,
+                id="retry",
+            ),
+            pytest.param(
+                DeferTask(
+                    next_method="execute_complete",
+                    
classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
+                    trigger_kwargs={},
+                ),
+                "defer",
+                TaskInstanceState.DEFERRED,
+                id="defer",
+            ),
+            pytest.param(
+                RescheduleTask(
+                    reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
+                    end_date=timezone.parse("2024-10-31T12:00:00Z"),
+                ),
+                "reschedule",
+                TaskInstanceState.UP_FOR_RESCHEDULE,
+                id="reschedule",
+            ),
+        ],
+    )
+    def test_terminal_state_not_set_when_direct_api_fails(
+        self, watched_subprocess, mocker, msg, api_method, expected_state
+    ):
+        """`_terminal_state` must NOT be set when the dedicated terminal-state
+        API raises.
+
+        The original message is captured in `_pending_terminal_state_msg`
+        BEFORE the API call so the recovery dispatcher in
+        `update_task_state_if_needed` can re-issue it on subprocess exit.
+        Covers all four terminal-state message types.
+        """
+        watched_subprocess, _ = watched_subprocess
+        setattr(
+            watched_subprocess.client.task_instances,
+            api_method,
+            mocker.Mock(side_effect=httpx.ConnectError("connection refused")),
+        )
+
+        with pytest.raises(httpx.ConnectError):
+            watched_subprocess._handle_request(msg, mocker.Mock(), req_id=1)
+
+        assert watched_subprocess._terminal_state is None
+        # Pending msg preserved so the recovery dispatcher can re-issue.
+        assert watched_subprocess._pending_terminal_state_msg is msg
+
+    @pytest.mark.parametrize(
+        ("msg", "api_method", "expected_state"),
+        [
+            pytest.param(
+                SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "succeed",
+                TaskInstanceState.SUCCESS,
+                id="succeed",
+            ),
+            pytest.param(
+                RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "retry",
+                TaskInstanceState.UP_FOR_RETRY,
+                id="retry",
+            ),
+            pytest.param(
+                DeferTask(
+                    next_method="execute_complete",
+                    
classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
+                    trigger_kwargs={},
+                ),
+                "defer",
+                TaskInstanceState.DEFERRED,
+                id="defer",
+            ),
+            pytest.param(
+                RescheduleTask(
+                    reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
+                    end_date=timezone.parse("2024-10-31T12:00:00Z"),
+                ),
+                "reschedule",
+                TaskInstanceState.UP_FOR_RESCHEDULE,
+                id="reschedule",
+            ),
+        ],
+    )
+    def test_update_task_state_replays_pending_terminal_state_call(
+        self, watched_subprocess, mocker, msg, api_method, expected_state
+    ):
+        """If a direct terminal-state API call was attempted and raised, the
+        recovery dispatcher must re-issue the dedicated endpoint (not
+        `finish()`, which the server-side endpoint refuses for SUCCESS /
+        DEFERRED / SERVER_TERMINATED). Covers all four message types.
+        """
+        watched_subprocess, _ = watched_subprocess
+        watched_subprocess._exit_code = 0
+        # Simulate the failure scenario: original API call raised, msg preserved.
+        watched_subprocess._pending_terminal_state_msg = msg
+
+        watched_subprocess.update_task_state_if_needed()
+
+        # Recovery re-issues the dedicated endpoint, NOT finish().
+        getattr(watched_subprocess.client.task_instances, 
api_method).assert_called_once()
+        watched_subprocess.client.task_instances.finish.assert_not_called()
+        assert watched_subprocess._terminal_state == expected_state
+        assert watched_subprocess._pending_terminal_state_msg is None
+
+    def test_update_task_state_no_recovery_without_pending_msg(self, 
watched_subprocess, mocker):
+        """No replay when nothing was pending — preserves the original
+        STATES_SENT_DIRECTLY short-circuit for the happy path."""
+        watched_subprocess, _ = watched_subprocess
+        watched_subprocess._exit_code = 0
+        watched_subprocess._terminal_state = TaskInstanceState.SUCCESS
+        watched_subprocess._pending_terminal_state_msg = None
+
+        watched_subprocess.update_task_state_if_needed()
+
+        watched_subprocess.client.task_instances.finish.assert_not_called()
+        watched_subprocess.client.task_instances.succeed.assert_not_called()
+
 
 class TestSetSupervisorComms:
     class DummyComms:

Reply via email to