This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 30f776693e Fix tests/operators/test_bash.py and branch_operator.py for Database Isolation Mode (#41277)
30f776693e is described below
commit 30f776693e8b0a6497e21ea72baaf3b36418ba34
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Aug 6 19:59:38 2024 +0200
Fix tests/operators/test_bash.py and branch_operator.py for Database Isolation Mode (#41277)
* Fix tests/operators/test_bash.py and branch_operator.py for Database Isolation Tests
* Review Feedback
---
tests/conftest.py | 2 +-
tests/operators/test_bash.py | 43 ++---
tests/operators/test_branch_operator.py | 274 +++++++++++++++++---------------
3 files changed, 168 insertions(+), 151 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index 11b47363b0..5601461409 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1152,7 +1152,7 @@ def create_task_instance_of_operator(dag_maker):
session=None,
**operator_kwargs,
) -> TaskInstance:
- with dag_maker(dag_id=dag_id, session=session):
+ with dag_maker(dag_id=dag_id, session=session, serialized=True):
operator_class(**operator_kwargs)
if execution_date is None:
dagrun_kwargs = {}
diff --git a/tests/operators/test_bash.py b/tests/operators/test_bash.py
index 4ef4973f08..f63dd0dea2 100644
--- a/tests/operators/test_bash.py
+++ b/tests/operators/test_bash.py
@@ -28,7 +28,6 @@ from unittest import mock
import pytest
from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout
-from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator
from airflow.utils import timezone
from airflow.utils.state import State
@@ -65,7 +64,9 @@ class TestBashOperator:
(True, {"AIRFLOW_HOME": "OVERRIDDEN_AIRFLOW_HOME"}, "OVERRIDDEN_AIRFLOW_HOME"),
],
)
- def test_echo_env_variables(self, append_env, user_defined_env, expected_airflow_home, tmp_path):
+ def test_echo_env_variables(
+ self, append_env, user_defined_env, expected_airflow_home, dag_maker, tmp_path
+ ):
"""
Test that env variables are exported correctly to the task bash environment.
"""
@@ -79,15 +80,28 @@ class TestBashOperator:
f"manual__{utc_now.isoformat()}\n"
)
- dag = DAG(
- dag_id="bash_op_test",
+ with dag_maker(
+ "bash_op_test",
default_args={"owner": "airflow", "retries": 100, "start_date": DEFAULT_DATE},
schedule="@daily",
dagrun_timeout=timedelta(minutes=60),
- )
+ serialized=True,
+ ):
+ tmp_file = tmp_path / "testfile"
+ task = BashOperator(
+ task_id="echo_env_vars",
+ bash_command=f"echo $AIRFLOW_HOME>> {tmp_file};"
+ f"echo $PYTHONPATH>> {tmp_file};"
+ f"echo $AIRFLOW_CTX_DAG_ID >> {tmp_file};"
+ f"echo $AIRFLOW_CTX_TASK_ID>> {tmp_file};"
+ f"echo $AIRFLOW_CTX_EXECUTION_DATE>> {tmp_file};"
+ f"echo $AIRFLOW_CTX_DAG_RUN_ID>> {tmp_file};",
+ append_env=append_env,
+ env=user_defined_env,
+ )
execution_date = utc_now
- dag.create_dagrun(
+ dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=execution_date,
start_date=utc_now,
@@ -96,20 +110,6 @@ class TestBashOperator:
data_interval=(execution_date, execution_date),
)
- tmp_file = tmp_path / "testfile"
- task = BashOperator(
- task_id="echo_env_vars",
- dag=dag,
- bash_command=f"echo $AIRFLOW_HOME>> {tmp_file};"
- f"echo $PYTHONPATH>> {tmp_file};"
- f"echo $AIRFLOW_CTX_DAG_ID >> {tmp_file};"
- f"echo $AIRFLOW_CTX_TASK_ID>> {tmp_file};"
- f"echo $AIRFLOW_CTX_EXECUTION_DATE>> {tmp_file};"
- f"echo $AIRFLOW_CTX_DAG_RUN_ID>> {tmp_file};",
- append_env=append_env,
- env=user_defined_env,
- )
-
with mock.patch.dict(
"os.environ", {"AIRFLOW_HOME": "MY_PATH_TO_AIRFLOW_HOME", "PYTHONPATH": "AWESOME_PYTHONPATH"}
):
@@ -244,12 +244,13 @@ class TestBashOperator:
import psutil
sleep_time = f"100{os.getpid()}"
- with dag_maker():
+ with dag_maker(serialized=True):
op = BashOperator(
task_id="test_bash_operator_kill",
execution_timeout=timedelta(microseconds=25),
bash_command=f"/bin/bash -c 'sleep {sleep_time}'",
)
+ dag_maker.create_dagrun()
with pytest.raises(AirflowTaskTimeout):
op.run()
sleep(2)
diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py
index a8c904fe2c..fe475843bc 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -21,13 +21,10 @@ import datetime
import pytest
-from airflow.models.dag import DAG
-from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance as TI
from airflow.operators.branch import BaseBranchOperator
from airflow.operators.empty import EmptyOperator
from airflow.utils import timezone
-from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
@@ -54,86 +51,82 @@ class ChooseBranchThree(BaseBranchOperator):
class TestBranchOperator:
- @classmethod
- def setup_class(cls):
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
-
- def setup_method(self):
- self.dag = DAG(
- "branch_operator_test",
+ def test_without_dag_run(self, dag_maker):
+ """This checks the defensive against non existent tasks in a dag run"""
+ dag_id = "branch_operator_test"
+ with dag_maker(
+ dag_id,
default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
schedule=INTERVAL,
- )
-
- self.branch_1 = EmptyOperator(task_id="branch_1", dag=self.dag)
- self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag)
- self.branch_3 = None
- self.branch_op = None
-
- def teardown_method(self):
- with create_session() as session:
- session.query(DagRun).delete()
- session.query(TI).delete()
+ serialized=True,
+ ):
+ branch_1 = EmptyOperator(task_id="branch_1")
+ branch_2 = EmptyOperator(task_id="branch_2")
+ branch_op = ChooseBranchOne(task_id="make_choice")
+ branch_1.set_upstream(branch_op)
+ branch_2.set_upstream(branch_op)
+ dag_maker.create_dagrun()
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE):
+ if ti.task_id == "make_choice":
+ assert ti.state == State.SUCCESS
+ elif ti.task_id == "branch_1":
+ # should exist with state None
+ assert ti.state == State.NONE
+ elif ti.task_id == "branch_2":
+ assert ti.state == State.SKIPPED
+ else:
+ raise Exception
- def test_without_dag_run(self):
- """This checks the defensive against non existent tasks in a dag run"""
- self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
- self.branch_1.set_upstream(self.branch_op)
- self.branch_2.set_upstream(self.branch_op)
- self.dag.clear()
-
- self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- with create_session() as session:
- tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
-
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- # should exist with state None
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.SKIPPED
- else:
- raise Exception
-
- def test_branch_list_without_dag_run(self):
+ def test_branch_list_without_dag_run(self, dag_maker):
"""This checks if the BranchOperator supports branching off to a list of tasks."""
- self.branch_op = ChooseBranchOneTwo(task_id="make_choice", dag=self.dag)
- self.branch_1.set_upstream(self.branch_op)
- self.branch_2.set_upstream(self.branch_op)
- self.branch_3 = EmptyOperator(task_id="branch_3", dag=self.dag)
- self.branch_3.set_upstream(self.branch_op)
- self.dag.clear()
-
- self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
-
- with create_session() as session:
- tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id, TI.execution_date == DEFAULT_DATE)
-
- expected = {
- "make_choice": State.SUCCESS,
- "branch_1": State.NONE,
- "branch_2": State.NONE,
- "branch_3": State.SKIPPED,
- }
-
- for ti in tis:
- if ti.task_id in expected:
- assert ti.state == expected[ti.task_id]
- else:
- raise Exception
-
- def test_with_dag_run(self):
- self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
- self.branch_1.set_upstream(self.branch_op)
- self.branch_2.set_upstream(self.branch_op)
- self.dag.clear()
-
- dagrun = self.dag.create_dagrun(
+ dag_id = "branch_operator_test"
+ with dag_maker(
+ dag_id,
+ default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+ schedule=INTERVAL,
+ serialized=True,
+ ):
+ branch_1 = EmptyOperator(task_id="branch_1")
+ branch_2 = EmptyOperator(task_id="branch_2")
+ branch_3 = EmptyOperator(task_id="branch_3")
+ branch_op = ChooseBranchOneTwo(task_id="make_choice")
+ branch_1.set_upstream(branch_op)
+ branch_2.set_upstream(branch_op)
+ branch_3.set_upstream(branch_op)
+ dag_maker.create_dagrun()
+
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ expected = {
+ "make_choice": State.SUCCESS,
+ "branch_1": State.NONE,
+ "branch_2": State.NONE,
+ "branch_3": State.SKIPPED,
+ }
+
+ for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE):
+ if ti.task_id in expected:
+ assert ti.state == expected[ti.task_id]
+ else:
+ raise Exception
+
+ def test_with_dag_run(self, dag_maker):
+ dag_id = "branch_operator_test"
+ with dag_maker(
+ dag_id,
+ default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+ schedule=INTERVAL,
+ serialized=True,
+ ):
+ branch_1 = EmptyOperator(task_id="branch_1")
+ branch_2 = EmptyOperator(task_id="branch_2")
+ branch_op = ChooseBranchOne(task_id="make_choice")
+ branch_1.set_upstream(branch_op)
+ branch_2.set_upstream(branch_op)
+ dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
@@ -141,26 +134,35 @@ class TestBranchOperator:
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
)
- self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- tis = dagrun.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id == "branch_1":
- assert ti.state == State.NONE
- elif ti.task_id == "branch_2":
- assert ti.state == State.SKIPPED
+ expected = {
+ "make_choice": State.SUCCESS,
+ "branch_1": State.NONE,
+ "branch_2": State.SKIPPED,
+ }
+
+ for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE):
+ if ti.task_id in expected:
+ assert ti.state == expected[ti.task_id]
else:
raise Exception
- def test_with_skip_in_branch_downstream_dependencies(self):
- self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
- self.branch_op >> self.branch_1 >> self.branch_2
- self.branch_op >> self.branch_2
- self.dag.clear()
-
- dagrun = self.dag.create_dagrun(
+ def test_with_skip_in_branch_downstream_dependencies(self, dag_maker):
+ dag_id = "branch_operator_test"
+ with dag_maker(
+ dag_id,
+ default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+ schedule=INTERVAL,
+ serialized=True,
+ ):
+ branch_1 = EmptyOperator(task_id="branch_1")
+ branch_2 = EmptyOperator(task_id="branch_2")
+ branch_op = ChooseBranchOne(task_id="make_choice")
+ branch_op >> branch_1 >> branch_2
+ branch_op >> branch_2
+
+ dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
@@ -168,25 +170,35 @@ class TestBranchOperator:
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
)
- self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- tis = dagrun.get_task_instances()
- for ti in tis:
- if ti.task_id == "make_choice":
- assert ti.state == State.SUCCESS
- elif ti.task_id in ("branch_1", "branch_2"):
- assert ti.state == State.NONE
+ expected = {
+ "make_choice": State.SUCCESS,
+ "branch_1": State.NONE,
+ "branch_2": State.NONE,
+ }
+
+ for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE):
+ if ti.task_id in expected:
+ assert ti.state == expected[ti.task_id]
else:
raise Exception
- def test_xcom_push(self):
- self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
-
- self.branch_1.set_upstream(self.branch_op)
- self.branch_2.set_upstream(self.branch_op)
- self.dag.clear()
-
- dr = self.dag.create_dagrun(
+ def test_xcom_push(self, dag_maker):
+ dag_id = "branch_operator_test"
+ with dag_maker(
+ dag_id,
+ default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+ schedule=INTERVAL,
+ serialized=True,
+ ):
+ branch_1 = EmptyOperator(task_id="branch_1")
+ branch_2 = EmptyOperator(task_id="branch_2")
+ branch_op = ChooseBranchOne(task_id="make_choice")
+ branch_1.set_upstream(branch_op)
+ branch_2.set_upstream(branch_op)
+
+ dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
@@ -194,26 +206,31 @@ class TestBranchOperator:
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
)
- self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- tis = dr.get_task_instances()
- for ti in tis:
+ for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE):
if ti.task_id == "make_choice":
assert ti.xcom_pull(task_ids="make_choice") == "branch_1"
- def test_with_dag_run_task_groups(self):
- self.branch_op = ChooseBranchThree(task_id="make_choice", dag=self.dag)
- self.branch_3 = TaskGroup("branch_3", dag=self.dag)
- _ = EmptyOperator(task_id="task_1", dag=self.dag, task_group=self.branch_3)
- _ = EmptyOperator(task_id="task_2", dag=self.dag, task_group=self.branch_3)
-
- self.branch_1.set_upstream(self.branch_op)
- self.branch_2.set_upstream(self.branch_op)
- self.branch_3.set_upstream(self.branch_op)
-
- self.dag.clear()
-
- dagrun = self.dag.create_dagrun(
+ def test_with_dag_run_task_groups(self, dag_maker):
+ dag_id = "branch_operator_test"
+ with dag_maker(
+ dag_id,
+ default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+ schedule=INTERVAL,
+ serialized=True,
+ ):
+ branch_1 = EmptyOperator(task_id="branch_1")
+ branch_2 = EmptyOperator(task_id="branch_2")
+ branch_3 = TaskGroup("branch_3")
+ EmptyOperator(task_id="task_1", task_group=branch_3)
+ EmptyOperator(task_id="task_2", task_group=branch_3)
+ branch_op = ChooseBranchThree(task_id="make_choice")
+ branch_1.set_upstream(branch_op)
+ branch_2.set_upstream(branch_op)
+ branch_3.set_upstream(branch_op)
+
+ dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
@@ -221,10 +238,9 @@ class TestBranchOperator:
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
)
- self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
- tis = dagrun.get_task_instances()
- for ti in tis:
+ for ti in dag_maker.session.query(TI).filter(TI.dag_id == dag_id, TI.execution_date == DEFAULT_DATE):
if ti.task_id == "make_choice":
assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":