This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6132261fb49 Rename task-failure handlers so naming matches the call hierarchy (#68264)
6132261fb49 is described below
commit 6132261fb49e6ffd46846b6c4cee9832fc472b96
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 9 12:41:07 2026 +0100
Rename task-failure handlers so naming matches the call hierarchy (#68264)
After #68254 routed the non-deferrable TriggerDagRunOperator failed-state branch
through the retry-policy wrapper, _handle_current_task_failed had zero direct
external callers -- every failure path now enters through
_apply_retry_policy_or_default. The names were inverted relative to the call
hierarchy: the function named like the umbrella ("handle current task failed")
was the private mechanics primitive, and the one named like a sub-step was the
actual entry point.
Rename so the umbrella name is the entry point:
- _apply_retry_policy_or_default -> _handle_current_task_failed
  (entry point: evaluate the policy, then fall back to the standard decision)
- old _handle_current_task_failed -> _finalize_task_failure
  (mechanics: emit metrics, build the retry-or-fail message)
Pure rename, no behavior change.
---
.../src/airflow/sdk/execution_time/task_runner.py | 36 +++++++++++++---------
.../task_sdk/definitions/test_retry_policy.py | 20 ++++++------
2 files changed, 32 insertions(+), 24 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 5ab862ef6e4..48822c52f5d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -1533,7 +1533,7 @@ def run(
# We should allow retries if the task has defined it.
log.exception("Task failed with exception")
log.info("::group::Post Execute")
- msg, state = _apply_retry_policy_or_default(ti, e, log, context)
+ msg, state = _handle_current_task_failed(ti, e, log, context)
error = e
except AirflowTaskTerminated as e:
# External state updates are already handled with `ti_heartbeat` and
will be
@@ -1553,12 +1553,12 @@ def run(
# SystemExit needs to be retried if they are eligible.
log.error("Task exited", exit_code=e.code)
log.info("::group::Post Execute")
- msg, state = _apply_retry_policy_or_default(ti, e, log, context)
+ msg, state = _handle_current_task_failed(ti, e, log, context)
error = e
except BaseException as e:
log.exception("Task failed with exception")
log.info("::group::Post Execute")
- msg, state = _apply_retry_policy_or_default(ti, e, log, context)
+ msg, state = _handle_current_task_failed(ti, e, log, context)
error = e
finally:
# `state` may still be unset if an exception handler above raised
before
@@ -1667,20 +1667,20 @@ def _evaluate_retry_policy(
return None
-def _apply_retry_policy_or_default(
+def _handle_current_task_failed(
ti: RuntimeTaskInstance,
exception: BaseException,
log: Logger,
context: Context | None = None,
) -> tuple[RetryTask | TaskState, TaskInstanceState]:
"""
- Evaluate the retry policy (if any) and decide the task's next state.
+ Handle a failed task: evaluate the retry policy (if any) and decide the
next state.
- When the policy returns FAIL the task is marked as failed immediately,
- bypassing the normal retry-count check. When it returns RETRY with
- a custom delay, that delay is forwarded in the ``RetryTask`` message.
- For DEFAULT (or when no policy is configured), the standard
- ``_handle_current_task_failed`` logic runs.
+ This is the entry point every failure path routes through. When the policy
+ returns FAIL the task is marked as failed immediately, bypassing the normal
+ retry-count check. When it returns RETRY with a custom delay, that delay
is
+ forwarded in the ``RetryTask`` message. For DEFAULT (or when no policy is
+ configured), the standard ``_finalize_task_failure`` logic runs.
"""
from airflow.sdk.definitions.retry_policy import RetryAction
@@ -1696,17 +1696,25 @@ def _apply_retry_policy_or_default(
TaskInstanceState.FAILED,
)
if decision is not None and decision.action == RetryAction.RETRY:
- return _handle_current_task_failed(
+ return _finalize_task_failure(
ti, retry_delay_override=decision.retry_delay,
retry_reason=decision.reason
)
- return _handle_current_task_failed(ti)
+ return _finalize_task_failure(ti)
-def _handle_current_task_failed(
+def _finalize_task_failure(
ti: RuntimeTaskInstance,
retry_delay_override: timedelta | None = None,
retry_reason: str | None = None,
) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]:
+ """
+ Record failure metrics and build the standard retry-or-fail outcome.
+
+ Returns an ``UP_FOR_RETRY`` ``RetryTask`` when the server marked the task
+ retry-eligible (optionally carrying a policy-supplied delay/reason), else a
+ ``FAILED`` ``TaskState``. This is the default path; retry-policy overrides
+ are decided in :func:`_handle_current_task_failed` before this is called.
+ """
end_date = datetime.now(tz=timezone.utc)
ti.end_date = end_date
@@ -1820,7 +1828,7 @@ def _handle_trigger_dag_run(
# therefore honours a configured retry_policy. Synthesize the
same
# exception here so non-deferrable waits evaluate the policy
too,
# falling back to the standard retry-count check when none is
set.
- return _apply_retry_policy_or_default(
+ return _handle_current_task_failed(
ti,
AirflowException(f"{drte.trigger_dag_id} failed with
failed state {comms_msg.state}"),
log,
diff --git a/task-sdk/tests/task_sdk/definitions/test_retry_policy.py
b/task-sdk/tests/task_sdk/definitions/test_retry_policy.py
index f3fb2ac9a2f..251a5fa8bc0 100644
--- a/task-sdk/tests/task_sdk/definitions/test_retry_policy.py
+++ b/task-sdk/tests/task_sdk/definitions/test_retry_policy.py
@@ -31,8 +31,8 @@ from airflow.sdk.definitions.retry_policy import (
RetryRule,
)
from airflow.sdk.execution_time.task_runner import (
- _apply_retry_policy_or_default,
_evaluate_retry_policy,
+ _handle_current_task_failed,
)
log = structlog.get_logger("test")
@@ -334,16 +334,16 @@ class TestEvaluateRetryPolicy:
# ---------------------------------------------------------------------------
-# _apply_retry_policy_or_default (task_runner.py)
+# _handle_current_task_failed (task_runner.py)
# ---------------------------------------------------------------------------
-class TestApplyRetryPolicyOrDefault:
+class TestHandleCurrentTaskFailed:
def test_fail_bypasses_retry_count(self):
"""Policy FAIL overrides should_retry=True -- task fails
immediately."""
policy =
ExceptionRetryPolicy(rules=[RetryRule(exception=PermissionError,
action=RetryAction.FAIL)])
ti = _make_mock_ti(policy=policy, should_retry=True)
- msg, state = _apply_retry_policy_or_default(ti,
PermissionError("denied"), log)
+ msg, state = _handle_current_task_failed(ti,
PermissionError("denied"), log)
assert state == TaskInstanceState.FAILED
def test_retry_with_delay_and_reason(self):
@@ -358,7 +358,7 @@ class TestApplyRetryPolicyOrDefault:
]
)
ti = _make_mock_ti(policy=policy)
- msg, state = _apply_retry_policy_or_default(ti,
ConnectionError("refused"), log)
+ msg, state = _handle_current_task_failed(ti,
ConnectionError("refused"), log)
assert state == TaskInstanceState.UP_FOR_RETRY
assert msg.retry_delay_seconds == 42.0
assert msg.retry_reason == "net error"
@@ -372,14 +372,14 @@ class TestApplyRetryPolicyOrDefault:
]
)
ti = _make_mock_ti(policy=policy)
- msg, state = _apply_retry_policy_or_default(ti, ValueError("bad"), log)
+ msg, state = _handle_current_task_failed(ti, ValueError("bad"), log)
assert state == TaskInstanceState.UP_FOR_RETRY
assert len(msg.retry_reason) == 500
def test_default_falls_through_to_standard_retry(self):
policy = ExceptionRetryPolicy(rules=[]) # no match -> DEFAULT
ti = _make_mock_ti(policy=policy, should_retry=True)
- msg, state = _apply_retry_policy_or_default(ti, ValueError("bad"), log)
+ msg, state = _handle_current_task_failed(ti, ValueError("bad"), log)
assert state == TaskInstanceState.UP_FOR_RETRY
assert msg.retry_delay_seconds is None # no override
@@ -391,15 +391,15 @@ class TestApplyRetryPolicyOrDefault:
]
)
ti = _make_mock_ti(policy=policy, should_retry=False)
- msg, state = _apply_retry_policy_or_default(ti, ConnectionError("x"),
log)
+ msg, state = _handle_current_task_failed(ti, ConnectionError("x"), log)
assert state == TaskInstanceState.FAILED
def test_no_policy_with_should_retry_true(self):
ti = _make_mock_ti(policy=None, should_retry=True)
- msg, state = _apply_retry_policy_or_default(ti, ValueError("bad"), log)
+ msg, state = _handle_current_task_failed(ti, ValueError("bad"), log)
assert state == TaskInstanceState.UP_FOR_RETRY
def test_no_policy_with_should_retry_false(self):
ti = _make_mock_ti(policy=None, should_retry=False)
- msg, state = _apply_retry_policy_or_default(ti, ValueError("bad"), log)
+ msg, state = _handle_current_task_failed(ti, ValueError("bad"), log)
assert state == TaskInstanceState.FAILED