This is an automated email from the ASF dual-hosted git repository.
vatsrahul1001 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 173c2a1806d Recover stuck TIs when direct terminal-state API call
fails (#66574)
173c2a1806d is described below
commit 173c2a1806dd087272ec287fb923917630ef8f81
Author: Jarek Potiuk <[email protected]>
AuthorDate: Tue May 19 20:35:42 2026 +0200
Recover stuck TIs when direct terminal-state API call fails (#66574)
* Recover stuck TIs when direct terminal-state API call fails
The supervisor's _handle_request for SucceedTask, RetryTask, DeferTask,
and RescheduleTask set _terminal_state BEFORE calling the matching
client.task_instances.{succeed,retry,defer,reschedule}() API. If that
API call raised (transient network blip, server 5xx, etc.),
_terminal_state was set on the supervisor but the server never saw
the transition. The supervisor's update_task_state_if_needed then
saw final_state in STATES_SENT_DIRECTLY and short-circuited the
recovery finish() call -- leaving the TaskInstance stuck RUNNING
on the server forever, blocking downstream dependencies and
triggering false alerts.
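Two-part fix:
1. Make the direct API call FIRST, and only set _terminal_state after
the call returns successfully. Before the call, the request message
is captured in the new _pending_terminal_state_msg field, which is
cleared only on success. If the API raises:
- _terminal_state stays unset;
- the message stays pending;
- the exception propagates to handle_requests, where the existing
catch-all sends an ErrorResponse to the task subprocess.
2. Have update_task_state_if_needed replay the pending message on
subprocess exit. If _pending_terminal_state_msg is still set, the
matching dedicated endpoint (succeed / retry / defer / reschedule)
is called again instead of finish(). finish() cannot be used here
because the server-side finish endpoint does not accept SUCCESS /
DEFERRED / SERVER_TERMINATED transitions. The replay is
best-effort: if it fails again, the error is logged, not raised.
(An earlier revision of this change tracked a
_terminal_state_synced_to_server flag and re-sent the unsynced state
through finish(). The final code replaces that with the replay
described in point 2.)
Pre-existing semantics for the no-direct-API states (FAILED,
UP_FOR_RETRY without RetryTask, etc.) are preserved. Those still go
through the finish() branch.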
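Tests added:
- _terminal_state is not set, and the original message stays in
_pending_terminal_state_msg, when the direct API call raises.
Parametrized over SucceedTask, RetryTask, DeferTask and
RescheduleTask.
- update_task_state_if_needed replays the dedicated endpoint, and
does NOT call finish(), when a message is pending. It then sets
_terminal_state and clears the pending message. Parametrized over
the same four message types.
- update_task_state_if_needed calls neither finish() nor succeed()
when nothing is pending and _terminal_state is SUCCESS. This
preserves the existing happy-path optimisation.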
Reported by the L3 ASVS sweep at apache/tooling-agents#24 (FINDING-007).
* Refactor terminal-state dispatch and parametrize tests across all 4 states
Address review feedback on #66574:
- Extract `_send_terminal_state_msg` helper so the per-msg-type dispatch
for succeed / retry / defer / reschedule lives in one place. Both
`_handle_request` and `_replay_pending_terminal_state_msg` now go
through it instead of duplicating the four-branch isinstance chain.
- Parametrize the two recovery tests over all four terminal-state
message types (was only Succeed + Defer); add UP_FOR_RETRY and
UP_FOR_RESCHEDULE coverage.
* Narrow _pending_terminal_state_msg type to satisfy mypy
The field was annotated as BaseModel | None, but _send_terminal_state_msg
expects SucceedTask | RetryTask | DeferTask | RescheduleTask. mypy
couldn't prove the narrowing at the _replay_pending_terminal_state_msg
call site. Tighten the field type to the exact union the setter assigns
and the consumer accepts.
---------
Co-authored-by: vatsrahul1001 <[email protected]>
Co-authored-by: Rahul Vats <[email protected]>
---
.../src/airflow/sdk/execution_time/supervisor.py | 112 ++++++++++++++----
.../task_sdk/execution_time/test_supervisor.py | 131 +++++++++++++++++++++
2 files changed, 219 insertions(+), 24 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 7c9ecbeab9e..3e705b9d211 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -1152,6 +1152,18 @@ class ActivitySubprocess(WatchedSubprocess):
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)
+ # The terminal-state message currently being processed by
`_handle_request`,
+ # captured BEFORE the dedicated API call (succeed / retry / defer /
+ # reschedule). If the API call raises (network blip, server 5xx, etc.),
+ # this attribute stays set and the dispatcher in
+ # `update_task_state_if_needed` re-issues the matching API call on
+ # subprocess exit — re-attempting the original transition rather than
+ # falling back to `finish()`, which doesn't accept SUCCESS / DEFERRED /
+ # SERVER_TERMINATED on the server side. Cleared (and `_terminal_state`
+ # set) only after the API call returns successfully.
+ _pending_terminal_state_msg: SucceedTask | RetryTask | DeferTask |
RescheduleTask | None = attrs.field(
+ default=None, init=False
+ )
_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
@@ -1269,10 +1281,23 @@ class ActivitySubprocess(WatchedSubprocess):
return self._exit_code
def update_task_state_if_needed(self):
- # If the process has finished non-directly patched state (directly
means deferred, reschedule, etc.),
- # update the state of the TaskInstance to reflect the final state of
the process.
- # For states like `deferred`, `up_for_reschedule`, the process will
exit with 0, but the state will be updated
- # by the subprocess in the `handle_requests` method.
+ # If a direct-state API call (succeed / retry / defer / reschedule)
+ # was attempted but raised, `_pending_terminal_state_msg` still holds
+ # the original request. Re-issue the matching dedicated API call so
+ # the server learns the terminal state we couldn't deliver earlier.
+ # Without this recovery, a transient API failure during the direct
+ # call would leave the TI stuck RUNNING on the server — `finish()`
+ # cannot substitute because the server-side `finish` endpoint does
+ # not accept SUCCESS / DEFERRED / SERVER_TERMINATED transitions.
+ if self._pending_terminal_state_msg is not None:
+ self._replay_pending_terminal_state_msg()
+ return
+
+ # If the process has finished a non-directly-patched state (e.g.
+ # FAILED, UP_FOR_RETRY without RetryTask), `finish()` is the
+ # dedicated endpoint for those transitions. For states already in
+ # STATES_SENT_DIRECTLY whose direct API call succeeded, no further
+ # action is needed.
if self.final_state not in STATES_SENT_DIRECTLY:
self.client.task_instances.finish(
id=self.id,
@@ -1281,6 +1306,58 @@ class ActivitySubprocess(WatchedSubprocess):
rendered_map_index=self._rendered_map_index,
)
+ def _send_terminal_state_msg(self, msg: SucceedTask | RetryTask |
DeferTask | RescheduleTask) -> None:
+ # Capture the message BEFORE the API call so the recovery dispatcher
+ # in `update_task_state_if_needed` can re-issue it if the call raises
+ # (network blip, transient server 5xx). Clear the pending slot and
+ # record the resulting state only after the call returns successfully.
+ self._pending_terminal_state_msg = msg
+ if isinstance(msg, SucceedTask):
+ self.client.task_instances.succeed(
+ id=self.id,
+ when=msg.end_date,
+ task_outlets=msg.task_outlets,
+ outlet_events=msg.outlet_events,
+ rendered_map_index=self._rendered_map_index,
+ )
+ self._terminal_state = msg.state
+ elif isinstance(msg, RetryTask):
+ self.client.task_instances.retry(
+ id=self.id,
+ end_date=msg.end_date,
+ rendered_map_index=self._rendered_map_index,
+ retry_delay_seconds=getattr(msg, "retry_delay_seconds", None),
+ retry_reason=getattr(msg, "retry_reason", None),
+ )
+ self._terminal_state = msg.state
+ elif isinstance(msg, DeferTask):
+ self.client.task_instances.defer(self.id, msg)
+ self._terminal_state = TaskInstanceState.DEFERRED
+ elif isinstance(msg, RescheduleTask):
+ self.client.task_instances.reschedule(self.id, msg)
+ self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
+ self._pending_terminal_state_msg = None
+
+ def _replay_pending_terminal_state_msg(self) -> None:
+ """
+ Re-issue the dedicated API call for an unsynced terminal-state msg.
+
+ Best-effort — if the second attempt also fails the exception is
+ logged and we move on; the supervisor's overall failure handling
+ (heartbeat, exit-code reporting) will eventually surface the issue.
+ """
+ msg = self._pending_terminal_state_msg
+ if msg is None:
+ return
+ try:
+ self._send_terminal_state_msg(msg)
+ except Exception:
+ log.exception(
+ "Recovery retry of terminal-state API call failed; TI may be
stuck on the server",
+ ti_id=self.id,
+ msg_type=type(msg).__name__,
+ )
+
def _upload_logs(self):
"""
Upload all log files found to the remote storage.
@@ -1452,31 +1529,20 @@ class ActivitySubprocess(WatchedSubprocess):
resp: BaseModel | None = None
dump_opts: dict[str, bool] = {}
if isinstance(msg, TaskState):
+ # No direct API call here — the recovery path in
+ # `update_task_state_if_needed` will call `finish()` for
+ # non-direct states (FAILED, etc.) once the subprocess exits.
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
elif isinstance(msg, SucceedTask):
- self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
- self.client.task_instances.succeed(
- id=self.id,
- when=msg.end_date,
- task_outlets=msg.task_outlets,
- outlet_events=msg.outlet_events,
- rendered_map_index=self._rendered_map_index,
- )
+ self._send_terminal_state_msg(msg)
elif isinstance(msg, RetryTask):
- self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
- self.client.task_instances.retry(
- id=self.id,
- end_date=msg.end_date,
- rendered_map_index=self._rendered_map_index,
- retry_delay_seconds=getattr(msg, "retry_delay_seconds", None),
- retry_reason=getattr(msg, "retry_reason", None),
- )
+ self._send_terminal_state_msg(msg)
elif isinstance(msg, GetConnection):
resp, dump_opts = handle_get_connection(self.client, msg)
elif isinstance(msg, GetVariable):
@@ -1512,12 +1578,10 @@ class ActivitySubprocess(WatchedSubprocess):
)
resp = XComSequenceSliceResult.from_response(xcoms)
elif isinstance(msg, DeferTask):
- self._terminal_state = TaskInstanceState.DEFERRED
self._rendered_map_index = msg.rendered_map_index
- self.client.task_instances.defer(self.id, msg)
+ self._send_terminal_state_msg(msg)
elif isinstance(msg, RescheduleTask):
- self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
- self.client.task_instances.reschedule(self.id, msg)
+ self._send_terminal_state_msg(msg)
elif isinstance(msg, SkipDownstreamTasks):
self.client.task_instances.skip_downstream_tasks(self.id, msg)
elif isinstance(msg, SetXCom):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 0b3cd64a21e..811101dbf4c 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -3079,6 +3079,137 @@ class TestHandleRequest:
# Should not raise StopIteration (which would mean the loop crashed).
generator.send(req2)
+ @pytest.mark.parametrize(
+ ("msg", "api_method", "expected_state"),
+ [
+ pytest.param(
+ SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+ "succeed",
+ TaskInstanceState.SUCCESS,
+ id="succeed",
+ ),
+ pytest.param(
+ RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+ "retry",
+ TaskInstanceState.UP_FOR_RETRY,
+ id="retry",
+ ),
+ pytest.param(
+ DeferTask(
+ next_method="execute_complete",
+
classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
+ trigger_kwargs={},
+ ),
+ "defer",
+ TaskInstanceState.DEFERRED,
+ id="defer",
+ ),
+ pytest.param(
+ RescheduleTask(
+ reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
+ end_date=timezone.parse("2024-10-31T12:00:00Z"),
+ ),
+ "reschedule",
+ TaskInstanceState.UP_FOR_RESCHEDULE,
+ id="reschedule",
+ ),
+ ],
+ )
+ def test_terminal_state_not_set_when_direct_api_fails(
+ self, watched_subprocess, mocker, msg, api_method, expected_state
+ ):
+ """`_terminal_state` must NOT be set when the dedicated terminal-state
+ API raises.
+
+ The original message is captured in `_pending_terminal_state_msg`
+ BEFORE the API call so the recovery dispatcher in
+ `update_task_state_if_needed` can re-issue it on subprocess exit.
+ Covers all four terminal-state message types.
+ """
+ watched_subprocess, _ = watched_subprocess
+ setattr(
+ watched_subprocess.client.task_instances,
+ api_method,
+ mocker.Mock(side_effect=httpx.ConnectError("connection refused")),
+ )
+
+ with pytest.raises(httpx.ConnectError):
+ watched_subprocess._handle_request(msg, mocker.Mock(), req_id=1)
+
+ assert watched_subprocess._terminal_state is None
+ # Pending msg preserved so the recovery dispatcher can re-issue.
+ assert watched_subprocess._pending_terminal_state_msg is msg
+
+ @pytest.mark.parametrize(
+ ("msg", "api_method", "expected_state"),
+ [
+ pytest.param(
+ SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+ "succeed",
+ TaskInstanceState.SUCCESS,
+ id="succeed",
+ ),
+ pytest.param(
+ RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+ "retry",
+ TaskInstanceState.UP_FOR_RETRY,
+ id="retry",
+ ),
+ pytest.param(
+ DeferTask(
+ next_method="execute_complete",
+
classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
+ trigger_kwargs={},
+ ),
+ "defer",
+ TaskInstanceState.DEFERRED,
+ id="defer",
+ ),
+ pytest.param(
+ RescheduleTask(
+ reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
+ end_date=timezone.parse("2024-10-31T12:00:00Z"),
+ ),
+ "reschedule",
+ TaskInstanceState.UP_FOR_RESCHEDULE,
+ id="reschedule",
+ ),
+ ],
+ )
+ def test_update_task_state_replays_pending_terminal_state_call(
+ self, watched_subprocess, mocker, msg, api_method, expected_state
+ ):
+ """If a direct terminal-state API call was attempted and raised, the
+ recovery dispatcher must re-issue the dedicated endpoint (not
+ `finish()`, which the server-side endpoint refuses for SUCCESS /
+ DEFERRED / SERVER_TERMINATED). Covers all four message types.
+ """
+ watched_subprocess, _ = watched_subprocess
+ watched_subprocess._exit_code = 0
+ # Simulate the failure scenario: original API call raised, msg
preserved.
+ watched_subprocess._pending_terminal_state_msg = msg
+
+ watched_subprocess.update_task_state_if_needed()
+
+ # Recovery re-issues the dedicated endpoint, NOT finish().
+ getattr(watched_subprocess.client.task_instances,
api_method).assert_called_once()
+ watched_subprocess.client.task_instances.finish.assert_not_called()
+ assert watched_subprocess._terminal_state == expected_state
+ assert watched_subprocess._pending_terminal_state_msg is None
+
+ def test_update_task_state_no_recovery_without_pending_msg(self,
watched_subprocess, mocker):
+ """No replay when nothing was pending — preserves the original
+ STATES_SENT_DIRECTLY short-circuit for the happy path."""
+ watched_subprocess, _ = watched_subprocess
+ watched_subprocess._exit_code = 0
+ watched_subprocess._terminal_state = TaskInstanceState.SUCCESS
+ watched_subprocess._pending_terminal_state_msg = None
+
+ watched_subprocess.update_task_state_if_needed()
+
+ watched_subprocess.client.task_instances.finish.assert_not_called()
+ watched_subprocess.client.task_instances.succeed.assert_not_called()
+
class TestSetSupervisorComms:
class DummyComms: