This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a1834ce  Use dag_maker fixture in test_processor.py (#17506)
a1834ce is described below

commit a1834ce873d28c373bc11d0637bda57c17c79189
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Wed Aug 11 23:09:48 2021 +0100

    Use dag_maker fixture in test_processor.py (#17506)
    
    This change applies the dag_maker fixture in test_processor.py
    
    fixup! Use dag_maker fixture in test_processor.py
    
    fixup! fixup! Use dag_maker fixture in test_processor.py
---
 tests/conftest.py                      |  82 ++++++++++++++++
 tests/dag_processing/test_processor.py |  98 +++++++------------
 tests/models/test_taskinstance.py      | 173 +++++++++++++--------------------
 3 files changed, 187 insertions(+), 166 deletions(-)
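
For orientation, here is a minimal sketch of how a test might use the two
fixtures this commit adds (dag_maker and create_dummy_dag, both defined in
the tests/conftest.py hunk below). The test name and assertions are
illustrative only, and assume create_dagrun's defaults suffice:

    from airflow.operators.dummy import DummyOperator

    def test_with_dag_maker(dag_maker):
        # dag_maker is used as a context manager and accepts the same
        # arguments as DAG; operators instantiated inside the block
        # attach to that DAG.
        with dag_maker(dag_id='example_dag') as dag:
            DummyOperator(task_id='example_task')
        # Every test that uses dag_maker should also create a DagRun.
        dag_run = dag_maker.create_dagrun()
        assert dag_run.dag_id == dag.dag_id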

diff --git a/tests/conftest.py b/tests/conftest.py
index 197a573..c7685d4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -428,6 +428,30 @@ def app():
 
 @pytest.fixture
 def dag_maker(request):
+    """
+    The dag_maker fixture creates a DAG and its DagModel automatically.
+
+    Use the dag_maker as a context manager; it takes the same
+    arguments as DAG::
+
+        with dag_maker(dag_id="mydag") as dag:
+            task1 = DummyOperator(task_id='mytask')
+            task2 = DummyOperator(task_id='mytask2')
+
+    If the DagModel you want to use needs different parameters than the one
+    the dag_maker creates automatically, update the DagModel as below::
+
+        dag_maker.dag_model.is_active = False
+        session.merge(dag_maker.dag_model)
+        session.commit()
+
+    In any test that uses the dag_maker, make sure to create a DagRun::
+
+        dag_maker.create_dagrun()
+
+    The dag_maker.create_dagrun method takes the same arguments as dag.create_dagrun.
+
+    """
     from airflow.models import DAG, DagModel
     from airflow.utils import timezone
     from airflow.utils.session import provide_session
@@ -473,7 +497,12 @@ def dag_maker(request):
             self.kwargs = kwargs
             self.session = session
             self.start_date = self.kwargs.get('start_date', None)
+            default_args = kwargs.get('default_args', None)
+            if default_args and not self.start_date:
+                if 'start_date' in default_args:
+                    self.start_date = default_args.get('start_date')
             if not self.start_date:
+
                 if hasattr(request.module, 'DEFAULT_DATE'):
                     self.start_date = getattr(request.module, 'DEFAULT_DATE')
                 else:
@@ -484,3 +513,56 @@ def dag_maker(request):
             return self
 
     return DagFactory()
+
+
+@pytest.fixture
+def create_dummy_dag(dag_maker):
+    """
+    This fixture creates a `DAG` with a single `DummyOperator` task.
+    A DagRun and a DagModel are also created.
+
+    Apart from the named parameters below, any other keyword argument is
+    passed to the DAG, not to the DummyOperator task.
+
+    If you have an argument that you want to pass to the DummyOperator and
+    it is not listed here, use `default_args` so that the DAG will pass it
+    to the task::
+
+        dag, task = create_dummy_dag(default_args={'start_date': timezone.datetime(2016, 1, 1)})
+
+    You cannot alter the created DagRun or DagModel; use the `dag_maker` fixture instead.
+    """
+    from airflow.operators.dummy import DummyOperator
+    from airflow.utils.types import DagRunType
+
+    def create_dag(
+        dag_id='dag',
+        task_id='op1',
+        task_concurrency=16,
+        pool='default_pool',
+        executor_config={},
+        trigger_rule='all_done',
+        on_success_callback=None,
+        on_execute_callback=None,
+        on_failure_callback=None,
+        on_retry_callback=None,
+        email=None,
+        **kwargs,
+    ):
+        with dag_maker(dag_id, **kwargs) as dag:
+            op = DummyOperator(
+                task_id=task_id,
+                task_concurrency=task_concurrency,
+                executor_config=executor_config,
+                on_success_callback=on_success_callback,
+                on_execute_callback=on_execute_callback,
+                on_failure_callback=on_failure_callback,
+                on_retry_callback=on_retry_callback,
+                email=email,
+                pool=pool,
+                trigger_rule=trigger_rule,
+            )
+        dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        return dag, op
+
+    return create_dag
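
A hedged usage sketch for the create_dummy_dag fixture above: the keyword
values are illustrative, operator-level settings go through the named
parameters, and any extra keyword argument is passed to the DAG instead:

    def test_with_create_dummy_dag(create_dummy_dag):
        # The factory returns the DAG and its single DummyOperator task;
        # a DagModel and a scheduled DagRun are created as a side effect.
        dag, task = create_dummy_dag(
            dag_id='example_dag',
            task_id='example_task',
            pool='default_pool',
        )
        assert task.dag is dag
        assert task.pool == 'default_pool'
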
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index fcfd0f6..d99a8cd 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -19,7 +19,6 @@
 
 import datetime
 import os
-from datetime import timedelta
 from tempfile import NamedTemporaryFile
 from unittest import mock
 from unittest.mock import MagicMock, patch
@@ -30,7 +29,7 @@ import pytest
 from airflow import settings
 from airflow.configuration import conf
 from airflow.dag_processing.processor import DagFileProcessor
-from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance, errors
+from airflow.models import DagBag, SlaMiss, TaskInstance, errors
 from airflow.models.taskinstance import SimpleTaskInstance
 from airflow.operators.dummy import DummyOperator
 from airflow.utils import timezone
@@ -97,34 +96,12 @@ class TestDagFileProcessor:
             self.scheduler_job = None
         self.clean_db()
 
-    def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs):
-        dag = DAG(
-            dag_id='test_scheduler_reschedule',
-            start_date=start_date,
-            # Make sure it only creates a single DAG Run
-            end_date=end_date,
-        )
-        dag.clear()
-        dag.is_subdag = False
-        with create_session() as session:
-            orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False)
-            session.merge(orm_dag)
-            session.commit()
-        return dag
-
-    @classmethod
-    def setup_class(cls):
-        # Ensure the DAGs we are looking at from the DB are up-to-date
-        non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
-        non_serialized_dagbag.sync_to_db()
-        cls.dagbag = DagBag(read_dags_from_db=True)
-
     def _process_file(self, file_path, session):
         dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
 
         dag_file_processor.process_file(file_path, [], False, session)
 
-    def test_dag_file_processor_sla_miss_callback(self):
+    def test_dag_file_processor_sla_miss_callback(self, create_dummy_dag):
         """
         Test that the dag file processor calls the sla miss callback
         """
@@ -135,14 +112,13 @@ class TestDagFileProcessor:
         # Create dag with a start of 1 day ago, but an sla of 0
         # so we'll already have an sla_miss on the books.
         test_start_date = days_ago(1)
-        dag = DAG(
+        dag, task = create_dummy_dag(
             dag_id='test_sla_miss',
+            task_id='dummy',
             sla_miss_callback=sla_callback,
             default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
         )
 
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
         session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
 
         session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
@@ -152,7 +128,7 @@ class TestDagFileProcessor:
 
         assert sla_callback.called
 
-    def test_dag_file_processor_sla_miss_callback_invalid_sla(self):
+    def test_dag_file_processor_sla_miss_callback_invalid_sla(self, create_dummy_dag):
         """
         Test that the dag file processor does not call the sla miss callback when
         given an invalid sla
@@ -165,14 +141,13 @@ class TestDagFileProcessor:
         # so we'll already have an sla_miss on the books.
         # Pass anything besides a timedelta object to the sla argument.
         test_start_date = days_ago(1)
-        dag = DAG(
+        dag, task = create_dummy_dag(
             dag_id='test_sla_miss',
+            task_id='dummy',
             sla_miss_callback=sla_callback,
             default_args={'start_date': test_start_date, 'sla': None},
         )
 
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
         session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
 
         session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date))
@@ -181,7 +156,7 @@ class TestDagFileProcessor:
         dag_file_processor.manage_slas(dag=dag, session=session)
         sla_callback.assert_not_called()
 
-    def test_dag_file_processor_sla_miss_callback_sent_notification(self):
+    def test_dag_file_processor_sla_miss_callback_sent_notification(self, create_dummy_dag):
         """
         Test that the dag file processor does not call the sla_miss_callback when a
         notification has already been sent
@@ -194,14 +169,13 @@ class TestDagFileProcessor:
         # Create dag with a start of 2 days ago, but an sla of 1 day
         # ago so we'll already have an sla_miss on the books
         test_start_date = days_ago(2)
-        dag = DAG(
+        dag, task = create_dummy_dag(
             dag_id='test_sla_miss',
+            task_id='dummy',
             sla_miss_callback=sla_callback,
             default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
         )
 
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-
         # Create a TaskInstance for two days ago
         session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success'))
 
@@ -222,7 +196,7 @@ class TestDagFileProcessor:
 
         sla_callback.assert_not_called()
 
-    def test_dag_file_processor_sla_miss_callback_exception(self):
+    def test_dag_file_processor_sla_miss_callback_exception(self, create_dummy_dag):
         """
         Test that the dag file processor gracefully logs an exception if there is a problem
         calling the sla_miss_callback
@@ -232,14 +206,13 @@ class TestDagFileProcessor:
         sla_callback = MagicMock(side_effect=RuntimeError('Could not call function'))
 
         test_start_date = days_ago(2)
-        dag = DAG(
+        dag, task = create_dummy_dag(
             dag_id='test_sla_miss',
+            task_id='dummy',
             sla_miss_callback=sla_callback,
-            default_args={'start_date': test_start_date},
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
         )
 
-        task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1))
-
         session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
 
         # Create an SlaMiss where notification was sent, but email was not
@@ -255,18 +228,18 @@ class TestDagFileProcessor:
         )
 
     @mock.patch('airflow.dag_processing.processor.send_email')
-    def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email):
+    def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(
+        self, mock_send_email, create_dummy_dag
+    ):
         session = settings.Session()
 
         test_start_date = days_ago(2)
-        dag = DAG(
-            dag_id='test_sla_miss',
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
-        )
-
         email1 = 'te...@test.com'
-        task = DummyOperator(
-            task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1)
+        dag, task = create_dummy_dag(
+            dag_id='test_sla_miss',
+            task_id='sla_missed',
+            email=email1,
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
         )
 
         session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
@@ -288,7 +261,9 @@ class TestDagFileProcessor:
 
     @mock.patch('airflow.dag_processing.processor.Stats.incr')
     @mock.patch("airflow.utils.email.send_email")
-    def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr):
+    def test_dag_file_processor_sla_miss_email_exception(
+        self, mock_send_email, mock_stats_incr, create_dummy_dag
+    ):
         """
         Test that the dag file processor gracefully logs an exception if there is a problem
         sending an email
@@ -299,14 +274,13 @@ class TestDagFileProcessor:
         mock_send_email.side_effect = RuntimeError('Could not send an email')
 
         test_start_date = days_ago(2)
-        dag = DAG(
+        dag, task = create_dummy_dag(
             dag_id='test_sla_miss',
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
-        )
-
-        task = DummyOperator(
-            task_id='dummy', dag=dag, owner='airflow', email='t...@test.com', sla=datetime.timedelta(hours=1)
+            task_id='dummy',
+            email='t...@test.com',
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
         )
+        mock_stats_incr.reset_mock()
 
         session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
 
@@ -322,7 +296,7 @@ class TestDagFileProcessor:
         )
         mock_stats_incr.assert_called_once_with('sla_email_notification_failure')
 
-    def test_dag_file_processor_sla_miss_deleted_task(self):
+    def test_dag_file_processor_sla_miss_deleted_task(self, create_dummy_dag):
         """
         Test that the dag file processor will not crash when trying to send
         sla miss notification for a deleted task
@@ -330,13 +304,11 @@ class TestDagFileProcessor:
         session = settings.Session()
 
         test_start_date = days_ago(2)
-        dag = DAG(
+        dag, task = create_dummy_dag(
             dag_id='test_sla_miss',
-            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)},
-        )
-
-        task = DummyOperator(
-            task_id='dummy', dag=dag, owner='airflow', email='t...@test.com', sla=datetime.timedelta(hours=1)
+            task_id='dummy',
+            email='t...@test.com',
+            default_args={'start_date': test_start_date, 'sla': datetime.timedelta(hours=1)},
         )
 
         session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success'))
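
Every hunk above applies the same mechanical change; condensed into one
hedged sketch (sla_callback and test_start_date as in the tests above, the
commented-out block showing the old hand-built setup):

    import datetime
    from unittest.mock import MagicMock

    from airflow.utils.dates import days_ago

    def test_sla_miss_sketch(create_dummy_dag):
        sla_callback = MagicMock()
        test_start_date = days_ago(1)

        # Before this commit, each test built its own DAG and task:
        #     dag = DAG(dag_id='test_sla_miss', sla_miss_callback=sla_callback,
        #               default_args={'start_date': test_start_date,
        #                             'sla': datetime.timedelta()})
        #     task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')

        # After: one fixture call creates the DAG, task, DagModel and DagRun.
        dag, task = create_dummy_dag(
            dag_id='test_sla_miss',
            task_id='dummy',
            sla_miss_callback=sla_callback,
            default_args={'start_date': test_start_date, 'sla': datetime.timedelta()},
        )
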
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index d60d1b7..93f6138 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -100,39 +100,6 @@ class CallbackWrapper:
         self.task_state_in_callback = temp_instance.state
 
 
-@pytest.fixture
-def get_dummy_dag(dag_maker):
-    def create_dag(
-        dag_id='dag',
-        task_id='op1',
-        task_concurrency=16,
-        pool='default_pool',
-        executor_config={},
-        trigger_rule='all_done',
-        on_success_callback=None,
-        on_execute_callback=None,
-        on_failure_callback=None,
-        on_retry_callback=None,
-        **kwargs,
-    ):
-        with dag_maker(dag_id, **kwargs) as dag:
-            op = DummyOperator(
-                task_id=task_id,
-                task_concurrency=task_concurrency,
-                executor_config=executor_config,
-                on_success_callback=on_success_callback,
-                on_execute_callback=on_execute_callback,
-                on_failure_callback=on_failure_callback,
-                on_retry_callback=on_retry_callback,
-                pool=pool,
-                trigger_rule=trigger_rule,
-            )
-        dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
-        return dag, op
-
-    return create_dag
-
-
 class TestTaskInstance:
     @staticmethod
     def clean_db():
@@ -265,13 +232,13 @@ class TestTaskInstance:
         assert op.dag is dag
         assert op in dag.tasks
 
-    def test_infer_dag(self, get_dummy_dag):
+    def test_infer_dag(self, create_dummy_dag):
         op1 = DummyOperator(task_id='test_op_1')
         op2 = DummyOperator(task_id='test_op_2')
 
-        dag, op3 = get_dummy_dag(task_id='test_op_3')
+        dag, op3 = create_dummy_dag(task_id='test_op_3')
 
-        _, op4 = get_dummy_dag('dag2', task_id='test_op_4')
+        _, op4 = create_dummy_dag('dag2', task_id='test_op_4')
 
         # double check dags
         assert [i.has_dag() for i in [op1, op2, op3, op4]] == [False, False, True, True]
@@ -304,10 +271,10 @@ class TestTaskInstance:
         assert op2 in op3.downstream_list
 
     @patch.object(DAG, 'get_concurrency_reached')
-    def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, get_dummy_dag):
+    def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_dummy_dag):
         mock_concurrency_reached.return_value = True
 
-        _, task = get_dummy_dag(
+        _, task = create_dummy_dag(
             dag_id='test_requeue_over_dag_concurrency',
             task_id='test_requeue_over_dag_concurrency_op',
             max_active_runs=1,
@@ -322,8 +289,8 @@ class TestTaskInstance:
         ti.run()
         assert ti.state == State.NONE
 
-    def test_requeue_over_task_concurrency(self, get_dummy_dag):
-        _, task = get_dummy_dag(
+    def test_requeue_over_task_concurrency(self, create_dummy_dag):
+        _, task = create_dummy_dag(
             dag_id='test_requeue_over_task_concurrency',
             task_id='test_requeue_over_task_concurrency_op',
             task_concurrency=0,
@@ -339,8 +306,8 @@ class TestTaskInstance:
         ti.run()
         assert ti.state == State.NONE
 
-    def test_requeue_over_pool_concurrency(self, get_dummy_dag):
-        _, task = get_dummy_dag(
+    def test_requeue_over_pool_concurrency(self, create_dummy_dag):
+        _, task = create_dummy_dag(
             dag_id='test_requeue_over_pool_concurrency',
             task_id='test_requeue_over_pool_concurrency_op',
             task_concurrency=0,
@@ -391,13 +358,13 @@ class TestTaskInstance:
         for (dep_patch, method_patch) in patch_dict.values():
             dep_patch.stop()
 
-    def test_mark_non_runnable_task_as_success(self, get_dummy_dag):
+    def test_mark_non_runnable_task_as_success(self, create_dummy_dag):
         """
         test that running task with mark_success param update task state
         as SUCCESS without running task despite it fails dependency checks.
         """
         non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop()
-        _, task = get_dummy_dag(
+        _, task = create_dummy_dag(
             dag_id='test_mark_non_runnable_task_as_success',
             task_id='test_mark_non_runnable_task_as_success_op',
         )
@@ -409,11 +376,11 @@ class TestTaskInstance:
         ti.run(mark_success=True)
         assert ti.state == State.SUCCESS
 
-    def test_run_pooling_task(self, get_dummy_dag):
+    def test_run_pooling_task(self, create_dummy_dag):
         """
         test that running a task in an existing pool update task state as SUCCESS.
         """
-        _, task = get_dummy_dag(
+        _, task = create_dummy_dag(
             dag_id='test_run_pooling_task',
             task_id='test_run_pooling_task_op',
             pool='test_pool',
@@ -444,11 +411,11 @@ class TestTaskInstance:
             create_task_instance()
 
     @provide_session
-    def test_ti_updates_with_task(self, get_dummy_dag, session=None):
+    def test_ti_updates_with_task(self, create_dummy_dag, session=None):
         """
         test that updating the executor_config propagates to the TaskInstance DB
         """
-        dag, task = get_dummy_dag(
+        dag, task = create_dummy_dag(
             dag_id='test_run_pooling_task',
             task_id='test_run_pooling_task_op',
             executor_config={'foo': 'bar'},
@@ -472,13 +439,13 @@ class TestTaskInstance:
         assert {'bar': 'baz'} == tis[1].executor_config
         session.rollback()
 
-    def test_run_pooling_task_with_mark_success(self, get_dummy_dag):
+    def test_run_pooling_task_with_mark_success(self, create_dummy_dag):
         """
         test that running task in an existing pool with mark_success param
         update task state as SUCCESS without running task
         despite it fails dependency checks.
         """
-        _, task = get_dummy_dag(
+        _, task = create_dummy_dag(
             dag_id='test_run_pooling_task_with_mark_success',
             task_id='test_run_pooling_task_with_mark_success_op',
         )
@@ -944,9 +911,9 @@ class TestTaskInstance:
         flag_upstream_failed: bool,
         expect_state: State,
         expect_completed: bool,
-        get_dummy_dag,
+        create_dummy_dag,
     ):
-        dag, downstream = get_dummy_dag('test-dag', task_id='downstream', trigger_rule=trigger_rule)
+        dag, downstream = create_dummy_dag('test-dag', task_id='downstream', trigger_rule=trigger_rule)
         for i in range(5):
             task = DummyOperator(task_id=f'runme_{i}', dag=dag)
             task.set_downstream(downstream)
@@ -967,8 +934,8 @@ class TestTaskInstance:
         assert completed == expect_completed
         assert ti.state == expect_state
 
-    def test_respects_prev_dagrun_dep(self, get_dummy_dag):
-        _, task = get_dummy_dag(dag_id='test_dag')
+    def test_respects_prev_dagrun_dep(self, create_dummy_dag):
+        _, task = create_dummy_dag(dag_id='test_dag')
         ti = TI(task, DEFAULT_DATE)
         failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]
         passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')]
@@ -991,8 +958,8 @@ class TestTaskInstance:
             (State.NONE, False),
         ],
     )
-    def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done, get_dummy_dag):
-        dag, task = get_dummy_dag()
+    def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done, create_dummy_dag):
+        dag, task = create_dummy_dag()
         downstream_task = DummyOperator(task_id='downstream_task', dag=dag)
         task >> downstream_task
 
@@ -1002,11 +969,11 @@ class TestTaskInstance:
         downstream_ti.set_state(downstream_ti_state)
         assert ti.are_dependents_done() == expected_are_dependents_done
 
-    def test_xcom_pull(self, get_dummy_dag):
+    def test_xcom_pull(self, create_dummy_dag):
         """
         Test xcom_pull, using different filtering methods.
         """
-        dag, task1 = get_dummy_dag(
+        dag, task1 = create_dummy_dag(
             dag_id='test_xcom',
             task_id='test_xcom_1',
             schedule_interval='@monthly',
@@ -1040,14 +1007,14 @@ class TestTaskInstance:
         result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
         assert result == ['bar', 'baz']
 
-    def test_xcom_pull_after_success(self, get_dummy_dag):
+    def test_xcom_pull_after_success(self, create_dummy_dag):
         """
         tests xcom set/clear relative to a task in a 'success' rerun scenario
         """
         key = 'xcom_key'
         value = 'xcom_value'
 
-        _, task = get_dummy_dag(
+        _, task = create_dummy_dag(
             dag_id='test_xcom',
             schedule_interval='@monthly',
             task_id='test_xcom',
@@ -1072,7 +1039,7 @@ class TestTaskInstance:
         ti.run(ignore_all_deps=True)
         assert ti.xcom_pull(task_ids='test_xcom', key=key) is None
 
-    def test_xcom_pull_different_execution_date(self, get_dummy_dag):
+    def test_xcom_pull_different_execution_date(self, create_dummy_dag):
         """
         tests xcom fetch behavior with different execution dates, using
         both xcom_pull with "include_prior_dates" and without
@@ -1080,7 +1047,7 @@ class TestTaskInstance:
         key = 'xcom_key'
         value = 'xcom_value'
 
-        dag, task = get_dummy_dag(
+        dag, task = create_dummy_dag(
             dag_id='test_xcom',
             schedule_interval='@monthly',
             task_id='test_xcom',
@@ -1146,8 +1113,8 @@ class TestTaskInstance:
         with pytest.raises(TestError):
             ti.run()
 
-    def test_check_and_change_state_before_execution(self, get_dummy_dag):
-        _, task = get_dummy_dag(dag_id='test_check_and_change_state_before_execution')
+    def test_check_and_change_state_before_execution(self, create_dummy_dag):
+        _, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution')
         ti = TI(task=task, execution_date=DEFAULT_DATE)
         assert ti._try_number == 0
         assert ti.check_and_change_state_before_execution()
@@ -1155,18 +1122,18 @@ class TestTaskInstance:
         assert ti.state == State.RUNNING
         assert ti._try_number == 1
 
-    def test_check_and_change_state_before_execution_dep_not_met(self, get_dummy_dag):
-        dag, task = get_dummy_dag(dag_id='test_check_and_change_state_before_execution')
+    def test_check_and_change_state_before_execution_dep_not_met(self, create_dummy_dag):
+        dag, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution')
         task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
         task >> task2
         ti = TI(task=task2, execution_date=timezone.utcnow())
         assert not ti.check_and_change_state_before_execution()
 
-    def test_try_number(self, get_dummy_dag):
+    def test_try_number(self, create_dummy_dag):
         """
         Test the try_number accessor behaves in various running states
         """
-        _, task = get_dummy_dag(dag_id='test_check_and_change_state_before_execution')
+        _, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution')
         ti = TI(task=task, execution_date=timezone.utcnow())
         assert 1 == ti.try_number
         ti.try_number = 2
@@ -1175,11 +1142,11 @@ class TestTaskInstance:
         ti.state = State.SUCCESS
         assert 3 == ti.try_number
 
-    def test_get_num_running_task_instances(self, get_dummy_dag):
+    def test_get_num_running_task_instances(self, create_dummy_dag):
         session = settings.Session()
 
-        _, task = get_dummy_dag(dag_id='test_get_num_running_task_instances', task_id='task1')
-        _, task2 = get_dummy_dag(dag_id='test_get_num_running_task_instances_dummy', task_id='task2')
+        _, task = create_dummy_dag(dag_id='test_get_num_running_task_instances', task_id='task1')
+        _, task2 = create_dummy_dag(dag_id='test_get_num_running_task_instances_dummy', task_id='task2')
         ti1 = TI(task=task, execution_date=DEFAULT_DATE)
         ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
         ti3 = TI(task=task2, execution_date=DEFAULT_DATE)
@@ -1207,8 +1174,8 @@ class TestTaskInstance:
     #     self.assertEqual(d['task_id'][0], 'op')
     #     self.assertEqual(pendulum.parse(d['execution_date'][0]), now)
 
-    def test_log_url(self, get_dummy_dag):
-        _, task = get_dummy_dag('dag', task_id='op')
+    def test_log_url(self, create_dummy_dag):
+        _, task = create_dummy_dag('dag', task_id='op')
         ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1))
 
         expected_url = (
@@ -1219,9 +1186,9 @@ class TestTaskInstance:
         )
         assert ti.log_url == expected_url
 
-    def test_mark_success_url(self, get_dummy_dag):
+    def test_mark_success_url(self, create_dummy_dag):
         now = pendulum.now('Europe/Brussels')
-        _, task = get_dummy_dag('dag', task_id='op')
+        _, task = create_dummy_dag('dag', task_id='op')
         ti = TI(task=task, execution_date=now)
         query = urllib.parse.parse_qs(
             urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True
@@ -1324,9 +1291,9 @@ class TestTaskInstance:
         ti.set_duration()
         assert ti.duration is None
 
-    def test_success_callback_no_race_condition(self, get_dummy_dag):
+    def test_success_callback_no_race_condition(self, create_dummy_dag):
         callback_wrapper = CallbackWrapper()
-        _, task = get_dummy_dag(
+        _, task = create_dummy_dag(
             'test_success_callback_no_race_condition',
             on_success_callback=callback_wrapper.success_handler,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
@@ -1472,8 +1439,8 @@ class TestTaskInstance:
         assert ti_2.get_previous_start_date() == ti_1.start_date
         assert ti_1.start_date is None
 
-    def test_pendulum_template_dates(self, get_dummy_dag):
-        dag, task = get_dummy_dag(
+    def test_pendulum_template_dates(self, create_dummy_dag):
+        dag, task = create_dummy_dag(
             dag_id='test_pendulum_template_dates',
             task_id='test_pendulum_template_dates_task',
             schedule_interval='0 12 * * *',
@@ -1500,7 +1467,7 @@ class TestTaskInstance:
             ('{{ conn.a_connection.extra_dejson.extra__asana__workspace }}', 'extra1'),
         ],
     )
-    def test_template_with_connection(self, content, expected_output, get_dummy_dag):
+    def test_template_with_connection(self, content, expected_output, create_dummy_dag):
         """
         Test the availability of variables in templates
         """
@@ -1522,7 +1489,7 @@ class TestTaskInstance:
                 session,
             )
 
-        _, task = get_dummy_dag()
+        _, task = create_dummy_dag()
 
         ti = TI(task=task, execution_date=DEFAULT_DATE)
         context = ti.get_template_context()
@@ -1538,24 +1505,24 @@ class TestTaskInstance:
             ('{{ var.value.get("missing_variable", "fallback") }}', 'fallback'),
         ],
     )
-    def test_template_with_variable(self, content, expected_output, get_dummy_dag):
+    def test_template_with_variable(self, content, expected_output, create_dummy_dag):
         """
         Test the availability of variables in templates
         """
         Variable.set('a_variable', 'a test value')
 
-        _, task = get_dummy_dag()
+        _, task = create_dummy_dag()
 
         ti = TI(task=task, execution_date=DEFAULT_DATE)
         context = ti.get_template_context()
         result = task.render_template(content, context)
         assert result == expected_output
 
-    def test_template_with_variable_missing(self, get_dummy_dag):
+    def test_template_with_variable_missing(self, create_dummy_dag):
         """
         Test the availability of variables in templates
         """
-        _, task = get_dummy_dag()
+        _, task = create_dummy_dag()
 
         ti = TI(task=task, execution_date=DEFAULT_DATE)
         context = ti.get_template_context()
@@ -1572,28 +1539,28 @@ class TestTaskInstance:
             ('{{ var.json.get("missing_variable", {"a": {"test": "fallback"}})["a"]["test"] }}', 'fallback'),
         ],
     )
-    def test_template_with_json_variable(self, content, expected_output, get_dummy_dag):
+    def test_template_with_json_variable(self, content, expected_output, create_dummy_dag):
         """
         Test the availability of variables in templates
         """
         Variable.set('a_variable', {'a': {'test': 'value'}}, serialize_json=True)
 
-        _, task = get_dummy_dag()
+        _, task = create_dummy_dag()
 
         ti = TI(task=task, execution_date=DEFAULT_DATE)
         context = ti.get_template_context()
         result = task.render_template(content, context)
         assert result == expected_output
 
-    def test_template_with_json_variable_missing(self, get_dummy_dag):
-        _, task = get_dummy_dag()
+    def test_template_with_json_variable_missing(self, create_dummy_dag):
+        _, task = create_dummy_dag()
 
         ti = TI(task=task, execution_date=DEFAULT_DATE)
         context = ti.get_template_context()
         with pytest.raises(KeyError):
             task.render_template('{{ var.json.get("missing_variable") }}', context)
 
-    def test_execute_callback(self, get_dummy_dag):
+    def test_execute_callback(self, create_dummy_dag):
         called = False
 
         def on_execute_callable(context):
@@ -1601,7 +1568,7 @@ class TestTaskInstance:
             called = True
             assert context['dag_run'].dag_id == 'test_dagrun_execute_callback'
 
-        _, task = get_dummy_dag(
+        _, task = create_dummy_dag(
             'test_execute_callback',
             on_execute_callback=on_execute_callable,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
@@ -1656,12 +1623,12 @@ class TestTaskInstance:
         assert not completed
         ti.log.exception.assert_called_once_with(expected_message)
 
-    def test_handle_failure(self, get_dummy_dag):
+    def test_handle_failure(self, create_dummy_dag):
         start_date = timezone.datetime(2016, 6, 1)
 
         mock_on_failure_1 = mock.MagicMock()
         mock_on_retry_1 = mock.MagicMock()
-        dag, task1 = get_dummy_dag(
+        dag, task1 = create_dummy_dag(
             dag_id="test_handle_failure",
             schedule_interval=None,
             start_date=start_date,
@@ -1777,8 +1744,8 @@ class TestTaskInstance:
         assert ti.state == State.SUCCESS
 
     @patch.object(Stats, 'incr')
-    def test_task_stats(self, stats_mock, get_dummy_dag):
-        dag, op = get_dummy_dag(
+    def test_task_stats(self, stats_mock, create_dummy_dag):
+        dag, op = create_dummy_dag(
             'test_task_start_end_stats',
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
         )
@@ -1942,8 +1909,8 @@ class TestTaskInstance:
 
             render_k8s_pod_yaml.assert_called_once()
 
-    def test_set_state_up_for_retry(self, get_dummy_dag):
-        dag, op1 = get_dummy_dag('dag')
+    def test_set_state_up_for_retry(self, create_dummy_dag):
+        dag, op1 = create_dummy_dag('dag')
 
         ti = TI(task=op1, execution_date=timezone.utcnow(), state=State.RUNNING)
         start_date = timezone.utcnow()
@@ -2072,8 +2039,8 @@ class TestRunRawTaskQueriesCount:
             (7, True),
         ],
     )
-    def test_execute_queries_count(self, expected_query_count, mark_success, get_dummy_dag):
-        _, task = get_dummy_dag()
+    def test_execute_queries_count(self, expected_query_count, mark_success, create_dummy_dag):
+        _, task = create_dummy_dag()
         with create_session() as session:
 
             ti = TI(task=task, execution_date=datetime.datetime.now())
@@ -2091,8 +2058,8 @@ class TestRunRawTaskQueriesCount:
         with assert_queries_count(expected_query_count_based_on_db):
             ti._run_raw_task(mark_success=mark_success)
 
-    def test_execute_queries_count_store_serialized(self, get_dummy_dag):
-        _, task = get_dummy_dag()
+    def test_execute_queries_count_store_serialized(self, create_dummy_dag):
+        _, task = create_dummy_dag()
         with create_session() as session:
             ti = TI(task=task, execution_date=datetime.datetime.now())
             ti.state = State.RUNNING
@@ -2105,9 +2072,9 @@ class TestRunRawTaskQueriesCount:
         with assert_queries_count(expected_query_count_based_on_db):
             ti._run_raw_task()
 
-    def test_operator_field_with_serialization(self, get_dummy_dag):
+    def test_operator_field_with_serialization(self, create_dummy_dag):
 
-        _, task = get_dummy_dag()
+        _, task = create_dummy_dag()
         assert task.task_type == 'DummyOperator'
 
         # Verify that ti.operator field renders correctly "without" Serialization
