This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new bb19b91  Remove side effects from tests (#9675)
bb19b91 is described below

commit bb19b9179ac645156a59054a34f147ca6b146b9f
Author: Kaxil Naik <[email protected]>
AuthorDate: Sun Jul 5 22:56:15 2020 +0100

    Remove side effects from tests (#9675)
    
    Add setUp and tearDown methods to clear tables
---
 .../endpoints/test_connection_endpoint.py          |  5 ++--
 .../endpoints/test_event_log_endpoint.py           |  9 +++---
 .../api_connexion/endpoints/test_xcom_endpoint.py  | 16 ++++++-----
 tests/jobs/test_backfill_job.py                    |  8 +++++-
 tests/jobs/test_local_task_job.py                  |  8 +++++-
 tests/jobs/test_scheduler_job.py                   | 17 +++++++-----
 tests/models/test_cleartasks.py                    |  7 +++--
 tests/models/test_dagbag.py                        |  7 +++++
 tests/models/test_taskinstance.py                  | 32 ++++++++++------------
 tests/models/test_variable.py                      |  3 ++
 tests/models/test_xcom.py                          |  8 ++++++
 tests/sensors/test_base_sensor.py                  | 19 ++++++++-----
 tests/sensors/test_weekday_sensor.py               | 18 ++++++------
 tests/test_utils/db.py                             | 25 +++++++++++++++--
 14 files changed, 121 insertions(+), 61 deletions(-)

diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py 
b/tests/api_connexion/endpoints/test_connection_endpoint.py
index 12c1c05..4c0af2f 100644
--- a/tests/api_connexion/endpoints/test_connection_endpoint.py
+++ b/tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -19,7 +19,7 @@ import unittest
 from parameterized import parameterized
 
 from airflow.models import Connection
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import provide_session
 from airflow.www import app
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_connections
@@ -34,8 +34,7 @@ class TestConnectionEndpoint(unittest.TestCase):
     def setUp(self) -> None:
         self.client = self.app.test_client()  # type:ignore
         # we want only the connection created here for this test
-        with create_session() as session:
-            session.query(Connection).delete()
+        clear_db_connections()
 
     def tearDown(self) -> None:
         clear_db_connections()
diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py 
b/tests/api_connexion/endpoints/test_event_log_endpoint.py
index fd6a8bd..e21b4b7 100644
--- a/tests/api_connexion/endpoints/test_event_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py
@@ -22,9 +22,10 @@ from airflow import DAG
 from airflow.models import Log, TaskInstance
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.utils import timezone
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import provide_session
 from airflow.www import app
 from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_logs
 
 
 class TestEventLogEndpoint(unittest.TestCase):
@@ -35,14 +36,12 @@ class TestEventLogEndpoint(unittest.TestCase):
 
     def setUp(self) -> None:
         self.client = self.app.test_client()  # type:ignore
-        with create_session() as session:
-            session.query(Log).delete()
+        clear_db_logs()
         self.default_time = "2020-06-10T20:00:00+00:00"
         self.default_time_2 = '2020-06-11T07:00:00+00:00'
 
     def tearDown(self) -> None:
-        with create_session() as session:
-            session.query(Log).delete()
+        clear_db_logs()
 
     def _create_task_instance(self):
         dag = DAG('TEST_DAG_ID', start_date=timezone.parse(self.default_time),
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py 
b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index 3f36e60..ad29cc2 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -21,9 +21,10 @@ from parameterized import parameterized
 
 from airflow.models import DagRun as DR, XCom
 from airflow.utils.dates import parse_execution_date
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import provide_session
 from airflow.utils.types import DagRunType
 from airflow.www import app
+from tests.test_utils.db import clear_db_runs, clear_db_xcom
 
 
 class TestXComEndpoint(unittest.TestCase):
@@ -32,23 +33,24 @@ class TestXComEndpoint(unittest.TestCase):
         super().setUpClass()
         cls.app = app.create_app(testing=True)  # type:ignore
 
+    @staticmethod
+    def clean_db():
+        clear_db_runs()
+        clear_db_xcom()
+
     def setUp(self) -> None:
         """
         Setup For XCom endpoint TC
         """
         self.client = self.app.test_client()  # type:ignore
         # clear existing xcoms
-        with create_session() as session:
-            session.query(XCom).delete()
-            session.query(DR).delete()
+        self.clean_db()
 
     def tearDown(self) -> None:
         """
         Clear Hanging XComs
         """
-        with create_session() as session:
-            session.query(XCom).delete()
-            session.query(DR).delete()
+        self.clean_db()
 
 
 class TestDeleteXComEntry(TestXComEndpoint):
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index c4823f1..d9a7dcf 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -82,12 +82,18 @@ class TestBackfillJob(unittest.TestCase):
     def setUpClass(cls):
         cls.dagbag = DagBag(include_examples=True)
 
-    def setUp(self):
+    @staticmethod
+    def clean_db():
         clear_db_runs()
         clear_db_pools()
 
+    def setUp(self):
+        self.clean_db()
         self.parser = cli_parser.get_parser()
 
+    def tearDown(self) -> None:
+        self.clean_db()
+
     def test_unfinished_dag_runs_set_to_failed(self):
         dag = self._get_dummy_dag('dummy_dag')
 
diff --git a/tests/jobs/test_local_task_job.py 
b/tests/jobs/test_local_task_job.py
index 2d91efe..1a8b4f9 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -41,7 +41,7 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from tests.test_utils.asserts import assert_queries_count
-from tests.test_utils.db import clear_db_runs
+from tests.test_utils.db import clear_db_jobs, clear_db_runs
 from tests.test_utils.mock_executor import MockExecutor
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -50,11 +50,16 @@ TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']
 
 class TestLocalTaskJob(unittest.TestCase):
     def setUp(self):
+        clear_db_jobs()
         clear_db_runs()
         patcher = patch('airflow.jobs.base_job.sleep')
         self.addCleanup(patcher.stop)
         self.mock_base_job_sleep = patcher.start()
 
+    def tearDown(self) -> None:
+        clear_db_jobs()
+        clear_db_runs()
+
     def test_localtaskjob_essential_attr(self):
         """
         Check whether essential attributes
@@ -414,6 +419,7 @@ class TestLocalTaskJob(unittest.TestCase):
 @pytest.fixture()
 def clean_db_helper():
     yield
+    clear_db_jobs()
     clear_db_runs()
 
 
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 6ba94b7..e78f07d 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -55,7 +55,8 @@ from airflow.utils.types import DagRunType
 from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.config import conf_vars, env_vars
 from tests.test_utils.db import (
-    clear_db_dags, clear_db_errors, clear_db_pools, clear_db_runs, 
clear_db_sla_miss, set_default_pool_slots,
+    clear_db_dags, clear_db_errors, clear_db_jobs, clear_db_pools, 
clear_db_runs, clear_db_sla_miss,
+    set_default_pool_slots,
 )
 from tests.test_utils.mock_executor import MockExecutor
 
@@ -88,23 +89,25 @@ def disable_load_example():
 
 @pytest.mark.usefixtures("disable_load_example")
 class TestDagFileProcessor(unittest.TestCase):
-    def setUp(self):
+
+    @staticmethod
+    def clean_db():
         clear_db_runs()
         clear_db_pools()
         clear_db_dags()
         clear_db_sla_miss()
         clear_db_errors()
+        clear_db_jobs()
+
+    def setUp(self):
+        self.clean_db()
 
         # Speed up some tests by not running the tasks, just look at what we
         # enqueue!
         self.null_exec = MockExecutor()
 
     def tearDown(self) -> None:
-        clear_db_runs()
-        clear_db_pools()
-        clear_db_dags()
-        clear_db_sla_miss()
-        clear_db_errors()
+        self.clean_db()
 
     def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + 
timedelta(hours=1), **kwargs):
         dag = DAG(
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index e59e862..bd6244d 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -25,13 +25,16 @@ from airflow.operators.dummy_operator import DummyOperator
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from tests.models import DEFAULT_DATE
+from tests.test_utils import db
 
 
 class TestClearTasks(unittest.TestCase):
 
+    def setUp(self) -> None:
+        db.clear_db_runs()
+
     def tearDown(self):
-        with create_session() as session:
-            session.query(TI).delete()
+        db.clear_db_runs()
 
     def test_clear_task_instances(self):
         dag = DAG('test_clear_task_instances', start_date=DEFAULT_DATE,
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 4abd446..15ae731 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -31,6 +31,7 @@ from airflow.models import DagBag, DagModel
 from airflow.utils.session import create_session
 from tests.models import TEST_DAGS_FOLDER
 from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_dags
 
 
 class TestDagBag(unittest.TestCase):
@@ -42,6 +43,12 @@ class TestDagBag(unittest.TestCase):
     def tearDownClass(cls):
         shutil.rmtree(cls.empty_dir)
 
+    def setUp(self) -> None:
+        clear_db_dags()
+
+    def tearDown(self) -> None:
+        clear_db_dags()
+
     def test_get_existing_dag(self):
         """
         Test that we're able to parse some example DAGs and retrieve them
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index 6aaf9ef..6cd8330 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -33,7 +33,7 @@ from sqlalchemy.orm.session import Session
 from airflow import models, settings
 from airflow.exceptions import AirflowException, AirflowFailException, 
AirflowSkipException
 from airflow.models import (
-    DAG, DagRun, Pool, RenderedTaskInstanceFields, TaskFail, TaskInstance as 
TI, TaskReschedule, Variable,
+    DAG, DagRun, Pool, RenderedTaskInstanceFields, TaskInstance as TI, 
TaskReschedule, Variable,
 )
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy_operator import DummyOperator
@@ -53,9 +53,6 @@ from tests.models import DEFAULT_DATE
 from tests.test_utils import db
 from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.config import conf_vars
-from tests.test_utils.db import (
-    clear_db_dags, clear_db_errors, clear_db_pools, clear_db_runs, 
clear_db_sla_miss,
-)
 
 
 class CallbackWrapper:
@@ -84,22 +81,23 @@ class CallbackWrapper:
 
 class TestTaskInstance(unittest.TestCase):
 
-    def setUp(self):
+    @staticmethod
+    def clean_db():
         db.clear_db_pools()
+        db.clear_db_runs()
+        db.clear_db_task_fail()
         db.clear_rendered_ti_fields()
+        db.clear_db_task_reschedule()
+
+    def setUp(self):
+        self.clean_db()
         with create_session() as session:
             test_pool = Pool(pool='test_pool', slots=1)
             session.add(test_pool)
             session.commit()
 
     def tearDown(self):
-        db.clear_db_pools()
-        db.clear_rendered_ti_fields()
-        with create_session() as session:
-            session.query(TaskFail).delete()
-            session.query(TaskReschedule).delete()
-            session.query(models.TaskInstance).delete()
-            session.query(models.DagRun).delete()
+        self.clean_db()
 
     def test_set_task_dates(self):
         """
@@ -1713,11 +1711,11 @@ class TestRunRawTaskQueriesCount(unittest.TestCase):
 
     @staticmethod
     def _clean():
-        clear_db_runs()
-        clear_db_pools()
-        clear_db_dags()
-        clear_db_sla_miss()
-        clear_db_errors()
+        db.clear_db_runs()
+        db.clear_db_pools()
+        db.clear_db_dags()
+        db.clear_db_sla_miss()
+        db.clear_db_errors()
 
     def setUp(self) -> None:
         self._clean()
diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py
index bf2de1b..4b507e6 100644
--- a/tests/models/test_variable.py
+++ b/tests/models/test_variable.py
@@ -23,15 +23,18 @@ from parameterized import parameterized
 
 from airflow import settings
 from airflow.models import Variable, crypto
+from tests.test_utils import db
 from tests.test_utils.config import conf_vars
 
 
 class TestVariable(unittest.TestCase):
     def setUp(self):
         crypto._fernet = None
+        db.clear_db_variables()
 
     def tearDown(self):
         crypto._fernet = None
+        db.clear_db_variables()
 
     @conf_vars({('core', 'fernet_key'): ''})
     def test_variable_no_encryption(self):
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index cb9cf0a..39586d1 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -21,6 +21,7 @@ from airflow import settings
 from airflow.configuration import conf
 from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
 from airflow.utils import timezone
+from tests.test_utils import db
 from tests.test_utils.config import conf_vars
 
 
@@ -31,6 +32,13 @@ class CustomXCom(BaseXCom):
 
 
 class TestXCom(unittest.TestCase):
+
+    def setUp(self) -> None:
+        db.clear_db_xcom()
+
+    def tearDown(self) -> None:
+        db.clear_db_xcom()
+
     @conf_vars({("core", "xcom_backend"): "tests.models.test_xcom.CustomXCom"})
     def test_resolve_xcom_class(self):
         cls = resolve_xcom_backend()
diff --git a/tests/sensors/test_base_sensor.py 
b/tests/sensors/test_base_sensor.py
index f11b7e9..a91412c 100644
--- a/tests/sensors/test_base_sensor.py
+++ b/tests/sensors/test_base_sensor.py
@@ -24,8 +24,8 @@ from unittest.mock import Mock, patch
 from freezegun import freeze_time
 
 from airflow.exceptions import AirflowException, AirflowRescheduleException, 
AirflowSensorTimeout
-from airflow.models import DagBag, DagRun, TaskInstance, TaskReschedule
-from airflow.models.dag import DAG, settings
+from airflow.models import DagBag, TaskInstance, TaskReschedule
+from airflow.models.dag import DAG
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.sensors.base_sensor_operator import BaseSensorOperator, 
poke_mode_only
 from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
@@ -33,6 +33,7 @@ from airflow.utils import timezone
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from airflow.utils.types import DagRunType
+from tests.test_utils import db
 
 DEFAULT_DATE = datetime(2015, 1, 1)
 TEST_DAG_ID = 'unit_test_dag'
@@ -51,18 +52,22 @@ class DummySensor(BaseSensorOperator):
 
 
 class TestBaseSensor(unittest.TestCase):
+    @staticmethod
+    def clean_db():
+        db.clear_db_runs()
+        db.clear_db_task_reschedule()
+        db.clear_db_xcom()
+
     def setUp(self):
         args = {
             'owner': 'airflow',
             'start_date': DEFAULT_DATE
         }
         self.dag = DAG(TEST_DAG_ID, default_args=args)
+        self.clean_db()
 
-        session = settings.Session()
-        session.query(TaskReschedule).delete()
-        session.query(DagRun).delete()
-        session.query(TaskInstance).delete()
-        session.commit()
+    def tearDown(self) -> None:
+        self.clean_db()
 
     def _make_dag_run(self):
         return self.dag.create_dagrun(
diff --git a/tests/sensors/test_weekday_sensor.py 
b/tests/sensors/test_weekday_sensor.py
index 00bd8a1..0ff4f19 100644
--- a/tests/sensors/test_weekday_sensor.py
+++ b/tests/sensors/test_weekday_sensor.py
@@ -22,12 +22,12 @@ import unittest
 from parameterized import parameterized
 
 from airflow.exceptions import AirflowSensorTimeout
-from airflow.models import DagBag, TaskFail, TaskInstance
+from airflow.models import DagBag
 from airflow.models.dag import DAG
 from airflow.sensors.weekday_sensor import DayOfWeekSensor
-from airflow.settings import Session
 from airflow.utils.timezone import datetime
 from airflow.utils.weekday import WeekDay
+from tests.test_utils import db
 
 DEFAULT_DATE = datetime(2018, 12, 10)
 WEEKDAY_DATE = datetime(2018, 12, 20)
@@ -38,7 +38,13 @@ DEV_NULL = '/dev/null'
 
 class TestDayOfWeekSensor(unittest.TestCase):
 
+    @staticmethod
+    def clean_db():
+        db.clear_db_runs()
+        db.clear_db_task_fail()
+
     def setUp(self):
+        self.clean_db()
         self.dagbag = DagBag(
             dag_folder=DEV_NULL,
             include_examples=True
@@ -51,13 +57,7 @@ class TestDayOfWeekSensor(unittest.TestCase):
         self.dag = dag
 
     def tearDown(self):
-        session = Session()
-        session.query(TaskInstance).filter_by(
-            dag_id=TEST_DAG_ID).delete()
-        session.query(TaskFail).filter_by(
-            dag_id=TEST_DAG_ID).delete()
-        session.commit()
-        session.close()
+        self.clean_db()
 
     @parameterized.expand([
         ("with-string", 'Thursday'),
diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
index 6c2c297..44477cf 100644
--- a/tests/test_utils/db.py
+++ b/tests/test_utils/db.py
@@ -15,9 +15,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from airflow.jobs.base_job import BaseJob
 from airflow.models import (
-    Connection, DagModel, DagRun, DagTag, Pool, RenderedTaskInstanceFields, 
SlaMiss, TaskInstance, Variable,
-    XCom, errors,
+    Connection, DagModel, DagRun, DagTag, Log, Pool, 
RenderedTaskInstanceFields, SlaMiss, TaskFail,
+    TaskInstance, TaskReschedule, Variable, XCom, errors,
 )
 from airflow.models.dagcode import DagCode
 from airflow.models.serialized_dag import SerializedDagModel
@@ -93,3 +94,23 @@ def clear_db_import_errors():
 def clear_db_xcom():
     with create_session() as session:
         session.query(XCom).delete()
+
+
+def clear_db_logs():
+    with create_session() as session:
+        session.query(Log).delete()
+
+
+def clear_db_jobs():
+    with create_session() as session:
+        session.query(BaseJob).delete()
+
+
+def clear_db_task_fail():
+    with create_session() as session:
+        session.query(TaskFail).delete()
+
+
+def clear_db_task_reschedule():
+    with create_session() as session:
+        session.query(TaskReschedule).delete()

Reply via email to