uranusjr commented on code in PR #41277:
URL: https://github.com/apache/airflow/pull/41277#discussion_r1705212652


##########
tests/operators/test_branch_operator.py:
##########
@@ -127,113 +120,152 @@ def test_branch_list_without_dag_run(self):
                 else:
                     raise Exception
 
-    def test_with_dag_run(self):
-        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.dag.clear()
-
-        dagrun = self.dag.create_dagrun(
+    def test_with_dag_run(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_op = ChooseBranchOne(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dagrun.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.SKIPPED
-            else:
-                raise Exception
-
-    def test_with_skip_in_branch_downstream_dependencies(self):
-        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
-        self.branch_op >> self.branch_1 >> self.branch_2
-        self.branch_op >> self.branch_2
-        self.dag.clear()
-
-        dagrun = self.dag.create_dagrun(
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        with create_session() as session:
+            tis = session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE)
+
+            expected = {
+                "make_choice": State.SUCCESS,
+                "branch_1": State.NONE,
+                "branch_2": State.SKIPPED,
+            }
+
+            for ti in tis:
+                if ti.task_id in expected:
+                    assert ti.state == expected[ti.task_id]
+                else:
+                    raise Exception
+
+    def test_with_skip_in_branch_downstream_dependencies(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_op = ChooseBranchOne(task_id="make_choice")
+            branch_op >> branch_1 >> branch_2
+            branch_op >> branch_2
+
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        tis = dagrun.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id in ("branch_1", "branch_2"):
-                assert ti.state == State.NONE
-            else:
-                raise Exception
-
-    def test_xcom_push(self):
-        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
+        with create_session() as session:
+            tis = session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE)
+            expected = {
+                "make_choice": State.SUCCESS,
+                "branch_1": State.NONE,
+                "branch_2": State.NONE,
+            }
 
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.dag.clear()
+            for ti in tis:
+                if ti.task_id in expected:
+                    assert ti.state == expected[ti.task_id]
+                else:
+                    raise Exception
 
-        dr = self.dag.create_dagrun(
+    def test_xcom_push(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_op = ChooseBranchOne(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-    def test_with_dag_run_task_groups(self):
-        self.branch_op = ChooseBranchThree(task_id="make_choice", dag=self.dag)
-        self.branch_3 = TaskGroup("branch_3", dag=self.dag)
-        _ = EmptyOperator(task_id="task_1", dag=self.dag, task_group=self.branch_3)
-        _ = EmptyOperator(task_id="task_2", dag=self.dag, task_group=self.branch_3)
-
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.branch_3.set_upstream(self.branch_op)
-
-        self.dag.clear()
+        with create_session() as session:
+            tis = session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE)
+            for ti in tis:
+                if ti.task_id == "make_choice":
+                    assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
 
-        dagrun = self.dag.create_dagrun(
+    def test_with_dag_run_task_groups(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_3 = TaskGroup("branch_3")
+            EmptyOperator(task_id="task_1", task_group=branch_3)
+            EmptyOperator(task_id="task_2", task_group=branch_3)
+            branch_op = ChooseBranchThree(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+            branch_3.set_upstream(branch_op)
+
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        tis = dagrun.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.SKIPPED
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.SKIPPED
-            elif ti.task_id == "branch_3.task_1":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_3.task_2":
-                assert ti.state == State.NONE
-            else:
-                raise Exception
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        with create_session() as session:

Review Comment:
   There’s a `session` fixture. Can we use it here instead of opening a new session with `create_session()`? 
Alternatively, `dag_maker` also keeps a session internally. Maybe that one can be used instead?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to