This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 30f776693e Fix tests/operators/test_bash.py and branch_operator.py for Database Isolation Mode (#41277)
30f776693e is described below

commit 30f776693e8b0a6497e21ea72baaf3b36418ba34
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Aug 6 19:59:38 2024 +0200

    Fix tests/operators/test_bash.py and branch_operator.py for Database Isolation Mode (#41277)
    
    * Fix tests/operators/test_bash.py and branch_operator.py for Database Isolation Tests
    
    * Review Feedback
---
 tests/conftest.py                       |   2 +-
 tests/operators/test_bash.py            |  43 ++---
 tests/operators/test_branch_operator.py | 274 +++++++++++++++++---------------
 3 files changed, 168 insertions(+), 151 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 11b47363b0..5601461409 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1152,7 +1152,7 @@ def create_task_instance_of_operator(dag_maker):
         session=None,
         **operator_kwargs,
     ) -> TaskInstance:
-        with dag_maker(dag_id=dag_id, session=session):
+        with dag_maker(dag_id=dag_id, session=session, serialized=True):
             operator_class(**operator_kwargs)
         if execution_date is None:
             dagrun_kwargs = {}
diff --git a/tests/operators/test_bash.py b/tests/operators/test_bash.py
index 4ef4973f08..f63dd0dea2 100644
--- a/tests/operators/test_bash.py
+++ b/tests/operators/test_bash.py
@@ -28,7 +28,6 @@ from unittest import mock
 import pytest
 
 from airflow.exceptions import AirflowException, AirflowSkipException, 
AirflowTaskTimeout
-from airflow.models.dag import DAG
 from airflow.operators.bash import BashOperator
 from airflow.utils import timezone
 from airflow.utils.state import State
@@ -65,7 +64,9 @@ class TestBashOperator:
             (True, {"AIRFLOW_HOME": "OVERRIDDEN_AIRFLOW_HOME"}, 
"OVERRIDDEN_AIRFLOW_HOME"),
         ],
     )
-    def test_echo_env_variables(self, append_env, user_defined_env, 
expected_airflow_home, tmp_path):
+    def test_echo_env_variables(
+        self, append_env, user_defined_env, expected_airflow_home, dag_maker, 
tmp_path
+    ):
         """
         Test that env variables are exported correctly to the task bash 
environment.
         """
@@ -79,15 +80,28 @@ class TestBashOperator:
             f"manual__{utc_now.isoformat()}\n"
         )
 
-        dag = DAG(
-            dag_id="bash_op_test",
+        with dag_maker(
+            "bash_op_test",
             default_args={"owner": "airflow", "retries": 100, "start_date": 
DEFAULT_DATE},
             schedule="@daily",
             dagrun_timeout=timedelta(minutes=60),
-        )
+            serialized=True,
+        ):
+            tmp_file = tmp_path / "testfile"
+            task = BashOperator(
+                task_id="echo_env_vars",
+                bash_command=f"echo $AIRFLOW_HOME>> {tmp_file};"
+                f"echo $PYTHONPATH>> {tmp_file};"
+                f"echo $AIRFLOW_CTX_DAG_ID >> {tmp_file};"
+                f"echo $AIRFLOW_CTX_TASK_ID>> {tmp_file};"
+                f"echo $AIRFLOW_CTX_EXECUTION_DATE>> {tmp_file};"
+                f"echo $AIRFLOW_CTX_DAG_RUN_ID>> {tmp_file};",
+                append_env=append_env,
+                env=user_defined_env,
+            )
 
         execution_date = utc_now
-        dag.create_dagrun(
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             execution_date=execution_date,
             start_date=utc_now,
@@ -96,20 +110,6 @@ class TestBashOperator:
             data_interval=(execution_date, execution_date),
         )
 
-        tmp_file = tmp_path / "testfile"
-        task = BashOperator(
-            task_id="echo_env_vars",
-            dag=dag,
-            bash_command=f"echo $AIRFLOW_HOME>> {tmp_file};"
-            f"echo $PYTHONPATH>> {tmp_file};"
-            f"echo $AIRFLOW_CTX_DAG_ID >> {tmp_file};"
-            f"echo $AIRFLOW_CTX_TASK_ID>> {tmp_file};"
-            f"echo $AIRFLOW_CTX_EXECUTION_DATE>> {tmp_file};"
-            f"echo $AIRFLOW_CTX_DAG_RUN_ID>> {tmp_file};",
-            append_env=append_env,
-            env=user_defined_env,
-        )
-
         with mock.patch.dict(
             "os.environ", {"AIRFLOW_HOME": "MY_PATH_TO_AIRFLOW_HOME", 
"PYTHONPATH": "AWESOME_PYTHONPATH"}
         ):
@@ -244,12 +244,13 @@ class TestBashOperator:
         import psutil
 
         sleep_time = f"100{os.getpid()}"
-        with dag_maker():
+        with dag_maker(serialized=True):
             op = BashOperator(
                 task_id="test_bash_operator_kill",
                 execution_timeout=timedelta(microseconds=25),
                 bash_command=f"/bin/bash -c 'sleep {sleep_time}'",
             )
+        dag_maker.create_dagrun()
         with pytest.raises(AirflowTaskTimeout):
             op.run()
         sleep(2)
diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py
index a8c904fe2c..fe475843bc 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -21,13 +21,10 @@ import datetime
 
 import pytest
 
-from airflow.models.dag import DAG
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.operators.branch import BaseBranchOperator
 from airflow.operators.empty import EmptyOperator
 from airflow.utils import timezone
-from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.types import DagRunType
@@ -54,86 +51,82 @@ class ChooseBranchThree(BaseBranchOperator):
 
 
 class TestBranchOperator:
-    @classmethod
-    def setup_class(cls):
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
-
-    def setup_method(self):
-        self.dag = DAG(
-            "branch_operator_test",
+    def test_without_dag_run(self, dag_maker):
+        """This checks the defensive against non existent tasks in a dag run"""
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
             default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
             schedule=INTERVAL,
-        )
-
-        self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
-        self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
-        self.branch_3 = None
-        self.branch_op = None
-
-    def teardown_method(self):
-        with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TI).delete()
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_op = ChooseBranchOne(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+        dag_maker.create_dagrun()
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, 
TI.execution_date == DEFAULT_DATE):
+            if ti.task_id == "make_choice":
+                assert ti.state == State.SUCCESS
+            elif ti.task_id == "branch_1":
+                # should exist with state None
+                assert ti.state == State.NONE
+            elif ti.task_id == "branch_2":
+                assert ti.state == State.SKIPPED
+            else:
+                raise Exception
 
-    def test_without_dag_run(self):
-        """This checks the defensive against non existent tasks in a dag run"""
-        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.dag.clear()
-
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        with create_session() as session:
-            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, 
TI.execution_date == DEFAULT_DATE)
-
-            for ti in tis:
-                if ti.task_id == "make_choice":
-                    assert ti.state == State.SUCCESS
-                elif ti.task_id == "branch_1":
-                    # should exist with state None
-                    assert ti.state == State.NONE
-                elif ti.task_id == "branch_2":
-                    assert ti.state == State.SKIPPED
-                else:
-                    raise Exception
-
-    def test_branch_list_without_dag_run(self):
+    def test_branch_list_without_dag_run(self, dag_maker):
         """This checks if the BranchOperator supports branching off to a list 
of tasks."""
-        self.branch_op = ChooseBranchOneTwo(task_id="make_choice", 
dag=self.dag)
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.branch_3 = EmptyOperator(task_id="branch_3", dag=self.dag)
-        self.branch_3.set_upstream(self.branch_op)
-        self.dag.clear()
-
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
-        with create_session() as session:
-            tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, 
TI.execution_date == DEFAULT_DATE)
-
-            expected = {
-                "make_choice": State.SUCCESS,
-                "branch_1": State.NONE,
-                "branch_2": State.NONE,
-                "branch_3": State.SKIPPED,
-            }
-
-            for ti in tis:
-                if ti.task_id in expected:
-                    assert ti.state == expected[ti.task_id]
-                else:
-                    raise Exception
-
-    def test_with_dag_run(self):
-        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.dag.clear()
-
-        dagrun = self.dag.create_dagrun(
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_3 = EmptyOperator(task_id="branch_3")
+            branch_op = ChooseBranchOneTwo(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+            branch_3.set_upstream(branch_op)
+        dag_maker.create_dagrun()
+
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        expected = {
+            "make_choice": State.SUCCESS,
+            "branch_1": State.NONE,
+            "branch_2": State.NONE,
+            "branch_3": State.SKIPPED,
+        }
+
+        for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, 
TI.execution_date == DEFAULT_DATE):
+            if ti.task_id in expected:
+                assert ti.state == expected[ti.task_id]
+            else:
+                raise Exception
+
+    def test_with_dag_run(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_op = ChooseBranchOne(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
@@ -141,26 +134,35 @@ class TestBranchOperator:
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        tis = dagrun.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id == "branch_1":
-                assert ti.state == State.NONE
-            elif ti.task_id == "branch_2":
-                assert ti.state == State.SKIPPED
+        expected = {
+            "make_choice": State.SUCCESS,
+            "branch_1": State.NONE,
+            "branch_2": State.SKIPPED,
+        }
+
+        for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, 
TI.execution_date == DEFAULT_DATE):
+            if ti.task_id in expected:
+                assert ti.state == expected[ti.task_id]
             else:
                 raise Exception
 
-    def test_with_skip_in_branch_downstream_dependencies(self):
-        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
-        self.branch_op >> self.branch_1 >> self.branch_2
-        self.branch_op >> self.branch_2
-        self.dag.clear()
-
-        dagrun = self.dag.create_dagrun(
+    def test_with_skip_in_branch_downstream_dependencies(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_op = ChooseBranchOne(task_id="make_choice")
+            branch_op >> branch_1 >> branch_2
+            branch_op >> branch_2
+
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
@@ -168,25 +170,35 @@ class TestBranchOperator:
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        tis = dagrun.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "make_choice":
-                assert ti.state == State.SUCCESS
-            elif ti.task_id in ("branch_1", "branch_2"):
-                assert ti.state == State.NONE
+        expected = {
+            "make_choice": State.SUCCESS,
+            "branch_1": State.NONE,
+            "branch_2": State.NONE,
+        }
+
+        for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, 
TI.execution_date == DEFAULT_DATE):
+            if ti.task_id in expected:
+                assert ti.state == expected[ti.task_id]
             else:
                 raise Exception
 
-    def test_xcom_push(self):
-        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
-
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.dag.clear()
-
-        dr = self.dag.create_dagrun(
+    def test_xcom_push(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_op = ChooseBranchOne(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
@@ -194,26 +206,31 @@ class TestBranchOperator:
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        tis = dr.get_task_instances()
-        for ti in tis:
+        for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, 
TI.execution_date == DEFAULT_DATE):
             if ti.task_id == "make_choice":
                 assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
 
-    def test_with_dag_run_task_groups(self):
-        self.branch_op = ChooseBranchThree(task_id="make_choice", dag=self.dag)
-        self.branch_3 = TaskGroup("branch_3", dag=self.dag)
-        _ = EmptyOperator(task_id="task_1", dag=self.dag, 
task_group=self.branch_3)
-        _ = EmptyOperator(task_id="task_2", dag=self.dag, 
task_group=self.branch_3)
-
-        self.branch_1.set_upstream(self.branch_op)
-        self.branch_2.set_upstream(self.branch_op)
-        self.branch_3.set_upstream(self.branch_op)
-
-        self.dag.clear()
-
-        dagrun = self.dag.create_dagrun(
+    def test_with_dag_run_task_groups(self, dag_maker):
+        dag_id = "branch_operator_test"
+        with dag_maker(
+            dag_id,
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule=INTERVAL,
+            serialized=True,
+        ):
+            branch_1 = EmptyOperator(task_id="branch_1")
+            branch_2 = EmptyOperator(task_id="branch_2")
+            branch_3 = TaskGroup("branch_3")
+            EmptyOperator(task_id="task_1", task_group=branch_3)
+            EmptyOperator(task_id="task_2", task_group=branch_3)
+            branch_op = ChooseBranchThree(task_id="make_choice")
+            branch_1.set_upstream(branch_op)
+            branch_2.set_upstream(branch_op)
+            branch_3.set_upstream(branch_op)
+
+        dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             start_date=timezone.utcnow(),
             execution_date=DEFAULT_DATE,
@@ -221,10 +238,9 @@ class TestBranchOperator:
             data_interval=(DEFAULT_DATE, DEFAULT_DATE),
         )
 
-        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
-        tis = dagrun.get_task_instances()
-        for ti in tis:
+        for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, 
TI.execution_date == DEFAULT_DATE):
             if ti.task_id == "make_choice":
                 assert ti.state == State.SUCCESS
             elif ti.task_id == "branch_1":

Reply via email to