kaxil commented on code in PR #44972:
URL: https://github.com/apache/airflow/pull/44972#discussion_r1888209958


##########
airflow/dag_processing/processor.py:
##########
@@ -16,610 +16,229 @@
 # under the License.
 from __future__ import annotations
 
-import importlib
-import logging
 import os
-import signal
-import threading
-import time
-import zipfile
-from collections.abc import Generator, Iterable
-from contextlib import contextmanager, redirect_stderr, redirect_stdout, 
suppress
-from dataclasses import dataclass
-from typing import TYPE_CHECKING
-
-from setproctitle import setproctitle
-from sqlalchemy import event
-
-from airflow import settings
+import sys
+import traceback
+from collections.abc import Generator
+from typing import TYPE_CHECKING, Annotated, Callable, Literal, Union
+
+import attrs
+import pydantic
+
 from airflow.callbacks.callback_requests import (
+    CallbackRequest,
     DagCallbackRequest,
     TaskCallbackRequest,
 )
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
-from airflow.models.dag import DAG
 from airflow.models.dagbag import DagBag
-from airflow.models.pool import Pool
-from airflow.models.serialized_dag import SerializedDagModel
-from airflow.models.taskinstance import TaskInstance, _run_finished_callback
+from airflow.sdk.execution_time.comms import GetConnection, GetVariable
+from airflow.sdk.execution_time.supervisor import WatchedSubprocess
+from airflow.serialization.serialized_objects import LazyDeserializedDAG, 
SerializedDAG
 from airflow.stats import Stats
-from airflow.utils import timezone
-from airflow.utils.file import iter_airflow_imports, might_contain_dag
-from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, 
set_context
-from airflow.utils.mixins import MultiprocessingStartMethodMixin
-from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
-    import multiprocessing
-    from datetime import datetime
-    from multiprocessing.connection import Connection as 
MultiprocessingConnection
-
-    from sqlalchemy.orm.session import Session
-
-    from airflow.callbacks.callback_requests import CallbackRequest
-    from airflow.models.operator import Operator
+    from airflow.typing_compat import Self
+    from airflow.utils.context import Context
+
+
+def _parse_file_entrypoint():
+    import os
+
+    import structlog
+
+    from airflow.sdk.execution_time import task_runner
+    # Parse DAG file, send JSON back up!
+
+    comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, 
DagFileParsingResult](
+        input=sys.stdin,
+        decoder=pydantic.TypeAdapter[DagFileParseRequest](DagFileParseRequest),
+    )
+    msg = comms_decoder.get_message()
+    comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", 
buffering=0)
+
+    log = structlog.get_logger(logger_name="task")
+
+    result = _parse_file(msg, log)
+    comms_decoder.send_request(log, result)
+
+
+def _parse_file(msg: DagFileParseRequest, log):
+    # TODO: Set known_pool names on DagBag!
+    bag = DagBag(
+        dag_folder=msg.file,
+        include_examples=False,
+        safe_mode=True,
+        load_op_links=False,
+    )
+    serialized_dags, serialization_import_errors = _serialize_dags(bag, log)
+    bag.import_errors.update(serialization_import_errors)
+    dags = [LazyDeserializedDAG(data=serdag) for serdag in serialized_dags]
+    result = DagFileParsingResult(
+        fileloc=msg.file,
+        serialized_dags=dags,
+        import_errors=bag.import_errors,
+        # TODO: Make `bag.dag_warnings` not return SQLA model objects
+        warnings=[],
+    )
+
+    if msg.callback_requests:
+        _execute_callbacks(bag, msg.callback_requests, log)
+    return result
+
+
+def _serialize_dags(bag, log):
+    serialization_import_errors = {}
+    serialized_dags = []
+    for dag in bag.dags.values():
+        try:
+            serialized_dag = SerializedDAG.to_dict(dag)
+            serialized_dags.append(serialized_dag)
+        except Exception:
+            log.exception("Failed to serialize DAG: %s", dag.fileloc)
+            dagbag_import_error_traceback_depth = conf.getint(
+                "core", "dagbag_import_error_traceback_depth", fallback=None
+            )
+            serialization_import_errors[dag.fileloc] = traceback.format_exc(
+                limit=-dagbag_import_error_traceback_depth
+            )
+    return serialized_dags, serialization_import_errors
 
 
-@dataclass
-class _QueryCounter:
-    queries_number: int = 0
+def _execute_callbacks(dagbag: DagBag, callback_requests: 
list[CallbackRequest], log):
+    for request in callback_requests:
+        log.debug("Processing Callback Request", request=request)
+        if isinstance(request, TaskCallbackRequest):
+            raise NotImplementedError("Haven't coded Task callback yet!")
+            # _execute_task_callbacks(dagbag, request)
+        elif isinstance(request, DagCallbackRequest):
+            _execute_dag_callbacks(dagbag, request, log)
 
-    def inc(self):
-        self.queries_number += 1
 
+def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log):
+    dag = dagbag.dags[request.dag_id]
 
-@contextmanager
-def count_queries(session: Session) -> Generator[_QueryCounter, None, None]:
-    # using list allows to read the updated counter from what context manager 
returns
-    counter: _QueryCounter = _QueryCounter()
+    callbacks = dag.on_failure_callback if request.is_failure_callback else 
dag.on_success_callback
+    if not callbacks:
+        log.warning("Callback requested, but dag didn't have any", 
dag_id=request.dag_id)
+        return
 
-    @event.listens_for(session, "do_orm_execute")
-    def _count_db_queries(orm_execute_state):
-        nonlocal counter
-        counter.inc()
+    callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
+    # TODO:We need a proper context object!
+    context: Context = {}

Review Comment:
   So we need a method on the DAG object then, similar to `dag.fetch_callback`,
that fetches the TI and passes it in via the context.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to