This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 82c4ae283f Bugfix query count statistics when parsing DAG file (#41149)
82c4ae283f is described below

commit 82c4ae283fc48bf86c0a9bb76951f3bd4fc5f415
Author: max <[email protected]>
AuthorDate: Fri Aug 2 07:58:32 2024 +0000

    Bugfix query count statistics when parsing DAG file (#41149)
---
 airflow/dag_processing/processor.py | 318 ++++++++++++++++++------------------
 1 file changed, 162 insertions(+), 156 deletions(-)
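
For readers skimming the diff below: the fix replaces the per-method integer
return values with a single count_queries context manager that yields a mutable
_QueryCounter, so process_file can wrap all of its database work in one counting
block and cache the result. The following is a minimal, self-contained sketch of
that pattern, not the commit itself; the in-memory SQLite engine and the sample
statements are assumptions made purely for illustration.

from contextlib import contextmanager
from dataclasses import dataclass

from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import Session


@dataclass
class _QueryCounter:
    queries_number: int = 0

    def inc(self) -> None:
        self.queries_number += 1


@contextmanager
def count_queries(session):
    # Yield a mutable object so the caller can still read the final count
    # after the with-block exits (a plain int could not be updated in place).
    counter = _QueryCounter()

    @event.listens_for(session, "do_orm_execute")
    def _count_db_queries(orm_execute_state):
        counter.inc()

    yield counter
    event.remove(session, "do_orm_execute", _count_db_queries)


if __name__ == "__main__":
    engine = create_engine("sqlite://")  # illustrative stand-in for the Airflow metadata DB
    with Session(engine) as session:
        with count_queries(session) as counter:
            session.execute(text("SELECT 1"))
            session.execute(text("SELECT 1"))
    print(counter.queries_number)  # expected: 2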

diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 68854dce1a..86db0b5b88 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -24,6 +24,7 @@ import threading
 import time
 import zipfile
 from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
+from dataclasses import dataclass
 from datetime import timedelta
 from typing import TYPE_CHECKING, Generator, Iterable, Iterator
 
@@ -68,17 +69,25 @@ if TYPE_CHECKING:
     from airflow.models.operator import Operator
 
 
+@dataclass
+class _QueryCounter:
+    queries_number: int = 0
+
+    def inc(self):
+        self.queries_number += 1
+
+
 @contextmanager
-def count_queries(session: Session) -> Generator[list[int], None, None]:
+def count_queries(session: Session) -> Generator[_QueryCounter, None, None]:
     # using list allows to read the updated counter from what context manager returns
-    query_count: list[int] = [0]
+    counter: _QueryCounter = _QueryCounter()
 
     @event.listens_for(session, "do_orm_execute")
     def _count_db_queries(orm_execute_state):
-        nonlocal query_count
-        query_count[0] += 1
+        nonlocal counter
+        counter.inc()
 
-    yield query_count
+    yield counter
     event.remove(session, "do_orm_execute", _count_db_queries)
 
 
@@ -613,7 +622,7 @@ class DagFileProcessor(LoggingMixin):
         import_errors: dict[str, str],
         processor_subdir: str | None,
         session: Session = NEW_SESSION,
-    ) -> int:
+    ) -> None:
         """
         Update any import errors to be displayed in the UI.
 
@@ -626,55 +635,51 @@ class DagFileProcessor(LoggingMixin):
         """
         files_without_error = file_last_changed - import_errors.keys()
 
-        with count_queries(session) as query_count:
-            # Clear the errors of the processed files
-            # that no longer have errors
-            for dagbag_file in files_without_error:
-                session.execute(
-                    delete(ParseImportError)
-                    .where(ParseImportError.filename.startswith(dagbag_file))
-                    .execution_options(synchronize_session="fetch")
-                )
+        # Clear the errors of the processed files
+        # that no longer have errors
+        for dagbag_file in files_without_error:
+            session.execute(
+                delete(ParseImportError)
+                .where(ParseImportError.filename.startswith(dagbag_file))
+                .execution_options(synchronize_session="fetch")
+            )
 
-            # files that still have errors
-            existing_import_error_files = [x.filename for x in session.query(ParseImportError.filename).all()]
+        # files that still have errors
+        existing_import_error_files = [x.filename for x in session.query(ParseImportError.filename).all()]
 
-            # Add the errors of the processed files
-            for filename, stacktrace in import_errors.items():
-                if filename in existing_import_error_files:
-                    session.query(ParseImportError).filter(ParseImportError.filename == filename).update(
-                        {"filename": filename, "timestamp": timezone.utcnow(), "stacktrace": stacktrace},
-                        synchronize_session="fetch",
-                    )
-                    # sending notification when an existing dag import error occurs
-                    get_listener_manager().hook.on_existing_dag_import_error(
-                        filename=filename, stacktrace=stacktrace
-                    )
-                else:
-                    session.add(
-                        ParseImportError(
-                            filename=filename,
-                            timestamp=timezone.utcnow(),
-                            stacktrace=stacktrace,
-                            processor_subdir=processor_subdir,
-                        )
-                    )
-                    # sending notification when a new dag import error occurs
-                    get_listener_manager().hook.on_new_dag_import_error(
-                        filename=filename, stacktrace=stacktrace
+        # Add the errors of the processed files
+        for filename, stacktrace in import_errors.items():
+            if filename in existing_import_error_files:
+                session.query(ParseImportError).filter(ParseImportError.filename == filename).update(
+                    {"filename": filename, "timestamp": timezone.utcnow(), "stacktrace": stacktrace},
+                    synchronize_session="fetch",
+                )
+                # sending notification when an existing dag import error occurs
+                get_listener_manager().hook.on_existing_dag_import_error(
+                    filename=filename, stacktrace=stacktrace
+                )
+            else:
+                session.add(
+                    ParseImportError(
+                        filename=filename,
+                        timestamp=timezone.utcnow(),
+                        stacktrace=stacktrace,
+                        processor_subdir=processor_subdir,
                     )
-                (
-                    session.query(DagModel)
-                    .filter(DagModel.fileloc == filename)
-                    .update({"has_import_errors": True}, synchronize_session="fetch")
                 )
+                # sending notification when a new dag import error occurs
+                get_listener_manager().hook.on_new_dag_import_error(filename=filename, stacktrace=stacktrace)
+            (
+                session.query(DagModel)
+                .filter(DagModel.fileloc == filename)
+                .update({"has_import_errors": True}, synchronize_session="fetch")
+            )
 
-            session.commit()
-            session.flush()
-        return query_count[0]
+        session.commit()
+        session.flush()
 
     @classmethod
-    def update_dag_warnings(cla, *, dagbag: DagBag) -> int:
+    def update_dag_warnings(cla, *, dagbag: DagBag) -> None:
         """Validate and raise exception if any task in a dag is using a 
non-existent pool."""
 
         def get_pools(dag) -> dict[str, set[str]]:
@@ -693,33 +698,31 @@ class DagFileProcessor(LoggingMixin):
     @provide_session
     def _validate_task_pools_and_update_dag_warnings(
         cls, pool_dict: dict[str, set[str]], dag_ids: set[str], session: Session = NEW_SESSION
-    ) -> int:
-        with count_queries(session) as query_count:
-            from airflow.models.pool import Pool
-
-            all_pools = {p.pool for p in Pool.get_pools(session)}
-            warnings: set[DagWarning] = set()
-            for dag_id, dag_pools in pool_dict.items():
-                nonexistent_pools = dag_pools - all_pools
-                if nonexistent_pools:
-                    warnings.add(
-                        DagWarning(
-                            dag_id,
-                            DagWarningType.NONEXISTENT_POOL,
-                            f"Dag '{dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}",
-                        )
+    ) -> None:
+        from airflow.models.pool import Pool
+
+        all_pools = {p.pool for p in Pool.get_pools(session)}
+        warnings: set[DagWarning] = set()
+        for dag_id, dag_pools in pool_dict.items():
+            nonexistent_pools = dag_pools - all_pools
+            if nonexistent_pools:
+                warnings.add(
+                    DagWarning(
+                        dag_id,
+                        DagWarningType.NONEXISTENT_POOL,
+                        f"Dag '{dag_id}' references non-existent pools: {sorted(nonexistent_pools)!r}",
                     )
+                )
 
-            stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dag_ids)).all())
+        stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dag_ids)).all())
 
-            for warning_to_delete in stored_warnings - warnings:
-                session.delete(warning_to_delete)
+        for warning_to_delete in stored_warnings - warnings:
+            session.delete(warning_to_delete)
 
-            for warning_to_add in warnings:
-                session.merge(warning_to_add)
-            session.flush()
-            session.commit()
-        return query_count[0]
+        for warning_to_add in warnings:
+            session.merge(warning_to_add)
+        session.flush()
+        session.commit()
 
     @classmethod
     @internal_api_call
@@ -730,7 +733,7 @@ class DagFileProcessor(LoggingMixin):
         callback_requests: list[CallbackRequest],
         unit_test_mode: bool,
         session: Session = NEW_SESSION,
-    ) -> int:
+    ) -> None:
         """
         Execute on failure callbacks.
 
@@ -742,57 +745,53 @@ class DagFileProcessor(LoggingMixin):
 
         :return: number of queries executed
         """
-        with count_queries(session) as query_count:
-            for request in callback_requests:
-                cls.logger().debug("Processing Callback Request: %s", request)
-                try:
-                    if isinstance(request, TaskCallbackRequest):
-                        cls._execute_task_callbacks(dagbag, request, unit_test_mode, session=session)
-                    elif isinstance(request, SlaCallbackRequest):
-                        if InternalApiConfig.get_use_internal_api():
-                            cls.logger().warning(
-                                "SlaCallbacks are not supported when the Internal API is enabled"
-                            )
-                        else:
-                            DagFileProcessor.manage_slas(dagbag.dag_folder, request.dag_id, session=session)
-                    elif isinstance(request, DagCallbackRequest):
-                        cls._execute_dag_callbacks(dagbag, request, session=session)
-                except Exception:
-                    cls.logger().exception(
-                        "Error executing %s callback for file: %s",
-                        request.__class__.__name__,
-                        request.full_filepath,
-                    )
-            session.flush()
-            session.commit()
-        return query_count[0]
+        for request in callback_requests:
+            cls.logger().debug("Processing Callback Request: %s", request)
+            try:
+                if isinstance(request, TaskCallbackRequest):
+                    cls._execute_task_callbacks(dagbag, request, unit_test_mode, session=session)
+                elif isinstance(request, SlaCallbackRequest):
+                    if InternalApiConfig.get_use_internal_api():
+                        cls.logger().warning(
+                            "SlaCallbacks are not supported when the Internal API is enabled"
+                        )
+                    else:
+                        DagFileProcessor.manage_slas(dagbag.dag_folder, request.dag_id, session=session)
+                elif isinstance(request, DagCallbackRequest):
+                    cls._execute_dag_callbacks(dagbag, request, session=session)
+            except Exception:
+                cls.logger().exception(
+                    "Error executing %s callback for file: %s",
+                    request.__class__.__name__,
+                    request.full_filepath,
+                )
+        session.flush()
+        session.commit()
 
     @classmethod
     @internal_api_call
     @provide_session
     def execute_callbacks_without_dag(
         cls, callback_requests: list[CallbackRequest], unit_test_mode: bool, session: Session = NEW_SESSION
-    ) -> int:
+    ) -> None:
         """
         Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors.
 
         This is so important so that tasks that failed when there is a parse
         error don't get stuck in queued state.
         """
-        with count_queries(session) as query_count:
-            for request in callback_requests:
-                cls.logger().debug("Processing Callback Request: %s", request)
-                if isinstance(request, TaskCallbackRequest):
-                    cls._execute_task_callbacks(None, request, unit_test_mode, session)
-                else:
-                    cls.logger().info(
-                        "Not executing %s callback for file %s as there was a dag parse error",
-                        request.__class__.__name__,
-                        request.full_filepath,
-                    )
-            session.flush()
-            session.commit()
-        return query_count[0]
+        for request in callback_requests:
+            cls.logger().debug("Processing Callback Request: %s", request)
+            if isinstance(request, TaskCallbackRequest):
+                cls._execute_task_callbacks(None, request, unit_test_mode, session)
+            else:
+                cls.logger().info(
+                    "Not executing %s callback for file %s as there was a dag parse error",
+                    request.__class__.__name__,
+                    request.full_filepath,
+                )
+        session.flush()
+        session.commit()
 
     @classmethod
     def _execute_dag_callbacks(cls, dagbag: DagBag, request: DagCallbackRequest, session: Session):
@@ -886,11 +885,13 @@ class DagFileProcessor(LoggingMixin):
             Stats.incr("dag_file_refresh_error", tags={"file_path": file_path})
             raise
 
+    @provide_session
     def process_file(
         self,
         file_path: str,
         callback_requests: list[CallbackRequest],
         pickle_dags: bool = False,
+        session: Session = NEW_SESSION,
     ) -> tuple[int, int, int]:
         """
         Process a Python file containing Airflow DAGs.
@@ -911,58 +912,63 @@ class DagFileProcessor(LoggingMixin):
         :return: number of dags found, count of import errors, last number of db queries
         """
         self.log.info("Processing file %s for tasks to queue", file_path)
-        try:
-            dagbag = DagFileProcessor._get_dagbag(file_path)
-        except Exception:
-            self.log.exception("Failed at reloading the DAG file %s", file_path)
-            Stats.incr("dag_file_refresh_error", 1, 1, tags={"file_path": file_path})
-            return 0, 0, self._last_num_of_db_queries
 
-        if dagbag.dags:
-            self.log.info("DAG(s) %s retrieved from %s", ", ".join(map(repr, dagbag.dags)), file_path)
-        else:
-            self.log.warning("No viable dags retrieved from %s", file_path)
-            self._last_num_of_db_queries += DagFileProcessor.update_import_errors(
-                file_last_changed=dagbag.file_last_changed,
-                import_errors=dagbag.import_errors,
-                processor_subdir=self._dag_directory,
-            )
-            if callback_requests:
-                # If there were callback requests for this file but there was a
-                # parse error we still need to progress the state of TIs,
-                # otherwise they might be stuck in queued/running for ever!
-                self._last_num_of_db_queries += DagFileProcessor.execute_callbacks_without_dag(
-                    callback_requests, self.UNIT_TEST_MODE
+        with count_queries(session) as query_counter:
+            try:
+                dagbag = DagFileProcessor._get_dagbag(file_path)
+            except Exception:
+                self.log.exception("Failed at reloading the DAG file %s", file_path)
+                Stats.incr("dag_file_refresh_error", 1, 1, tags={"file_path": file_path})
+                return 0, 0, self._cache_last_num_of_db_queries(query_counter)
+
+            if dagbag.dags:
+                self.log.info("DAG(s) %s retrieved from %s", ", ".join(map(repr, dagbag.dags)), file_path)
+            else:
+                self.log.warning("No viable dags retrieved from %s", file_path)
+                DagFileProcessor.update_import_errors(
+                    file_last_changed=dagbag.file_last_changed,
+                    import_errors=dagbag.import_errors,
+                    processor_subdir=self._dag_directory,
                 )
-            return 0, len(dagbag.import_errors), self._last_num_of_db_queries
-
-        self._last_num_of_db_queries += self.execute_callbacks(dagbag, callback_requests, self.UNIT_TEST_MODE)
+                if callback_requests:
+                    # If there were callback requests for this file but there was a
+                    # parse error we still need to progress the state of TIs,
+                    # otherwise they might be stuck in queued/running for ever!
+                    DagFileProcessor.execute_callbacks_without_dag(callback_requests, self.UNIT_TEST_MODE)
+                return 0, len(dagbag.import_errors), self._cache_last_num_of_db_queries(query_counter)
+
+            self.execute_callbacks(dagbag, callback_requests, self.UNIT_TEST_MODE)
+
+            serialize_errors = DagFileProcessor.save_dag_to_db(
+                dags=dagbag.dags,
+                dag_directory=self._dag_directory,
+                pickle_dags=pickle_dags,
+            )
 
-        serialize_errors = DagFileProcessor.save_dag_to_db(
-            dags=dagbag.dags,
-            dag_directory=self._dag_directory,
-            pickle_dags=pickle_dags,
-        )
+            dagbag.import_errors.update(dict(serialize_errors))
 
-        dagbag.import_errors.update(dict(serialize_errors))
+            # Record import errors into the ORM
+            try:
+                DagFileProcessor.update_import_errors(
+                    file_last_changed=dagbag.file_last_changed,
+                    import_errors=dagbag.import_errors,
+                    processor_subdir=self._dag_directory,
+                )
+            except Exception:
+                self.log.exception("Error logging import errors!")
 
-        # Record import errors into the ORM
-        try:
-            self._last_num_of_db_queries += DagFileProcessor.update_import_errors(
-                file_last_changed=dagbag.file_last_changed,
-                import_errors=dagbag.import_errors,
-                processor_subdir=self._dag_directory,
-            )
-        except Exception:
-            self.log.exception("Error logging import errors!")
+            # Record DAG warnings in the metadatabase.
+            try:
+                self.update_dag_warnings(dagbag=dagbag)
+            except Exception:
+                self.log.exception("Error logging DAG warnings.")
 
-        # Record DAG warnings in the metadatabase.
-        try:
-            self._last_num_of_db_queries += self.update_dag_warnings(dagbag=dagbag)
-        except Exception:
-            self.log.exception("Error logging DAG warnings.")
+        return len(dagbag.dags), len(dagbag.import_errors), self._cache_last_num_of_db_queries(query_counter)
 
-        return len(dagbag.dags), len(dagbag.import_errors), self._last_num_of_db_queries
+    def _cache_last_num_of_db_queries(self, query_counter: _QueryCounter | None = None):
+        if query_counter:
+            self._last_num_of_db_queries = query_counter.queries_number
+        return self._last_num_of_db_queries
 
     @staticmethod
     @internal_api_call

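A usage note on the tail of the diff above: helpers such as update_import_errors
and execute_callbacks no longer return query counts; instead process_file
snapshots the counter through _cache_last_num_of_db_queries and reports the
cached value in its return tuple. A small hedged sketch of that caching idiom
follows; the Processor class name and the standalone counter construction are
illustrative assumptions, not code from the commit.

from dataclasses import dataclass


@dataclass
class _QueryCounter:
    queries_number: int = 0


class Processor:  # illustrative stand-in for DagFileProcessor
    def __init__(self):
        self._last_num_of_db_queries = 0

    def _cache_last_num_of_db_queries(self, query_counter=None):
        # Remember the counter's final value when one is supplied; either way,
        # return the most recently recorded count.
        if query_counter:
            self._last_num_of_db_queries = query_counter.queries_number
        return self._last_num_of_db_queries


processor = Processor()
print(processor._cache_last_num_of_db_queries(_QueryCounter(queries_number=5)))  # 5
print(processor._cache_last_num_of_db_queries())  # still 5 when called without a counter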