This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ed2d35fa1ad feat(hitl): add fail_on_reject to ApprovalOperator (#55255)
ed2d35fa1ad is described below

commit ed2d35fa1ad918750cf1a713caa001b181ec6ef2
Author: Wei Lee <[email protected]>
AuthorDate: Thu Sep 18 16:48:06 2025 +0800

    feat(hitl): add fail_on_reject to ApprovalOperator (#55255)
---
 .../src/airflow/providers/standard/exceptions.py   |  4 +
 .../airflow/providers/standard/operators/hitl.py   | 37 ++++++++-
 .../tests/unit/standard/operators/test_hitl.py     | 91 ++++++++++++++++++----
 3 files changed, 114 insertions(+), 18 deletions(-)

diff --git a/providers/standard/src/airflow/providers/standard/exceptions.py b/providers/standard/src/airflow/providers/standard/exceptions.py
index 98449fed173..4afef460324 100644
--- a/providers/standard/src/airflow/providers/standard/exceptions.py
+++ b/providers/standard/src/airflow/providers/standard/exceptions.py
@@ -63,3 +63,7 @@ class HITLTriggerEventError(Exception):
 
 class HITLTimeoutError(HITLTriggerEventError):
     """Raised when HITLOperator timeouts."""
+
+
+class HITLRejectException(AirflowException):
+    """Raised when an ApprovalOperator receives a "Reject" response when fail_on_reject is set to True."""
diff --git a/providers/standard/src/airflow/providers/standard/operators/hitl.py b/providers/standard/src/airflow/providers/standard/operators/hitl.py
index f40abe1e467..f76d62d8a16 100644
--- a/providers/standard/src/airflow/providers/standard/operators/hitl.py
+++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py
@@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any
 from urllib.parse import ParseResult, urlencode, urlparse, urlunparse
 
 from airflow.configuration import conf
-from airflow.providers.standard.exceptions import HITLTimeoutError, HITLTriggerEventError
+from airflow.providers.standard.exceptions import HITLRejectException, HITLTimeoutError, HITLTriggerEventError
 from airflow.providers.standard.operators.branch import BranchMixIn
 from airflow.providers.standard.triggers.hitl import HITLTrigger, HITLTriggerEventSuccessPayload
 from airflow.providers.standard.utils.skipmixin import SkipMixin
@@ -303,12 +303,42 @@ class ApprovalOperator(HITLOperator, SkipMixin):
     APPROVE = "Approve"
     REJECT = "Reject"
 
-    def __init__(self, ignore_downstream_trigger_rules: bool = False, **kwargs) -> None:
+    def __init__(
+        self,
+        *,
+        ignore_downstream_trigger_rules: bool = False,
+        fail_on_reject: bool = False,
+        **kwargs,
+    ) -> None:
+        """
+        Human-in-the-loop Operator for simple approval workflows.
+
+        This operator presents the user with two fixed options: "Approve" and "Reject".
+
+        Behavior:
+        - "Approve": Downstream tasks execute as normal.
+        - "Reject":
+            - Downstream tasks are skipped according to the `ignore_downstream_trigger_rules` setting.
+            - If `fail_on_reject=True`, the task fails instead of only skipping downstream tasks.
+
+        Warning:
+            Using `fail_on_reject=True` is generally discouraged. A HITLOperator's role is to collect
+            human input, and receiving any response—including "Reject"—indicates the task succeeded.
+            Treating "Reject" as a task failure mixes human decision outcomes with Airflow task
+            success/failure states.
+            Only use this option if you explicitly intend for a "Reject" response to fail the task.
+
+        Args:
+            ignore_downstream_trigger_rules: If True, skips all downstream tasks regardless of trigger rules.
+            fail_on_reject: If True, the task fails when "Reject" is selected. Generally discouraged.
+                Read the warning carefully before using.
+        """
         for arg in self.FIXED_ARGS:
             if arg in kwargs:
                 raise ValueError(f"Passing {arg} to ApprovalOperator is not allowed.")
 
         self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules
+        self.fail_on_reject = fail_on_reject
 
         super().__init__(
             options=[self.APPROVE, self.REJECT],
@@ -324,6 +354,9 @@ class ApprovalOperator(HITLOperator, SkipMixin):
             self.log.info("Approved. Proceeding with downstream tasks...")
             return ret
 
+        if self.fail_on_reject and chosen_option == self.REJECT:
+            raise HITLRejectException('Receive "Reject"')
+
         if not self.downstream_task_ids:
             self.log.info("No downstream tasks; nothing to do.")
             return ret
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py b/providers/standard/tests/unit/standard/operators/test_hitl.py
index f198cc461cf..c7c03df683e 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import pytest
 
-from airflow.providers.standard.exceptions import HITLTimeoutError, HITLTriggerEventError
+from airflow.providers.standard.exceptions import HITLRejectException, HITLTimeoutError, HITLTriggerEventError
 
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
 
@@ -46,12 +46,15 @@ from airflow.providers.standard.operators.hitl import (
 from airflow.sdk import Param, timezone
 from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.execution_time.hitl import HITLUser
+from airflow.utils.context import Context
 
 from tests_common.test_utils.config import conf_vars
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
+    from airflow.sdk.definitions.context import Context
+
     from tests_common.pytest_plugin import DagMaker
 
 pytestmark = pytest.mark.db_test
@@ -77,6 +80,37 @@ def hitl_task_and_ti_for_generating_link(dag_maker: DagMaker) -> tuple[HITLOpera
     return task, dag_maker.run_ti(task.task_id, dr)
 
 
+@pytest.fixture
+def get_context_from_model_ti(mock_supervisor_comms):
+    def _get_context(ti: TaskInstance) -> Context:
+        from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+            DagRun as DRDataModel,
+            TaskInstance as TIDataModel,
+            TIRunContext,
+        )
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        # make mypy happy
+        assert ti is not None
+
+        dag_run = ti.dag_run
+        ti_model = TIDataModel.model_validate(ti, from_attributes=True)
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **ti_model.model_dump(exclude_unset=True),
+            task=ti.task,
+            _ti_context_from_server=TIRunContext(
+                dag_run=DRDataModel.model_validate(dag_run, from_attributes=True),
+                max_tries=ti.max_tries,
+                variables=[],
+                connections=[],
+                xcom_keys_to_clear=[],
+            ),
+        )
+        return runtime_ti.get_template_context()
+
+    return _get_context
+
+
 class TestHITLOperator:
     def test_validate_options(self) -> None:
         hitl_op = HITLOperator(
@@ -451,7 +485,9 @@ class TestApprovalOperator:
             "responded_by_user": {"id": "test", "name": "test"},
         }
 
-    def test_execute_complete_with_downstream_tasks(self, dag_maker) -> None:
+    def test_execute_complete_with_downstream_tasks(
+        self, dag_maker: DagMaker, get_context_from_model_ti
+    ) -> None:
         with dag_maker("hitl_test_dag", serialized=True):
             hitl_op = ApprovalOperator(
                 task_id="hitl_test",
@@ -461,10 +497,9 @@ class TestApprovalOperator:
 
         dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance("hitl_test")
-
         with pytest.raises(DownstreamTasksSkipped) as exc_info:
             hitl_op.execute_complete(
-                context={"ti": ti, "task": ti.task},
+                context=get_context_from_model_ti(ti),
                 event={
                     "chosen_options": ["Reject"],
                     "params_input": {},
@@ -474,6 +509,26 @@ class TestApprovalOperator:
             )
         assert set(exc_info.value.tasks) == {"op1"}
 
+    def test_execute_complete_with_fail_on_reject_set_to_true(
+        self, dag_maker: DagMaker, get_context_from_model_ti
+    ) -> None:
+        with dag_maker("hitl_test_dag", serialized=True):
+            hitl_op = ApprovalOperator(task_id="hitl_test", subject="This is subject", fail_on_reject=True)
+            (hitl_op >> EmptyOperator(task_id="op1"))
+
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance("hitl_test")
+        with pytest.raises(HITLRejectException):
+            hitl_op.execute_complete(
+                context=get_context_from_model_ti(ti),
+                event={
+                    "chosen_options": ["Reject"],
+                    "params_input": {},
+                    "responded_at": timezone.utcnow(),
+                    "responded_by_user": {"id": "test", "name": "test"},
+                },
+            )
+
 
 class TestHITLEntryOperator:
     def test_init_without_options_and_default(self) -> None:
@@ -513,7 +568,7 @@ class TestHITLEntryOperator:
 
 
 class TestHITLBranchOperator:
-    def test_execute_complete(self, dag_maker) -> None:
+    def test_execute_complete(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
         with dag_maker("hitl_test_dag", serialized=True):
             branch_op = HITLBranchOperator(
                 task_id="make_choice",
@@ -527,7 +582,7 @@ class TestHITLBranchOperator:
         ti = dr.get_task_instance("make_choice")
         with pytest.raises(DownstreamTasksSkipped) as exc_info:
             branch_op.execute_complete(
-                context={"ti": ti, "task": ti.task},
+                context=get_context_from_model_ti(ti),
                 event={
                     "chosen_options": ["branch_1"],
                     "params_input": {},
@@ -537,7 +592,9 @@ class TestHITLBranchOperator:
             )
         assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(2, 6))
 
-    def test_execute_complete_with_multiple_branches(self, dag_maker) -> None:
+    def test_execute_complete_with_multiple_branches(
+        self, dag_maker: DagMaker, get_context_from_model_ti
+    ) -> None:
         with dag_maker("hitl_test_dag", serialized=True):
             branch_op = HITLBranchOperator(
                 task_id="make_choice",
@@ -554,7 +611,7 @@ class TestHITLBranchOperator:
         ti = dr.get_task_instance("make_choice")
         with pytest.raises(DownstreamTasksSkipped) as exc_info:
             branch_op.execute_complete(
-                context={"ti": ti, "task": ti.task},
+                context=get_context_from_model_ti(ti),
                 event={
                     "chosen_options": [f"branch_{i}" for i in range(1, 4)],
                     "params_input": {},
@@ -564,7 +621,7 @@ class TestHITLBranchOperator:
             )
         assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in range(4, 6))
 
-    def test_mapping_applies_for_single_choice(self, dag_maker):
+    def test_mapping_applies_for_single_choice(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
         # ["Approve"]; map -> "publish"
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -580,7 +637,7 @@ class TestHITLBranchOperator:
 
         with pytest.raises(DownstreamTasksSkipped) as exc:
             op.execute_complete(
-                context={"ti": ti, "task": ti.task},
+                context=get_context_from_model_ti(ti),
                 event={
                     "chosen_options": ["Approve"],
                     "params_input": {},
@@ -591,7 +648,7 @@ class TestHITLBranchOperator:
         # checks to see that the "archive" task was skipped
         assert set(exc.value.tasks) == {("archive", -1)}
 
-    def test_mapping_with_multiple_choices(self, dag_maker):
+    def test_mapping_with_multiple_choices(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
         # multiple=True; mapping applied per option; no dedup implied
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -612,7 +669,7 @@ class TestHITLBranchOperator:
 
         with pytest.raises(DownstreamTasksSkipped) as exc:
             op.execute_complete(
-                context={"ti": ti, "task": ti.task},
+                context=get_context_from_model_ti(ti),
                 event={
                     "chosen_options": ["Approve", "KeepAsIs"],
                     "params_input": {},
@@ -623,7 +680,7 @@ class TestHITLBranchOperator:
         # publish + keep chosen → only "other" skipped
         assert set(exc.value.tasks) == {("other", -1)}
 
-    def test_fallback_to_option_when_not_mapped(self, dag_maker):
+    def test_fallback_to_option_when_not_mapped(self, dag_maker: DagMaker, get_context_from_model_ti) -> None:
         # No mapping: option must match downstream task_id
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -638,7 +695,7 @@ class TestHITLBranchOperator:
 
         with pytest.raises(DownstreamTasksSkipped) as exc:
             op.execute_complete(
-                context={"ti": ti, "task": ti.task},
+                context=get_context_from_model_ti(ti),
                 event={
                     "chosen_options": ["branch_2"],
                     "params_input": {},
@@ -648,7 +705,9 @@ class TestHITLBranchOperator:
             )
         assert set(exc.value.tasks) == {("branch_1", -1)}
 
-    def test_error_if_mapped_branch_not_direct_downstream(self, dag_maker):
+    def test_error_if_mapped_branch_not_direct_downstream(
+        self, dag_maker: DagMaker, get_context_from_model_ti
+    ):
         # Don't add the mapped task downstream → expect a clean error
         with dag_maker("hitl_map_dag", serialized=True):
             op = HITLBranchOperator(
@@ -664,7 +723,7 @@ class TestHITLBranchOperator:
 
         with pytest.raises(AirflowException, match="downstream|not found"):
             op.execute_complete(
-                context={"ti": ti, "task": ti.task},
+                context=get_context_from_model_ti(ti),
                 event={
                     "chosen_options": ["Approve"],
                     "params_input": {},

Reply via email to