This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v2-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 5bcc620e3f6a5adc17222f598fdc065a8cca53a2 Author: Ephraim Anierobi <[email protected]> AuthorDate: Fri Jun 25 05:36:56 2021 +0100 Move DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py (#16581) This change moves DagFileProcessor and DagFileProcessorProcess out of scheduler_job.py. Also, dag_processing.py was moved out of airflow/utils. (cherry picked from commit 88ee2aa7ddf91799f25add9c57e1ea128de2b7aa) --- .github/boring-cyborg.yml | 2 +- airflow/dag_processing/__init__.py | 16 + .../manager.py} | 0 airflow/dag_processing/processor.py | 650 ++++++++++++++++++ airflow/jobs/scheduler_job.py | 619 +---------------- tests/dag_processing/__init__.py | 16 + .../test_manager.py} | 16 +- tests/dag_processing/test_processor.py | 749 +++++++++++++++++++++ tests/jobs/test_scheduler_job.py | 700 +------------------ tests/test_utils/perf/perf_kit/python.py | 2 +- tests/test_utils/perf/perf_kit/sqlalchemy.py | 2 +- 11 files changed, 1456 insertions(+), 1316 deletions(-) diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml index d5f7632..8ae0532 100644 --- a/.github/boring-cyborg.yml +++ b/.github/boring-cyborg.yml @@ -157,7 +157,7 @@ labelPRBasedOnFilePath: - airflow/executors/**/* - airflow/jobs/**/* - airflow/task/task_runner/**/* - - airflow/utils/dag_processing.py + - airflow/dag_processing/**/* - docs/apache-airflow/executor/**/* - docs/apache-airflow/scheduler.rst - tests/executors/**/* diff --git a/airflow/dag_processing/__init__.py b/airflow/dag_processing/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/airflow/dag_processing/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/utils/dag_processing.py b/airflow/dag_processing/manager.py similarity index 100% rename from airflow/utils/dag_processing.py rename to airflow/dag_processing/manager.py diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py new file mode 100644 index 0000000..44dc5f2 --- /dev/null +++ b/airflow/dag_processing/processor.py @@ -0,0 +1,650 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import datetime +import logging +import multiprocessing +import os +import signal +import threading +from contextlib import redirect_stderr, redirect_stdout, suppress +from datetime import timedelta +from multiprocessing.connection import Connection as MultiprocessingConnection +from typing import List, Optional, Set, Tuple + +from setproctitle import setproctitle # pylint: disable=no-name-in-module +from sqlalchemy import func, or_ +from sqlalchemy.orm.session import Session + +from airflow import models, settings +from airflow.configuration import conf +from airflow.dag_processing.manager import AbstractDagFileProcessorProcess +from airflow.exceptions import AirflowException, TaskNotFound +from airflow.models import DAG, DagModel, SlaMiss, errors +from airflow.models.dagbag import DagBag +from airflow.stats import Stats +from airflow.utils import timezone +from airflow.utils.callback_requests import ( + CallbackRequest, + DagCallbackRequest, + SlaCallbackRequest, + TaskCallbackRequest, +) +from airflow.utils.email import get_email_address_list, send_email +from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context +from airflow.utils.mixins import MultiprocessingStartMethodMixin +from airflow.utils.session import provide_session +from airflow.utils.state import State + +TI = models.TaskInstance + + +class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, MultiprocessingStartMethodMixin): + """Runs DAG processing in a separate process using DagFileProcessor + + :param file_path: a Python file containing Airflow DAG definitions + :type file_path: str + :param pickle_dags: whether to serialize the DAG objects to the DB + :type pickle_dags: bool + :param dag_ids: If specified, only look at these DAG ID's + :type dag_ids: List[str] + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] + """ + + # Counter that increments every time an 
instance of this class is created + class_creation_counter = 0 + + def __init__( + self, + file_path: str, + pickle_dags: bool, + dag_ids: Optional[List[str]], + callback_requests: List[CallbackRequest], + ): + super().__init__() + self._file_path = file_path + self._pickle_dags = pickle_dags + self._dag_ids = dag_ids + self._callback_requests = callback_requests + + # The process that was launched to process the given . + self._process: Optional[multiprocessing.process.BaseProcess] = None + # The result of DagFileProcessor.process_file(file_path). + self._result: Optional[Tuple[int, int]] = None + # Whether the process is done running. + self._done = False + # When the process started. + self._start_time: Optional[datetime.datetime] = None + # This ID is use to uniquely name the process / thread that's launched + # by this processor instance + self._instance_id = DagFileProcessorProcess.class_creation_counter + + self._parent_channel: Optional[MultiprocessingConnection] = None + DagFileProcessorProcess.class_creation_counter += 1 + + @property + def file_path(self) -> str: + return self._file_path + + @staticmethod + def _run_file_processor( + result_channel: MultiprocessingConnection, + parent_channel: MultiprocessingConnection, + file_path: str, + pickle_dags: bool, + dag_ids: Optional[List[str]], + thread_name: str, + callback_requests: List[CallbackRequest], + ) -> None: + """ + Process the given file. 
+ + :param result_channel: the connection to use for passing back the result + :type result_channel: multiprocessing.Connection + :param parent_channel: the parent end of the channel to close in the child + :type parent_channel: multiprocessing.Connection + :param file_path: the file to process + :type file_path: str + :param pickle_dags: whether to pickle the DAGs found in the file and + save them to the DB + :type pickle_dags: bool + :param dag_ids: if specified, only examine DAG ID's that are + in this list + :type dag_ids: list[str] + :param thread_name: the name to use for the process that is launched + :type thread_name: str + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] + :return: the process that was launched + :rtype: multiprocessing.Process + """ + # This helper runs in the newly created process + log: logging.Logger = logging.getLogger("airflow.processor") + + # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in + # the child, else it won't get closed properly until we exit. + log.info("Closing parent pipe") + + parent_channel.close() + del parent_channel + + set_context(log, file_path) + setproctitle(f"airflow scheduler - DagFileProcessor {file_path}") + + try: + # redirect stdout/stderr to log + with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr( + StreamLogWriter(log, logging.WARN) + ), Stats.timer() as timer: + # Re-configure the ORM engine as there are issues with multiple processes + settings.configure_orm() + + # Change the thread name to differentiate log lines. This is + # really a separate process, but changing the name of the + # process doesn't work, so changing the thread name instead. 
+ threading.current_thread().name = thread_name + + log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) + dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) + result: Tuple[int, int] = dag_file_processor.process_file( + file_path=file_path, + pickle_dags=pickle_dags, + callback_requests=callback_requests, + ) + result_channel.send(result) + log.info("Processing %s took %.3f seconds", file_path, timer.duration) + except Exception: # pylint: disable=broad-except + # Log exceptions through the logging framework. + log.exception("Got an exception! Propagating...") + raise + finally: + # We re-initialized the ORM within this Process above so we need to + # tear it down manually here + settings.dispose_orm() + + result_channel.close() + + def start(self) -> None: + """Launch the process and start processing the DAG.""" + start_method = self._get_multiprocessing_start_method() + context = multiprocessing.get_context(start_method) + + _parent_channel, _child_channel = context.Pipe(duplex=False) + process = context.Process( + target=type(self)._run_file_processor, + args=( + _child_channel, + _parent_channel, + self.file_path, + self._pickle_dags, + self._dag_ids, + f"DagFileProcessor{self._instance_id}", + self._callback_requests, + ), + name=f"DagFileProcessor{self._instance_id}-Process", + ) + self._process = process + self._start_time = timezone.utcnow() + process.start() + + # Close the child side of the pipe now the subprocess has started -- otherwise this would prevent it + # from closing in some cases + _child_channel.close() + del _child_channel + + # Don't store it on self until after we've started the child process - we don't want to keep it from + # getting GCd/closed + self._parent_channel = _parent_channel + + def kill(self) -> None: + """Kill the process launched to process the file, and ensure consistent state.""" + if self._process is None: + raise AirflowException("Tried to kill before starting!") + self._kill_process() 
+ + def terminate(self, sigkill: bool = False) -> None: + """ + Terminate (and then kill) the process launched to process the file. + + :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work. + :type sigkill: bool + """ + if self._process is None or self._parent_channel is None: + raise AirflowException("Tried to call terminate before starting!") + + self._process.terminate() + # Arbitrarily wait 5s for the process to die + with suppress(TimeoutError): + self._process._popen.wait(5) # type: ignore # pylint: disable=protected-access + if sigkill: + self._kill_process() + self._parent_channel.close() + + def _kill_process(self) -> None: + if self._process is None: + raise AirflowException("Tried to kill process before starting!") + + if self._process.is_alive() and self._process.pid: + self.log.warning("Killing DAGFileProcessorProcess (PID=%d)", self._process.pid) + os.kill(self._process.pid, signal.SIGKILL) + if self._parent_channel: + self._parent_channel.close() + + @property + def pid(self) -> int: + """ + :return: the PID of the process launched to process the given file + :rtype: int + """ + if self._process is None or self._process.pid is None: + raise AirflowException("Tried to get PID before starting!") + return self._process.pid + + @property + def exit_code(self) -> Optional[int]: + """ + After the process is finished, this can be called to get the return code + + :return: the exit code of the process + :rtype: int + """ + if self._process is None: + raise AirflowException("Tried to get exit code before starting!") + if not self._done: + raise AirflowException("Tried to call retcode before process was finished!") + return self._process.exitcode + + @property + def done(self) -> bool: + """ + Check if the process launched to process this file is done. 
+ + :return: whether the process is finished running + :rtype: bool + """ + if self._process is None or self._parent_channel is None: + raise AirflowException("Tried to see if it's done before starting!") + + if self._done: + return True + + if self._parent_channel.poll(): + try: + self._result = self._parent_channel.recv() + self._done = True + self.log.debug("Waiting for %s", self._process) + self._process.join() + self._parent_channel.close() + return True + except EOFError: + # If we get an EOFError, it means the child end of the pipe has been closed. This only happens + # in the finally block. But due to a possible race condition, the process may have not yet + # terminated (it could be doing cleanup/python shutdown still). So we kill it here after a + # "suitable" timeout. + self._done = True + # Arbitrary timeout -- error/race condition only, so this doesn't need to be tunable. + self._process.join(timeout=5) + if self._process.is_alive(): + # Didn't shut down cleanly - kill it + self._kill_process() + + if not self._process.is_alive(): + self._done = True + self.log.debug("Waiting for %s", self._process) + self._process.join() + self._parent_channel.close() + return True + + return False + + @property + def result(self) -> Optional[Tuple[int, int]]: + """ + :return: result of running DagFileProcessor.process_file() + :rtype: tuple[int, int] or None + """ + if not self.done: + raise AirflowException("Tried to get the result before it's done!") + return self._result + + @property + def start_time(self) -> datetime.datetime: + """ + :return: when this started to process the file + :rtype: datetime + """ + if self._start_time is None: + raise AirflowException("Tried to get start time before it started!") + return self._start_time + + @property + def waitable_handle(self): + return self._process.sentinel + + +class DagFileProcessor(LoggingMixin): + """ + Process a Python file containing Airflow DAGs. + + This includes: + + 1. 
Execute the file and look for DAG objects in the namespace. + 2. Execute any Callbacks if passed to DagFileProcessor.process_file + 3. Serialize the DAGs and save it to DB (or update existing record in the DB). + 4. Pickle the DAG and save it to the DB (if necessary). + 5. Record any errors importing the file into ORM + + Returns a tuple of 'number of dags found' and 'the count of import errors' + + :param dag_ids: If specified, only look at these DAG ID's + :type dag_ids: List[str] + :param log: Logger to save the processing process + :type log: logging.Logger + """ + + UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE') + + def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger): + super().__init__() + self.dag_ids = dag_ids + self._log = log + + @provide_session + def manage_slas(self, dag: DAG, session: Session = None) -> None: + """ + Finding all tasks that have SLAs defined, and sending alert emails + where needed. New SLA misses are also recorded in the database. + + We are assuming that the scheduler runs often, so we only check for + tasks that should have succeeded in the past hour. 
+ """ + self.log.info("Running SLA Checks for %s", dag.dag_id) + if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks): + self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag) + return + + qry = ( + session.query(TI.task_id, func.max(TI.execution_date).label('max_ti')) + .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql') + .filter(TI.dag_id == dag.dag_id) + .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED)) + .filter(TI.task_id.in_(dag.task_ids)) + .group_by(TI.task_id) + .subquery('sq') + ) + + max_tis: List[TI] = ( + session.query(TI) + .filter( + TI.dag_id == dag.dag_id, + TI.task_id == qry.c.task_id, + TI.execution_date == qry.c.max_ti, + ) + .all() + ) + + ts = timezone.utcnow() + for ti in max_tis: + task = dag.get_task(ti.task_id) + if task.sla and not isinstance(task.sla, timedelta): + raise TypeError( + f"SLA is expected to be timedelta object, got " + f"{type(task.sla)} in {task.dag_id}:{task.task_id}" + ) + + dttm = dag.following_schedule(ti.execution_date) + while dttm < timezone.utcnow(): + following_schedule = dag.following_schedule(dttm) + if following_schedule + task.sla < timezone.utcnow(): + session.merge( + SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts) + ) + dttm = dag.following_schedule(dttm) + session.commit() + + # pylint: disable=singleton-comparison + slas: List[SlaMiss] = ( + session.query(SlaMiss) + .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa + .all() + ) + # pylint: enable=singleton-comparison + + if slas: # pylint: disable=too-many-nested-blocks + sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas] + fetched_tis: List[TI] = ( + session.query(TI) + .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id) + .all() + ) + blocking_tis: List[TI] = [] + for ti in fetched_tis: + if ti.task_id in dag.task_ids: + ti.task = dag.get_task(ti.task_id) + 
blocking_tis.append(ti) + else: + session.delete(ti) + session.commit() + + task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas) + blocking_task_list = "\n".join( + ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis + ) + # Track whether email or any alert notification sent + # We consider email or the alert callback as notifications + email_sent = False + notification_sent = False + if dag.sla_miss_callback: + # Execute the alert callback + self.log.info('Calling SLA miss callback') + try: + dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis) + notification_sent = True + except Exception: # pylint: disable=broad-except + self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id) + email_content = f"""\ + Here's a list of tasks that missed their SLAs: + <pre><code>{task_list}\n<code></pre> + Blocking tasks: + <pre><code>{blocking_task_list}<code></pre> + Airflow Webserver URL: {conf.get(section='webserver', key='base_url')} + """ + + tasks_missed_sla = [] + for sla in slas: + try: + task = dag.get_task(sla.task_id) + except TaskNotFound: + # task already deleted from DAG, skip it + self.log.warning( + "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id + ) + continue + tasks_missed_sla.append(task) + + emails: Set[str] = set() + for task in tasks_missed_sla: + if task.email: + if isinstance(task.email, str): + emails |= set(get_email_address_list(task.email)) + elif isinstance(task.email, (list, tuple)): + emails |= set(task.email) + if emails: + try: + send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content) + email_sent = True + notification_sent = True + except Exception: # pylint: disable=broad-except + Stats.incr('sla_email_notification_failure') + self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id) + # If we sent any notification, update the sla_miss table + if 
notification_sent: + for sla in slas: + sla.email_sent = email_sent + sla.notification_sent = True + session.merge(sla) + session.commit() + + @staticmethod + def update_import_errors(session: Session, dagbag: DagBag) -> None: + """ + For the DAGs in the given DagBag, record any associated import errors and clears + errors for files that no longer have them. These are usually displayed through the + Airflow UI so that users know that there are issues parsing DAGs. + + :param session: session for ORM operations + :type session: sqlalchemy.orm.session.Session + :param dagbag: DagBag containing DAGs with import errors + :type dagbag: airflow.DagBag + """ + # Clear the errors of the processed files + for dagbag_file in dagbag.file_last_changed: + session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete() + + # Add the errors of the processed files + for filename, stacktrace in dagbag.import_errors.items(): + session.add( + errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace) + ) + session.commit() + + @provide_session + def execute_callbacks( + self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None + ) -> None: + """ + Execute on failure callbacks. These objects can come from SchedulerJob or from + DagFileProcessorManager. + + :param dagbag: Dag Bag of dags + :param callback_requests: failure callbacks to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] + :param session: DB session. 
+ """ + for request in callback_requests: + self.log.debug("Processing Callback Request: %s", request) + try: + if isinstance(request, TaskCallbackRequest): + self._execute_task_callbacks(dagbag, request) + elif isinstance(request, SlaCallbackRequest): + self.manage_slas(dagbag.dags.get(request.dag_id)) + elif isinstance(request, DagCallbackRequest): + self._execute_dag_callbacks(dagbag, request, session) + except Exception: # pylint: disable=broad-except + self.log.exception( + "Error executing %s callback for file: %s", + request.__class__.__name__, + request.full_filepath, + ) + + session.commit() + + @provide_session + def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): + dag = dagbag.dags[request.dag_id] + dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session) + dag.handle_callback( + dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session + ) + + def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): + simple_ti = request.simple_task_instance + if simple_ti.dag_id in dagbag.dags: + dag = dagbag.dags[simple_ti.dag_id] + if simple_ti.task_id in dag.task_ids: + task = dag.get_task(simple_ti.task_id) + ti = TI(task, simple_ti.execution_date) + # Get properties needed for failure handling from SimpleTaskInstance. + ti.start_date = simple_ti.start_date + ti.end_date = simple_ti.end_date + ti.try_number = simple_ti.try_number + ti.state = simple_ti.state + ti.test_mode = self.UNIT_TEST_MODE + if request.is_failure_callback: + ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode) + self.log.info('Executed failure callback for %s in state %s', ti, ti.state) + + @provide_session + def process_file( + self, + file_path: str, + callback_requests: List[CallbackRequest], + pickle_dags: bool = False, + session: Session = None, + ) -> Tuple[int, int]: + """ + Process a Python file containing Airflow DAGs. 
+ + This includes: + + 1. Execute the file and look for DAG objects in the namespace. + 2. Execute any Callbacks if passed to this method. + 3. Serialize the DAGs and save it to DB (or update existing record in the DB). + 4. Pickle the DAG and save it to the DB (if necessary). + 5. Record any errors importing the file into ORM + + :param file_path: the path to the Python file that should be executed + :type file_path: str + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.dag_processing.CallbackRequest] + :param pickle_dags: whether serialize the DAGs found in the file and + save them to the db + :type pickle_dags: bool + :param session: Sqlalchemy ORM Session + :type session: Session + :return: number of dags found, count of import errors + :rtype: Tuple[int, int] + """ + self.log.info("Processing file %s for tasks to queue", file_path) + + try: + dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False) + except Exception: # pylint: disable=broad-except + self.log.exception("Failed at reloading the DAG file %s", file_path) + Stats.incr('dag_file_refresh_error', 1, 1) + return 0, 0 + + if len(dagbag.dags) > 0: + self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) + else: + self.log.warning("No viable dags retrieved from %s", file_path) + self.update_import_errors(session, dagbag) + return 0, len(dagbag.import_errors) + + self.execute_callbacks(dagbag, callback_requests) + + # Save individual DAGs in the ORM + dagbag.sync_to_db() + + if pickle_dags: + paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) + + unpaused_dags: List[DAG] = [ + dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids + ] + + for dag in unpaused_dags: + dag.pickle(session) + + # Record import errors into the ORM + try: + self.update_import_errors(session, dagbag) + except Exception: # pylint: disable=broad-except + self.log.exception("Error logging import 
errors!") + + return len(dagbag.dags), len(dagbag.import_errors) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index fe8e0b0..5b24e00 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -23,15 +23,11 @@ import multiprocessing import os import signal import sys -import threading import time from collections import defaultdict -from contextlib import redirect_stderr, redirect_stdout, suppress from datetime import timedelta -from multiprocessing.connection import Connection as MultiprocessingConnection from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple -from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_, tuple_ from sqlalchemy.exc import OperationalError from sqlalchemy.orm import load_only, selectinload @@ -39,10 +35,13 @@ from sqlalchemy.orm.session import Session, make_transient from airflow import models, settings from airflow.configuration import conf -from airflow.exceptions import AirflowException, SerializedDagNotFound, TaskNotFound +from airflow.dag_processing.manager import DagFileProcessorAgent +from airflow.dag_processing.processor import DagFileProcessorProcess +from airflow.exceptions import SerializedDagNotFound from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS from airflow.jobs.base_job import BaseJob -from airflow.models import DAG, DagModel, SlaMiss, errors +from airflow.models import DAG +from airflow.models.dag import DagModel from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.serialized_dag import SerializedDagModel @@ -50,17 +49,8 @@ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey from airflow.stats import Stats from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.utils import timezone -from airflow.utils.callback_requests import ( - CallbackRequest, - DagCallbackRequest, - SlaCallbackRequest, - 
TaskCallbackRequest, -) -from airflow.utils.dag_processing import AbstractDagFileProcessorProcess, DagFileProcessorAgent -from airflow.utils.email import get_email_address_list, send_email +from airflow.utils.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest from airflow.utils.event_scheduler import EventScheduler -from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context -from airflow.utils.mixins import MultiprocessingStartMethodMixin from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries from airflow.utils.session import create_session, provide_session from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks @@ -72,603 +62,6 @@ DR = models.DagRun DM = models.DagModel -class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, MultiprocessingStartMethodMixin): - """Runs DAG processing in a separate process using DagFileProcessor - - :param file_path: a Python file containing Airflow DAG definitions - :type file_path: str - :param pickle_dags: whether to serialize the DAG objects to the DB - :type pickle_dags: bool - :param dag_ids: If specified, only look at these DAG ID's - :type dag_ids: List[str] - :param callback_requests: failure callback to execute - :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] - """ - - # Counter that increments every time an instance of this class is created - class_creation_counter = 0 - - def __init__( - self, - file_path: str, - pickle_dags: bool, - dag_ids: Optional[List[str]], - callback_requests: List[CallbackRequest], - ): - super().__init__() - self._file_path = file_path - self._pickle_dags = pickle_dags - self._dag_ids = dag_ids - self._callback_requests = callback_requests - - # The process that was launched to process the given . 
- self._process: Optional[multiprocessing.process.BaseProcess] = None - # The result of DagFileProcessor.process_file(file_path). - self._result: Optional[Tuple[int, int]] = None - # Whether the process is done running. - self._done = False - # When the process started. - self._start_time: Optional[datetime.datetime] = None - # This ID is use to uniquely name the process / thread that's launched - # by this processor instance - self._instance_id = DagFileProcessorProcess.class_creation_counter - - self._parent_channel: Optional[MultiprocessingConnection] = None - DagFileProcessorProcess.class_creation_counter += 1 - - @property - def file_path(self) -> str: - return self._file_path - - @staticmethod - def _run_file_processor( - result_channel: MultiprocessingConnection, - parent_channel: MultiprocessingConnection, - file_path: str, - pickle_dags: bool, - dag_ids: Optional[List[str]], - thread_name: str, - callback_requests: List[CallbackRequest], - ) -> None: - """ - Process the given file. 
- - :param result_channel: the connection to use for passing back the result - :type result_channel: multiprocessing.Connection - :param parent_channel: the parent end of the channel to close in the child - :type parent_channel: multiprocessing.Connection - :param file_path: the file to process - :type file_path: str - :param pickle_dags: whether to pickle the DAGs found in the file and - save them to the DB - :type pickle_dags: bool - :param dag_ids: if specified, only examine DAG ID's that are - in this list - :type dag_ids: list[str] - :param thread_name: the name to use for the process that is launched - :type thread_name: str - :param callback_requests: failure callback to execute - :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] - :return: the process that was launched - :rtype: multiprocessing.Process - """ - # This helper runs in the newly created process - log: logging.Logger = logging.getLogger("airflow.processor") - - # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in - # the child, else it won't get closed properly until we exit. - log.info("Closing parent pipe") - - parent_channel.close() - del parent_channel - - set_context(log, file_path) - setproctitle(f"airflow scheduler - DagFileProcessor {file_path}") - - try: - # redirect stdout/stderr to log - with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr( - StreamLogWriter(log, logging.WARN) - ), Stats.timer() as timer: - # Re-configure the ORM engine as there are issues with multiple processes - settings.configure_orm() - - # Change the thread name to differentiate log lines. This is - # really a separate process, but changing the name of the - # process doesn't work, so changing the thread name instead. 
- threading.current_thread().name = thread_name - - log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) - dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) - result: Tuple[int, int] = dag_file_processor.process_file( - file_path=file_path, - pickle_dags=pickle_dags, - callback_requests=callback_requests, - ) - result_channel.send(result) - log.info("Processing %s took %.3f seconds", file_path, timer.duration) - except Exception: # pylint: disable=broad-except - # Log exceptions through the logging framework. - log.exception("Got an exception! Propagating...") - raise - finally: - # We re-initialized the ORM within this Process above so we need to - # tear it down manually here - settings.dispose_orm() - - result_channel.close() - - def start(self) -> None: - """Launch the process and start processing the DAG.""" - start_method = self._get_multiprocessing_start_method() - context = multiprocessing.get_context(start_method) - - _parent_channel, _child_channel = context.Pipe(duplex=False) - process = context.Process( - target=type(self)._run_file_processor, - args=( - _child_channel, - _parent_channel, - self.file_path, - self._pickle_dags, - self._dag_ids, - f"DagFileProcessor{self._instance_id}", - self._callback_requests, - ), - name=f"DagFileProcessor{self._instance_id}-Process", - ) - self._process = process - self._start_time = timezone.utcnow() - process.start() - - # Close the child side of the pipe now the subprocess has started -- otherwise this would prevent it - # from closing in some cases - _child_channel.close() - del _child_channel - - # Don't store it on self until after we've started the child process - we don't want to keep it from - # getting GCd/closed - self._parent_channel = _parent_channel - - def kill(self) -> None: - """Kill the process launched to process the file, and ensure consistent state.""" - if self._process is None: - raise AirflowException("Tried to kill before starting!") - self._kill_process() 
- - def terminate(self, sigkill: bool = False) -> None: - """ - Terminate (and then kill) the process launched to process the file. - - :param sigkill: whether to issue a SIGKILL if SIGTERM doesn't work. - :type sigkill: bool - """ - if self._process is None or self._parent_channel is None: - raise AirflowException("Tried to call terminate before starting!") - - self._process.terminate() - # Arbitrarily wait 5s for the process to die - with suppress(TimeoutError): - self._process._popen.wait(5) # type: ignore # pylint: disable=protected-access - if sigkill: - self._kill_process() - self._parent_channel.close() - - def _kill_process(self) -> None: - if self._process is None: - raise AirflowException("Tried to kill process before starting!") - - if self._process.is_alive() and self._process.pid: - self.log.warning("Killing DAGFileProcessorProcess (PID=%d)", self._process.pid) - os.kill(self._process.pid, signal.SIGKILL) - if self._parent_channel: - self._parent_channel.close() - - @property - def pid(self) -> int: - """ - :return: the PID of the process launched to process the given file - :rtype: int - """ - if self._process is None or self._process.pid is None: - raise AirflowException("Tried to get PID before starting!") - return self._process.pid - - @property - def exit_code(self) -> Optional[int]: - """ - After the process is finished, this can be called to get the return code - - :return: the exit code of the process - :rtype: int - """ - if self._process is None: - raise AirflowException("Tried to get exit code before starting!") - if not self._done: - raise AirflowException("Tried to call retcode before process was finished!") - return self._process.exitcode - - @property - def done(self) -> bool: - """ - Check if the process launched to process this file is done. 
- - :return: whether the process is finished running - :rtype: bool - """ - if self._process is None or self._parent_channel is None: - raise AirflowException("Tried to see if it's done before starting!") - - if self._done: - return True - - if self._parent_channel.poll(): - try: - self._result = self._parent_channel.recv() - self._done = True - self.log.debug("Waiting for %s", self._process) - self._process.join() - self._parent_channel.close() - return True - except EOFError: - # If we get an EOFError, it means the child end of the pipe has been closed. This only happens - # in the finally block. But due to a possible race condition, the process may have not yet - # terminated (it could be doing cleanup/python shutdown still). So we kill it here after a - # "suitable" timeout. - self._done = True - # Arbitrary timeout -- error/race condition only, so this doesn't need to be tunable. - self._process.join(timeout=5) - if self._process.is_alive(): - # Didn't shut down cleanly - kill it - self._kill_process() - - if not self._process.is_alive(): - self._done = True - self.log.debug("Waiting for %s", self._process) - self._process.join() - self._parent_channel.close() - return True - - return False - - @property - def result(self) -> Optional[Tuple[int, int]]: - """ - :return: result of running DagFileProcessor.process_file() - :rtype: tuple[int, int] or None - """ - if not self.done: - raise AirflowException("Tried to get the result before it's done!") - return self._result - - @property - def start_time(self) -> datetime.datetime: - """ - :return: when this started to process the file - :rtype: datetime - """ - if self._start_time is None: - raise AirflowException("Tried to get start time before it started!") - return self._start_time - - @property - def waitable_handle(self): - return self._process.sentinel - - -class DagFileProcessor(LoggingMixin): - """ - Process a Python file containing Airflow DAGs. - - This includes: - - 1. 
Execute the file and look for DAG objects in the namespace. - 2. Execute any Callbacks if passed to DagFileProcessor.process_file - 3. Serialize the DAGs and save it to DB (or update existing record in the DB). - 4. Pickle the DAG and save it to the DB (if necessary). - 5. Record any errors importing the file into ORM - - Returns a tuple of 'number of dags found' and 'the count of import errors' - - :param dag_ids: If specified, only look at these DAG ID's - :type dag_ids: List[str] - :param log: Logger to save the processing process - :type log: logging.Logger - """ - - UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE') - - def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger): - super().__init__() - self.dag_ids = dag_ids - self._log = log - - @provide_session - def manage_slas(self, dag: DAG, session: Session = None) -> None: - """ - Finding all tasks that have SLAs defined, and sending alert emails - where needed. New SLA misses are also recorded in the database. - - We are assuming that the scheduler runs often, so we only check for - tasks that should have succeeded in the past hour. 
- """ - self.log.info("Running SLA Checks for %s", dag.dag_id) - if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks): - self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag) - return - - qry = ( - session.query(TI.task_id, func.max(TI.execution_date).label('max_ti')) - .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql') - .filter(TI.dag_id == dag.dag_id) - .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED)) - .filter(TI.task_id.in_(dag.task_ids)) - .group_by(TI.task_id) - .subquery('sq') - ) - - max_tis: List[TI] = ( - session.query(TI) - .filter( - TI.dag_id == dag.dag_id, - TI.task_id == qry.c.task_id, - TI.execution_date == qry.c.max_ti, - ) - .all() - ) - - ts = timezone.utcnow() - for ti in max_tis: - task = dag.get_task(ti.task_id) - if task.sla and not isinstance(task.sla, timedelta): - raise TypeError( - f"SLA is expected to be timedelta object, got " - f"{type(task.sla)} in {task.dag_id}:{task.task_id}" - ) - - dttm = dag.following_schedule(ti.execution_date) - while dttm < timezone.utcnow(): - following_schedule = dag.following_schedule(dttm) - if following_schedule + task.sla < timezone.utcnow(): - session.merge( - SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts) - ) - dttm = dag.following_schedule(dttm) - session.commit() - - # pylint: disable=singleton-comparison - slas: List[SlaMiss] = ( - session.query(SlaMiss) - .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa - .all() - ) - # pylint: enable=singleton-comparison - - if slas: # pylint: disable=too-many-nested-blocks - sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas] - fetched_tis: List[TI] = ( - session.query(TI) - .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id) - .all() - ) - blocking_tis: List[TI] = [] - for ti in fetched_tis: - if ti.task_id in dag.task_ids: - ti.task = dag.get_task(ti.task_id) - 
blocking_tis.append(ti) - else: - session.delete(ti) - session.commit() - - task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas) - blocking_task_list = "\n".join( - ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis - ) - # Track whether email or any alert notification sent - # We consider email or the alert callback as notifications - email_sent = False - notification_sent = False - if dag.sla_miss_callback: - # Execute the alert callback - self.log.info('Calling SLA miss callback') - try: - dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis) - notification_sent = True - except Exception: # pylint: disable=broad-except - self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id) - email_content = f"""\ - Here's a list of tasks that missed their SLAs: - <pre><code>{task_list}\n<code></pre> - Blocking tasks: - <pre><code>{blocking_task_list}<code></pre> - Airflow Webserver URL: {conf.get(section='webserver', key='base_url')} - """ - - tasks_missed_sla = [] - for sla in slas: - try: - task = dag.get_task(sla.task_id) - except TaskNotFound: - # task already deleted from DAG, skip it - self.log.warning( - "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id - ) - continue - tasks_missed_sla.append(task) - - emails: Set[str] = set() - for task in tasks_missed_sla: - if task.email: - if isinstance(task.email, str): - emails |= set(get_email_address_list(task.email)) - elif isinstance(task.email, (list, tuple)): - emails |= set(task.email) - if emails: - try: - send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content) - email_sent = True - notification_sent = True - except Exception: # pylint: disable=broad-except - Stats.incr('sla_email_notification_failure') - self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id) - # If we sent any notification, update the sla_miss table - if 
notification_sent: - for sla in slas: - sla.email_sent = email_sent - sla.notification_sent = True - session.merge(sla) - session.commit() - - @staticmethod - def update_import_errors(session: Session, dagbag: DagBag) -> None: - """ - For the DAGs in the given DagBag, record any associated import errors and clears - errors for files that no longer have them. These are usually displayed through the - Airflow UI so that users know that there are issues parsing DAGs. - - :param session: session for ORM operations - :type session: sqlalchemy.orm.session.Session - :param dagbag: DagBag containing DAGs with import errors - :type dagbag: airflow.DagBag - """ - # Clear the errors of the processed files - for dagbag_file in dagbag.file_last_changed: - session.query(errors.ImportError).filter(errors.ImportError.filename == dagbag_file).delete() - - # Add the errors of the processed files - for filename, stacktrace in dagbag.import_errors.items(): - session.add( - errors.ImportError(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace) - ) - session.commit() - - @provide_session - def execute_callbacks( - self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = None - ) -> None: - """ - Execute on failure callbacks. These objects can come from SchedulerJob or from - DagFileProcessorManager. - - :param dagbag: Dag Bag of dags - :param callback_requests: failure callbacks to execute - :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] - :param session: DB session. 
- """ - for request in callback_requests: - self.log.debug("Processing Callback Request: %s", request) - try: - if isinstance(request, TaskCallbackRequest): - self._execute_task_callbacks(dagbag, request) - elif isinstance(request, SlaCallbackRequest): - self.manage_slas(dagbag.dags.get(request.dag_id)) - elif isinstance(request, DagCallbackRequest): - self._execute_dag_callbacks(dagbag, request, session) - except Exception: # pylint: disable=broad-except - self.log.exception( - "Error executing %s callback for file: %s", - request.__class__.__name__, - request.full_filepath, - ) - - session.commit() - - @provide_session - def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): - dag = dagbag.dags[request.dag_id] - dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session) - dag.handle_callback( - dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session - ) - - def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): - simple_ti = request.simple_task_instance - if simple_ti.dag_id in dagbag.dags: - dag = dagbag.dags[simple_ti.dag_id] - if simple_ti.task_id in dag.task_ids: - task = dag.get_task(simple_ti.task_id) - ti = TI(task, simple_ti.execution_date) - # Get properties needed for failure handling from SimpleTaskInstance. - ti.start_date = simple_ti.start_date - ti.end_date = simple_ti.end_date - ti.try_number = simple_ti.try_number - ti.state = simple_ti.state - ti.test_mode = self.UNIT_TEST_MODE - if request.is_failure_callback: - ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode) - self.log.info('Executed failure callback for %s in state %s', ti, ti.state) - - @provide_session - def process_file( - self, - file_path: str, - callback_requests: List[CallbackRequest], - pickle_dags: bool = False, - session: Session = None, - ) -> Tuple[int, int]: - """ - Process a Python file containing Airflow DAGs. 
- - This includes: - - 1. Execute the file and look for DAG objects in the namespace. - 2. Execute any Callbacks if passed to this method. - 3. Serialize the DAGs and save it to DB (or update existing record in the DB). - 4. Pickle the DAG and save it to the DB (if necessary). - 5. Record any errors importing the file into ORM - - :param file_path: the path to the Python file that should be executed - :type file_path: str - :param callback_requests: failure callback to execute - :type callback_requests: List[airflow.utils.dag_processing.CallbackRequest] - :param pickle_dags: whether serialize the DAGs found in the file and - save them to the db - :type pickle_dags: bool - :param session: Sqlalchemy ORM Session - :type session: Session - :return: number of dags found, count of import errors - :rtype: Tuple[int, int] - """ - self.log.info("Processing file %s for tasks to queue", file_path) - - try: - dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False) - except Exception: # pylint: disable=broad-except - self.log.exception("Failed at reloading the DAG file %s", file_path) - Stats.incr('dag_file_refresh_error', 1, 1) - return 0, 0 - - if len(dagbag.dags) > 0: - self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) - else: - self.log.warning("No viable dags retrieved from %s", file_path) - self.update_import_errors(session, dagbag) - return 0, len(dagbag.import_errors) - - self.execute_callbacks(dagbag, callback_requests) - - # Save individual DAGs in the ORM - dagbag.sync_to_db() - - if pickle_dags: - paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) - - unpaused_dags: List[DAG] = [ - dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids - ] - - for dag in unpaused_dags: - dag.pickle(session) - - # Record import errors into the ORM - try: - self.update_import_errors(session, dagbag) - except Exception: # pylint: disable=broad-except - self.log.exception("Error logging import 
errors!") - - return len(dagbag.dags), len(dagbag.import_errors) - - def _is_parent_process(): """ Returns True if the current process is the parent process. False if the current process is a child diff --git a/tests/dag_processing/__init__.py b/tests/dag_processing/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/tests/dag_processing/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/utils/test_dag_processing.py b/tests/dag_processing/test_manager.py similarity index 99% rename from tests/utils/test_dag_processing.py rename to tests/dag_processing/test_manager.py index 58ad010..0ab7f2b 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/dag_processing/test_manager.py @@ -34,20 +34,20 @@ import pytest from freezegun import freeze_time from airflow.configuration import conf -from airflow.jobs.local_task_job import LocalTaskJob as LJ -from airflow.jobs.scheduler_job import DagFileProcessorProcess -from airflow.models import DagBag, DagModel, TaskInstance as TI -from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import SimpleTaskInstance -from airflow.utils import timezone -from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest -from airflow.utils.dag_processing import ( +from airflow.dag_processing.manager import ( DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, DagParsingSignal, DagParsingStat, ) +from airflow.dag_processing.processor import DagFileProcessorProcess +from airflow.jobs.local_task_job import LocalTaskJob as LJ +from airflow.models import DagBag, DagModel, TaskInstance as TI +from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.taskinstance import SimpleTaskInstance +from airflow.utils import timezone +from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest from airflow.utils.net import get_hostname from airflow.utils.session import create_session from airflow.utils.state import State diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py new file mode 100644 index 0000000..5953517 --- /dev/null +++ b/tests/dag_processing/test_processor.py @@ -0,0 +1,749 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# pylint: disable=attribute-defined-outside-init +import datetime +import os +import unittest +from datetime import timedelta +from tempfile import NamedTemporaryFile +from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest +from parameterized import parameterized + +from airflow import settings +from airflow.configuration import conf +from airflow.dag_processing.processor import DagFileProcessor +from airflow.jobs.scheduler_job import SchedulerJob +from airflow.models import DAG, DagBag, DagModel, SlaMiss, TaskInstance +from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel +from airflow.models.taskinstance import SimpleTaskInstance +from airflow.operators.bash import BashOperator +from airflow.operators.dummy import DummyOperator +from airflow.serialization.serialized_objects import SerializedDAG +from airflow.utils import timezone +from airflow.utils.callback_requests import TaskCallbackRequest +from airflow.utils.dates import days_ago +from airflow.utils.session import create_session +from airflow.utils.state import State +from airflow.utils.types import DagRunType +from tests.test_utils.config import conf_vars, env_vars +from tests.test_utils.db import ( + clear_db_dags, + clear_db_import_errors, + 
clear_db_jobs, + clear_db_pools, + clear_db_runs, + clear_db_serialized_dags, + clear_db_sla_miss, +) +from tests.test_utils.mock_executor import MockExecutor + +DEFAULT_DATE = timezone.datetime(2016, 1, 1) + + +@pytest.fixture(scope="class") +def disable_load_example(): + with conf_vars({('core', 'load_examples'): 'false'}): + with env_vars({('core', 'load_examples'): 'false'}): + yield + + +@pytest.mark.usefixtures("disable_load_example") +class TestDagFileProcessor(unittest.TestCase): + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_pools() + clear_db_dags() + clear_db_sla_miss() + clear_db_import_errors() + clear_db_jobs() + clear_db_serialized_dags() + + def setUp(self): + self.clean_db() + + # Speed up some tests by not running the tasks, just look at what we + # enqueue! + self.null_exec = MockExecutor() + self.scheduler_job = None + + def tearDown(self) -> None: + if self.scheduler_job and self.scheduler_job.processor_agent: + self.scheduler_job.processor_agent.end() + self.scheduler_job = None + self.clean_db() + + def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs): + dag = DAG( + dag_id='test_scheduler_reschedule', + start_date=start_date, + # Make sure it only creates a single DAG Run + end_date=end_date, + ) + dag.clear() + dag.is_subdag = False + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False) + session.merge(orm_dag) + session.commit() + return dag + + @classmethod + def setUpClass(cls): + # Ensure the DAGs we are looking at from the DB are up-to-date + non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False) + non_serialized_dagbag.sync_to_db() + cls.dagbag = DagBag(read_dags_from_db=True) + + def test_dag_file_processor_sla_miss_callback(self): + """ + Test that the dag file processor calls the sla miss callback + """ + session = settings.Session() + + sla_callback = MagicMock() + + # Create dag with a start of 1 day ago, 
but an sla of 0 + # so we'll already have an sla_miss on the books. + test_start_date = days_ago(1) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta()}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) + + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor.manage_slas(dag=dag, session=session) + + assert sla_callback.called + + def test_dag_file_processor_sla_miss_callback_invalid_sla(self): + """ + Test that the dag file processor does not call the sla miss callback when + given an invalid sla + """ + session = settings.Session() + + sla_callback = MagicMock() + + # Create dag with a start of 1 day ago, but an sla of 0 + # so we'll already have an sla_miss on the books. + # Pass anything besides a timedelta object to the sla argument. 
+ test_start_date = days_ago(1) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': None}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) + + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor.manage_slas(dag=dag, session=session) + sla_callback.assert_not_called() + + def test_dag_file_processor_sla_miss_callback_sent_notification(self): + """ + Test that the dag file processor does not call the sla_miss_callback when a + notification has already been sent + """ + session = settings.Session() + + # Mock the callback function so we can verify that it was not called + sla_callback = MagicMock() + + # Create dag with a start of 2 days ago, but an sla of 1 day + # ago so we'll already have an sla_miss on the books + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + # Create a TaskInstance for two days ago + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge( + SlaMiss( + task_id='dummy', + dag_id='test_sla_miss', + execution_date=test_start_date, + email_sent=False, + notification_sent=True, + ) + ) + + # Now call manage_slas and see if the sla_miss callback gets called + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + dag_file_processor.manage_slas(dag=dag, session=session) + + sla_callback.assert_not_called() + + def test_dag_file_processor_sla_miss_callback_exception(self): + """ + 
Test that the dag file processor gracefully logs an exception if there is a problem + calling the sla_miss_callback + """ + session = settings.Session() + + sla_callback = MagicMock(side_effect=RuntimeError('Could not call function')) + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + sla_miss_callback=sla_callback, + default_args={'start_date': test_start_date}, + ) + + task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1)) + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + # Now call manage_slas and see if the sla_miss callback gets called + mock_log = mock.MagicMock() + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + dag_file_processor.manage_slas(dag=dag, session=session) + assert sla_callback.called + mock_log.exception.assert_called_once_with( + 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss' + ) + + @mock.patch('airflow.dag_processing.processor.send_email') + def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email): + session = settings.Session() + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + email1 = '[email protected]' + task = DummyOperator( + task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1) + ) + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + email2 = '[email protected]' + DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2) + + session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date)) + + dag_file_processor = DagFileProcessor(dag_ids=[], 
log=mock.MagicMock()) + + dag_file_processor.manage_slas(dag=dag, session=session) + + assert len(mock_send_email.call_args_list) == 1 + + send_email_to = mock_send_email.call_args_list[0][0][0] + assert email1 in send_email_to + assert email2 not in send_email_to + + @mock.patch('airflow.jobs.scheduler_job.Stats.incr') + @mock.patch("airflow.utils.email.send_email") + def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr): + """ + Test that the dag file processor gracefully logs an exception if there is a problem + sending an email + """ + session = settings.Session() + + # Mock the callback function so we can verify that it was not called + mock_send_email.side_effect = RuntimeError('Could not send an email') + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + task = DummyOperator( + task_id='dummy', dag=dag, owner='airflow', email='[email protected]', sla=datetime.timedelta(hours=1) + ) + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) + + mock_log = mock.MagicMock() + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + + dag_file_processor.manage_slas(dag=dag, session=session) + mock_log.exception.assert_called_once_with( + 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss' + ) + mock_stats_incr.assert_called_once_with('sla_email_notification_failure') + + def test_dag_file_processor_sla_miss_deleted_task(self): + """ + Test that the dag file processor will not crash when trying to send + sla miss notification for a deleted task + """ + session = settings.Session() + + test_start_date = days_ago(2) + dag = DAG( + dag_id='test_sla_miss', + default_args={'start_date': 
test_start_date, 'sla': datetime.timedelta(days=1)}, + ) + + task = DummyOperator( + task_id='dummy', dag=dag, owner='airflow', email='[email protected]', sla=datetime.timedelta(hours=1) + ) + + session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) + + # Create an SlaMiss where notification was sent, but email was not + session.merge( + SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date) + ) + + mock_log = mock.MagicMock() + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) + dag_file_processor.manage_slas(dag=dag, session=session) + + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) + def test_dag_file_processor_process_task_instances(self, state, start_date, end_date): + """ + Test if _process_task_instances puts the right task instances into the + mock_list. 
+ """ + dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') + + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + dag.clear() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None + + with create_session() as session: + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date + + count = self.scheduler_job._schedule_dag_run(dr, set(), session) + assert count == 1 + + session.refresh(ti) + assert ti.state == State.SCHEDULED + + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) + def test_dag_file_processor_process_task_instances_with_task_concurrency( + self, + state, + start_date, + end_date, + ): + """ + Test if _process_task_instances puts the right task instances into the + mock_list. 
+ """ + dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi') + + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + dag.clear() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None + + with create_session() as session: + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date + + count = self.scheduler_job._schedule_dag_run(dr, set(), session) + assert count == 1 + + session.refresh(ti) + assert ti.state == State.SCHEDULED + + @parameterized.expand( + [ + [State.NONE, None, None], + [ + State.UP_FOR_RETRY, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + [ + State.UP_FOR_RESCHEDULE, + timezone.utcnow() - datetime.timedelta(minutes=30), + timezone.utcnow() - datetime.timedelta(minutes=15), + ], + ] + ) + def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date): + """ + Test if _process_task_instances puts the right task instances into the + mock_list. 
+ """ + dag = DAG( + dag_id='test_scheduler_process_execute_task_depends_on_past', + start_date=DEFAULT_DATE, + default_args={ + 'depends_on_past': True, + }, + ) + BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi') + BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi') + + with create_session() as session: + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + dag.clear() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None + + with create_session() as session: + tis = dr.get_task_instances(session=session) + for ti in tis: + ti.state = state + ti.start_date = start_date + ti.end_date = end_date + + count = self.scheduler_job._schedule_dag_run(dr, set(), session) + assert count == 2 + + session.refresh(tis[0]) + session.refresh(tis[1]) + assert tis[0].state == State.SCHEDULED + assert tis[1].state == State.SCHEDULED + + def test_scheduler_job_add_new_task(self): + """ + Test if a task instance will be added if the dag is updated + """ + dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test') + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + + # Since we don't want to store the code for the DAG defined in this file + with mock.patch.object(settings, "STORE_DAG_CODE", False): + self.scheduler_job.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + if self.scheduler_job.processor_agent: + self.scheduler_job.processor_agent.end() + 
self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session) + self.scheduler_job._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + tis = dr.get_task_instances() + assert len(tis) == 1 + + BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test') + SerializedDagModel.write_dag(dag=dag) + + scheduled_tis = self.scheduler_job._schedule_dag_run(dr, set(), session) + session.flush() + assert scheduled_tis == 2 + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + tis = dr.get_task_instances() + assert len(tis) == 2 + + def test_runs_respected_after_clear(self): + """ + Test if _process_task_instances only schedules ti's up to max_active_runs + (related to issue AIRFLOW-137) + """ + dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE) + dag.max_active_runs = 3 + + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi') + + session = settings.Session() + orm_dag = DagModel(dag_id=dag.dag_id) + session.merge(orm_dag) + session.commit() + session.close() + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.processor_agent = mock.MagicMock() + self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + dag.clear() + + date = DEFAULT_DATE + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + + 
# First create up to 3 dagruns in RUNNING state. + assert dr1 is not None + assert dr2 is not None + assert dr3 is not None + assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3 + + # Reduce max_active_runs to 1 + dag.max_active_runs = 1 + + # and schedule them in, so we can check how many + # tasks are put on the task_instances_list (should be one, not 3) + with create_session() as session: + num_scheduled = self.scheduler_job._schedule_dag_run(dr1, set(), session) + assert num_scheduled == 1 + num_scheduled = self.scheduler_job._schedule_dag_run(dr2, {dr1.execution_date}, session) + assert num_scheduled == 0 + num_scheduled = self.scheduler_job._schedule_dag_run(dr3, {dr1.execution_date}, session) + assert num_scheduled == 0 + + @patch.object(TaskInstance, 'handle_failure_with_callback') + def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): + dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + with create_session() as session: + session.query(TaskInstance).delete() + dag = dagbag.get_dag('example_branch_operator') + task = dag.get_task(task_id='run_this_first') + + ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) + + session.add(ti) + session.commit() + + requests = [ + TaskCallbackRequest( + full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" + ) + ] + dag_file_processor.execute_callbacks(dagbag, requests) + mock_ti_handle_failure.assert_called_once_with( + error="Message", + test_mode=conf.getboolean('core', 'unit_test_mode'), + ) + + def test_process_file_should_failure_callback(self): + dag_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py' + ) + dagbag = DagBag(dag_folder=dag_file, include_examples=False) + dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + with create_session() as session, 
NamedTemporaryFile(delete=False) as callback_file: + session.query(TaskInstance).delete() + dag = dagbag.get_dag('test_om_failure_callback_dag') + task = dag.get_task(task_id='test_om_failure_callback_task') + + ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) + + session.add(ti) + session.commit() + + requests = [ + TaskCallbackRequest( + full_filepath=dag.full_filepath, + simple_task_instance=SimpleTaskInstance(ti), + msg="Message", + ) + ] + callback_file.close() + + with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}): + dag_file_processor.process_file(dag_file, requests) + with open(callback_file.name) as callback_file2: + content = callback_file2.read() + assert "Callback fired" == content + os.remove(callback_file.name) + + def test_should_mark_dummy_task_as_success(self): + dag_file = os.path.join( + os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py' + ) + + # Write DAGs to dag and serialized_dag table + dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False) + dagbag.sync_to_db() + + self.scheduler_job_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job_job.processor_agent = mock.MagicMock() + dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks") + + # Create DagRun + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + self.scheduler_job_job._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Schedule TaskInstances + self.scheduler_job_job._schedule_dag_run(dr, {}, session) + with create_session() as session: + tis = session.query(TaskInstance).all() + + dags = self.scheduler_job_job.dagbag.dags.values() + assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags] + assert 5 == len(tis) + assert { + ('test_task_a', 'success'), + ('test_task_b', None), + ('test_task_c', 'success'), + ('test_task_on_execute', 'scheduled'), + 
('test_task_on_success', 'scheduled'), + } == {(ti.task_id, ti.state) for ti in tis} + for state, start_date, end_date, duration in [ + (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis + ]: + if state == 'success': + assert start_date is not None + assert end_date is not None + assert 0.0 == duration + else: + assert start_date is None + assert end_date is None + assert duration is None + + self.scheduler_job_job._schedule_dag_run(dr, {}, session) + with create_session() as session: + tis = session.query(TaskInstance).all() + + assert 5 == len(tis) + assert { + ('test_task_a', 'success'), + ('test_task_b', 'success'), + ('test_task_c', 'success'), + ('test_task_on_execute', 'scheduled'), + ('test_task_on_success', 'scheduled'), + } == {(ti.task_id, ti.state) for ti in tis} + for state, start_date, end_date, duration in [ + (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis + ]: + if state == 'success': + assert start_date is not None + assert end_date is not None + assert 0.0 == duration + else: + assert start_date is None + assert end_date is None + assert duration is None diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index fe0b257..9fe8517 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -22,7 +22,7 @@ import os import shutil import unittest from datetime import timedelta -from tempfile import NamedTemporaryFile, mkdtemp +from tempfile import mkdtemp from time import sleep from unittest import mock from unittest.mock import MagicMock, patch @@ -37,22 +37,20 @@ from sqlalchemy import func import airflow.example_dags import airflow.smart_sensor_dags from airflow import settings -from airflow.configuration import conf +from airflow.dag_processing.manager import DagFileProcessorAgent from airflow.exceptions import AirflowException from airflow.executors.base_executor import BaseExecutor from airflow.jobs.backfill_job import BackfillJob -from 
airflow.jobs.scheduler_job import DagFileProcessor, SchedulerJob -from airflow.models import DAG, DagBag, DagModel, Pool, SlaMiss, TaskInstance, errors +from airflow.jobs.scheduler_job import SchedulerJob +from airflow.models import DAG, DagBag, DagModel, Pool, TaskInstance, errors from airflow.models.dagrun import DagRun from airflow.models.serialized_dag import SerializedDagModel -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey +from airflow.models.taskinstance import TaskInstanceKey from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone -from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest -from airflow.utils.dag_processing import DagFileProcessorAgent -from airflow.utils.dates import days_ago +from airflow.utils.callback_requests import DagCallbackRequest from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session, provide_session from airflow.utils.state import State @@ -101,688 +99,6 @@ def disable_load_example(): @pytest.mark.usefixtures("disable_load_example") -class TestDagFileProcessor(unittest.TestCase): - @staticmethod - def clean_db(): - clear_db_runs() - clear_db_pools() - clear_db_dags() - clear_db_sla_miss() - clear_db_import_errors() - clear_db_jobs() - clear_db_serialized_dags() - - def setUp(self): - self.clean_db() - - # Speed up some tests by not running the tasks, just look at what we - # enqueue! 
- self.null_exec = MockExecutor() - self.scheduler_job = None - - def tearDown(self) -> None: - if self.scheduler_job and self.scheduler_job.processor_agent: - self.scheduler_job.processor_agent.end() - self.scheduler_job = None - self.clean_db() - - def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timedelta(hours=1), **kwargs): - dag = DAG( - dag_id='test_scheduler_reschedule', - start_date=start_date, - # Make sure it only creates a single DAG Run - end_date=end_date, - ) - dag.clear() - dag.is_subdag = False - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False) - session.merge(orm_dag) - session.commit() - return dag - - @classmethod - def setUpClass(cls): - # Ensure the DAGs we are looking at from the DB are up-to-date - non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False) - non_serialized_dagbag.sync_to_db() - cls.dagbag = DagBag(read_dags_from_db=True) - - def test_dag_file_processor_sla_miss_callback(self): - """ - Test that the dag file processor calls the sla miss callback - """ - session = settings.Session() - - sla_callback = MagicMock() - - # Create dag with a start of 1 day ago, but an sla of 0 - # so we'll already have an sla_miss on the books. 
- test_start_date = days_ago(1) - dag = DAG( - dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date, 'sla': datetime.timedelta()}, - ) - - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') - - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) - - session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_file_processor.manage_slas(dag=dag, session=session) - - assert sla_callback.called - - def test_dag_file_processor_sla_miss_callback_invalid_sla(self): - """ - Test that the dag file processor does not call the sla miss callback when - given an invalid sla - """ - session = settings.Session() - - sla_callback = MagicMock() - - # Create dag with a start of 1 day ago, but an sla of 0 - # so we'll already have an sla_miss on the books. - # Pass anything besides a timedelta object to the sla argument. 
- test_start_date = days_ago(1) - dag = DAG( - dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date, 'sla': None}, - ) - - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') - - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) - - session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_file_processor.manage_slas(dag=dag, session=session) - sla_callback.assert_not_called() - - def test_dag_file_processor_sla_miss_callback_sent_notification(self): - """ - Test that the dag file processor does not call the sla_miss_callback when a - notification has already been sent - """ - session = settings.Session() - - # Mock the callback function so we can verify that it was not called - sla_callback = MagicMock() - - # Create dag with a start of 2 days ago, but an sla of 1 day - # ago so we'll already have an sla_miss on the books - test_start_date = days_ago(2) - dag = DAG( - dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, - ) - - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow') - - # Create a TaskInstance for two days ago - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='success')) - - # Create an SlaMiss where notification was sent, but email was not - session.merge( - SlaMiss( - task_id='dummy', - dag_id='test_sla_miss', - execution_date=test_start_date, - email_sent=False, - notification_sent=True, - ) - ) - - # Now call manage_slas and see if the sla_miss callback gets called - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_file_processor.manage_slas(dag=dag, session=session) - - sla_callback.assert_not_called() - - def test_dag_file_processor_sla_miss_callback_exception(self): - """ - 
Test that the dag file processor gracefully logs an exception if there is a problem - calling the sla_miss_callback - """ - session = settings.Session() - - sla_callback = MagicMock(side_effect=RuntimeError('Could not call function')) - - test_start_date = days_ago(2) - dag = DAG( - dag_id='test_sla_miss', - sla_miss_callback=sla_callback, - default_args={'start_date': test_start_date}, - ) - - task = DummyOperator(task_id='dummy', dag=dag, owner='airflow', sla=datetime.timedelta(hours=1)) - - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) - - # Create an SlaMiss where notification was sent, but email was not - session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) - - # Now call manage_slas and see if the sla_miss callback gets called - mock_log = mock.MagicMock() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) - dag_file_processor.manage_slas(dag=dag, session=session) - assert sla_callback.called - mock_log.exception.assert_called_once_with( - 'Could not call sla_miss_callback for DAG %s', 'test_sla_miss' - ) - - @mock.patch('airflow.jobs.scheduler_job.send_email') - def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(self, mock_send_email): - session = settings.Session() - - test_start_date = days_ago(2) - dag = DAG( - dag_id='test_sla_miss', - default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, - ) - - email1 = '[email protected]' - task = DummyOperator( - task_id='sla_missed', dag=dag, owner='airflow', email=email1, sla=datetime.timedelta(hours=1) - ) - - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) - - email2 = '[email protected]' - DummyOperator(task_id='sla_not_missed', dag=dag, owner='airflow', email=email2) - - session.merge(SlaMiss(task_id='sla_missed', dag_id='test_sla_miss', execution_date=test_start_date)) - - dag_file_processor = DagFileProcessor(dag_ids=[], 
log=mock.MagicMock()) - - dag_file_processor.manage_slas(dag=dag, session=session) - - assert len(mock_send_email.call_args_list) == 1 - - send_email_to = mock_send_email.call_args_list[0][0][0] - assert email1 in send_email_to - assert email2 not in send_email_to - - @mock.patch('airflow.jobs.scheduler_job.Stats.incr') - @mock.patch("airflow.utils.email.send_email") - def test_dag_file_processor_sla_miss_email_exception(self, mock_send_email, mock_stats_incr): - """ - Test that the dag file processor gracefully logs an exception if there is a problem - sending an email - """ - session = settings.Session() - - # Mock the callback function so we can verify that it was not called - mock_send_email.side_effect = RuntimeError('Could not send an email') - - test_start_date = days_ago(2) - dag = DAG( - dag_id='test_sla_miss', - default_args={'start_date': test_start_date, 'sla': datetime.timedelta(days=1)}, - ) - - task = DummyOperator( - task_id='dummy', dag=dag, owner='airflow', email='[email protected]', sla=datetime.timedelta(hours=1) - ) - - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) - - # Create an SlaMiss where notification was sent, but email was not - session.merge(SlaMiss(task_id='dummy', dag_id='test_sla_miss', execution_date=test_start_date)) - - mock_log = mock.MagicMock() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) - - dag_file_processor.manage_slas(dag=dag, session=session) - mock_log.exception.assert_called_once_with( - 'Could not send SLA Miss email notification for DAG %s', 'test_sla_miss' - ) - mock_stats_incr.assert_called_once_with('sla_email_notification_failure') - - def test_dag_file_processor_sla_miss_deleted_task(self): - """ - Test that the dag file processor will not crash when trying to send - sla miss notification for a deleted task - """ - session = settings.Session() - - test_start_date = days_ago(2) - dag = DAG( - dag_id='test_sla_miss', - default_args={'start_date': 
test_start_date, 'sla': datetime.timedelta(days=1)}, - ) - - task = DummyOperator( - task_id='dummy', dag=dag, owner='airflow', email='[email protected]', sla=datetime.timedelta(hours=1) - ) - - session.merge(TaskInstance(task=task, execution_date=test_start_date, state='Success')) - - # Create an SlaMiss where notification was sent, but email was not - session.merge( - SlaMiss(task_id='dummy_deleted', dag_id='test_sla_miss', execution_date=test_start_date) - ) - - mock_log = mock.MagicMock() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) - dag_file_processor.manage_slas(dag=dag, session=session) - - @parameterized.expand( - [ - [State.NONE, None, None], - [ - State.UP_FOR_RETRY, - timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15), - ], - [ - State.UP_FOR_RESCHEDULE, - timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15), - ], - ] - ) - def test_dag_file_processor_process_task_instances(self, state, start_date, end_date): - """ - Test if _process_task_instances puts the right task instances into the - mock_list. 
- """ - dag = DAG(dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE) - BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') - - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - dag.clear() - dr = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - assert dr is not None - - with create_session() as session: - ti = dr.get_task_instances(session=session)[0] - ti.state = state - ti.start_date = start_date - ti.end_date = end_date - - count = self.scheduler_job._schedule_dag_run(dr, set(), session) - assert count == 1 - - session.refresh(ti) - assert ti.state == State.SCHEDULED - - @parameterized.expand( - [ - [State.NONE, None, None], - [ - State.UP_FOR_RETRY, - timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15), - ], - [ - State.UP_FOR_RESCHEDULE, - timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15), - ], - ] - ) - def test_dag_file_processor_process_task_instances_with_task_concurrency( - self, - state, - start_date, - end_date, - ): - """ - Test if _process_task_instances puts the right task instances into the - mock_list. 
- """ - dag = DAG(dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE) - BashOperator(task_id='dummy', task_concurrency=2, dag=dag, owner='airflow', bash_command='echo Hi') - - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - dag.clear() - dr = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - assert dr is not None - - with create_session() as session: - ti = dr.get_task_instances(session=session)[0] - ti.state = state - ti.start_date = start_date - ti.end_date = end_date - - count = self.scheduler_job._schedule_dag_run(dr, set(), session) - assert count == 1 - - session.refresh(ti) - assert ti.state == State.SCHEDULED - - @parameterized.expand( - [ - [State.NONE, None, None], - [ - State.UP_FOR_RETRY, - timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15), - ], - [ - State.UP_FOR_RESCHEDULE, - timezone.utcnow() - datetime.timedelta(minutes=30), - timezone.utcnow() - datetime.timedelta(minutes=15), - ], - ] - ) - def test_dag_file_processor_process_task_instances_depends_on_past(self, state, start_date, end_date): - """ - Test if _process_task_instances puts the right task instances into the - mock_list. 
- """ - dag = DAG( - dag_id='test_scheduler_process_execute_task_depends_on_past', - start_date=DEFAULT_DATE, - default_args={ - 'depends_on_past': True, - }, - ) - BashOperator(task_id='dummy1', dag=dag, owner='airflow', bash_command='echo hi') - BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo hi') - - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - dag.clear() - dr = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=DEFAULT_DATE, - state=State.RUNNING, - ) - assert dr is not None - - with create_session() as session: - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.state = state - ti.start_date = start_date - ti.end_date = end_date - - count = self.scheduler_job._schedule_dag_run(dr, set(), session) - assert count == 2 - - session.refresh(tis[0]) - session.refresh(tis[1]) - assert tis[0].state == State.SCHEDULED - assert tis[1].state == State.SCHEDULED - - def test_scheduler_job_add_new_task(self): - """ - Test if a task instance will be added if the dag is updated - """ - dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE) - BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test') - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - - # Since we don't want to store the code for the DAG defined in this file - with mock.patch.object(settings, "STORE_DAG_CODE", False): - self.scheduler_job.dagbag.sync_to_db() - - session = settings.Session() - orm_dag = session.query(DagModel).get(dag.dag_id) - assert orm_dag is not None - - if self.scheduler_job.processor_agent: - self.scheduler_job.processor_agent.end() - 
self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.processor_agent = mock.MagicMock() - dag = self.scheduler_job.dagbag.get_dag('test_scheduler_add_new_task', session=session) - self.scheduler_job._create_dag_runs([orm_dag], session) - - drs = DagRun.find(dag_id=dag.dag_id, session=session) - assert len(drs) == 1 - dr = drs[0] - - tis = dr.get_task_instances() - assert len(tis) == 1 - - BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test') - SerializedDagModel.write_dag(dag=dag) - - scheduled_tis = self.scheduler_job._schedule_dag_run(dr, set(), session) - session.flush() - assert scheduled_tis == 2 - - drs = DagRun.find(dag_id=dag.dag_id, session=session) - assert len(drs) == 1 - dr = drs[0] - - tis = dr.get_task_instances() - assert len(tis) == 2 - - def test_runs_respected_after_clear(self): - """ - Test if _process_task_instances only schedules ti's up to max_active_runs - (related to issue AIRFLOW-137) - """ - dag = DAG(dag_id='test_scheduler_max_active_runs_respected_after_clear', start_date=DEFAULT_DATE) - dag.max_active_runs = 3 - - BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo Hi') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - self.scheduler_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job.processor_agent = mock.MagicMock() - self.scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - dag.clear() - - date = DEFAULT_DATE - dr1 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) - date = dag.following_schedule(date) - dr2 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) - date = dag.following_schedule(date) - dr3 = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=date, - state=State.RUNNING, - ) - - 
# First create up to 3 dagruns in RUNNING state. - assert dr1 is not None - assert dr2 is not None - assert dr3 is not None - assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3 - - # Reduce max_active_runs to 1 - dag.max_active_runs = 1 - - # and schedule them in, so we can check how many - # tasks are put on the task_instances_list (should be one, not 3) - with create_session() as session: - num_scheduled = self.scheduler_job._schedule_dag_run(dr1, set(), session) - assert num_scheduled == 1 - num_scheduled = self.scheduler_job._schedule_dag_run(dr2, {dr1.execution_date}, session) - assert num_scheduled == 0 - num_scheduled = self.scheduler_job._schedule_dag_run(dr3, {dr1.execution_date}, session) - assert num_scheduled == 0 - - @patch.object(TaskInstance, 'handle_failure_with_callback') - def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): - dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - with create_session() as session: - session.query(TaskInstance).delete() - dag = dagbag.get_dag('example_branch_operator') - task = dag.get_task(task_id='run_this_first') - - ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) - - session.add(ti) - session.commit() - - requests = [ - TaskCallbackRequest( - full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" - ) - ] - dag_file_processor.execute_callbacks(dagbag, requests) - mock_ti_handle_failure.assert_called_once_with( - error="Message", - test_mode=conf.getboolean('core', 'unit_test_mode'), - ) - - def test_process_file_should_failure_callback(self): - dag_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py' - ) - dagbag = DagBag(dag_folder=dag_file, include_examples=False) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - with create_session() as session, 
NamedTemporaryFile(delete=False) as callback_file: - session.query(TaskInstance).delete() - dag = dagbag.get_dag('test_om_failure_callback_dag') - task = dag.get_task(task_id='test_om_failure_callback_task') - - ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING) - - session.add(ti) - session.commit() - - requests = [ - TaskCallbackRequest( - full_filepath=dag.full_filepath, - simple_task_instance=SimpleTaskInstance(ti), - msg="Message", - ) - ] - callback_file.close() - - with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}): - dag_file_processor.process_file(dag_file, requests) - with open(callback_file.name) as callback_file2: - content = callback_file2.read() - assert "Callback fired" == content - os.remove(callback_file.name) - - def test_should_mark_dummy_task_as_success(self): - dag_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py' - ) - - # Write DAGs to dag and serialized_dag table - dagbag = DagBag(dag_folder=dag_file, include_examples=False, read_dags_from_db=False) - dagbag.sync_to_db() - - self.scheduler_job_job = SchedulerJob(subdir=os.devnull) - self.scheduler_job_job.processor_agent = mock.MagicMock() - dag = self.scheduler_job_job.dagbag.get_dag("test_only_dummy_tasks") - - # Create DagRun - session = settings.Session() - orm_dag = session.query(DagModel).get(dag.dag_id) - self.scheduler_job_job._create_dag_runs([orm_dag], session) - - drs = DagRun.find(dag_id=dag.dag_id, session=session) - assert len(drs) == 1 - dr = drs[0] - - # Schedule TaskInstances - self.scheduler_job_job._schedule_dag_run(dr, {}, session) - with create_session() as session: - tis = session.query(TaskInstance).all() - - dags = self.scheduler_job_job.dagbag.dags.values() - assert ['test_only_dummy_tasks'] == [dag.dag_id for dag in dags] - assert 5 == len(tis) - assert { - ('test_task_a', 'success'), - ('test_task_b', None), - ('test_task_c', 'success'), - ('test_task_on_execute', 'scheduled'), - 
('test_task_on_success', 'scheduled'), - } == {(ti.task_id, ti.state) for ti in tis} - for state, start_date, end_date, duration in [ - (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis - ]: - if state == 'success': - assert start_date is not None - assert end_date is not None - assert 0.0 == duration - else: - assert start_date is None - assert end_date is None - assert duration is None - - self.scheduler_job_job._schedule_dag_run(dr, {}, session) - with create_session() as session: - tis = session.query(TaskInstance).all() - - assert 5 == len(tis) - assert { - ('test_task_a', 'success'), - ('test_task_b', 'success'), - ('test_task_c', 'success'), - ('test_task_on_execute', 'scheduled'), - ('test_task_on_success', 'scheduled'), - } == {(ti.task_id, ti.state) for ti in tis} - for state, start_date, end_date, duration in [ - (ti.state, ti.start_date, ti.end_date, ti.duration) for ti in tis - ]: - if state == 'success': - assert start_date is not None - assert end_date is not None - assert 0.0 == duration - else: - assert start_date is None - assert end_date is None - assert duration is None - - [email protected]("disable_load_example") class TestSchedulerJob(unittest.TestCase): @staticmethod def clean_db(): @@ -802,7 +118,7 @@ class TestSchedulerJob(unittest.TestCase): # enqueue! 
self.null_exec = MockExecutor() - self.patcher = patch('airflow.utils.dag_processing.SerializedDagModel.remove_deleted_dags') + self.patcher = patch('airflow.dag_processing.manager.SerializedDagModel.remove_deleted_dags') # Since we don't want to store the code for the DAG defined in this file self.patcher_dag_code = patch.object(settings, "STORE_DAG_CODE", False) self.patcher.start() @@ -3213,7 +2529,7 @@ class TestSchedulerJob(unittest.TestCase): dagbag.bag_dag(dag=dag, root_dag=dag) dagbag.sync_to_db() - @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) + @mock.patch('airflow.dag_processing.processor.DagBag', return_value=dagbag) def do_schedule(mock_dagbag): # Use a empty file since the above mock will return the # expected DAGs. Also specify only a single file so that it doesn't diff --git a/tests/test_utils/perf/perf_kit/python.py b/tests/test_utils/perf/perf_kit/python.py index 7d92a49..596f4f6 100644 --- a/tests/test_utils/perf/perf_kit/python.py +++ b/tests/test_utils/perf/perf_kit/python.py @@ -91,7 +91,7 @@ if __name__ == "__main__": import logging import airflow - from airflow.jobs.scheduler_job import DagFileProcessor + from airflow.dag_processing.processor import DagFileProcessor log = logging.getLogger(__name__) processor = DagFileProcessor(dag_ids=[], log=log) diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py index e60ad51..37cf0fe 100644 --- a/tests/test_utils/perf/perf_kit/sqlalchemy.py +++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py @@ -218,7 +218,7 @@ if __name__ == "__main__": import logging from unittest import mock - from airflow.jobs.scheduler_job import DagFileProcessor + from airflow.dag_processing.processor import DagFileProcessor with mock.patch.dict( "os.environ",
