kazanzhy commented on a change in pull request #20044:
URL: https://github.com/apache/airflow/pull/20044#discussion_r762954844



##########
File path: tests/operators/test_python.py
##########
@@ -580,128 +582,164 @@ def test_raise_exception_on_invalid_task_id(self):
 class TestShortCircuitOperator(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        super().setUpClass()
-
         with create_session() as session:
             session.query(DagRun).delete()
             session.query(TI).delete()
 
-    def tearDown(self):
-        super().tearDown()
+    def setUp(self):
+        self.dag = DAG(
+            "short_circuit_op_test",
+            start_date=DEFAULT_DATE,
+            schedule_interval=INTERVAL,
+        )
 
+        with self.dag:
+            self.op1 = DummyOperator(task_id="op1")
+            self.op2 = DummyOperator(task_id="op2")
+            self.op1.set_downstream(self.op2)
+
+    def tearDown(self):
         with create_session() as session:
             session.query(DagRun).delete()
             session.query(TI).delete()
 
-    def test_with_dag_run(self):
-        value = False
-        dag = DAG(
-            'shortcircuit_operator_test_with_dag_run',
-            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
-            schedule_interval=INTERVAL,
+    def _assert_expected_task_states(self, dagrun, expected_states):
+        """Helper function that asserts `TaskInstances` of a given `task_id` are in a given state."""
+
+        tis = dagrun.get_task_instances()
+        for ti in tis:
+            try:
+                expected_state = expected_states[ti.task_id]
+            except KeyError:
+                raise ValueError(f"Invalid task id {ti.task_id} found!")
+            else:
+                assert ti.state == expected_state
+
+    all_downstream_skipped_states = {
+        "short_circuit": State.SUCCESS,
+        "op1": State.SKIPPED,
+        "op2": State.SKIPPED,
+    }
+    all_success_states = {"short_circuit": State.SUCCESS, "op1": State.SUCCESS, "op2": State.SUCCESS}
+
+    @parameterized.expand(
+        [
+            # Skip downstream tasks, do not respect trigger rules, default trigger rule on all downstream
+            # tasks
+            (False, True, TriggerRule.ALL_SUCCESS, all_downstream_skipped_states),
+            # Skip downstream tasks, do not respect trigger rules, non-default trigger rule on a downstream
+            # task
+            (False, True, TriggerRule.ALL_DONE, all_downstream_skipped_states),
+            # Skip downstream tasks, respect trigger rules, default trigger rule on all downstream tasks
+            (
+                False,
+                False,
+                TriggerRule.ALL_SUCCESS,
+                {"short_circuit": State.SUCCESS, "op1": State.SKIPPED, "op2": State.NONE},
+            ),
+            # Skip downstream tasks, respect trigger rules, non-default trigger rule on a downstream task
+            (
+                False,
+                False,
+                TriggerRule.ALL_DONE,
+                {"short_circuit": State.SUCCESS, "op1": State.SKIPPED, "op2": State.SUCCESS},
+            ),
+            # Do not skip downstream tasks, do not respect trigger rules, 
default trigger rule on all
+            # downstream tasks
+            (True, True, TriggerRule.ALL_SUCCESS, all_success_states),
+            # Do not skip downstream tasks, do not respect trigger rules, 
non-default trigger rule on a
+            # downstream task
+            (True, True, TriggerRule.ALL_DONE, all_success_states),
+            # Do not skip downstream tasks, respect trigger rules, default trigger rule on all downstream
+            # tasks
+            (True, False, TriggerRule.ALL_SUCCESS, all_success_states),
+            # Do not skip downstream tasks, respect trigger rules, non-default trigger rule on a downstream
+            # task
+            (True, False, TriggerRule.ALL_DONE, all_success_states),
+        ],
+    )
+    def test_short_circuiting(
+        self, callable_return, test_ignore_downstream_trigger_rules, test_trigger_rule, expected_task_states
+    ):
+        """
+        Checking the behavior of the ShortCircuitOperator in several scenarios enabling/disabling the skipping
+        of downstream tasks, both short-circuiting modes, and various trigger rules of downstream tasks.
+        """
+
+        self.short_circuit = ShortCircuitOperator(
+            task_id="short_circuit",
+            python_callable=lambda: callable_return,
+            ignore_downstream_trigger_rules=test_ignore_downstream_trigger_rules,
+            dag=self.dag,
         )
-        short_op = ShortCircuitOperator(task_id='make_choice', dag=dag, python_callable=lambda: value)
-        branch_1 = DummyOperator(task_id='branch_1', dag=dag)
-        branch_1.set_upstream(short_op)
-        branch_2 = DummyOperator(task_id='branch_2', dag=dag)
-        branch_2.set_upstream(branch_1)
-        upstream = DummyOperator(task_id='upstream', dag=dag)
-        upstream.set_downstream(short_op)
-        dag.clear()
-
-        logging.error("Tasks %s", dag.tasks)
-        dr = dag.create_dagrun(
+        self.short_circuit.set_downstream(self.op1)
+        self.op2.trigger_rule = test_trigger_rule
+        self.dag.clear()
+
+        dagrun = self.dag.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
         )
 
-        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        assert len(tis) == 4
-        for ti in tis:
-            if ti.task_id == 'make_choice':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'upstream':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
-                assert ti.state == State.SKIPPED
-            else:
-                raise ValueError(f'Invalid task id {ti.task_id} found!')
+        self.short_circuit.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        value = True
-        dag.clear()
-        dr.verify_integrity()
-        upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        assert self.short_circuit.ignore_downstream_trigger_rules == test_ignore_downstream_trigger_rules
+        assert self.short_circuit.trigger_rule == TriggerRule.ALL_SUCCESS
+        assert self.op1.trigger_rule == TriggerRule.ALL_SUCCESS
+        assert self.op2.trigger_rule == test_trigger_rule
 
-        tis = dr.get_task_instances()
-        assert len(tis) == 4
-        for ti in tis:
-            if ti.task_id == 'make_choice':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'upstream':
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
-                assert ti.state == State.NONE
-            else:
-                raise ValueError(f'Invalid task id {ti.task_id} found!')
+        self._assert_expected_task_states(dagrun, expected_task_states)
 
     def test_clear_skipped_downstream_task(self):
         """
         After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task
         should not cause it to be executed.

Review comment:
       ```suggestion
           After a downstream task with the "all_success" trigger rule is skipped by ShortCircuitOperator,
           clearing the skipped task should not cause it to be executed.
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to