This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0359a42a39 Migrate DagFileProcessor.manage_slas to Internal API (#28502)
0359a42a39 is described below

commit 0359a42a3975d0d7891a39abe4395bdd6f210718
Author: Vincent <[email protected]>
AuthorDate: Mon Jan 23 15:54:25 2023 -0500

    Migrate DagFileProcessor.manage_slas to Internal API (#28502)
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |   1 +
 airflow/dag_processing/processor.py                |  33 ++++--
 airflow/utils/log/logging_mixin.py                 |  20 +++-
 tests/dag_processing/test_processor.py             | 121 +++++++++++++--------
 4 files changed, 114 insertions(+), 61 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 1a12ae5bf2..eb7a0fbfa3 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -39,6 +39,7 @@ def _initialize_map() -> dict[str, Callable]:
 
     functions: list[Callable] = [
         DagFileProcessor.update_import_errors,
+        DagFileProcessor.manage_slas,
         DagModel.get_paused_dag_ids,
         DagFileProcessorManager.clear_nonexistent_import_errors,
         XCom.get_value,
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 5c9b7ebfd6..a50ca933dc 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -365,8 +365,10 @@ class DagFileProcessor(LoggingMixin):
         self._dag_directory = dag_directory
         self.dag_warnings: set[tuple[str, str]] = set()
 
+    @classmethod
+    @internal_api_call
     @provide_session
-    def manage_slas(self, dag: DAG, session: Session = None) -> None:
+    def manage_slas(cls, dag_folder, dag_id: str, session: Session = NEW_SESSION) -> None:
         """
         Finding all tasks that have SLAs defined, and sending alert emails when needed.
 
@@ -375,9 +377,11 @@ class DagFileProcessor(LoggingMixin):
         We are assuming that the scheduler runs often, so we only check for
         tasks that should have succeeded in the past hour.
         """
-        self.log.info("Running SLA Checks for %s", dag.dag_id)
+        dagbag = DagFileProcessor._get_dagbag(dag_folder)
+        dag = dagbag.get_dag(dag_id)
+        cls.logger().info("Running SLA Checks for %s", dag.dag_id)
         if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
-            self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
+            cls.logger().info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
             return
 
         qry = (
@@ -481,7 +485,7 @@ class DagFileProcessor(LoggingMixin):
                     else [dag.sla_miss_callback]
                 )
                 for callback in callbacks:
-                    self.log.info("Calling SLA miss callback %s", callback)
+                    cls.logger().info("Calling SLA miss callback %s", callback)
                     try:
                         callback(dag, task_list, blocking_task_list, slas, blocking_tis)
                         notification_sent = True
@@ -493,7 +497,7 @@ class DagFileProcessor(LoggingMixin):
                                 "func_name": callback.__name__,
                             },
                         )
-                        self.log.exception(
+                        cls.logger().exception(
                             "Could not call sla_miss_callback(%s) for DAG %s",
                             callback.__name__,
                             dag.dag_id,
@@ -512,7 +516,7 @@ class DagFileProcessor(LoggingMixin):
                     task = dag.get_task(sla.task_id)
                 except TaskNotFound:
                     # task already deleted from DAG, skip it
-                    self.log.warning(
+                    cls.logger().warning(
                         "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
                     )
                     continue
@@ -532,7 +536,9 @@ class DagFileProcessor(LoggingMixin):
                     notification_sent = True
                 except Exception:
                     Stats.incr("sla_email_notification_failure", tags={"dag_id": dag.dag_id})
-                    self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
+                    cls.logger().exception(
+                        "Could not send SLA Miss email notification for DAG %s", dag.dag_id
+                    )
             # If we sent any notification, update the sla_miss table
             if notification_sent:
                 for sla in slas:
@@ -652,7 +658,7 @@ class DagFileProcessor(LoggingMixin):
                 if isinstance(request, TaskCallbackRequest):
                     self._execute_task_callbacks(dagbag, request, session=session)
                 elif isinstance(request, SlaCallbackRequest):
-                    self.manage_slas(dagbag.get_dag(request.dag_id), session=session)
+                    DagFileProcessor.manage_slas(dagbag.dag_folder, request.dag_id, session=session)
                 elif isinstance(request, DagCallbackRequest):
                     self._execute_dag_callbacks(dagbag, request, session)
             except Exception:
@@ -736,6 +742,15 @@ class DagFileProcessor(LoggingMixin):
         self.log.info("Executed failure callback for %s in state %s", ti, ti.state)
         session.flush()
 
+    @classmethod
+    def _get_dagbag(cls, file_path: str):
+        try:
+            return DagBag(file_path, include_examples=False)
+        except Exception:
+            cls.logger().exception("Failed at reloading the DAG file %s", file_path)
+            Stats.incr("dag_file_refresh_error", 1, 1)
+            raise
+
     @provide_session
     def process_file(
         self,
@@ -766,7 +781,7 @@ class DagFileProcessor(LoggingMixin):
         self.log.info("Processing file %s for tasks to queue", file_path)
 
         try:
-            dagbag = DagBag(file_path, include_examples=False)
+            dagbag = DagFileProcessor._get_dagbag(file_path)
         except Exception:
             self.log.exception("Failed at reloading the DAG file %s", file_path)
             Stats.incr("dag_file_refresh_error", 1, 1, tags={"file_path": file_path})
diff --git a/airflow/utils/log/logging_mixin.py b/airflow/utils/log/logging_mixin.py
index 85ff71a94f..79746c8492 100644
--- a/airflow/utils/log/logging_mixin.py
+++ b/airflow/utils/log/logging_mixin.py
@@ -24,7 +24,7 @@ import re
 import sys
 from io import IOBase
 from logging import Handler, Logger, StreamHandler
-from typing import IO, cast
+from typing import IO, Any, TypeVar, cast
 
 from airflow.settings import IS_K8S_EXECUTOR_POD
 
@@ -59,6 +59,9 @@ def remove_escape_codes(text: str) -> str:
     return ANSI_ESCAPE.sub("", text)
 
 
+_T = TypeVar("_T")
+
+
 class LoggingMixin:
     """Convenience super-class to have a logger configured with the class name"""
 
@@ -67,12 +70,21 @@ class LoggingMixin:
     def __init__(self, context=None):
         self._set_context(context)
 
+    @staticmethod
+    def _get_log(obj: Any, clazz: type[_T]) -> Logger:
+        if obj._log is None:
+            obj._log = logging.getLogger(f"{clazz.__module__}.{clazz.__name__}")
+        return obj._log
+
+    @classmethod
+    def logger(cls) -> Logger:
+        """Returns a logger."""
+        return LoggingMixin._get_log(cls, cls)
+
     @property
     def log(self) -> Logger:
         """Returns a logger."""
-        if self._log is None:
-            self._log = logging.getLogger(self.__class__.__module__ + "." + self.__class__.__name__)
-        return self._log
+        return LoggingMixin._get_log(self, self.__class__)
 
     def _set_context(self, context):
         if context is not None:
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index 8bad4c8cb8..12528ae462 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -105,16 +105,15 @@ class TestDagFileProcessor:
 
         dag_file_processor.process_file(file_path, [], False, session)
 
-    def test_dag_file_processor_sla_miss_callback(self, create_dummy_dag):
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
+    def test_dag_file_processor_sla_miss_callback(self, mock_get_dagbag, create_dummy_dag, get_test_dag):
         """
         Test that the dag file processor calls the sla miss callback
         """
         session = settings.Session()
-
         sla_callback = MagicMock()
 
-        # Create dag with a start of 1 day ago, but an sla of 0
-        # so we'll already have an sla_miss on the books.
+        # Create dag with a start of 1 day ago, but a sla of 0, so we'll 
already have a sla_miss on the books.
         test_start_date = timezone.utcnow() - datetime.timedelta(days=1)
         dag, task = create_dummy_dag(
             dag_id="test_sla_miss",
@@ -124,17 +123,18 @@ class TestDagFileProcessor:
         )
 
         session.merge(TaskInstance(task=task, execution_date=test_start_date, 
state="success"))
-
         session.merge(SlaMiss(task_id="dummy", dag_id="test_sla_miss", 
execution_date=test_start_date))
 
-        dag_file_processor = DagFileProcessor(
-            dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
-        )
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
+
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
 
         assert sla_callback.called
 
-    def test_dag_file_processor_sla_miss_callback_invalid_sla(self, create_dummy_dag):
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
+    def test_dag_file_processor_sla_miss_callback_invalid_sla(self, mock_get_dagbag, create_dummy_dag):
         """
         Test that the dag file processor does not call the sla miss callback 
when
         given an invalid sla
@@ -155,16 +155,17 @@ class TestDagFileProcessor:
         )
 
         session.merge(TaskInstance(task=task, execution_date=test_start_date, 
state="success"))
-
         session.merge(SlaMiss(task_id="dummy", dag_id="test_sla_miss", 
execution_date=test_start_date))
 
-        dag_file_processor = DagFileProcessor(
-            dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
-        )
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
+
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
         sla_callback.assert_not_called()
 
-    def test_dag_file_processor_sla_miss_callback_sent_notification(self, create_dummy_dag):
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
+    def test_dag_file_processor_sla_miss_callback_sent_notification(self, mock_get_dagbag, create_dummy_dag):
         """
         Test that the dag file processor does not call the sla_miss_callback 
when a
         notification has already been sent
@@ -198,16 +199,20 @@ class TestDagFileProcessor:
             )
         )
 
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
+
         # Now call manage_slas and see if the sla_miss callback gets called
-        dag_file_processor = DagFileProcessor(
-            dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
-        )
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
 
         sla_callback.assert_not_called()
 
     @mock.patch("airflow.dag_processing.processor.Stats.incr")
-    def test_dag_file_processor_sla_miss_doesnot_raise_integrity_error(self, mock_stats_incr, dag_maker):
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
+    def test_dag_file_processor_sla_miss_doesnot_raise_integrity_error(
+        self, mock_get_dagbag, mock_stats_incr, dag_maker
+    ):
         """
         Test that the dag file processor does not try to insert already 
existing item into the database
         """
@@ -229,10 +234,11 @@ class TestDagFileProcessor:
         session.merge(ti)
         session.flush()
 
-        dag_file_processor = DagFileProcessor(
-            dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
-        )
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
+
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
         sla_miss_count = (
             session.query(SlaMiss)
             .filter(
@@ -249,11 +255,12 @@ class TestDagFileProcessor:
         # because of existing SlaMiss above.
         # Since this is run often, it's possible that it runs before another
         # ti is successful thereby trying to insert a duplicate record.
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
 
     @mock.patch("airflow.dag_processing.processor.Stats.incr")
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
     def test_dag_file_processor_sla_miss_continue_checking_the_task_instances_after_recording_missing_sla(
-        self, mock_stats_incr, dag_maker
+        self, mock_get_dagbag, mock_stats_incr, dag_maker
     ):
         """
         Test that the dag file processor continue checking subsequent task 
instances
@@ -279,10 +286,11 @@ class TestDagFileProcessor:
         )
         session.flush()
 
-        dag_file_processor = DagFileProcessor(
-            dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
-        )
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
+
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
         sla_miss_count = (
             session.query(SlaMiss)
             .filter(
@@ -296,8 +304,12 @@ class TestDagFileProcessor:
             "sla_missed", tags={"dag_id": "test_sla_miss", "run_id": "test", 
"task_id": "dummy"}
         )
 
+    @patch.object(DagFileProcessor, "logger")
     @mock.patch("airflow.dag_processing.processor.Stats.incr")
-    def test_dag_file_processor_sla_miss_callback_exception(self, mock_stats_incr, create_dummy_dag):
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
+    def test_dag_file_processor_sla_miss_callback_exception(
+        self, mock_get_dagbag, mock_stats_incr, mock_get_log, create_dummy_dag
+    ):
         """
         Test that the dag file processor gracefully logs an exception if there 
is a problem
         calling the sla_miss_callback
@@ -327,9 +339,13 @@ class TestDagFileProcessor:
             )
 
             # Now call manage_slas and see if the sla_miss callback gets called
-            mock_log = mock.MagicMock()
-            dag_file_processor = DagFileProcessor(dag_ids=[], 
dag_directory=TEST_DAGS_FOLDER, log=mock_log)
-            dag_file_processor.manage_slas(dag=dag, session=session)
+            mock_log = mock.Mock()
+            mock_get_log.return_value = mock_log
+            mock_dagbag = mock.Mock()
+            mock_dagbag.get_dag.return_value = dag
+            mock_get_dagbag.return_value = mock_dagbag
+
+            DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
             assert sla_callback.called
             mock_log.exception.assert_called_once_with(
                 "Could not call sla_miss_callback(%s) for DAG %s",
@@ -342,8 +358,9 @@ class TestDagFileProcessor:
             )
 
     @mock.patch("airflow.dag_processing.processor.send_email")
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
     def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(
-        self, mock_send_email, create_dummy_dag
+        self, mock_get_dagbag, mock_send_email, create_dummy_dag
     ):
         session = settings.Session()
 
@@ -363,11 +380,11 @@ class TestDagFileProcessor:
 
         session.merge(SlaMiss(task_id="sla_missed", dag_id="test_sla_miss", 
execution_date=test_start_date))
 
-        dag_file_processor = DagFileProcessor(
-            dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock.MagicMock()
-        )
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
 
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
 
         assert len(mock_send_email.call_args_list) == 1
 
@@ -375,10 +392,12 @@ class TestDagFileProcessor:
         assert email1 in send_email_to
         assert email2 not in send_email_to
 
+    @patch.object(DagFileProcessor, "logger")
     @mock.patch("airflow.dag_processing.processor.Stats.incr")
     @mock.patch("airflow.utils.email.send_email")
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
     def test_dag_file_processor_sla_miss_email_exception(
-        self, mock_send_email, mock_stats_incr, create_dummy_dag
+        self, mock_get_dagbag, mock_send_email, mock_stats_incr, mock_get_log, 
create_dummy_dag
     ):
         """
         Test that the dag file processor gracefully logs an exception if there 
is a problem
@@ -403,10 +422,13 @@ class TestDagFileProcessor:
         # Create an SlaMiss where notification was sent, but email was not
         session.merge(SlaMiss(task_id="dummy", dag_id="test_sla_miss", 
execution_date=test_start_date))
 
-        mock_log = mock.MagicMock()
-        dag_file_processor = DagFileProcessor(dag_ids=[], 
dag_directory=TEST_DAGS_FOLDER, log=mock_log)
+        mock_log = mock.Mock()
+        mock_get_log.return_value = mock_log
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
 
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
         mock_log.exception.assert_called_once_with(
             "Could not send SLA Miss email notification for DAG %s", 
"test_sla_miss"
         )
@@ -414,7 +436,8 @@ class TestDagFileProcessor:
             "sla_email_notification_failure", tags={"dag_id": "test_sla_miss"}
         )
 
-    def test_dag_file_processor_sla_miss_deleted_task(self, create_dummy_dag):
+    @mock.patch("airflow.dag_processing.processor.DagFileProcessor._get_dagbag")
+    def test_dag_file_processor_sla_miss_deleted_task(self, mock_get_dagbag, create_dummy_dag):
         """
         Test that the dag file processor will not crash when trying to send
         sla miss notification for a deleted task
@@ -436,9 +459,11 @@ class TestDagFileProcessor:
             SlaMiss(task_id="dummy_deleted", dag_id="test_sla_miss", 
execution_date=test_start_date)
         )
 
-        mock_log = mock.MagicMock()
-        dag_file_processor = DagFileProcessor(dag_ids=[], 
dag_directory=TEST_DAGS_FOLDER, log=mock_log)
-        dag_file_processor.manage_slas(dag=dag, session=session)
+        mock_dagbag = mock.Mock()
+        mock_dagbag.get_dag.return_value = dag
+        mock_get_dagbag.return_value = mock_dagbag
+
+        DagFileProcessor.manage_slas(dag_folder=dag.fileloc, 
dag_id="test_sla_miss", session=session)
 
     @patch.object(TaskInstance, "handle_failure")
     def test_execute_on_failure_callbacks(self, mock_ti_handle_failure):

Reply via email to