This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ed2d35fa1ad feat(hitl): add fail_on_reject to ApprovalOperator (#55255)
ed2d35fa1ad is described below
commit ed2d35fa1ad918750cf1a713caa001b181ec6ef2
Author: Wei Lee <[email protected]>
AuthorDate: Thu Sep 18 16:48:06 2025 +0800
feat(hitl): add fail_on_reject to ApprovalOperator (#55255)
---
.../src/airflow/providers/standard/exceptions.py | 4 +
.../airflow/providers/standard/operators/hitl.py | 37 ++++++++-
.../tests/unit/standard/operators/test_hitl.py | 91 ++++++++++++++++++----
3 files changed, 114 insertions(+), 18 deletions(-)
diff --git a/providers/standard/src/airflow/providers/standard/exceptions.py
b/providers/standard/src/airflow/providers/standard/exceptions.py
index 98449fed173..4afef460324 100644
--- a/providers/standard/src/airflow/providers/standard/exceptions.py
+++ b/providers/standard/src/airflow/providers/standard/exceptions.py
@@ -63,3 +63,7 @@ class HITLTriggerEventError(Exception):
class HITLTimeoutError(HITLTriggerEventError):
"""Raised when HITLOperator timeouts."""
+
+
+class HITLRejectException(AirflowException):
+    """Raised when an ApprovalOperator with ``fail_on_reject=True`` receives a "Reject" response."""
diff --git
a/providers/standard/src/airflow/providers/standard/operators/hitl.py
b/providers/standard/src/airflow/providers/standard/operators/hitl.py
index f40abe1e467..f76d62d8a16 100644
--- a/providers/standard/src/airflow/providers/standard/operators/hitl.py
+++ b/providers/standard/src/airflow/providers/standard/operators/hitl.py
@@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any
from urllib.parse import ParseResult, urlencode, urlparse, urlunparse
from airflow.configuration import conf
-from airflow.providers.standard.exceptions import HITLTimeoutError,
HITLTriggerEventError
+from airflow.providers.standard.exceptions import HITLRejectException,
HITLTimeoutError, HITLTriggerEventError
from airflow.providers.standard.operators.branch import BranchMixIn
from airflow.providers.standard.triggers.hitl import HITLTrigger,
HITLTriggerEventSuccessPayload
from airflow.providers.standard.utils.skipmixin import SkipMixin
@@ -303,12 +303,42 @@ class ApprovalOperator(HITLOperator, SkipMixin):
APPROVE = "Approve"
REJECT = "Reject"
- def __init__(self, ignore_downstream_trigger_rules: bool = False,
**kwargs) -> None:
+ def __init__(
+ self,
+ *,
+ ignore_downstream_trigger_rules: bool = False,
+ fail_on_reject: bool = False,
+ **kwargs,
+ ) -> None:
+ """
+ Human-in-the-loop Operator for simple approval workflows.
+
+ This operator presents the user with two fixed options: "Approve" and
"Reject".
+
+ Behavior:
+ - "Approve": Downstream tasks execute as normal.
+ - "Reject":
+ - Downstream tasks are skipped according to the
`ignore_downstream_trigger_rules` setting.
+ - If `fail_on_reject=True`, the task fails instead of only
skipping downstream tasks.
+
+ Warning:
+ Using `fail_on_reject=True` is generally discouraged. A
HITLOperator's role is to collect
+ human input, and receiving any response—including
"Reject"—indicates the task succeeded.
+ Treating "Reject" as a task failure mixes human decision outcomes
with Airflow task
+ success/failure states.
+ Only use this option if you explicitly intend for a "Reject"
response to fail the task.
+
+ Args:
+ ignore_downstream_trigger_rules: If True, skips all downstream
tasks regardless of trigger rules.
+ fail_on_reject: If True, the task fails when "Reject" is selected.
Generally discouraged.
+ Read the warning carefully before using.
+ """
for arg in self.FIXED_ARGS:
if arg in kwargs:
raise ValueError(f"Passing {arg} to ApprovalOperator is not
allowed.")
self.ignore_downstream_trigger_rules = ignore_downstream_trigger_rules
+ self.fail_on_reject = fail_on_reject
super().__init__(
options=[self.APPROVE, self.REJECT],
@@ -324,6 +354,9 @@ class ApprovalOperator(HITLOperator, SkipMixin):
self.log.info("Approved. Proceeding with downstream tasks...")
return ret
+ if self.fail_on_reject and chosen_option == self.REJECT:
+ raise HITLRejectException('Receive "Reject"')
+
if not self.downstream_task_ids:
self.log.info("No downstream tasks; nothing to do.")
return ret
diff --git a/providers/standard/tests/unit/standard/operators/test_hitl.py
b/providers/standard/tests/unit/standard/operators/test_hitl.py
index f198cc461cf..c7c03df683e 100644
--- a/providers/standard/tests/unit/standard/operators/test_hitl.py
+++ b/providers/standard/tests/unit/standard/operators/test_hitl.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import pytest
-from airflow.providers.standard.exceptions import HITLTimeoutError,
HITLTriggerEventError
+from airflow.providers.standard.exceptions import HITLRejectException,
HITLTimeoutError, HITLTriggerEventError
from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
@@ -46,12 +46,15 @@ from airflow.providers.standard.operators.hitl import (
from airflow.sdk import Param, timezone
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.execution_time.hitl import HITLUser
+from airflow.utils.context import Context
from tests_common.test_utils.config import conf_vars
if TYPE_CHECKING:
from sqlalchemy.orm import Session
+ from airflow.sdk.definitions.context import Context
+
from tests_common.pytest_plugin import DagMaker
pytestmark = pytest.mark.db_test
@@ -77,6 +80,37 @@ def hitl_task_and_ti_for_generating_link(dag_maker:
DagMaker) -> tuple[HITLOpera
return task, dag_maker.run_ti(task.task_id, dr)
+@pytest.fixture
+def get_context_from_model_ti(mock_supervisor_comms):
+    def _get_context(ti: TaskInstance) -> Context:
+        from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+            DagRun as DRDataModel,
+            TaskInstance as TIDataModel,
+            TIRunContext,
+        )
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        # callers pass DagRun.get_task_instance() results, which may be None; narrow for mypy
+        assert ti is not None
+
+        dag_run = ti.dag_run
+        ti_model = TIDataModel.model_validate(ti, from_attributes=True)
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **ti_model.model_dump(exclude_unset=True),
+            task=ti.task,
+            _ti_context_from_server=TIRunContext(
+                dag_run=DRDataModel.model_validate(dag_run, from_attributes=True),
+                max_tries=ti.max_tries,
+                variables=[],
+                connections=[],
+                xcom_keys_to_clear=[],
+            ),
+        )
+        return runtime_ti.get_template_context()
+
+    return _get_context
+
+
class TestHITLOperator:
def test_validate_options(self) -> None:
hitl_op = HITLOperator(
@@ -451,7 +485,9 @@ class TestApprovalOperator:
"responded_by_user": {"id": "test", "name": "test"},
}
- def test_execute_complete_with_downstream_tasks(self, dag_maker) -> None:
+ def test_execute_complete_with_downstream_tasks(
+ self, dag_maker: DagMaker, get_context_from_model_ti
+ ) -> None:
with dag_maker("hitl_test_dag", serialized=True):
hitl_op = ApprovalOperator(
task_id="hitl_test",
@@ -461,10 +497,9 @@ class TestApprovalOperator:
dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("hitl_test")
-
with pytest.raises(DownstreamTasksSkipped) as exc_info:
hitl_op.execute_complete(
- context={"ti": ti, "task": ti.task},
+ context=get_context_from_model_ti(ti),
event={
"chosen_options": ["Reject"],
"params_input": {},
@@ -474,6 +509,26 @@ class TestApprovalOperator:
)
assert set(exc_info.value.tasks) == {"op1"}
+    def test_execute_complete_with_fail_on_reject_set_to_true(
+        self, dag_maker: DagMaker, get_context_from_model_ti
+    ) -> None:
+        with dag_maker("hitl_test_dag", serialized=True):
+            hitl_op = ApprovalOperator(task_id="hitl_test", subject="This is subject", fail_on_reject=True)
+            hitl_op >> EmptyOperator(task_id="op1")
+
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance("hitl_test")
+        with pytest.raises(HITLRejectException):
+            hitl_op.execute_complete(
+                context=get_context_from_model_ti(ti),
+                event={
+                    "chosen_options": ["Reject"],
+                    "params_input": {},
+                    "responded_at": timezone.utcnow(),
+                    "responded_by_user": {"id": "test", "name": "test"},
+                },
+            )
+
class TestHITLEntryOperator:
def test_init_without_options_and_default(self) -> None:
@@ -513,7 +568,7 @@ class TestHITLEntryOperator:
class TestHITLBranchOperator:
- def test_execute_complete(self, dag_maker) -> None:
+ def test_execute_complete(self, dag_maker: DagMaker,
get_context_from_model_ti) -> None:
with dag_maker("hitl_test_dag", serialized=True):
branch_op = HITLBranchOperator(
task_id="make_choice",
@@ -527,7 +582,7 @@ class TestHITLBranchOperator:
ti = dr.get_task_instance("make_choice")
with pytest.raises(DownstreamTasksSkipped) as exc_info:
branch_op.execute_complete(
- context={"ti": ti, "task": ti.task},
+ context=get_context_from_model_ti(ti),
event={
"chosen_options": ["branch_1"],
"params_input": {},
@@ -537,7 +592,9 @@ class TestHITLBranchOperator:
)
assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in
range(2, 6))
- def test_execute_complete_with_multiple_branches(self, dag_maker) -> None:
+ def test_execute_complete_with_multiple_branches(
+ self, dag_maker: DagMaker, get_context_from_model_ti
+ ) -> None:
with dag_maker("hitl_test_dag", serialized=True):
branch_op = HITLBranchOperator(
task_id="make_choice",
@@ -554,7 +611,7 @@ class TestHITLBranchOperator:
ti = dr.get_task_instance("make_choice")
with pytest.raises(DownstreamTasksSkipped) as exc_info:
branch_op.execute_complete(
- context={"ti": ti, "task": ti.task},
+ context=get_context_from_model_ti(ti),
event={
"chosen_options": [f"branch_{i}" for i in range(1, 4)],
"params_input": {},
@@ -564,7 +621,7 @@ class TestHITLBranchOperator:
)
assert set(exc_info.value.tasks) == set((f"branch_{i}", -1) for i in
range(4, 6))
- def test_mapping_applies_for_single_choice(self, dag_maker):
+ def test_mapping_applies_for_single_choice(self, dag_maker: DagMaker,
get_context_from_model_ti) -> None:
# ["Approve"]; map -> "publish"
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -580,7 +637,7 @@ class TestHITLBranchOperator:
with pytest.raises(DownstreamTasksSkipped) as exc:
op.execute_complete(
- context={"ti": ti, "task": ti.task},
+ context=get_context_from_model_ti(ti),
event={
"chosen_options": ["Approve"],
"params_input": {},
@@ -591,7 +648,7 @@ class TestHITLBranchOperator:
# checks to see that the "archive" task was skipped
assert set(exc.value.tasks) == {("archive", -1)}
- def test_mapping_with_multiple_choices(self, dag_maker):
+ def test_mapping_with_multiple_choices(self, dag_maker: DagMaker,
get_context_from_model_ti) -> None:
# multiple=True; mapping applied per option; no dedup implied
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -612,7 +669,7 @@ class TestHITLBranchOperator:
with pytest.raises(DownstreamTasksSkipped) as exc:
op.execute_complete(
- context={"ti": ti, "task": ti.task},
+ context=get_context_from_model_ti(ti),
event={
"chosen_options": ["Approve", "KeepAsIs"],
"params_input": {},
@@ -623,7 +680,7 @@ class TestHITLBranchOperator:
# publish + keep chosen → only "other" skipped
assert set(exc.value.tasks) == {("other", -1)}
- def test_fallback_to_option_when_not_mapped(self, dag_maker):
+ def test_fallback_to_option_when_not_mapped(self, dag_maker: DagMaker,
get_context_from_model_ti) -> None:
# No mapping: option must match downstream task_id
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -638,7 +695,7 @@ class TestHITLBranchOperator:
with pytest.raises(DownstreamTasksSkipped) as exc:
op.execute_complete(
- context={"ti": ti, "task": ti.task},
+ context=get_context_from_model_ti(ti),
event={
"chosen_options": ["branch_2"],
"params_input": {},
@@ -648,7 +705,9 @@ class TestHITLBranchOperator:
)
assert set(exc.value.tasks) == {("branch_1", -1)}
- def test_error_if_mapped_branch_not_direct_downstream(self, dag_maker):
+ def test_error_if_mapped_branch_not_direct_downstream(
+ self, dag_maker: DagMaker, get_context_from_model_ti
+ ):
# Don't add the mapped task downstream → expect a clean error
with dag_maker("hitl_map_dag", serialized=True):
op = HITLBranchOperator(
@@ -664,7 +723,7 @@ class TestHITLBranchOperator:
with pytest.raises(AirflowException, match="downstream|not found"):
op.execute_complete(
- context={"ti": ti, "task": ti.task},
+ context=get_context_from_model_ti(ti),
event={
"chosen_options": ["Approve"],
"params_input": {},