This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 007c4b16595 Improve dag_maker compatibility handling (#44125)
007c4b16595 is described below
commit 007c4b16595cb30870e2b64dc105dd1fcff2ef46
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Nov 18 18:27:40 2024 +0800
Improve dag_maker compatibility handling (#44125)
---
.../tests/amazon/aws/operators/test_base_aws.py | 7 +-
.../tests/amazon/aws/sensors/test_base_aws.py | 7 +-
.../kubernetes/operators/test_spark_kubernetes.py | 10 +-
providers/tests/sftp/operators/test_sftp.py | 97 +++++-----------
providers/tests/standard/operators/test_bash.py | 32 ++----
.../tests/standard/operators/test_datetime.py | 43 ++-----
providers/tests/standard/operators/test_python.py | 33 ++----
providers/tests/standard/operators/test_weekday.py | 111 +++++-------------
tests_common/pytest_plugin.py | 126 ++++++++++-----------
9 files changed, 143 insertions(+), 323 deletions(-)
diff --git a/providers/tests/amazon/aws/operators/test_base_aws.py
b/providers/tests/amazon/aws/operators/test_base_aws.py
index d5526b7436e..e95748f0989 100644
--- a/providers/tests/amazon/aws/operators/test_base_aws.py
+++ b/providers/tests/amazon/aws/operators/test_base_aws.py
@@ -25,8 +25,6 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.utils import timezone
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
TEST_CONN = "aws_test_conn"
@@ -118,10 +116,7 @@ class TestAwsBaseOperator:
with dag_maker("test_aws_base_operator", serialized=True):
FakeS3Operator(task_id="fake-task-id", **op_kwargs)
- if AIRFLOW_V_3_0_PLUS:
- dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
- else:
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
tis["fake-task-id"].run()
diff --git a/providers/tests/amazon/aws/sensors/test_base_aws.py
b/providers/tests/amazon/aws/sensors/test_base_aws.py
index 08717bb2c05..435f80c2aed 100644
--- a/providers/tests/amazon/aws/sensors/test_base_aws.py
+++ b/providers/tests/amazon/aws/sensors/test_base_aws.py
@@ -25,8 +25,6 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.utils import timezone
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
TEST_CONN = "aws_test_conn"
@@ -120,10 +118,7 @@ class TestAwsBaseSensor:
with dag_maker("test_aws_base_sensor", serialized=True):
FakeDynamoDBSensor(task_id="fake-task-id", **op_kwargs,
poke_interval=1)
- if AIRFLOW_V_3_0_PLUS:
- dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
- else:
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
tis["fake-task-id"].run()
diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
index 7f9b743f3eb..784b2c80c38 100644
--- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -733,10 +733,7 @@ def test_resolve_application_file_template_file(dag_maker,
tmp_path, session):
kubernetes_conn_id="kubernetes_default_kube_config",
task_id="test_template_body_templating_task",
)
- if AIRFLOW_V_3_0_PLUS:
- ti =
dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
- else:
- ti =
dag_maker.create_dagrun(execution_date=logical_date).task_instances[0]
+ ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
session.add(ti)
session.commit()
ti.render_templates()
@@ -776,10 +773,7 @@ def
test_resolve_application_file_template_non_dictionary(dag_maker, tmp_path, b
kubernetes_conn_id="kubernetes_default_kube_config",
task_id="test_template_body_templating_task",
)
- if AIRFLOW_V_3_0_PLUS:
- ti =
dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
- else:
- ti =
dag_maker.create_dagrun(execution_date=logical_date).task_instances[0]
+ ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
session.add(ti)
session.commit()
ti.render_templates()
diff --git a/providers/tests/sftp/operators/test_sftp.py
b/providers/tests/sftp/operators/test_sftp.py
index b1551e9c456..96d43145cc1 100644
--- a/providers/tests/sftp/operators/test_sftp.py
+++ b/providers/tests/sftp/operators/test_sftp.py
@@ -36,7 +36,6 @@ from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.utils import timezone
from airflow.utils.timezone import datetime
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.config import conf_vars
pytestmark = pytest.mark.db_test
@@ -184,10 +183,7 @@ class TestSFTPOperator:
command=f"cat {self.test_remote_filepath_int_dir}",
do_xcom_push=True,
)
- if AIRFLOW_V_3_0_PLUS:
- dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
- else:
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
with pytest.warns(AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"):
tis["test_sftp"].run()
@@ -220,10 +216,7 @@ class TestSFTPOperator:
command=f"cat {self.test_remote_filepath}",
do_xcom_push=True,
)
- if AIRFLOW_V_3_0_PLUS:
- dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
- else:
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
with pytest.warns(AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"):
tis["put_test_task"].run()
@@ -249,18 +242,11 @@ class TestSFTPOperator:
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
)
- if AIRFLOW_V_3_0_PLUS:
- for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
- with pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ):
- ti.run()
- else:
- for ti in
dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
- with pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ):
- ti.run()
+ for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+ with pytest.warns(
+ AirflowProviderDeprecationWarning, match="Parameter `ssh_hook`
is deprecated..*"
+ ):
+ ti.run()
# Test the received content.
with open(self.test_local_filepath, "rb") as file:
@@ -277,18 +263,11 @@ class TestSFTPOperator:
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
)
- if AIRFLOW_V_3_0_PLUS:
- for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
- with pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ):
- ti.run()
- else:
- for ti in
dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
- with pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ):
- ti.run()
+ for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+ with pytest.warns(
+ AirflowProviderDeprecationWarning, match="Parameter `ssh_hook`
is deprecated..*"
+ ):
+ ti.run()
# Test the received content.
content_received = None
@@ -307,30 +286,17 @@ class TestSFTPOperator:
operation=SFTPOperation.GET,
)
- if AIRFLOW_V_3_0_PLUS:
- for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
- # This should raise an error with "No such file" as the
directory
- # does not exist.
- with (
- pytest.raises(AirflowException) as ctx,
- pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ),
- ):
- ti.run()
- assert "No such file" in str(ctx.value)
- else:
- for ti in
dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
- # This should raise an error with "No such file" as the
directory
- # does not exist.
- with (
- pytest.raises(AirflowException) as ctx,
- pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ),
- ):
- ti.run()
- assert "No such file" in str(ctx.value)
+ for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+ # This should raise an error with "No such file" as the directory
+ # does not exist.
+ with (
+ pytest.raises(AirflowException) as ctx,
+ pytest.warns(
+ AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
+ ),
+ ):
+ ti.run()
+ assert "No such file" in str(ctx.value)
@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker,
create_remote_file_and_cleanup):
@@ -344,18 +310,11 @@ class TestSFTPOperator:
create_intermediate_dirs=True,
)
- if AIRFLOW_V_3_0_PLUS:
- for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
- with pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ):
- ti.run()
- else:
- for ti in
dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
- with pytest.warns(
- AirflowProviderDeprecationWarning, match="Parameter
`ssh_hook` is deprecated..*"
- ):
- ti.run()
+ for ti in
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
+ with pytest.warns(
+ AirflowProviderDeprecationWarning, match="Parameter `ssh_hook`
is deprecated..*"
+ ):
+ ti.run()
# Test the received content.
content_received = None
diff --git a/providers/tests/standard/operators/test_bash.py
b/providers/tests/standard/operators/test_bash.py
index d548204e931..9b299c2423f 100644
--- a/providers/tests/standard/operators/test_bash.py
+++ b/providers/tests/standard/operators/test_bash.py
@@ -36,9 +36,6 @@ from airflow.utils.types import DagRunType
from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.types import DagRunTriggeredByType
-
if TYPE_CHECKING:
from airflow.models import TaskInstance
@@ -111,27 +108,14 @@ class TestBashOperator:
)
logical_date = utc_now
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
- if AIRFLOW_V_3_0_PLUS:
- dag_maker.create_dagrun(
- run_type=DagRunType.MANUAL,
- logical_date=logical_date,
- start_date=utc_now,
- state=State.RUNNING,
- external_trigger=False,
- data_interval=(logical_date, logical_date),
- **triggered_by_kwargs,
- )
- else:
- dag_maker.create_dagrun(
- run_type=DagRunType.MANUAL,
- execution_date=logical_date,
- start_date=utc_now,
- state=State.RUNNING,
- external_trigger=False,
- data_interval=(logical_date, logical_date),
- **triggered_by_kwargs,
- )
+ dag_maker.create_dagrun(
+ run_type=DagRunType.MANUAL,
+ logical_date=logical_date,
+ start_date=utc_now,
+ state=State.RUNNING,
+ external_trigger=False,
+ data_interval=(logical_date, logical_date),
+ )
with mock.patch.dict(
"os.environ", {"AIRFLOW_HOME": "MY_PATH_TO_AIRFLOW_HOME",
"PYTHONPATH": "AWESOME_PYTHONPATH"}
diff --git a/providers/tests/standard/operators/test_datetime.py
b/providers/tests/standard/operators/test_datetime.py
index b51821275dc..67f72f2c6e2 100644
--- a/providers/tests/standard/operators/test_datetime.py
+++ b/providers/tests/standard/operators/test_datetime.py
@@ -31,11 +31,6 @@ from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.types import DagRunTriggeredByType
-
pytestmark = pytest.mark.db_test
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -79,24 +74,12 @@ class TestBranchDateTimeOperator:
self.branch_1.set_upstream(self.branch_op)
self.branch_2.set_upstream(self.branch_op)
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
- if AIRFLOW_V_3_0_PLUS:
self.dr = dag_maker.create_dagrun(
run_id="manual__",
start_date=DEFAULT_DATE,
logical_date=DEFAULT_DATE,
state=State.RUNNING,
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
- else:
- self.dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=DEFAULT_DATE,
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
)
def teardown_method(self):
@@ -251,25 +234,13 @@ class TestBranchDateTimeOperator:
"""Check if BranchDateTimeOperator uses task logical date"""
in_between_date = timezone.datetime(2020, 7, 7, 10, 30, 0)
self.branch_op.use_task_logical_date = True
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
- if AIRFLOW_V_3_0_PLUS:
- self.dr = dag_maker.create_dagrun(
- run_id="manual_exec_date__",
- start_date=in_between_date,
- logical_date=in_between_date,
- state=State.RUNNING,
- data_interval=(in_between_date, in_between_date),
- **triggered_by_kwargs,
- )
- else:
- self.dr = dag_maker.create_dagrun(
- run_id="manual_exec_date__",
- start_date=in_between_date,
- execution_date=in_between_date,
- state=State.RUNNING,
- data_interval=(in_between_date, in_between_date),
- **triggered_by_kwargs,
- )
+ self.dr = dag_maker.create_dagrun(
+ run_id="manual_exec_date__",
+ start_date=in_between_date,
+ logical_date=in_between_date,
+ state=State.RUNNING,
+ data_interval=(in_between_date, in_between_date),
+ )
self.branch_op.target_lower = target_lower
self.branch_op.target_upper = target_upper
diff --git a/providers/tests/standard/operators/test_python.py
b/providers/tests/standard/operators/test_python.py
index 3798a712384..efeade9ca09 100644
--- a/providers/tests/standard/operators/test_python.py
+++ b/providers/tests/standard/operators/test_python.py
@@ -75,10 +75,6 @@ from tests_common.test_utils import AIRFLOW_MAIN_FOLDER
from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS,
AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.db import clear_db_runs
-if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.types import DagRunTriggeredByType
-
-
if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
@@ -148,27 +144,14 @@ class BasePythonTest:
return kwargs
def create_dag_run(self) -> DagRun:
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
- if AIRFLOW_V_3_0_PLUS:
- return self.dag_maker.create_dagrun(
- state=DagRunState.RUNNING,
- start_date=self.dag_maker.start_date,
- session=self.dag_maker.session,
- logical_date=self.default_date,
- run_type=DagRunType.MANUAL,
- data_interval=(self.default_date, self.default_date),
- **triggered_by_kwargs, # type: ignore
- )
- else:
- return self.dag_maker.create_dagrun(
- state=DagRunState.RUNNING,
- start_date=self.dag_maker.start_date,
- session=self.dag_maker.session,
- execution_date=self.default_date,
- run_type=DagRunType.MANUAL,
- data_interval=(self.default_date, self.default_date),
- **triggered_by_kwargs, # type: ignore
- )
+ return self.dag_maker.create_dagrun(
+ state=DagRunState.RUNNING,
+ start_date=self.dag_maker.start_date,
+ session=self.dag_maker.session,
+ logical_date=self.default_date,
+ run_type=DagRunType.MANUAL,
+ data_interval=(self.default_date, self.default_date),
+ )
def create_ti(self, fn, **kwargs) -> TI:
"""Create TaskInstance for class defined Operator."""
diff --git a/providers/tests/standard/operators/test_weekday.py
b/providers/tests/standard/operators/test_weekday.py
index a9eb19189ec..1a182e9bbfc 100644
--- a/providers/tests/standard/operators/test_weekday.py
+++ b/providers/tests/standard/operators/test_weekday.py
@@ -33,11 +33,6 @@ from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.weekday import WeekDay
-from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.types import DagRunTriggeredByType
-
pytestmark = pytest.mark.db_test
DEFAULT_DATE = timezone.datetime(2020, 2, 5) # Wednesday
@@ -105,25 +100,13 @@ class TestBranchDayOfWeekOperator:
branch_2.set_upstream(branch_op)
branch_3.set_upstream(branch_op)
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
- if AIRFLOW_V_3_0_PLUS:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- logical_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
- else:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
+ dr = dag_maker.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ logical_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ )
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -155,26 +138,13 @@ class TestBranchDayOfWeekOperator:
branch_1.set_upstream(branch_op)
branch_2.set_upstream(branch_op)
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
-
- if AIRFLOW_V_3_0_PLUS:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- logical_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
- else:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
+ dr = dag_maker.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ logical_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ )
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -204,26 +174,13 @@ class TestBranchDayOfWeekOperator:
branch_1.set_upstream(branch_op)
branch_2.set_upstream(branch_op)
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
-
- if AIRFLOW_V_3_0_PLUS:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- logical_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
- else:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
+ dr = dag_maker.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ logical_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ )
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -314,25 +271,13 @@ class TestBranchDayOfWeekOperator:
branch_1.set_upstream(branch_op)
branch_2.set_upstream(branch_op)
- triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {}
- if AIRFLOW_V_3_0_PLUS:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- logical_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
- else:
- dr = dag_maker.create_dagrun(
- run_id="manual__",
- start_date=timezone.utcnow(),
- execution_date=DEFAULT_DATE,
- state=State.RUNNING,
- data_interval=(DEFAULT_DATE, DEFAULT_DATE),
- **triggered_by_kwargs,
- )
+ dr = dag_maker.create_dagrun(
+ run_id="manual__",
+ start_date=timezone.utcnow(),
+ logical_date=DEFAULT_DATE,
+ state=State.RUNNING,
+ data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+ )
branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 20479300166..c29cd1ea657 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -727,7 +727,14 @@ class DagMaker(Protocol):
def get_serialized_data(self) -> dict[str, Any]: ...
- def create_dagrun(self, **kwargs) -> DagRun: ...
+ def create_dagrun(
+ self,
+ run_id: str = ...,
+ logical_date: datetime = ...,
+ data_interval: DataInterval = ...,
+ run_type: DagRunType = ...,
+ **kwargs,
+ ) -> DagRun: ...
def create_dagrun_after(self, dagrun: DagRun, **kwargs) -> DagRun: ...
@@ -890,17 +897,20 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
else:
self._bag_dag_compat(self.dag)
- def create_dagrun(self, **kwargs):
+ def create_dagrun(self, *, logical_date=None, **kwargs):
from airflow.utils import timezone
- from airflow.utils.state import State
+ from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType
+ if "execution_date" in kwargs:
+ raise TypeError("use logical_date instead")
+
dag = self.dag
kwargs = {
- "state": State.RUNNING,
+ "state": DagRunState.RUNNING,
"start_date": self.start_date,
"session": self.session,
**kwargs,
@@ -912,31 +922,27 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
if "run_type" not in kwargs:
kwargs["run_type"] = DagRunType.from_run_id(kwargs["run_id"])
- if AIRFLOW_V_3_0_PLUS:
- if kwargs.get("logical_date") is None:
- if kwargs["run_type"] == DagRunType.MANUAL:
- kwargs["logical_date"] = self.start_date
- else:
- kwargs["logical_date"] =
dag.next_dagrun_info(None).logical_date
- else:
- if kwargs.get("execution_date") is None:
- if kwargs["run_type"] == DagRunType.MANUAL:
- kwargs["execution_date"] = self.start_date
- else:
- kwargs["execution_date"] =
dag.next_dagrun_info(None).logical_date
+
+ if logical_date is None:
+ if kwargs["run_type"] == DagRunType.MANUAL:
+ logical_date = self.start_date
+ else:
+ logical_date = dag.next_dagrun_info(None).logical_date
+ logical_date = timezone.coerce_datetime(logical_date)
+
if "data_interval" not in kwargs:
- logical_date = (
- timezone.coerce_datetime(kwargs["logical_date"])
- if AIRFLOW_V_3_0_PLUS
- else timezone.coerce_datetime(kwargs["execution_date"])
- )
if kwargs["run_type"] == DagRunType.MANUAL:
data_interval =
dag.timetable.infer_manual_data_interval(run_after=logical_date)
else:
data_interval =
dag.infer_automated_data_interval(logical_date)
kwargs["data_interval"] = data_interval
- if AIRFLOW_V_3_0_PLUS and "triggered_by" not in kwargs:
- kwargs["triggered_by"] = DagRunTriggeredByType.TEST
+
+ if AIRFLOW_V_3_0_PLUS:
+ kwargs.setdefault("triggered_by", DagRunTriggeredByType.TEST)
+ kwargs["logical_date"] = logical_date
+ else:
+ kwargs.pop("triggered_by", None)
+ kwargs["execution_date"] = logical_date
self.dag_run = dag.create_dagrun(**kwargs)
for ti in self.dag_run.task_instances:
@@ -949,18 +955,10 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
next_info =
self.dag.next_dagrun_info(self.dag.get_run_data_interval(dagrun))
if next_info is None:
raise ValueError(f"cannot create run after {dagrun}")
- return (
- self.create_dagrun(
- logical_date=next_info.logical_date,
- data_interval=next_info.data_interval,
- **kwargs,
- )
- if AIRFLOW_V_3_0_PLUS
- else self.create_dagrun(
- execution_date=next_info.logical_date,
- data_interval=next_info.data_interval,
- **kwargs,
- )
+ return self.create_dagrun(
+ logical_date=next_info.logical_date,
+ data_interval=next_info.data_interval,
+ **kwargs,
)
def __call__(
@@ -1199,7 +1197,7 @@ def create_task_instance(dag_maker: DagMaker,
create_dummy_dag: CreateDummyDAG)
"""
from airflow.operators.empty import EmptyOperator
- from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
+ from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS
def maker(
logical_date=None,
@@ -1228,17 +1226,12 @@ def create_task_instance(dag_maker: DagMaker,
create_dummy_dag: CreateDummyDAG)
last_heartbeat_at=None,
**kwargs,
) -> TaskInstance:
- if AIRFLOW_V_3_0_PLUS:
- from airflow.utils.types import DagRunTriggeredByType
-
if logical_date is None:
from airflow.utils import timezone
logical_date = timezone.utcnow()
with dag_maker(dag_id, **kwargs):
op_kwargs = {}
- from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS
-
if AIRFLOW_V_2_9_PLUS:
op_kwargs["task_display_name"] = task_display_name
task = EmptyOperator(
@@ -1255,12 +1248,10 @@ def create_task_instance(dag_maker: DagMaker,
create_dummy_dag: CreateDummyDAG)
trigger_rule=trigger_rule,
**op_kwargs,
)
- date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date"
dagrun_kwargs = {
- date_key: logical_date,
+ "logical_date": logical_date,
"state": dagrun_state,
}
- dagrun_kwargs.update({"triggered_by": DagRunTriggeredByType.TEST} if
AIRFLOW_V_3_0_PLUS else {})
if run_id is not None:
dagrun_kwargs["run_id"] = run_id
if run_type is not None:
@@ -1282,10 +1273,22 @@ def create_task_instance(dag_maker: DagMaker,
create_dummy_dag: CreateDummyDAG)
return maker
-@pytest.fixture
-def create_serialized_task_instance_of_operator(dag_maker: DagMaker):
- from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
+class CreateTaskInstanceOfOperator(Protocol):
+ """Type stub for create_task_instance_of_operator and
create_serialized_task_instance_of_operator."""
+ def __call__(
+ self,
+ operator_class: type[BaseOperator],
+ *,
+ dag_id: str,
+ logical_date: datetime = ...,
+ session: Session = ...,
+ **kwargs,
+ ) -> TaskInstance: ...
+
+
+@pytest.fixture
+def create_serialized_task_instance_of_operator(dag_maker: DagMaker) ->
CreateTaskInstanceOfOperator:
def _create_task_instance(
operator_class,
*,
@@ -1296,26 +1299,14 @@ def
create_serialized_task_instance_of_operator(dag_maker: DagMaker):
) -> TaskInstance:
with dag_maker(dag_id=dag_id, serialized=True, session=session):
operator_class(**operator_kwargs)
- if logical_date is None:
- dagrun_kwargs = {}
- else:
- dagrun_kwargs = {"logical_date" if AIRFLOW_V_3_0_PLUS else
"execution_date": logical_date}
- (ti,) = dag_maker.create_dagrun(**dagrun_kwargs).task_instances
+ (ti,) =
dag_maker.create_dagrun(logical_date=logical_date).task_instances
return ti
return _create_task_instance
-class CreateTaskInstanceOfOperator(Protocol):
- """Type stub for create_task_instance_of_operator."""
-
- def __call__(self, operator_class: type[BaseOperator], *args, **kwargs) ->
TaskInstance: ...
-
-
@pytest.fixture
def create_task_instance_of_operator(dag_maker: DagMaker) ->
CreateTaskInstanceOfOperator:
- from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
-
def _create_task_instance(
operator_class,
*,
@@ -1326,11 +1317,7 @@ def create_task_instance_of_operator(dag_maker:
DagMaker) -> CreateTaskInstanceO
) -> TaskInstance:
with dag_maker(dag_id=dag_id, session=session, serialized=True):
operator_class(**operator_kwargs)
- if logical_date is None:
- dagrun_kwargs = {}
- else:
- dagrun_kwargs = {"logical_date" if AIRFLOW_V_3_0_PLUS else
"execution_date": logical_date}
- (ti,) = dag_maker.create_dagrun(**dagrun_kwargs).task_instances
+ (ti,) =
dag_maker.create_dagrun(logical_date=logical_date).task_instances
return ti
return _create_task_instance
@@ -1339,7 +1326,14 @@ def create_task_instance_of_operator(dag_maker:
DagMaker) -> CreateTaskInstanceO
class CreateTaskOfOperator(Protocol):
"""Type stub for create_task_of_operator."""
- def __call__(self, operator_class: type[Op], *args, **kwargs) -> Op: ...
+ def __call__(
+ self,
+ operator_class: type[Op],
+ *,
+ dag_id: str,
+ session: Session = ...,
+ **kwargs,
+ ) -> Op: ...
@pytest.fixture