This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 007c4b16595 Improve dag_maker compatibility handling (#44125)
007c4b16595 is described below

commit 007c4b16595cb30870e2b64dc105dd1fcff2ef46
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Nov 18 18:27:40 2024 +0800

    Improve dag_maker compatibility handling (#44125)
---
 .../tests/amazon/aws/operators/test_base_aws.py    |   7 +-
 .../tests/amazon/aws/sensors/test_base_aws.py      |   7 +-
 .../kubernetes/operators/test_spark_kubernetes.py  |  10 +-
 providers/tests/sftp/operators/test_sftp.py        |  97 +++++-----------
 providers/tests/standard/operators/test_bash.py    |  32 ++----
 .../tests/standard/operators/test_datetime.py      |  43 ++-----
 providers/tests/standard/operators/test_python.py  |  33 ++----
 providers/tests/standard/operators/test_weekday.py | 111 +++++-------------
 tests_common/pytest_plugin.py                      | 126 ++++++++++-----------
 9 files changed, 143 insertions(+), 323 deletions(-)
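
The recurring pattern in these hunks: call sites previously branched on
AIRFLOW_V_3_0_PLUS to choose between the Airflow 3.0 logical_date keyword and
the older execution_date keyword when creating a dag run. This commit moves
that branching into the dag_maker fixture itself, so tests now pass
logical_date unconditionally. A minimal before/after sketch, distilled from
the call-site hunks below:

    # Before: version branching repeated at every call site.
    if AIRFLOW_V_3_0_PLUS:
        dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
    else:
        dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())

    # After: the fixture translates logical_date to execution_date on
    # Airflow 2 and fills in triggered_by on Airflow 3.0+.
    dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())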

diff --git a/providers/tests/amazon/aws/operators/test_base_aws.py b/providers/tests/amazon/aws/operators/test_base_aws.py
index d5526b7436e..e95748f0989 100644
--- a/providers/tests/amazon/aws/operators/test_base_aws.py
+++ b/providers/tests/amazon/aws/operators/test_base_aws.py
@@ -25,8 +25,6 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.utils import timezone
 
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
 TEST_CONN = "aws_test_conn"
 
 
@@ -118,10 +116,7 @@ class TestAwsBaseOperator:
         with dag_maker("test_aws_base_operator", serialized=True):
             FakeS3Operator(task_id="fake-task-id", **op_kwargs)
 
-        if AIRFLOW_V_3_0_PLUS:
-            dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
-        else:
-            dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
         tis = {ti.task_id: ti for ti in dagrun.task_instances}
         tis["fake-task-id"].run()
 
diff --git a/providers/tests/amazon/aws/sensors/test_base_aws.py b/providers/tests/amazon/aws/sensors/test_base_aws.py
index 08717bb2c05..435f80c2aed 100644
--- a/providers/tests/amazon/aws/sensors/test_base_aws.py
+++ b/providers/tests/amazon/aws/sensors/test_base_aws.py
@@ -25,8 +25,6 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.utils import timezone
 
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
 TEST_CONN = "aws_test_conn"
 
 
@@ -120,10 +118,7 @@ class TestAwsBaseSensor:
         with dag_maker("test_aws_base_sensor", serialized=True):
             FakeDynamoDBSensor(task_id="fake-task-id", **op_kwargs, poke_interval=1)
 
-        if AIRFLOW_V_3_0_PLUS:
-            dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
-        else:
-            dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
         tis = {ti.task_id: ti for ti in dagrun.task_instances}
         tis["fake-task-id"].run()
 
diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
index 7f9b743f3eb..784b2c80c38 100644
--- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -733,10 +733,7 @@ def test_resolve_application_file_template_file(dag_maker, tmp_path, session):
             kubernetes_conn_id="kubernetes_default_kube_config",
             task_id="test_template_body_templating_task",
         )
-    if AIRFLOW_V_3_0_PLUS:
-        ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
-    else:
-        ti = dag_maker.create_dagrun(execution_date=logical_date).task_instances[0]
+    ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
     session.add(ti)
     session.commit()
     ti.render_templates()
@@ -776,10 +773,7 @@ def test_resolve_application_file_template_non_dictionary(dag_maker, tmp_path, b
             kubernetes_conn_id="kubernetes_default_kube_config",
             task_id="test_template_body_templating_task",
         )
-    if AIRFLOW_V_3_0_PLUS:
-        ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
-    else:
-        ti = dag_maker.create_dagrun(execution_date=logical_date).task_instances[0]
+    ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
     session.add(ti)
     session.commit()
     ti.render_templates()
diff --git a/providers/tests/sftp/operators/test_sftp.py b/providers/tests/sftp/operators/test_sftp.py
index b1551e9c456..96d43145cc1 100644
--- a/providers/tests/sftp/operators/test_sftp.py
+++ b/providers/tests/sftp/operators/test_sftp.py
@@ -36,7 +36,6 @@ from airflow.providers.ssh.operators.ssh import SSHOperator
 from airflow.utils import timezone
 from airflow.utils.timezone import datetime
 
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
 from tests_common.test_utils.config import conf_vars
 
 pytestmark = pytest.mark.db_test
@@ -184,10 +183,7 @@ class TestSFTPOperator:
                 command=f"cat {self.test_remote_filepath_int_dir}",
                 do_xcom_push=True,
             )
-        if AIRFLOW_V_3_0_PLUS:
-            dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
-        else:
-            dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
         tis = {ti.task_id: ti for ti in dagrun.task_instances}
         with pytest.warns(AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"):
             tis["test_sftp"].run()
@@ -220,10 +216,7 @@ class TestSFTPOperator:
                 command=f"cat {self.test_remote_filepath}",
                 do_xcom_push=True,
             )
-        if AIRFLOW_V_3_0_PLUS:
-            dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
-        else:
-            dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
         tis = {ti.task_id: ti for ti in dagrun.task_instances}
         with pytest.warns(AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"):
             tis["put_test_task"].run()
@@ -249,18 +242,11 @@ class TestSFTPOperator:
                 remote_filepath=self.test_remote_filepath,
                 operation=SFTPOperation.GET,
             )
-        if AIRFLOW_V_3_0_PLUS:
-            for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
-                with pytest.warns(
-                    AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                ):
-                    ti.run()
-        else:
-            for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
-                with pytest.warns(
-                    AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                ):
-                    ti.run()
+        for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+            with pytest.warns(
+                AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
+            ):
+                ti.run()
 
         # Test the received content.
         with open(self.test_local_filepath, "rb") as file:
@@ -277,18 +263,11 @@ class TestSFTPOperator:
                 remote_filepath=self.test_remote_filepath,
                 operation=SFTPOperation.GET,
             )
-        if AIRFLOW_V_3_0_PLUS:
-            for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
-                with pytest.warns(
-                    AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                ):
-                    ti.run()
-        else:
-            for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
-                with pytest.warns(
-                    AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                ):
-                    ti.run()
+        for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+            with pytest.warns(
+                AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
+            ):
+                ti.run()
 
         # Test the received content.
         content_received = None
@@ -307,30 +286,17 @@ class TestSFTPOperator:
                 operation=SFTPOperation.GET,
             )
 
-        if AIRFLOW_V_3_0_PLUS:
-            for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
-                # This should raise an error with "No such file" as the directory
-                # does not exist.
-                with (
-                    pytest.raises(AirflowException) as ctx,
-                    pytest.warns(
-                        AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                    ),
-                ):
-                    ti.run()
-                assert "No such file" in str(ctx.value)
-        else:
-            for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
-                # This should raise an error with "No such file" as the directory
-                # does not exist.
-                with (
-                    pytest.raises(AirflowException) as ctx,
-                    pytest.warns(
-                        AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                    ),
-                ):
-                    ti.run()
-                assert "No such file" in str(ctx.value)
+        for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+            # This should raise an error with "No such file" as the directory
+            # does not exist.
+            with (
+                pytest.raises(AirflowException) as ctx,
+                pytest.warns(
+                    AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
+                ),
+            ):
+                ti.run()
+            assert "No such file" in str(ctx.value)
 
     @conf_vars({("core", "enable_xcom_pickling"): "True"})
     def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker, create_remote_file_and_cleanup):
@@ -344,18 +310,11 @@ class TestSFTPOperator:
                 create_intermediate_dirs=True,
             )
 
-        if AIRFLOW_V_3_0_PLUS:
-            for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
-                with pytest.warns(
-                    AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                ):
-                    ti.run()
-        else:
-            for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
-                with pytest.warns(
-                    AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
-                ):
-                    ti.run()
+        for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+            with pytest.warns(
+                AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
+            ):
+                ti.run()
 
         # Test the received content.
         content_received = None
diff --git a/providers/tests/standard/operators/test_bash.py b/providers/tests/standard/operators/test_bash.py
index d548204e931..9b299c2423f 100644
--- a/providers/tests/standard/operators/test_bash.py
+++ b/providers/tests/standard/operators/test_bash.py
@@ -36,9 +36,6 @@ from airflow.utils.types import DagRunType
 
 from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
 
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.utils.types import DagRunTriggeredByType
-
 if TYPE_CHECKING:
     from airflow.models import TaskInstance
 
@@ -111,27 +108,14 @@ class TestBashOperator:
             )
 
         logical_date = utc_now
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-        if AIRFLOW_V_3_0_PLUS:
-            dag_maker.create_dagrun(
-                run_type=DagRunType.MANUAL,
-                logical_date=logical_date,
-                start_date=utc_now,
-                state=State.RUNNING,
-                external_trigger=False,
-                data_interval=(logical_date, logical_date),
-                **triggered_by_kwargs,
-            )
-        else:
-            dag_maker.create_dagrun(
-                run_type=DagRunType.MANUAL,
-                execution_date=logical_date,
-                start_date=utc_now,
-                state=State.RUNNING,
-                external_trigger=False,
-                data_interval=(logical_date, logical_date),
-                **triggered_by_kwargs,
-            )
+        dag_maker.create_dagrun(
+            run_type=DagRunType.MANUAL,
+            logical_date=logical_date,
+            start_date=utc_now,
+            state=State.RUNNING,
+            external_trigger=False,
+            data_interval=(logical_date, logical_date),
+        )
 
         with mock.patch.dict(
             "os.environ", {"AIRFLOW_HOME": "MY_PATH_TO_AIRFLOW_HOME", 
"PYTHONPATH": "AWESOME_PYTHONPATH"}
diff --git a/providers/tests/standard/operators/test_datetime.py b/providers/tests/standard/operators/test_datetime.py
index b51821275dc..67f72f2c6e2 100644
--- a/providers/tests/standard/operators/test_datetime.py
+++ b/providers/tests/standard/operators/test_datetime.py
@@ -31,11 +31,6 @@ from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.utils.types import DagRunTriggeredByType
-
 pytestmark = pytest.mark.db_test
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -79,24 +74,12 @@ class TestBranchDateTimeOperator:
             self.branch_1.set_upstream(self.branch_op)
             self.branch_2.set_upstream(self.branch_op)
 
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-        if AIRFLOW_V_3_0_PLUS:
             self.dr = dag_maker.create_dagrun(
                 run_id="manual__",
                 start_date=DEFAULT_DATE,
                 logical_date=DEFAULT_DATE,
                 state=State.RUNNING,
                 data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
-        else:
-            self.dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=DEFAULT_DATE,
-                execution_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
             )
 
     def teardown_method(self):
@@ -251,25 +234,13 @@ class TestBranchDateTimeOperator:
         """Check if BranchDateTimeOperator uses task logical date"""
         in_between_date = timezone.datetime(2020, 7, 7, 10, 30, 0)
         self.branch_op.use_task_logical_date = True
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-        if AIRFLOW_V_3_0_PLUS:
-            self.dr = dag_maker.create_dagrun(
-                run_id="manual_exec_date__",
-                start_date=in_between_date,
-                logical_date=in_between_date,
-                state=State.RUNNING,
-                data_interval=(in_between_date, in_between_date),
-                **triggered_by_kwargs,
-            )
-        else:
-            self.dr = dag_maker.create_dagrun(
-                run_id="manual_exec_date__",
-                start_date=in_between_date,
-                execution_date=in_between_date,
-                state=State.RUNNING,
-                data_interval=(in_between_date, in_between_date),
-                **triggered_by_kwargs,
-            )
+        self.dr = dag_maker.create_dagrun(
+            run_id="manual_exec_date__",
+            start_date=in_between_date,
+            logical_date=in_between_date,
+            state=State.RUNNING,
+            data_interval=(in_between_date, in_between_date),
+        )
 
         self.branch_op.target_lower = target_lower
         self.branch_op.target_upper = target_upper
diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py
index 3798a712384..efeade9ca09 100644
--- a/providers/tests/standard/operators/test_python.py
+++ b/providers/tests/standard/operators/test_python.py
@@ -75,10 +75,6 @@ from tests_common.test_utils import AIRFLOW_MAIN_FOLDER
 from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS, AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
 from tests_common.test_utils.db import clear_db_runs
 
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.utils.types import DagRunTriggeredByType
-
-
 if TYPE_CHECKING:
     from airflow.models.dagrun import DagRun
 
@@ -148,27 +144,14 @@ class BasePythonTest:
         return kwargs
 
     def create_dag_run(self) -> DagRun:
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-        if AIRFLOW_V_3_0_PLUS:
-            return self.dag_maker.create_dagrun(
-                state=DagRunState.RUNNING,
-                start_date=self.dag_maker.start_date,
-                session=self.dag_maker.session,
-                logical_date=self.default_date,
-                run_type=DagRunType.MANUAL,
-                data_interval=(self.default_date, self.default_date),
-                **triggered_by_kwargs,  # type: ignore
-            )
-        else:
-            return self.dag_maker.create_dagrun(
-                state=DagRunState.RUNNING,
-                start_date=self.dag_maker.start_date,
-                session=self.dag_maker.session,
-                execution_date=self.default_date,
-                run_type=DagRunType.MANUAL,
-                data_interval=(self.default_date, self.default_date),
-                **triggered_by_kwargs,  # type: ignore
-            )
+        return self.dag_maker.create_dagrun(
+            state=DagRunState.RUNNING,
+            start_date=self.dag_maker.start_date,
+            session=self.dag_maker.session,
+            logical_date=self.default_date,
+            run_type=DagRunType.MANUAL,
+            data_interval=(self.default_date, self.default_date),
+        )
 
     def create_ti(self, fn, **kwargs) -> TI:
         """Create TaskInstance for class defined Operator."""
diff --git a/providers/tests/standard/operators/test_weekday.py b/providers/tests/standard/operators/test_weekday.py
index a9eb19189ec..1a182e9bbfc 100644
--- a/providers/tests/standard/operators/test_weekday.py
+++ b/providers/tests/standard/operators/test_weekday.py
@@ -33,11 +33,6 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.weekday import WeekDay
 
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.utils.types import DagRunTriggeredByType
-
 pytestmark = pytest.mark.db_test
 
 DEFAULT_DATE = timezone.datetime(2020, 2, 5)  # Wednesday
@@ -105,25 +100,13 @@ class TestBranchDayOfWeekOperator:
             branch_2.set_upstream(branch_op)
             branch_3.set_upstream(branch_op)
 
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-        if AIRFLOW_V_3_0_PLUS:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                logical_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
-        else:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                execution_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
+        dr = dag_maker.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            logical_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+        )
 
         branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
@@ -155,26 +138,13 @@ class TestBranchDayOfWeekOperator:
             branch_1.set_upstream(branch_op)
             branch_2.set_upstream(branch_op)
 
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-
-        if AIRFLOW_V_3_0_PLUS:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                logical_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
-        else:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                execution_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
+        dr = dag_maker.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            logical_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+        )
 
         branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
@@ -204,26 +174,13 @@ class TestBranchDayOfWeekOperator:
             branch_1.set_upstream(branch_op)
             branch_2.set_upstream(branch_op)
 
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-
-        if AIRFLOW_V_3_0_PLUS:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                logical_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
-        else:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                execution_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
+        dr = dag_maker.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            logical_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+        )
 
         branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
@@ -314,25 +271,13 @@ class TestBranchDayOfWeekOperator:
             branch_1.set_upstream(branch_op)
             branch_2.set_upstream(branch_op)
 
-        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-        if AIRFLOW_V_3_0_PLUS:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                logical_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
-        else:
-            dr = dag_maker.create_dagrun(
-                run_id="manual__",
-                start_date=timezone.utcnow(),
-                execution_date=DEFAULT_DATE,
-                state=State.RUNNING,
-                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
-                **triggered_by_kwargs,
-            )
+        dr = dag_maker.create_dagrun(
+            run_id="manual__",
+            start_date=timezone.utcnow(),
+            logical_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+        )
 
         branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 20479300166..c29cd1ea657 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -727,7 +727,14 @@ class DagMaker(Protocol):
 
     def get_serialized_data(self) -> dict[str, Any]: ...
 
-    def create_dagrun(self, **kwargs) -> DagRun: ...
+    def create_dagrun(
+        self,
+        run_id: str = ...,
+        logical_date: datetime = ...,
+        data_interval: DataInterval = ...,
+        run_type: DagRunType = ...,
+        **kwargs,
+    ) -> DagRun: ...
 
     def create_dagrun_after(self, dagrun: DagRun, **kwargs) -> DagRun: ...
 
@@ -890,17 +897,20 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
             else:
                 self._bag_dag_compat(self.dag)
 
-        def create_dagrun(self, **kwargs):
+        def create_dagrun(self, *, logical_date=None, **kwargs):
             from airflow.utils import timezone
-            from airflow.utils.state import State
+            from airflow.utils.state import DagRunState
             from airflow.utils.types import DagRunType
 
             if AIRFLOW_V_3_0_PLUS:
                 from airflow.utils.types import DagRunTriggeredByType
 
+            if "execution_date" in kwargs:
+                raise TypeError("use logical_date instead")
+
             dag = self.dag
             kwargs = {
-                "state": State.RUNNING,
+                "state": DagRunState.RUNNING,
                 "start_date": self.start_date,
                 "session": self.session,
                 **kwargs,
@@ -912,31 +922,27 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
 
             if "run_type" not in kwargs:
                 kwargs["run_type"] = DagRunType.from_run_id(kwargs["run_id"])
-            if AIRFLOW_V_3_0_PLUS:
-                if kwargs.get("logical_date") is None:
-                    if kwargs["run_type"] == DagRunType.MANUAL:
-                        kwargs["logical_date"] = self.start_date
-                    else:
-                        kwargs["logical_date"] = 
dag.next_dagrun_info(None).logical_date
-            else:
-                if kwargs.get("execution_date") is None:
-                    if kwargs["run_type"] == DagRunType.MANUAL:
-                        kwargs["execution_date"] = self.start_date
-                    else:
-                        kwargs["execution_date"] = 
dag.next_dagrun_info(None).logical_date
+
+            if logical_date is None:
+                if kwargs["run_type"] == DagRunType.MANUAL:
+                    logical_date = self.start_date
+                else:
+                    logical_date = dag.next_dagrun_info(None).logical_date
+            logical_date = timezone.coerce_datetime(logical_date)
+
             if "data_interval" not in kwargs:
-                logical_date = (
-                    timezone.coerce_datetime(kwargs["logical_date"])
-                    if AIRFLOW_V_3_0_PLUS
-                    else timezone.coerce_datetime(kwargs["execution_date"])
-                )
                 if kwargs["run_type"] == DagRunType.MANUAL:
                     data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
                 else:
                     data_interval = dag.infer_automated_data_interval(logical_date)
                 kwargs["data_interval"] = data_interval
-            if AIRFLOW_V_3_0_PLUS and "triggered_by" not in kwargs:
-                kwargs["triggered_by"] = DagRunTriggeredByType.TEST
+
+            if AIRFLOW_V_3_0_PLUS:
+                kwargs.setdefault("triggered_by", DagRunTriggeredByType.TEST)
+                kwargs["logical_date"] = logical_date
+            else:
+                kwargs.pop("triggered_by", None)
+                kwargs["execution_date"] = logical_date
 
             self.dag_run = dag.create_dagrun(**kwargs)
             for ti in self.dag_run.task_instances:
@@ -949,18 +955,10 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
             next_info = self.dag.next_dagrun_info(self.dag.get_run_data_interval(dagrun))
             if next_info is None:
                 raise ValueError(f"cannot create run after {dagrun}")
-            return (
-                self.create_dagrun(
-                    logical_date=next_info.logical_date,
-                    data_interval=next_info.data_interval,
-                    **kwargs,
-                )
-                if AIRFLOW_V_3_0_PLUS
-                else self.create_dagrun(
-                    execution_date=next_info.logical_date,
-                    data_interval=next_info.data_interval,
-                    **kwargs,
-                )
+            return self.create_dagrun(
+                logical_date=next_info.logical_date,
+                data_interval=next_info.data_interval,
+                **kwargs,
             )
 
         def __call__(
@@ -1199,7 +1197,7 @@ def create_task_instance(dag_maker: DagMaker, create_dummy_dag: CreateDummyDAG)
     """
     from airflow.operators.empty import EmptyOperator
 
-    from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
+    from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS
 
     def maker(
         logical_date=None,
@@ -1228,17 +1226,12 @@ def create_task_instance(dag_maker: DagMaker, create_dummy_dag: CreateDummyDAG)
         last_heartbeat_at=None,
         **kwargs,
     ) -> TaskInstance:
-        if AIRFLOW_V_3_0_PLUS:
-            from airflow.utils.types import DagRunTriggeredByType
-
         if logical_date is None:
             from airflow.utils import timezone
 
             logical_date = timezone.utcnow()
         with dag_maker(dag_id, **kwargs):
             op_kwargs = {}
-            from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS
-
             if AIRFLOW_V_2_9_PLUS:
                 op_kwargs["task_display_name"] = task_display_name
             task = EmptyOperator(
@@ -1255,12 +1248,10 @@ def create_task_instance(dag_maker: DagMaker, create_dummy_dag: CreateDummyDAG)
                 trigger_rule=trigger_rule,
                 **op_kwargs,
             )
-        date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
         dagrun_kwargs = {
-            date_key: logical_date,
+            "logical_date": logical_date,
             "state": dagrun_state,
         }
-        dagrun_kwargs.update({"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {})
         if run_id is not None:
             dagrun_kwargs["run_id"] = run_id
         if run_type is not None:
@@ -1282,10 +1273,22 @@ def create_task_instance(dag_maker: DagMaker, create_dummy_dag: CreateDummyDAG)
     return maker
 
 
-@pytest.fixture
-def create_serialized_task_instance_of_operator(dag_maker: DagMaker):
-    from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
+class CreateTaskInstanceOfOperator(Protocol):
+    """Type stub for create_task_instance_of_operator and 
create_serialized_task_instance_of_operator."""
 
+    def __call__(
+        self,
+        operator_class: type[BaseOperator],
+        *,
+        dag_id: str,
+        logical_date: datetime = ...,
+        session: Session = ...,
+        **kwargs,
+    ) -> TaskInstance: ...
+
+
+@pytest.fixture
+def create_serialized_task_instance_of_operator(dag_maker: DagMaker) -> CreateTaskInstanceOfOperator:
     def _create_task_instance(
         operator_class,
         *,
@@ -1296,26 +1299,14 @@ def create_serialized_task_instance_of_operator(dag_maker: DagMaker):
     ) -> TaskInstance:
         with dag_maker(dag_id=dag_id, serialized=True, session=session):
             operator_class(**operator_kwargs)
-        if logical_date is None:
-            dagrun_kwargs = {}
-        else:
-            dagrun_kwargs = {"logical_date" if AIRFLOW_V_3_0_PLUS else 
"execution_date": logical_date}
-        (ti,) = dag_maker.create_dagrun(**dagrun_kwargs).task_instances
+        (ti,) = dag_maker.create_dagrun(logical_date=logical_date).task_instances
         return ti
 
     return _create_task_instance
 
 
-class CreateTaskInstanceOfOperator(Protocol):
-    """Type stub for create_task_instance_of_operator."""
-
-    def __call__(self, operator_class: type[BaseOperator], *args, **kwargs) -> TaskInstance: ...
-
-
 @pytest.fixture
 def create_task_instance_of_operator(dag_maker: DagMaker) -> 
CreateTaskInstanceOfOperator:
-    from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
     def _create_task_instance(
         operator_class,
         *,
@@ -1326,11 +1317,7 @@ def create_task_instance_of_operator(dag_maker: DagMaker) -> CreateTaskInstanceO
     ) -> TaskInstance:
         with dag_maker(dag_id=dag_id, session=session, serialized=True):
             operator_class(**operator_kwargs)
-        if logical_date is None:
-            dagrun_kwargs = {}
-        else:
-            dagrun_kwargs = {"logical_date" if AIRFLOW_V_3_0_PLUS else 
"execution_date": logical_date}
-        (ti,) = dag_maker.create_dagrun(**dagrun_kwargs).task_instances
+        (ti,) = dag_maker.create_dagrun(logical_date=logical_date).task_instances
         return ti
 
     return _create_task_instance
@@ -1339,7 +1326,14 @@ def create_task_instance_of_operator(dag_maker: DagMaker) -> CreateTaskInstanceO
 class CreateTaskOfOperator(Protocol):
     """Type stub for create_task_of_operator."""
 
-    def __call__(self, operator_class: type[Op], *args, **kwargs) -> Op: ...
+    def __call__(
+        self,
+        operator_class: type[Op],
+        *,
+        dag_id: str,
+        session: Session = ...,
+        **kwargs,
+    ) -> Op: ...
 
 
 @pytest.fixture
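
The end result is a one-way compatibility shim: tests always pass
logical_date, and the fixture's create_dagrun() rejects the pre-3.0 keyword
outright instead of silently branching on the Airflow version. A minimal
sketch of the new contract (a hypothetical test, for illustration only):

    import pytest
    from airflow.utils import timezone

    def test_create_dagrun_rejects_execution_date(dag_maker):
        with dag_maker("demo_dag", serialized=True):
            pass  # an empty DAG is enough for this check
        # The Airflow 2 keyword now fails fast with a pointer to the
        # replacement, per the TypeError added in tests_common/pytest_plugin.py.
        with pytest.raises(TypeError, match="use logical_date instead"):
            dag_maker.create_dagrun(execution_date=timezone.utcnow())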

