This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 5c1e09c  Improve `dag_maker` fixture (#17324)
5c1e09c is described below

commit 5c1e09cafacea922b9281e901db7da7cadb3e9be
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Mon Aug 2 07:37:40 2021 +0100

    Improve `dag_maker` fixture (#17324)

    This PR improves the dag_maker fixture to enable creation of dagrun,
    dag and dag_model separately

    Co-authored-by: Tzu-ping Chung <uranu...@gmail.com>
---
 tests/conftest.py                 |  53 +++++-----
 tests/jobs/test_backfill_job.py   | 204 ++++++++++++++++++++------------------
 tests/jobs/test_local_task_job.py | 103 ++++++++-----------
 3 files changed, 175 insertions(+), 185 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 896e32a..48ac9b2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -428,8 +428,9 @@ def app():

 @pytest.fixture
 def dag_maker(request):
-    from airflow.models import DAG
+    from airflow.models import DAG, DagModel
     from airflow.utils import timezone
+    from airflow.utils.session import provide_session
     from airflow.utils.state import State

     DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -444,33 +445,39 @@ def dag_maker(request):
             dag.__exit__(type, value, traceback)
             if type is None:
                 dag.clear()
-                self.dag_run = dag.create_dagrun(
-                    run_id=self.kwargs.get("run_id", "test"),
-                    state=self.kwargs.get('state', State.RUNNING),
-                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
-                    start_date=self.kwargs['start_date'],
-                )
+
+        @provide_session
+        def make_dagmodel(self, session=None, **kwargs):
+            dag = self.dag
+            defaults = dict(dag_id=dag.dag_id, next_dagrun=dag.start_date, is_active=True)
+            kwargs = {**defaults, **kwargs}
+            dag_model = DagModel(**kwargs)
+            session.add(dag_model)
+            session.flush()
+            return dag_model
+
+        def create_dagrun(self, **kwargs):
+            dag = self.dag
+            defaults = dict(
+                run_id='test',
+                state=State.RUNNING,
+                execution_date=self.start_date,
+                start_date=self.start_date,
+            )
+            kwargs = {**defaults, **kwargs}
+            self.dag_run = dag.create_dagrun(**kwargs)
+            return self.dag_run

         def __call__(self, dag_id='test_dag', **kwargs):
             self.kwargs = kwargs
-            if "start_date" not in kwargs:
+            self.start_date = self.kwargs.get('start_date', None)
+            if not self.start_date:
                 if hasattr(request.module, 'DEFAULT_DATE'):
-                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                    self.start_date = getattr(request.module, 'DEFAULT_DATE')
                 else:
-                    kwargs['start_date'] = DEFAULT_DATE
-            dagrun_fields_not_in_dag = [
-                'state',
-                'execution_date',
-                'run_type',
-                'queued_at',
-                "run_id",
-                "creating_job_id",
-                "external_trigger",
-                "last_scheduling_decision",
-                "dag_hash",
-            ]
-            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
-            self.dag = DAG(dag_id, **kwargs)
+                    self.start_date = DEFAULT_DATE
+            self.kwargs['start_date'] = self.start_date
+            self.dag = DAG(dag_id, **self.kwargs)
             return self

     return DagFactory()
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index c110e63..d70606a 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -46,7 +46,7 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.types import DagRunType
-from tests.test_utils.db import clear_db_pools, clear_db_runs, set_default_pool_slots
+from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
 from tests.test_utils.mock_executor import MockExecutor

 logger = logging.getLogger(__name__)
@@ -59,44 +59,10 @@ def dag_bag():
     return DagBag(include_examples=True)


-@pytest.fixture
-def get_dummy_dag_and_run(dag_maker):
-    def _get_dummy_dag_and_run(
-        dag_id='test_dag', pool=Pool.DEFAULT_POOL_NAME, task_concurrency=None, task_id='op', **kwargs
-    ):
-        with dag_maker(dag_id=dag_id, schedule_interval='@daily', **kwargs) as dag:
-            DummyOperator(task_id=task_id, pool=pool, task_concurrency=task_concurrency)
-
-        return dag, dag_maker.dag_run
-
-    return _get_dummy_dag_and_run
-
-
-@pytest.fixture
-def get_dag_test_max_active_limits(dag_maker):
-    def _get_dag_test_max_active_limits(dag_id='test_dag', max_active_runs=1, **kwargs):
-        with dag_maker(
-            dag_id=dag_id,
-            start_date=DEFAULT_DATE,
-            schedule_interval="@hourly",
-            max_active_runs=max_active_runs,
-            **kwargs,
-        ) as dag:
-            op1 = DummyOperator(task_id='leave1')
-            op2 = DummyOperator(task_id='leave2')
-            op3 = DummyOperator(task_id='upstream_level_1')
-            op4 = DummyOperator(task_id='upstream_level_2')
-
-            op1 >> op2 >> op3
-            op4 >> op3
-        return dag, dag_maker.dag_run
-
-    return _get_dag_test_max_active_limits
-
-
 class TestBackfillJob:
     @staticmethod
     def clean_db():
+        clear_db_dags()
         clear_db_runs()
         clear_db_pools()

@@ -106,6 +72,20 @@ class TestBackfillJob:
         self.parser = cli_parser.get_parser()
         self.dagbag = dag_bag

+    def _get_dummy_dag(
+        self,
+        dag_maker_fixture,
+        dag_id='test_dag',
+        pool=Pool.DEFAULT_POOL_NAME,
+        task_concurrency=None,
+        task_id='op',
+        **kwargs,
+    ):
+        with dag_maker_fixture(dag_id=dag_id, schedule_interval='@daily', **kwargs) as dag:
+            DummyOperator(task_id=task_id, pool=pool, task_concurrency=task_concurrency)
+
+        return dag
+
     def _times_called_with(self, method, class_):
         count = 0
         for args in method.call_args_list:
@@ -113,8 +93,9 @@ class TestBackfillJob:
                 count += 1
         return count

-    def test_unfinished_dag_runs_set_to_failed(self, get_dummy_dag_and_run):
-        dag, dag_run = get_dummy_dag_and_run(dag_id='dummy_dag')
+    def test_unfinished_dag_runs_set_to_failed(self, dag_maker):
+        dag = self._get_dummy_dag(dag_maker)
+        dag_run = dag_maker.create_dagrun()

         job = BackfillJob(
             dag=dag,
@@ -129,8 +110,9 @@ class TestBackfillJob:

         assert State.FAILED == dag_run.state

-    def test_dag_run_with_finished_tasks_set_to_success(self, get_dummy_dag_and_run):
-        dag, dag_run = get_dummy_dag_and_run(dag_id='dummy_dag')
+    def test_dag_run_with_finished_tasks_set_to_success(self, dag_maker):
+        dag = self._get_dummy_dag(dag_maker)
+        dag_run = dag_maker.create_dagrun()

         for ti in dag_run.get_task_instances():
             ti.set_state(State.SUCCESS)
@@ -289,8 +271,9 @@ class TestBackfillJob:
             for task_id in expected_execution_order
         ] == executor.sorted_tasks

-    def test_backfill_conf(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_conf')
+    def test_backfill_conf(self, dag_maker):
+        dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_conf')
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -312,12 +295,14 @@ class TestBackfillJob:
         assert conf_ == dr[0].conf

     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_task_concurrency_limit(self, mock_log, get_dummy_dag_and_run):
+    def test_backfill_respect_task_concurrency_limit(self, mock_log, dag_maker):
         task_concurrency = 2
-        dag, _ = get_dummy_dag_and_run(
+        dag = self._get_dummy_dag(
+            dag_maker,
             dag_id='test_backfill_respect_task_concurrency_limit',
             task_concurrency=task_concurrency,
         )
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -364,9 +349,9 @@ class TestBackfillJob:
         assert times_task_concurrency_limit_reached_in_debug > 0

     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_dag_concurrency_limit(self, mock_log, get_dummy_dag_and_run):
-
-        dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_respect_concurrency_limit')
+    def test_backfill_respect_dag_concurrency_limit(self, mock_log, dag_maker):
+        dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_respect_concurrency_limit')
+        dag_maker.create_dagrun()
         dag.max_active_tasks = 2
         executor = MockExecutor()

@@ -415,11 +400,12 @@ class TestBackfillJob:
         assert times_dag_concurrency_limit_reached_in_debug > 0

     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_default_pool_limit(self, mock_log, get_dummy_dag_and_run):
+    def test_backfill_respect_default_pool_limit(self, mock_log, dag_maker):
         default_pool_slots = 2
         set_default_pool_slots(default_pool_slots)

-        dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_with_no_pool_limit')
+        dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_with_no_pool_limit')
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -469,11 +455,13 @@ class TestBackfillJob:
         assert 0 == times_task_concurrency_limit_reached_in_debug
         assert times_pool_limit_reached_in_debug > 0

-    def test_backfill_pool_not_found(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
+    def test_backfill_pool_not_found(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker,
             dag_id='test_backfill_pool_not_found',
             pool='king_pool',
         )
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -490,7 +478,7 @@ class TestBackfillJob:
             return

     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_pool_limit(self, mock_log, get_dummy_dag_and_run):
+    def test_backfill_respect_pool_limit(self, mock_log, dag_maker):
         session = settings.Session()

         slots = 2
@@ -501,10 +489,12 @@ class TestBackfillJob:
         session.add(pool)
         session.commit()

-        dag, _ = get_dummy_dag_and_run(
+        dag = self._get_dummy_dag(
+            dag_maker,
             dag_id='test_backfill_respect_pool_limit',
             pool=pool.pool,
         )
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -550,10 +540,11 @@ class TestBackfillJob:
         assert 0 == times_dag_concurrency_limit_reached_in_debug
         assert times_pool_limit_reached_in_debug > 0

-    def test_backfill_run_rescheduled(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
-            dag_id="test_backfill_run_rescheduled", task_id="test_backfill_run_rescheduled_task-1"
+    def test_backfill_run_rescheduled(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker, dag_id="test_backfill_run_rescheduled", task_id="test_backfill_run_rescheduled_task-1"
         )
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -581,10 +572,11 @@ class TestBackfillJob:
             ti.refresh_from_db()
             assert ti.state == State.SUCCESS

-    def test_backfill_rerun_failed_tasks(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
-            dag_id="test_backfill_rerun_failed", task_id="test_backfill_rerun_failed_task-1"
+    def test_backfill_rerun_failed_tasks(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker, dag_id="test_backfill_rerun_failed", task_id="test_backfill_rerun_failed_task-1"
         )
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -614,12 +606,11 @@ class TestBackfillJob:

     def test_backfill_rerun_upstream_failed_tasks(self, dag_maker):

-        with dag_maker(
-            dag_id='test_backfill_rerun_upstream_failed', start_date=DEFAULT_DATE, schedule_interval='@daily'
-        ) as dag:
+        with dag_maker(dag_id='test_backfill_rerun_upstream_failed', schedule_interval='@daily') as dag:
             op1 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-1')
             op2 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-2')
             op1.set_upstream(op2)
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -647,10 +638,11 @@ class TestBackfillJob:
             ti.refresh_from_db()
             assert ti.state == State.SUCCESS

-    def test_backfill_rerun_failed_tasks_without_flag(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
-            dag_id='test_backfill_rerun_failed', task_id='test_backfill_rerun_failed_task-1'
+    def test_backfill_rerun_failed_tasks_without_flag(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker, dag_id='test_backfill_rerun_failed', task_id='test_backfill_rerun_failed_task-1'
         )
+        dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -680,7 +672,6 @@ class TestBackfillJob:
     def test_backfill_retry_intermittent_failed_task(self, dag_maker):
         with dag_maker(
             dag_id='test_intermittent_failure_job',
-            start_date=DEFAULT_DATE,
             schedule_interval="@daily",
             default_args={
                 'retries': 2,
@@ -688,6 +679,7 @@ class TestBackfillJob:
             },
         ) as dag:
             task1 = DummyOperator(task_id="task1")
+        dag_maker.create_dagrun()

         executor = MockExecutor(parallelism=16)
         executor.mock_task_results[
@@ -707,7 +699,6 @@ class TestBackfillJob:
     def test_backfill_retry_always_failed_task(self, dag_maker):
         with dag_maker(
             dag_id='test_always_failure_job',
-            start_date=DEFAULT_DATE,
             schedule_interval="@daily",
             default_args={
                 'retries': 1,
@@ -715,6 +706,7 @@ class TestBackfillJob:
             },
         ) as dag:
             task1 = DummyOperator(task_id="task1")
+        dag_maker.create_dagrun()

         executor = MockExecutor(parallelism=16)
         executor.mock_task_results[
@@ -734,7 +726,6 @@ class TestBackfillJob:

         with dag_maker(
             dag_id='test_backfill_ordered_concurrent_execute',
-            start_date=DEFAULT_DATE,
             schedule_interval="@daily",
         ) as dag:
             op1 = DummyOperator(task_id='leave1')
@@ -747,6 +738,7 @@ class TestBackfillJob:
             op1.set_downstream(op3)
             op4.set_downstream(op5)
             op3.set_downstream(op4)
+        dag_maker.create_dagrun()

         executor = MockExecutor(parallelism=16)
         job = BackfillJob(
@@ -881,10 +873,29 @@ class TestBackfillJob:
             parsed_args = self.parser.parse_args(args)
             assert 0.5 == parsed_args.delay_on_limit

-    def test_backfill_max_limit_check_within_limit(self, get_dag_test_max_active_limits):
-        dag, _ = get_dag_test_max_active_limits(
-            dag_id='test_backfill_max_limit_check_within_limit', max_active_runs=16
+    def _get_dag_test_max_active_limits(
+        self, dag_maker_fixture, dag_id='test_dag', max_active_runs=1, **kwargs
+    ):
+        with dag_maker_fixture(
+            dag_id=dag_id,
+            schedule_interval="@hourly",
+            max_active_runs=max_active_runs,
+            **kwargs,
+        ) as dag:
+            op1 = DummyOperator(task_id='leave1')
+            op2 = DummyOperator(task_id='leave2')
+            op3 = DummyOperator(task_id='upstream_level_1')
+            op4 = DummyOperator(task_id='upstream_level_2')
+
+            op1 >> op2 >> op3
+            op4 >> op3
+        return dag
+
+    def test_backfill_max_limit_check_within_limit(self, dag_maker):
+        dag = self._get_dag_test_max_active_limits(
+            dag_maker, dag_id='test_backfill_max_limit_check_within_limit', max_active_runs=16
         )
+        dag_maker.create_dagrun()

         start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
         end_date = DEFAULT_DATE
@@ -898,7 +909,7 @@ class TestBackfillJob:
         assert 2 == len(dagruns)
         assert all(run.state == State.SUCCESS for run in dagruns)

-    def test_backfill_max_limit_check(self, get_dag_test_max_active_limits):
+    def test_backfill_max_limit_check(self, dag_maker):
         dag_id = 'test_backfill_max_limit_check'
         run_id = 'test_dag_run'
         start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
@@ -911,9 +922,12 @@ class TestBackfillJob:
             # this session object is different than the one in the main thread
             with create_session() as thread_session:
                 try:
-                    dag, _ = get_dag_test_max_active_limits(
-                        # Existing dagrun that is not within the backfill range
+                    dag = self._get_dag_test_max_active_limits(
+                        dag_maker,
                         dag_id=dag_id,
+                    )
+                    dag_maker.create_dagrun(
+                        # Existing dagrun that is not within the backfill range
                         run_id=run_id,
                         execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
                     )
@@ -960,11 +974,14 @@ class TestBackfillJob:
         finally:
             dag_run_created_cond.release()

-    def test_backfill_max_limit_check_no_count_existing(self, get_dag_test_max_active_limits):
+    def test_backfill_max_limit_check_no_count_existing(self, dag_maker):
         start_date = DEFAULT_DATE
         end_date = DEFAULT_DATE
         # Existing dagrun that is within the backfill range
-        dag, _ = get_dag_test_max_active_limits(dag_id='test_backfill_max_limit_check_no_count_existing')
+        dag = self._get_dag_test_max_active_limits(
+            dag_maker, dag_id='test_backfill_max_limit_check_no_count_existing'
+        )
+        dag_maker.create_dagrun()

         executor = MockExecutor()
         job = BackfillJob(
@@ -980,8 +997,11 @@ class TestBackfillJob:
         assert 1 == len(dagruns)
         assert State.SUCCESS == dagruns[0].state

-    def test_backfill_max_limit_check_complete_loop(self, get_dag_test_max_active_limits):
-        dag, _ = get_dag_test_max_active_limits(dag_id='test_backfill_max_limit_check_complete_loop')
+    def test_backfill_max_limit_check_complete_loop(self, dag_maker):
+        dag = self._get_dag_test_max_active_limits(
+            dag_maker, dag_id='test_backfill_max_limit_check_complete_loop'
+        )
+        dag_maker.create_dagrun()
        start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
         end_date = DEFAULT_DATE

@@ -1003,9 +1023,6 @@ class TestBackfillJob:

         with dag_maker(
             'test_sub_set_subdag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'},
-            execution_date=DEFAULT_DATE,
         ) as dag:
             op1 = DummyOperator(task_id='leave1')
             op2 = DummyOperator(task_id='leave2')
@@ -1018,7 +1035,7 @@ class TestBackfillJob:
             op4.set_downstream(op5)
             op3.set_downstream(op4)

-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()

         executor = MockExecutor()
         sub_dag = dag.partial_subset(
@@ -1043,9 +1060,6 @@ class TestBackfillJob:
     def test_backfill_fill_blanks(self, dag_maker):
         with dag_maker(
             'test_backfill_fill_blanks',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'},
-            execution_date=DEFAULT_DATE,
         ) as dag:
             op1 = DummyOperator(task_id='op1')
             op2 = DummyOperator(task_id='op2')
@@ -1054,7 +1068,7 @@ class TestBackfillJob:
             op5 = DummyOperator(task_id='op5')
             op6 = DummyOperator(task_id='op6')

-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()

         executor = MockExecutor()

@@ -1231,11 +1245,9 @@ class TestBackfillJob:
         dag.clear()

     def test_update_counters(self, dag_maker):
-        with dag_maker(
-            dag_id='test_manage_executor_state', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE
-        ) as dag:
-            task1 = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-        dr = dag_maker.dag_run
+        with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) as dag:
+            task1 = DummyOperator(task_id='dummy', owner='airflow')
+        dr = dag_maker.create_dagrun()
         job = BackfillJob(dag=dag)

         session = settings.Session()
@@ -1380,9 +1392,7 @@ class TestBackfillJob:
         states_to_reset = [State.QUEUED, State.SCHEDULED, State.NONE]

         tasks = []
-        with dag_maker(
-            dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily", run_id='test1'
-        ) as dag:
+        with dag_maker(dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily") as dag:
             for i in range(len(states)):
                 task_id = f"{prefix}_task_{i}"
                 task = DummyOperator(task_id=task_id)
@@ -1392,7 +1402,7 @@ class TestBackfillJob:
         job = BackfillJob(dag=dag)

         # create dagruns
-        dr1 = dag_maker.dag_run
+        dr1 = dag_maker.create_dagrun()
         dr2 = dag.create_dagrun(run_id='test2', state=State.SUCCESS)

         # create taskinstances and set states
@@ -1445,15 +1455,13 @@ class TestBackfillJob:
             dag_id=dag_id,
             start_date=DEFAULT_DATE,
             schedule_interval='@daily',
-            state=State.SUCCESS,
-            run_id='test1',
         ) as dag:
             DummyOperator(task_id=task_id, dag=dag)

         job = BackfillJob(dag=dag)
         session = settings.Session()
         # make two dagruns, only reset for one
-        dr1 = dag_maker.dag_run  # Already created in dag_maker with state=SUCCESS
+        dr1 = dag_maker.create_dagrun(state=State.SUCCESS)
         dr2 = dag.create_dagrun(run_id='test2', state=State.RUNNING)
         ti1 = dr1.get_task_instances(session=session)[0]
         ti2 = dr2.get_task_instances(session=session)[0]
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 7aa596c..1d6d572 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -27,14 +27,12 @@ from unittest import mock
 from unittest.mock import patch

 import pytest
-from parameterized import parameterized

 from airflow import settings
 from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.jobs.scheduler_job import SchedulerJob
-from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.taskinstance import TaskInstance
 from airflow.operators.dummy import DummyOperator
@@ -73,10 +71,19 @@ def clear_db_class():
         db.clear_db_task_fail()


+@pytest.fixture(scope='module')
+def dagbag():
+    return DagBag(
+        dag_folder=TEST_DAG_FOLDER,
+        include_examples=False,
+    )
+
+
 @pytest.mark.usefixtures('clear_db_class', 'clear_db')
 class TestLocalTaskJob:
     @pytest.fixture(autouse=True)
-    def set_instance_attrs(self):
+    def set_instance_attrs(self, dagbag):
+        self.dagbag = dagbag
         with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep:
             yield

@@ -92,12 +99,10 @@ class TestLocalTaskJob:
         of LocalTaskJob can be assigned with proper values
         without intervention
         """
-        with dag_maker(
-            'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
-        ):
+        with dag_maker('test_localtaskjob_essential_attr'):
             op1 = DummyOperator(task_id='op1')

-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()

         ti = dr.get_task_instance(task_id=op1.task_id)

@@ -116,7 +121,7 @@ class TestLocalTaskJob:
         with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1')

-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.hostname = "blablabla"
@@ -148,7 +153,7 @@ class TestLocalTaskJob:
         session = settings.Session()
         with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1', run_as_user='myuser')
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.pid = 2
@@ -190,7 +195,7 @@ class TestLocalTaskJob:
         session = settings.Session()
         with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1')
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.pid = 2
@@ -234,13 +239,10 @@ class TestLocalTaskJob:
         dag_id = 'test_heartbeat_failed_fast'
         task_id = 'test_heartbeat_failed_fast_op'
         with create_session() as session:
-            dagbag = DagBag(
-                dag_folder=TEST_DAG_FOLDER,
-                include_examples=False,
-            )
+
             dag_id = 'test_heartbeat_failed_fast'
             task_id = 'test_heartbeat_failed_fast_op'
-            dag = dagbag.get_dag(dag_id)
+            dag = self.dagbag.get_dag(dag_id)
             task = dag.get_task(task_id)

             dag.create_dagrun(
@@ -276,11 +278,7 @@ class TestLocalTaskJob:
         Test that ensures that mark_success in the UI doesn't cause
         the task to fail, and that the task exits
         """
-        dagbag = DagBag(
-            dag_folder=TEST_DAG_FOLDER,
-            include_examples=False,
-        )
-        dag = dagbag.dags.get('test_mark_success')
+        dag = self.dagbag.dags.get('test_mark_success')
         task = dag.get_task('task1')

         session = settings.Session()
@@ -316,11 +314,7 @@ class TestLocalTaskJob:

     def test_localtaskjob_double_trigger(self):
-        dagbag = DagBag(
-            dag_folder=TEST_DAG_FOLDER,
-            include_examples=False,
-        )
-        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
+        dag = self.dagbag.dags.get('test_localtaskjob_double_trigger')
         task = dag.get_task('test_localtaskjob_double_trigger_task')

         session = settings.Session()
@@ -356,11 +350,8 @@ class TestLocalTaskJob:

     @pytest.mark.quarantined
     def test_localtaskjob_maintain_heart_rate(self):
-        dagbag = DagBag(
-            dag_folder=TEST_DAG_FOLDER,
-            include_examples=False,
-        )
-        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
+
+        dag = self.dagbag.dags.get('test_localtaskjob_double_trigger')
         task = dag.get_task('test_localtaskjob_double_trigger_task')

         session = settings.Session()
@@ -439,6 +430,7 @@ class TestLocalTaskJob:
                 python_callable=task_function,
                 on_failure_callback=check_failure,
             )
+        dag_maker.create_dagrun()

         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
@@ -480,6 +472,7 @@ class TestLocalTaskJob:
                 python_callable=task_function,
                 on_failure_callback=failure_callback,
             )
+        dag_maker.create_dagrun()

         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
@@ -653,7 +646,8 @@ class TestLocalTaskJob:
         assert task_terminated_externally.value == 1
         assert not process.is_alive()

-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "conf, dependencies, init_state, first_run_state, second_run_state, error_message",
         [
             (
                 {('scheduler', 'schedule_after_task_execution'): 'True'},
@@ -687,27 +681,17 @@ class TestLocalTaskJob:
                 None,
                 "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
             ),
-        ]
+        ],
     )
     def test_fast_follow(
-        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
+        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message, dag_maker
     ):
         with conf_vars(conf):
             session = settings.Session()

-            dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
-
-            dag_model = DagModel(
-                dag_id=dag.dag_id,
-                next_dagrun=dag.start_date,
-                is_active=True,
-            )
-            session.add(dag_model)
-            session.flush()
-
             python_callable = lambda: True
-            with dag:
+            with dag_maker('test_dagrun_fast_follow') as dag:
                 task_a = PythonOperator(task_id='A', python_callable=python_callable)
                 task_b = PythonOperator(task_id='B', python_callable=python_callable)
                 task_c = PythonOperator(task_id='C', python_callable=python_callable)
@@ -716,6 +700,8 @@ class TestLocalTaskJob:
             for upstream, downstream in dependencies.items():
                 dag.set_dependency(upstream, downstream)

+            dag_maker.make_dagmodel()
+
             scheduler_job = SchedulerJob(subdir=os.devnull)
             scheduler_job.dagbag.bag_dag(dag, root_dag=dag)

@@ -851,34 +837,24 @@ class TestLocalTaskJob:
         assert retry_callback_called.value == 1
         assert task_terminated_externally.value == 1

-    def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
+    def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self, dag_maker):
         """Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
-        dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
-        op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)
+        with dag_maker(dag_id='test_dags') as dag:
+            op1 = PythonOperator(task_id='dummy', python_callable=lambda: True)

         session = settings.Session()
-        orm_dag = DagModel(
-            dag_id=dag.dag_id,
+        dag_maker.make_dagmodel(
             has_task_concurrency_limits=False,
-            next_dagrun=dag.start_date,
             next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
             is_active=True,
             is_paused=True,
         )
-        session.add(orm_dag)
-        session.flush()
         # Write Dag to DB
         dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
         dagbag.bag_dag(dag, root_dag=dag)
         dagbag.sync_to_db()

-        dr = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
         assert dr.state == State.RUNNING

         ti = TaskInstance(op1, dr.execution_date)
@@ -901,13 +877,12 @@ def clean_db_helper():
 class TestLocalTaskJobPerformance:
     @pytest.mark.parametrize("return_codes", [[0], 9 * [None] + [0]])  # type: ignore
     @mock.patch("airflow.jobs.local_task_job.get_task_runner")
-    def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
+    def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes, dag_maker):
         unique_prefix = str(uuid.uuid4())
-        dag = DAG(dag_id=f'{unique_prefix}_test_number_of_queries', start_date=DEFAULT_DATE)
-        task = DummyOperator(task_id='test_state_succeeded1', dag=dag)
+        with dag_maker(dag_id=f'{unique_prefix}_test_number_of_queries'):
+            task = DummyOperator(task_id='test_state_succeeded1')

-        dag.clear()
-        dag.create_dagrun(run_id=unique_prefix, execution_date=DEFAULT_DATE, state=State.NONE)
+        dag_maker.create_dagrun(run_id=unique_prefix, state=State.NONE)

         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)