This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 96feeaabb53e274351378b8e2f5dfd41985e8dc8 Author: Steven Schaerer <[email protected]> AuthorDate: Sat Mar 2 20:18:38 2024 +0100 Improve code coverage for TriggerRuleDep (#37680) (cherry picked from commit 73a632a5a0af3fe51484e89a7bd4c771800707e5) --- airflow/ti_deps/deps/trigger_rule_dep.py | 12 +- tests/ti_deps/deps/test_trigger_rule_dep.py | 1000 +++++++++++++++++---------- 2 files changed, 620 insertions(+), 392 deletions(-) diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index ca2a6100a2..c381166634 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -435,8 +435,8 @@ class TriggerRuleDep(BaseTIDep): if success + failed <= 0: yield self._failing_status( reason=( - f"Task's trigger rule '{trigger_rule}'" - "requires at least one upstream task failure or success" + f"Task's trigger rule '{trigger_rule}' " + "requires at least one upstream task failure or success " f"but none were failed or success. upstream_states={upstream_states}, " f"upstream_task_ids={task.upstream_task_ids}" ) @@ -521,14 +521,6 @@ class TriggerRuleDep(BaseTIDep): f"upstream_task_ids={task.upstream_task_ids}" ) ) - elif upstream_setup is None: # for now, None only happens in mapped case - yield self._failing_status( - reason=( - f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. " - f"upstream_states={upstream_states}, " - f"upstream_task_ids={task.upstream_task_ids}" - ) - ) elif upstream_setup and not success_setup: yield self._failing_status( reason=( diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 00cbcd449a..7656dcdec9 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -35,12 +35,16 @@ from airflow.utils.trigger_rule import TriggerRule pytestmark = pytest.mark.db_test - if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + from airflow.models.dagrun import DagRun SKIPPED = TaskInstanceState.SKIPPED UPSTREAM_FAILED = TaskInstanceState.UPSTREAM_FAILED +REMOVED = TaskInstanceState.REMOVED +SUCCESS = TaskInstanceState.SUCCESS +FAILED = TaskInstanceState.FAILED @pytest.fixture @@ -92,9 +96,21 @@ def get_task_instance(monkeypatch, session, dag_maker): @pytest.fixture def get_mapped_task_dagrun(session, dag_maker): - def _get_dagrun(trigger_rule=TriggerRule.ALL_SUCCESS, state=TaskInstanceState.SUCCESS): + def _get_dagrun(trigger_rule=TriggerRule.ALL_SUCCESS, state=SUCCESS, add_setup_tasks: bool = False): from airflow.decorators import task + @task + def setup_1(i): + return 1 + + @task + def setup_2(i): + return 1 + + @task + def setup_3(i): + return 1 + @task def do_something(i): return 1 @@ -106,46 +122,76 @@ def get_mapped_task_dagrun(session, dag_maker): with dag_maker(dag_id="test_dag"): nums = do_something.expand(i=[i + 1 for i in range(5)]) do_something_else.expand(i=nums) + if add_setup_tasks: + setup_nums = setup_1.expand(i=[i + 1 for i in range(5)]) + setup_more_nums = setup_2.expand(i=setup_nums) + setup_other_nums = setup_3.expand(i=setup_more_nums) + setup_more_nums.as_setup() >> nums + setup_nums.as_setup() >> nums + setup_other_nums.as_setup() >> nums dr = dag_maker.create_dagrun() - ti = dr.get_task_instance("do_something_else", session=session) - ti.map_index = 0 - for map_index in range(1, 5): - ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) - ti.dag_run = dr - session.add(ti) - session.flush() - tis = 
dr.get_task_instances() - for ti in tis: - if ti.task_id == "do_something": - if ti.map_index > 2: - ti.state = TaskInstanceState.REMOVED - else: - ti.state = state - session.merge(ti) + def _expand_tasks(task_instance: str, upstream: str) -> BaseOperator | None: + ti = dr.get_task_instance(task_instance, session=session) + ti.map_index = 0 + for map_index in range(1, 5): + ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + tis = dr.get_task_instances(session=session) + for ti in tis: + if ti.task_id == upstream: + if ti.map_index > 2: + ti.state = REMOVED + else: + ti.state = state + session.merge(ti) + return ti.task + + do_task = _expand_tasks("do_something_else", "do_something") + if add_setup_tasks: + _expand_tasks("setup_2", "setup_1") + setup_task = _expand_tasks("setup_3", "setup_2") + else: + setup_task = None session.commit() - return dr, ti.task + return dr, do_task, setup_task return _get_dagrun class TestTriggerRuleDep: - def test_no_upstream_tasks(self, get_task_instance): + def test_no_upstream_tasks(self, session, get_task_instance): """ If the TI has no upstream TIs then there is nothing to check and the dep is passed """ ti = get_task_instance(TriggerRule.ALL_DONE) - assert TriggerRuleDep().is_met(ti=ti) + dep_statuses = tuple( + TriggerRuleDep().get_dep_statuses(ti=ti, dep_context=DepContext(), session=session) + ) + assert len(dep_statuses) == 1 + assert dep_statuses[0].passed + assert dep_statuses[0].reason == "The task instance did not have any upstream tasks." - def test_always_tr(self, get_task_instance): + def test_always_tr(self, session, get_task_instance): """ The always trigger rule should always pass this dep """ - ti = get_task_instance(TriggerRule.ALWAYS) - assert TriggerRuleDep().is_met(ti=ti) + ti = get_task_instance(TriggerRule.ALWAYS, normal_tasks=["a"]) - def test_one_success_tr_success(self, session, get_task_instance): + dep_statuses = tuple( + TriggerRuleDep().get_dep_statuses(ti=ti, dep_context=DepContext(), session=session) + ) + assert len(dep_statuses) == 1 + assert dep_statuses[0].passed + assert dep_statuses[0].reason == "The task had a always trigger rule set." 
+ + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_one_success_tr_success( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ One-success trigger rule success """ @@ -158,39 +204,67 @@ class TestTriggerRuleDep: upstream_failed=2, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 0 - def test_one_success_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize( + "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), (False, None)] + ) + def test_one_success_tr_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ One-success trigger rule failure """ + ti = get_task_instance( + TriggerRule.ONE_SUCCESS, + success=0, + skipped=1, + failed=1, + removed=1, + upstream_failed=1, + done=4, + ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires one upstream task success, but none were found.", + expected_ti_state=expected_ti_state, + ) + + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_one_success_tr_failure_all_skipped( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): + """ + One-success trigger rule failure and all are skipped + """ ti = get_task_instance( TriggerRule.ONE_SUCCESS, success=0, skipped=2, - failed=2, + failed=0, removed=0, - upstream_failed=2, + upstream_failed=0, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires one upstream task success, but none were found.", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_one_failure_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_one_failure_tr_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ One-failure trigger rule failure """ @@ -203,17 +277,16 @@ class TestTriggerRuleDep: upstream_failed=0, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires one upstream task failure, but none were found.", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_one_failure_tr_success(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_one_failure_tr_success(self, session, get_task_instance, flag_upstream_failed): """ One-failure trigger rule success """ @@ -226,16 +299,10 @@ class TestTriggerRuleDep: upstream_failed=0, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - 
dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_one_failure_tr_success_no_failed(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_one_failure_tr_success_no_failed(self, session, get_task_instance, flag_upstream_failed): """ One-failure trigger rule success """ @@ -248,16 +315,10 @@ class TestTriggerRuleDep: upstream_failed=2, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_one_done_tr_success(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_one_done_tr_success(self, session, get_task_instance, flag_upstream_failed): """ One-done trigger rule success """ @@ -270,16 +331,10 @@ class TestTriggerRuleDep: upstream_failed=0, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_one_done_tr_success_with_failed(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_one_done_tr_success_with_failed(self, session, get_task_instance, flag_upstream_failed): """ One-done trigger rule success """ @@ -292,16 +347,10 @@ class TestTriggerRuleDep: upstream_failed=0, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_one_done_tr_skip(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_one_done_tr_skip(self, session, get_task_instance, flag_upstream_failed, expected_ti_state): """ One-done trigger rule skip """ @@ -314,17 +363,18 @@ class TestTriggerRuleDep: upstream_failed=0, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires at least one upstream task failure or success but none", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_one_done_tr_upstream_failed(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_one_done_tr_upstream_failed( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ One-done trigger rule upstream_failed """ @@ -337,17 +387,16 @@ class TestTriggerRuleDep: upstream_failed=2, done=2, ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + 
flag_upstream_failed=flag_upstream_failed, + expected_reason="requires at least one upstream task failure or success but none", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_all_success_tr_success(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_all_success_tr_success(self, session, get_task_instance, flag_upstream_failed): """ All-success trigger rule success """ @@ -361,16 +410,14 @@ class TestTriggerRuleDep: done=1, normal_tasks=["FakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_all_success_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize( + "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), (False, None)] + ) + def test_all_success_tr_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ All-success trigger rule failure """ @@ -384,20 +431,15 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have succeeded, but found 1", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - @pytest.mark.parametrize( - "flag_upstream_failed, expected_ti_state", - [(True, TaskInstanceState.SKIPPED), (False, None)], - ) + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) def test_all_success_tr_skip(self, session, get_task_instance, flag_upstream_failed, expected_ti_state): """ All-success trigger rule fails when some upstream tasks are skipped. @@ -412,18 +454,18 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=flag_upstream_failed), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have succeeded, but found 1", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - assert ti.state == expected_ti_state - def test_all_success_tr_skip_wait_for_past_depends_before_skipping(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_all_success_tr_skip_wait_for_past_depends_before_skipping( + self, session, get_task_instance, flag_upstream_failed + ): """ All-success trigger rule fails when some upstream tasks are skipped. 
The state of the ti should not be set to SKIPPED when flag_upstream_failed is True and @@ -442,21 +484,21 @@ class TestTriggerRuleDep: ti.task.xcom_pull.return_value = None xcom_mock = Mock(return_value=None) with mock.patch("airflow.models.taskinstance.TaskInstance.xcom_pull", xcom_mock): - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext( - flag_upstream_failed=True, wait_for_past_depends_before_skipping=True - ), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + wait_for_past_depends_before_skipping=True, + expected_reason=( + "Task should be skipped but the past depends are not met" + if flag_upstream_failed + else "requires all upstream tasks to have succeeded, but found 1" + ), ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - assert ti.state is None + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) def test_all_success_tr_skip_wait_for_past_depends_before_skipping_past_depends_met( - self, session, get_task_instance + self, session, get_task_instance, flag_upstream_failed, expected_ti_state ): """ All-success trigger rule fails when some upstream tasks are skipped. The state of the ti @@ -476,23 +518,19 @@ class TestTriggerRuleDep: ti.task.xcom_pull.return_value = None xcom_mock = Mock(return_value=True) with mock.patch("airflow.models.taskinstance.TaskInstance.xcom_pull", xcom_mock): - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext( - flag_upstream_failed=True, wait_for_past_depends_before_skipping=True - ), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + wait_for_past_depends_before_skipping=True, + expected_ti_state=expected_ti_state, + expected_reason="requires all upstream tasks to have succeeded, but found 1", ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - assert ti.state == TaskInstanceState.SKIPPED @pytest.mark.parametrize("flag_upstream_failed", [True, False]) def test_none_failed_tr_success(self, session, get_task_instance, flag_upstream_failed): """ - All success including skip trigger rule success + None failed trigger rule success """ ti = get_task_instance( TriggerRule.NONE_FAILED, @@ -504,19 +542,16 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=flag_upstream_failed), - session=session, - ) - ) - assert len(dep_statuses) == 0 - assert ti.state is None + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_none_failed_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize( + "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), (False, None)] + ) + def test_none_failed_tr_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ - All success including skip trigger rule failure + None failed trigger rule failure """ ti = get_task_instance( TriggerRule.NONE_FAILED, @@ -528,19 +563,45 @@ class TestTriggerRuleDep: done=3, normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + 
ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="all upstream tasks to have succeeded or been skipped, but found 1", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_none_failed_min_one_success_tr_success(self, session, get_task_instance): + @pytest.mark.parametrize( + "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), (False, None)] + ) + def test_none_failed_tr_failure_with_upstream_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ - All success including skip trigger rule success + None failed skip trigger rule failure + """ + ti = get_task_instance( + TriggerRule.NONE_FAILED, + success=1, + skipped=1, + failed=0, + removed=0, + upstream_failed=1, + done=3, + normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], + ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="all upstream tasks to have succeeded or been skipped, but found 1", + expected_ti_state=expected_ti_state, + ) + + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_none_failed_min_one_success_tr_success(self, session, get_task_instance, flag_upstream_failed): + """ + None failed min one success trigger rule success """ ti = get_task_instance( TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, @@ -552,18 +613,14 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_none_failed_min_one_success_tr_skipped(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_none_failed_min_one_success_tr_skipped( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ - All success including all upstream skips trigger rule success + None failed min one success trigger rule success with all skipped """ ti = get_task_instance( TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, @@ -575,19 +632,21 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=True), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 0 - assert ti.state == TaskInstanceState.SKIPPED - def test_none_failed_min_one_success_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize( + "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), (False, None)] + ) + def test_none_failed_min_one_success_tr_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ - All success including skip trigger rule failure + None failed min one success trigger rule failure due to single failure """ ti = get_task_instance( TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, @@ -599,17 +658,43 @@ class TestTriggerRuleDep: done=3, normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - 
ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="all upstream tasks to have succeeded or been skipped, but found 1", + expected_ti_state=expected_ti_state, + ) + + @pytest.mark.parametrize( + "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), (False, None)] + ) + def test_none_failed_min_one_success_tr_upstream_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): + """ + None failed min one success trigger rule failure due to single upstream failure + """ + ti = get_task_instance( + TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS, + success=1, + skipped=1, + failed=0, + removed=0, + upstream_failed=1, + done=3, + normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], + ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="all upstream tasks to have succeeded or been skipped, but found 1", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_all_failed_tr_success(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_all_failed_tr_success(self, session, get_task_instance, flag_upstream_failed): """ All-failed trigger rule success """ @@ -623,16 +708,10 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_all_failed_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_all_failed_tr_failure(self, session, get_task_instance, flag_upstream_failed, expected_ti_state): """ All-failed trigger rule failure """ @@ -646,17 +725,16 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have failed, but found 2", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_all_done_tr_success(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_all_done_tr_success(self, session, get_task_instance, flag_upstream_failed): """ All-done trigger rule success """ @@ -670,14 +748,7 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "OtherFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) @pytest.mark.parametrize( "task_cfg, states, exp_reason, exp_state", @@ -761,8 +832,9 @@ class TestTriggerRuleDep: ), ], ) + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) def test_teardown_tr_not_all_done( - 
self, task_cfg, states, exp_reason, exp_state, session, get_task_instance + self, task_cfg, states, exp_reason, exp_state, session, get_task_instance, flag_upstream_failed ): """ All-done trigger rule success @@ -773,22 +845,18 @@ class TestTriggerRuleDep: normal_tasks=[f"w{x}" for x in range(task_cfg["work"])], setup_tasks=[f"s{x}" for x in range(task_cfg["setup"])], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, dep_context=DepContext(flag_upstream_failed=True), session=session - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason=exp_reason, + expected_ti_state=exp_state if exp_state and flag_upstream_failed else None, ) - if exp_reason: - dep_status = dep_statuses[0] - assert len(dep_statuses) == 1 - assert exp_reason in dep_status.reason - assert dep_status.passed is False - assert ti.state == exp_state - else: - assert len(dep_statuses) == 0 - assert ti.state is None - def test_all_skipped_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_all_skipped_tr_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ All-skipped trigger rule failure """ @@ -802,17 +870,18 @@ class TestTriggerRuleDep: done=1, normal_tasks=["FakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have been skipped, but found 1", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_all_skipped_tr_failure_upstream_failed(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_all_skipped_tr_failure_upstream_failed( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ All-skipped trigger rule failure if an upstream task is in a `upstream_failed` state """ @@ -826,15 +895,13 @@ class TestTriggerRuleDep: done=1, normal_tasks=["FakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have been skipped, but found 1", + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed @pytest.mark.parametrize("flag_upstream_failed", [True, False]) def test_all_skipped_tr_success(self, session, get_task_instance, flag_upstream_failed): @@ -851,16 +918,10 @@ class TestTriggerRuleDep: done=3, normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=flag_upstream_failed), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - def test_all_done_tr_failure(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_all_done_tr_failure(self, session, get_task_instance, 
flag_upstream_failed): """ All-done trigger rule failure """ @@ -876,15 +937,12 @@ class TestTriggerRuleDep: ) EmptyOperator(task_id="OtherFakeTeakID", dag=ti.task.dag) >> ti.task # An unfinished upstream. - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have completed, but found 1", ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed @pytest.mark.parametrize("flag_upstream_failed", [True, False]) def test_none_skipped_tr_success(self, session, get_task_instance, flag_upstream_failed): @@ -901,17 +959,12 @@ class TestTriggerRuleDep: done=3, normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=flag_upstream_failed), - session=session, - ) - ) - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) - @pytest.mark.parametrize("flag_upstream_failed", [True, False]) - def test_none_skipped_tr_failure(self, session, get_task_instance, flag_upstream_failed): + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) + def test_none_skipped_tr_failure( + self, session, get_task_instance, flag_upstream_failed, expected_ti_state + ): """ None-skipped trigger rule failure """ @@ -925,17 +978,16 @@ class TestTriggerRuleDep: done=2, normal_tasks=["FakeTaskID", "SkippedTaskID"], ) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=flag_upstream_failed), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_ti_state=expected_ti_state, + expected_reason="requires all upstream tasks to not have been skipped, but found 1", ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_none_skipped_tr_failure_empty(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_none_skipped_tr_failure_empty(self, session, get_task_instance, flag_upstream_failed): """ None-skipped trigger rule fails until all upstream tasks have completed execution """ @@ -950,17 +1002,15 @@ class TestTriggerRuleDep: ) EmptyOperator(task_id="FakeTeakID", dag=ti.task.dag) >> ti.task # An unfinished upstream. 
- dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to not have been skipped, but found 0", ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed - def test_unknown_tr(self, session, get_task_instance): + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) + def test_unknown_tr(self, session, get_task_instance, flag_upstream_failed): """ Unknown trigger rules should cause this dep to fail """ @@ -975,15 +1025,12 @@ class TestTriggerRuleDep: ) ti.task.trigger_rule = "Unknown Trigger Rule" - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="No strategy to evaluate trigger rule 'Unknown Trigger Rule'.", ) - assert len(dep_statuses) == 1 - assert not dep_statuses[0].passed def test_UpstreamTIStates(self, session, dag_maker): """ @@ -1002,11 +1049,11 @@ class TestTriggerRuleDep: dr = dag_maker.create_dagrun() tis = {ti.task_id: ti for ti in dr.task_instances} - tis["op1"].state = TaskInstanceState.SUCCESS - tis["op2"].state = TaskInstanceState.FAILED - tis["op3"].state = TaskInstanceState.SUCCESS - tis["op4"].state = TaskInstanceState.SUCCESS - tis["op5"].state = TaskInstanceState.SUCCESS + tis["op1"].state = SUCCESS + tis["op2"].state = FAILED + tis["op3"].state = SUCCESS + tis["op4"].state = SUCCESS + tis["op5"].state = SUCCESS def _get_finished_tis(task_id: str) -> Iterator[TaskInstance]: return (ti for ti in tis.values() if ti.task_id in tis[task_id].task.upstream_task_ids) @@ -1019,16 +1066,19 @@ class TestTriggerRuleDep: dr.update_state(session=session) assert dr.state == DagRunState.SUCCESS + @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", [(True, REMOVED), (False, None)]) def test_mapped_task_upstream_removed_with_all_success_trigger_rules( self, monkeypatch, session, get_mapped_task_dagrun, + flag_upstream_failed, + expected_ti_state, ): """ Test ALL_SUCCESS trigger rule with mapped task upstream removed """ - dr, task = get_mapped_task_dagrun() + dr, task, _ = get_mapped_task_dagrun() # ti with removed upstream ti ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session) @@ -1046,28 +1096,26 @@ class TestTriggerRuleDep: ) monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - # Marks the task as removed if upstream is removed. 
- dep_context=DepContext(flag_upstream_failed=True), - session=session, - ) + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_ti_state=expected_ti_state, ) - assert len(dep_statuses) == 0 - assert ti.state == TaskInstanceState.REMOVED + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) def test_mapped_task_upstream_removed_with_all_failed_trigger_rules( self, monkeypatch, session, get_mapped_task_dagrun, + flag_upstream_failed, ): """ Test ALL_FAILED trigger rule with mapped task upstream removed """ - dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, state=TaskInstanceState.FAILED) + dr, task, _ = get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, state=FAILED) # ti with removed upstream ti ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session) @@ -1085,31 +1133,24 @@ class TestTriggerRuleDep: ) monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) @pytest.mark.parametrize( - "trigger_rule", - [TriggerRule.NONE_FAILED, TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS], + "trigger_rule", [TriggerRule.NONE_FAILED, TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS] ) + @pytest.mark.parametrize("flag_upstream_failed", [True, False]) def test_mapped_task_upstream_removed_with_none_failed_trigger_rules( self, monkeypatch, session, get_mapped_task_dagrun, trigger_rule, + flag_upstream_failed, ): """ Test NONE_FAILED trigger rule with mapped task upstream removed """ - dr, task = get_mapped_task_dagrun(trigger_rule=trigger_rule) + dr, task, _ = get_mapped_task_dagrun(trigger_rule=trigger_rule) # ti with removed upstream ti ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session) @@ -1127,15 +1168,7 @@ class TestTriggerRuleDep: ) monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states) - dep_statuses = tuple( - TriggerRuleDep()._evaluate_trigger_rule( - ti=ti, - dep_context=DepContext(flag_upstream_failed=False), - session=session, - ) - ) - - assert len(dep_statuses) == 0 + _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed) def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session): @@ -1148,12 +1181,18 @@ def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session): return x @task_group - def tg(x): + def tg1(x): t1 = t.override(task_id="t1")(x=x) return t.override(task_id="t2")(x=t1) - t2 = tg.expand(x=[1, 2, 3]) - t.override(task_id="t3")(x=t2) + t2 = tg1.expand(x=[1, 2, 3]) + + @task_group + def tg2(x): + return t.override(task_id="t3")(x=t2) + + vals2 = tg2.expand(x=[4, 5, 6]) + t.override(task_id="t4")(x=vals2) dr: DagRun = dag_maker.create_dagrun() @@ -1163,29 +1202,40 @@ def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session): # Initial decision. tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)] + assert sorted(tis) == [("tg1.t1", 0), ("tg1.t1", 1), ("tg1.t1", 2)] # After running the first t1, the first t2 becomes immediately available. 
- tis["tg.t1", 0].run() + tis["tg1.t1", 0].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)] + assert sorted(tis) == [("tg1.t1", 1), ("tg1.t1", 2), ("tg1.t2", 0)] # Similarly for the subsequent t2 instances. - tis["tg.t1", 2].run() + tis["tg1.t1", 2].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)] + assert sorted(tis) == [("tg1.t1", 1), ("tg1.t2", 0), ("tg1.t2", 2)] # But running t2 partially does not make t3 available. - tis["tg.t1", 1].run() - tis["tg.t2", 0].run() - tis["tg.t2", 2].run() + tis["tg1.t1", 1].run() + tis["tg1.t2", 0].run() + tis["tg1.t2", 2].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t2", 1)] + assert sorted(tis) == [("tg1.t2", 1)] # Only after all t2 instances are run does t3 become available. - tis["tg.t2", 1].run() + tis["tg1.t2", 1].run() + tis = _one_scheduling_decision_iteration() + assert sorted(tis) == [("tg2.t3", 0), ("tg2.t3", 1), ("tg2.t3", 2)] + + # But running t3 partially does not make t4 available. + tis["tg2.t3", 0].run() + tis["tg2.t3", 2].run() + tis = _one_scheduling_decision_iteration() + assert sorted(tis) == [("tg2.t3", 1)] + + # Only after all t3 instances are run does t4 become available. + tis["tg2.t3", 1].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("t3", -1)] + assert sorted(tis) == [("t4", -1)] def test_upstream_in_mapped_group_when_mapped_tasks_list_is_empty(dag_maker, session): @@ -1216,7 +1266,12 @@ def test_upstream_in_mapped_group_when_mapped_tasks_list_is_empty(dag_maker, ses assert tis == {} -def test_mapped_task_check_before_expand(dag_maker, session): [email protected]("flag_upstream_failed", [True, False]) +def test_mapped_task_check_before_expand(dag_maker, session, flag_upstream_failed): + """ + t3 depends on t2, which depends on t1 for expansion. Since t1 has not yet run, t2 has not expanded yet, + and we need to guarantee this lack of expansion does not fail the dependency-checking logic. + """ with dag_maker(session=session): @task @@ -1232,17 +1287,46 @@ def test_mapped_task_check_before_expand(dag_maker, session): tg.expand(a=t([1, 2, 3])) dr: DagRun = dag_maker.create_dagrun() - result_iterator = TriggerRuleDep()._evaluate_trigger_rule( - # t3 depends on t2, which depends on t1 for expansion. Since t1 has not - # yet run, t2 has not expanded yet, and we need to guarantee this lack - # of expansion does not fail the dependency-checking logic. + + _test_trigger_rule( ti=next(ti for ti in dr.task_instances if ti.task_id == "tg.t3" and ti.map_index == -1), - dep_context=DepContext(), session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have succeeded, but found 1", + ) + + [email protected]("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) +def test_mapped_task_group_finished_upstream_before_expand( + dag_maker, session, flag_upstream_failed, expected_ti_state +): + """ + t3 depends on t2, which was skipped before it was expanded. We need to guarantee this lack of expansion + does not fail the dependency-checking logic. 
+ """ + with dag_maker(session=session): + + @task + def t(x): + return x + + @task_group + def tg(x): + return t.override(task_id="t3")(x=x) + + t.override(task_id="t2").expand(x=t.override(task_id="t1")([1, 2])) >> tg.expand(x=[1, 2]) + + dr: DagRun = dag_maker.create_dagrun() + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + tis["t2"].set_state(SKIPPED, session=session) + session.flush() + _test_trigger_rule( + ti=tis["tg.t3"], + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="requires all upstream tasks to have succeeded, but found 1", + expected_ti_state=expected_ti_state, ) - results = list(result_iterator) - assert len(results) == 1 - assert results[0].passed is False class TestTriggerRuleDepSetupConstraint: @@ -1407,3 +1491,155 @@ class TestTriggerRuleDepSetupConstraint: (status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session) assert status.reason.startswith("All setup tasks must complete successfully") assert self.get_ti(dr, "w2").state == expected + + [email protected]( + "map_index, flag_upstream_failed, expected_ti_state", + [(2, True, None), (3, True, REMOVED), (4, True, REMOVED), (3, False, None)], +) +def test_setup_constraint_mapped_task_upstream_removed_and_success( + dag_maker, + session, + get_mapped_task_dagrun, + map_index, + flag_upstream_failed, + expected_ti_state, +): + """ + Dynamically mapped setup task with successful and removed upstream tasks. Expect rule to be + successful. State is set to REMOVED for map index >= n success + """ + dr, _, setup_task = get_mapped_task_dagrun(add_setup_tasks=True) + + ti = dr.get_task_instance(task_id="setup_3", map_index=map_index, session=session) + ti.task = setup_task + + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_ti_state=expected_ti_state, + ) + + [email protected]( + "flag_upstream_failed, wait_for_past_depends_before_skipping, past_depends_met, expected_ti_state, expect_failure", + [ + (False, True, True, None, False), + (False, True, False, None, False), + (False, False, False, None, False), + (False, False, True, None, False), + (True, False, False, SKIPPED, False), + (True, False, True, SKIPPED, False), + (True, True, False, None, True), + (True, True, True, SKIPPED, False), + ], +) +def test_setup_constraint_wait_for_past_depends_before_skipping( + dag_maker, + session, + get_task_instance, + monkeypatch, + flag_upstream_failed, + wait_for_past_depends_before_skipping, + past_depends_met, + expected_ti_state, + expect_failure, +): + """ + Setup task with a skipped upstream task. + * If flag_upstream_failed is False then do not expect either a failure nor a modified state. + * If flag_upstream_failed is True and wait_for_past_depends_before_skipping is False then expect the + state to be set to SKIPPED but no failure. + * If both flag_upstream_failed and wait_for_past_depends_before_skipping are True then if the past + depends are met the state is expected to be SKIPPED and no failure, otherwise the state is not + expected to change but the trigger rule should fail. 
+ """ + ti = get_task_instance( + trigger_rule=TriggerRule.ALL_DONE, + success=1, + skipped=1, + failed=0, + removed=0, + upstream_failed=0, + done=2, + setup_tasks=["FakeTaskID", "OtherFakeTaskID"], + ) + + ti.task.xcom_pull.return_value = None + xcom_mock = Mock(return_value=True if past_depends_met else None) + with mock.patch("airflow.models.taskinstance.TaskInstance.xcom_pull", xcom_mock): + _test_trigger_rule( + ti=ti, + session=session, + flag_upstream_failed=flag_upstream_failed, + wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, + expected_ti_state=expected_ti_state, + expected_reason=( + "Task should be skipped but the past depends are not met" if expect_failure else "" + ), + ) + + [email protected]("flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]) +def test_setup_mapped_task_group_finished_upstream_before_expand( + dag_maker, session, flag_upstream_failed, expected_ti_state +): + """ + t3 indirectly depends on t1, which was skipped before it was expanded. We need to guarantee this lack of + expansion does not fail the dependency-checking logic. + """ + with dag_maker(session=session): + + @task(trigger_rule=TriggerRule.ALL_DONE) + def t(x): + return x + + @task_group + def tg(x): + return t.override(task_id="t3")(x=x) + + vals = t.override(task_id="t1")([1, 2]).as_setup() + t.override(task_id="t2").expand(x=vals).as_setup() >> tg.expand(x=[1, 2]).as_setup() + + dr: DagRun = dag_maker.create_dagrun() + + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + tis["t1"].set_state(SKIPPED, session=session) + tis["t2"].set_state(SUCCESS, session=session) + session.flush() + _test_trigger_rule( + ti=tis["tg.t3"], + session=session, + flag_upstream_failed=flag_upstream_failed, + expected_reason="All setup tasks must complete successfully.", + expected_ti_state=expected_ti_state, + ) + + +def _test_trigger_rule( + ti: TaskInstance, + session: Session, + flag_upstream_failed: bool, + wait_for_past_depends_before_skipping: bool = False, + expected_reason: str = "", + expected_ti_state: TaskInstanceState | None = None, +) -> None: + assert ti.state is None + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + dep_context=DepContext( + flag_upstream_failed=flag_upstream_failed, + wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping, + ), + session=session, + ) + ) + if expected_reason: + assert len(dep_statuses) == 1 + assert not dep_statuses[0].passed + assert expected_reason in dep_statuses[0].reason + else: + assert not dep_statuses + assert ti.state == expected_ti_state

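For readers following the refactor, every converted test in this commit settles on the same pattern: parametrize over flag_upstream_failed (and, where relevant, the expected TI state) and delegate the DepContext construction plus the pass/fail/state assertions to the shared _test_trigger_rule helper defined at the bottom of the diff. Below is a minimal sketch of a test written in that style. It is not an additional test from the commit; the trigger rule, state counts, and expected reason are illustrative values copied from the fixtures in this diff, and it assumes the module-level imports, aliases (SKIPPED, etc.), fixtures, and helper from tests/ti_deps/deps/test_trigger_rule_dep.py shown above.

    import pytest

    # Assumes the surrounding test module's context: the `session` and
    # `get_task_instance` fixtures, the SKIPPED alias for
    # TaskInstanceState.SKIPPED, TriggerRule, and _test_trigger_rule.
    @pytest.mark.parametrize(
        "flag_upstream_failed, expected_ti_state", [(True, SKIPPED), (False, None)]
    )
    def test_one_success_all_skipped_sketch(
        session, get_task_instance, flag_upstream_failed, expected_ti_state
    ):
        # Build a TI whose two upstream tasks are both skipped
        # (illustrative counts, mirroring test_one_success_tr_failure_all_skipped).
        ti = get_task_instance(
            TriggerRule.ONE_SUCCESS,
            success=0,
            skipped=2,
            failed=0,
            removed=0,
            upstream_failed=0,
            done=2,
        )
        # The helper evaluates TriggerRuleDep()._evaluate_trigger_rule with a
        # DepContext built from the flags, asserts that the single failing
        # status contains expected_reason (or that no status is yielded when
        # expected_reason is empty), and finally asserts
        # ti.state == expected_ti_state.
        _test_trigger_rule(
            ti=ti,
            session=session,
            flag_upstream_failed=flag_upstream_failed,
            expected_reason="requires one upstream task success, but none were found.",
            expected_ti_state=expected_ti_state,
        )

With flag_upstream_failed=True the dep also mutates the TI state (here to SKIPPED); with False it only reports the failing status. Parametrizing over the flag is what lets each scenario cover both code paths without duplicating the test body, which is the bulk of the coverage gained by this commit.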