ashb commented on code in PR #44972:
URL: https://github.com/apache/airflow/pull/44972#discussion_r1891617511


##########
airflow/dag_processing/processor.py:
##########
@@ -16,610 +16,237 @@
 # under the License.
 from __future__ import annotations
 
-import importlib
-import logging
 import os
-import signal
-import threading
-import time
-import zipfile
-from collections.abc import Generator, Iterable
-from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
-from dataclasses import dataclass
-from typing import TYPE_CHECKING
-
-from setproctitle import setproctitle
-from sqlalchemy import event
-
-from airflow import settings
+import sys
+import traceback
+from collections.abc import Generator
+from typing import TYPE_CHECKING, Annotated, Callable, Literal, Union
+
+import attrs
+from pydantic import BaseModel, Field, TypeAdapter
+
 from airflow.callbacks.callback_requests import (
+    CallbackRequest,
     DagCallbackRequest,
     TaskCallbackRequest,
 )
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
-from airflow.models.dag import DAG
 from airflow.models.dagbag import DagBag
-from airflow.models.pool import Pool
-from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import TaskInstance, _run_finished_callback
+from airflow.sdk.execution_time.comms import GetConnection, GetVariable
+from airflow.sdk.execution_time.supervisor import WatchedSubprocess
+from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
 from airflow.stats import Stats
-from airflow.utils import timezone
-from airflow.utils.file import iter_airflow_imports, might_contain_dag
-from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context
-from airflow.utils.mixins import MultiprocessingStartMethodMixin
-from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
-    import multiprocessing
-    from datetime import datetime
-    from multiprocessing.connection import Connection as MultiprocessingConnection
+    from structlog.typing import FilteringBoundLogger
 
-    from sqlalchemy.orm.session import Session
+    from airflow.typing_compat import Self
+    from airflow.utils.context import Context
 
-    from airflow.callbacks.callback_requests import CallbackRequest
-    from airflow.models.operator import Operator
 
+def _parse_file_entrypoint():
+    import os
 
-@dataclass
-class _QueryCounter:
-    queries_number: int = 0
+    import structlog
 
-    def inc(self):
-        self.queries_number += 1
+    from airflow.sdk.execution_time import task_runner
+    # Parse DAG file, send JSON back up!
 
+    comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, DagFileParsingResult](
+        input=sys.stdin,
+        decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
+    )
+    msg = comms_decoder.get_message()
+    comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)
 
-@contextmanager
-def count_queries(session: Session) -> Generator[_QueryCounter, None, None]:
-    # using list allows to read the updated counter from what context manager returns
-    counter: _QueryCounter = _QueryCounter()
+    log = structlog.get_logger(logger_name="task")
 
-    @event.listens_for(session, "do_orm_execute")
-    def _count_db_queries(orm_execute_state):
-        nonlocal counter
-        counter.inc()
+    result = _parse_file(msg, log)
+    comms_decoder.send_request(log, result)
 
-    yield counter
-    event.remove(session, "do_orm_execute", _count_db_queries)
 
+def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult:
+    # TODO: Set known_pool names on DagBag!
+    bag = DagBag(
+        dag_folder=msg.file,
+        include_examples=False,
+        safe_mode=True,
+        load_op_links=False,
+    )
+    serialized_dags, serialization_import_errors = _serialize_dags(bag, log)
+    bag.import_errors.update(serialization_import_errors)
+    dags = [LazyDeserializedDAG(data=serdag) for serdag in serialized_dags]
+    result = DagFileParsingResult(
+        fileloc=msg.file,
+        serialized_dags=dags,
+        import_errors=bag.import_errors,
+        # TODO: Make `bag.dag_warnings` not return SQLA model objects
+        warnings=[],
+    )
 
-class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin):
-    """
-    Runs DAG processing in a separate process using DagFileProcessor.
+    if msg.callback_requests:
+        _execute_callbacks(bag, msg.callback_requests, log)
+    return result
 
-    :param file_path: a Python file containing Airflow DAG definitions
-    :param callback_requests: failure callback to execute
-    """
-
-    # Counter that increments every time an instance of this class is created
-    class_creation_counter = 0
-
-    def __init__(
-        self,
-        file_path: str,
-        dag_directory: str,
-        callback_requests: list[CallbackRequest],
-    ):
-        super().__init__()
-        self._file_path = file_path
-        self._dag_directory = dag_directory
-        self._callback_requests = callback_requests
-
-        # The process that was launched to process the given DAG file.
-        self._process: multiprocessing.process.BaseProcess | None = None
-        # The result of DagFileProcessor.process_file(file_path).
-        self._result: tuple[int, int, int] | None = None
-        # Whether the process is done running.
-        self._done = False
-        # When the process started.
-        self._start_time: datetime | None = None
-        # This ID is used to uniquely name the process / thread that's launched
-        # by this processor instance
-        self._instance_id = DagFileProcessorProcess.class_creation_counter
-
-        self._parent_channel: MultiprocessingConnection | None = None
-        DagFileProcessorProcess.class_creation_counter += 1
-
-    @property
-    def file_path(self) -> str:
-        return self._file_path
-
-    @staticmethod
-    def _run_file_processor(
-        result_channel: MultiprocessingConnection,
-        parent_channel: MultiprocessingConnection,
-        file_path: str,
-        thread_name: str,
-        dag_directory: str,
-        callback_requests: list[CallbackRequest],
-        known_pools: set[str] | None = None,
-    ) -> None:
-        """
-        Process the given file.
-
-        :param result_channel: the connection to use for passing back the result
-        :param parent_channel: the parent end of the channel to close in the child
-        :param file_path: the file to process
-        :param thread_name: the name to use for the process that is launched
-        :param callback_requests: failure callback to execute
-        :return: the process that was launched
-        """
-        # This helper runs in the newly created process
-        log: logging.Logger = logging.getLogger("airflow.processor")
-
-        # Since we share all open FDs from the parent, we need to close the parent side of the pipe here in
-        # the child, else it won't get closed properly until we exit.
-        parent_channel.close()
-        del parent_channel
-
-        set_context(log, file_path)
-        setproctitle(f"airflow scheduler - DagFileProcessor {file_path}")
-
-        def _handle_dag_file_processing():
-            # Re-configure the ORM engine as there are issues with multiple processes
-            settings.configure_orm()
-
-            # Change the thread name to differentiate log lines. This is
-            # really a separate process, but changing the name of the
-            # process doesn't work, so changing the thread name instead.
-            threading.current_thread().name = thread_name
-
-            log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
-            dag_file_processor = DagFileProcessor(dag_directory=dag_directory, log=log)
-            result: tuple[int, int, int] = dag_file_processor.process_file(
-                file_path=file_path,
-                callback_requests=callback_requests,
-                known_pools=known_pools,
-            )
-            result_channel.send(result)
 
+def _serialize_dags(bag: DagBag, log: FilteringBoundLogger) -> tuple[list[dict], dict[str, str]]:
+    serialization_import_errors = {}
+    serialized_dags = []
+    for dag in bag.dags.values():
         try:
-            DAG_PROCESSOR_LOG_TARGET = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_TARGET")
-            if DAG_PROCESSOR_LOG_TARGET == "stdout":
-                with Stats.timer() as timer:
-                    _handle_dag_file_processing()
-            else:
-                # The following line ensures that stdout goes to the same destination as the logs. If stdout
-                # gets sent to logs and logs are sent to stdout, this leads to an infinite loop. This
-                # necessitates this conditional based on the value of DAG_PROCESSOR_LOG_TARGET.
-                with (
-                    redirect_stdout(StreamLogWriter(log, logging.INFO)),
-                    redirect_stderr(StreamLogWriter(log, logging.WARNING)),
-                    Stats.timer() as timer,
-                ):
-                    _handle_dag_file_processing()
-            log.info("Processing %s took %.3f seconds", file_path, timer.duration)
+            serialized_dag = SerializedDAG.to_dict(dag)
+            serialized_dags.append(serialized_dag)
         except Exception:
-            # Log exceptions through the logging framework.
-            log.exception("Got an exception! Propagating...")
-            raise
-        finally:
-            # We re-initialized the ORM within this Process above so we need to
-            # tear it down manually here
-            settings.dispose_orm()
-
-            result_channel.close()
-
-    def start(self) -> None:
-        """Launch the process and start processing the DAG."""
-        if conf.getboolean("scheduler", "parsing_pre_import_modules", fallback=True):
-            # Read the file to pre-import airflow modules used.
-            # This prevents them from being re-imported from zero in each "processing" process
-            # and saves CPU time and memory.
-            zip_file_paths = []
-            if zipfile.is_zipfile(self.file_path):
-                try:
-                    with zipfile.ZipFile(self.file_path) as z:
-                        zip_file_paths.extend(
-                            [
-                                os.path.join(self.file_path, info.filename)
-                                for info in z.infolist()
-                                if might_contain_dag(info.filename, True, z)
-                            ]
-                        )
-                except zipfile.BadZipFile as err:
-                    self.log.error("There was an err accessing %s, %s", self.file_path, err)
-            if zip_file_paths:
-                self.import_modules(zip_file_paths)
-            else:
-                self.import_modules(self.file_path)
-
-        context = self._get_multiprocessing_context()
-
-        pool_names = {p.pool for p in Pool.get_pools()}
-
-        _parent_channel, _child_channel = context.Pipe(duplex=False)
-        process = context.Process(
-            target=type(self)._run_file_processor,
-            args=(
-                _child_channel,
-                _parent_channel,
-                self.file_path,
-                f"DagFileProcessor{self._instance_id}",
-                self._dag_directory,
-                self._callback_requests,
-                pool_names,
-            ),
-            name=f"DagFileProcessor{self._instance_id}-Process",
+            log.exception("Failed to serialize DAG: %s", dag.fileloc)
+            dagbag_import_error_traceback_depth = conf.getint(
+                "core", "dagbag_import_error_traceback_depth", fallback=None
+            )
+            serialization_import_errors[dag.fileloc] = traceback.format_exc(
+                limit=-dagbag_import_error_traceback_depth
+            )
+    return serialized_dags, serialization_import_errors
+
+
+def _execute_callbacks(
+    dagbag: DagBag, callback_requests: list[CallbackRequest], log: FilteringBoundLogger
+) -> None:
+    for request in callback_requests:
+        log.debug("Processing Callback Request", request=request)
+        if isinstance(request, TaskCallbackRequest):
+            raise NotImplementedError("Haven't coded Task callback yet!")
+            # _execute_task_callbacks(dagbag, request)

Review Comment:
   I'll link to https://github.com/apache/airflow/issues/44354



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to