This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch openlineage-process-execution
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 449dd3c6719781527c2e0a551a512bad7aa5a05d
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Fri May 31 14:47:28 2024 +0200

    openlineage: execute extraction and message sending in separate process
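
    The change wraps each task-instance listener callback in a
    fork-and-wait pattern: the parent forks, waits up to 10 seconds for
    the child to finish extraction and event emission, and kills it on
    timeout, so lineage handling cannot block the task indefinitely. A
    minimal standalone sketch of the pattern (the helper name and the
    timeout default are illustrative, not part of this commit):

        import os

        import psutil

        def fork_execute(target, timeout: float = 10.0) -> None:
            """Run `target` in a forked child; the parent waits with a timeout."""
            pid = os.fork()
            if pid:
                # Parent: wait for the child, kill it if it overruns the timeout.
                child = psutil.Process(pid)
                try:
                    child.wait(timeout)
                except psutil.TimeoutExpired:
                    child.kill()
            else:
                # Child: do the work, then exit immediately so the parent's
                # cleanup handlers do not run in the child (hence os._exit).
                try:
                    target()
                finally:
                    os._exit(0)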
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/providers/common/sql/operators/sql.py     |  6 ++++
 airflow/providers/openlineage/plugins/listener.py | 42 ++++++++++++++++++++---
 airflow/providers/openlineage/sqlparser.py        | 24 ++++++++++---
 airflow/providers/openlineage/utils/sql.py        |  8 +++++
 airflow/providers/openlineage/utils/utils.py      |  3 +-
 airflow/providers/snowflake/hooks/snowflake.py    |  6 ++--
 6 files changed, 76 insertions(+), 13 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index d50a6bf0f5..d8602eec93 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -309,6 +309,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
         except ImportError:
             return None
 
+        self.log.debug("Getting Hook for OL")
+
         hook = self.get_db_hook()
 
         try:
@@ -319,6 +321,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             # OpenLineage provider release < 1.8.0 - we always use connection
             use_external_connection = True
 
+        self.log.error("External connection? %s", use_external_connection)
+
         connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
         try:
             database_info = hook.get_openlineage_database_info(connection)
@@ -338,6 +342,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             self.log.debug("%s failed to get database dialect", hook)
             return None
 
+        self.log.error("SQL result? %s", str(sql_parser))
+
         operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
             sql=self.sql,
             hook=hook,
diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py
index e07c5507d8..57ba05fdbe 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -17,10 +17,12 @@
 from __future__ import annotations
 
 import logging
+import os
 from concurrent.futures import ProcessPoolExecutor
 from datetime import datetime
 from typing import TYPE_CHECKING
 
+import psutil
 from openlineage.client.serde import Serde
 from packaging.version import Version
 
@@ -37,6 +39,7 @@ from airflow.providers.openlineage.utils.utils import (
     is_selective_lineage_enabled,
     print_warning,
 )
+from airflow.settings import configure_orm
 from airflow.stats import Stats
 from airflow.utils.timeout import timeout
 
@@ -82,7 +85,7 @@ class OpenLineageListener:
             )
             return
 
-        self.log.debug("OpenLineage listener got notification about task 
instance start")
+        self.log.debug("OpenLineage listener got notification about task 
instance start - fork version")
         dagrun = task_instance.dag_run
         task = task_instance.task
         if TYPE_CHECKING:
@@ -155,7 +158,7 @@ class OpenLineageListener:
                 len(Serde.to_json(redacted_event).encode("utf-8")),
             )
 
-        on_running()
+        self._fork_execute(on_running, "on_running")
 
     @hookimpl
     def on_task_instance_success(
@@ -222,7 +225,7 @@ class OpenLineageListener:
                 len(Serde.to_json(redacted_event).encode("utf-8")),
             )
 
-        on_success()
+        self._fork_execute(on_success, "on_success")
 
     if _IS_AIRFLOW_2_10_OR_HIGHER:
 
@@ -317,10 +320,41 @@ class OpenLineageListener:
                 len(Serde.to_json(redacted_event).encode("utf-8")),
             )
 
-        on_failure()
+        self._fork_execute(on_failure, "on_failure")
+
+    def _fork_execute(self, target, callable_name: str):
+        self.log.debug("Will fork to execute OpenLineage process.")
+        pid = os.fork()
+        if pid:
+            # Parent process: wait for the child, kill it on timeout.
+            process = psutil.Process(pid)
+            try:
+                self.log.debug("Waiting for process %s", pid)
+                process.wait(10)
+            except psutil.TimeoutExpired:
+                self.log.warning(
+                    "OpenLineage process %s timed out. This should not affect task execution.", pid
+                )
+                process.kill()
+            except BaseException:
+                # Make sure the child does not outlive the parent.
+                try:
+                    process.kill()
+                except Exception:
+                    pass
+            self.log.info("Process with pid %s finished - parent", pid)
+        else:
+            # Child process: re-configure the ORM first, since SQLAlchemy
+            # engines and their pooled connections must not be shared
+            # across a fork.
+            configure_orm()
+            self.log.debug("After fork - new process with current PID.")
+            target()
+            self.log.debug("Process with current pid finishes after %s", callable_name)
+            os._exit(0)
 
     @property
     def executor(self) -> ProcessPoolExecutor:
+        # Executor for dag_run listener
         def initializer():
             # Re-configure the ORM engine as there are issues with multiple processes
             # if process calls Airflow DB.
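
Note on the listener change above: the configure_orm() call in the forked
child appears to follow SQLAlchemy's guidance that an Engine's pooled
connections must not be reused across a fork. A minimal illustration of the
underlying issue (the DSN is hypothetical; dispose(close=False) is available
in recent SQLAlchemy versions and leaves the parent's connections untouched):

    import os

    from sqlalchemy import create_engine

    # Hypothetical DSN - create_engine() does not connect eagerly.
    engine = create_engine("postgresql://user:pass@localhost/mydb")

    pid = os.fork()
    if pid == 0:
        # Child: pooled DBAPI connections inherited from the parent must not
        # be reused; drop references so the child opens fresh ones on demand.
        engine.dispose(close=False)
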
diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py
index f181ff8cce..38ad2a24a8 100644
--- a/airflow/providers/openlineage/sqlparser.py
+++ b/airflow/providers/openlineage/sqlparser.py
@@ -39,6 +39,7 @@ from airflow.providers.openlineage.utils.sql import (
     get_table_schemas,
 )
 from airflow.typing_compat import TypedDict
+from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import Engine
@@ -116,7 +117,7 @@ def from_table_meta(
     return Dataset(namespace=namespace, name=name if not is_uppercase else name.upper())
 
 
-class SQLParser:
+class SQLParser(LoggingMixin):
     """Interface for openlineage-sql.
 
     :param dialect: dialect specific to the database
@@ -124,11 +125,13 @@ class SQLParser:
     """
 
     def __init__(self, dialect: str | None = None, default_schema: str | None = None) -> None:
+        super().__init__()
         self.dialect = dialect
         self.default_schema = default_schema
 
     def parse(self, sql: list[str] | str) -> SqlMeta | None:
         """Parse a single or a list of SQL statements."""
+        self.log.error("PRE IN PARSER - %s %s %s", sql, self.dialect, 
self.default_schema)
         return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema)
 
     def parse_table_schemas(
@@ -151,6 +154,7 @@ class SQLParser:
             "database": database or database_info.database,
             "use_flat_cross_db_query": database_info.use_flat_cross_db_query,
         }
+        self.log.info("PRE getting schemas for input and output tables")
         return get_table_schemas(
             hook,
             namespace,
@@ -251,10 +255,15 @@ class SQLParser:
         :param sqlalchemy_engine: when passed, engine's dialect is used to compile SQL queries
         """
         job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=self.normalize_sql(sql))}
-        parse_result = self.parse(self.split_sql_string(sql))
+        self.log.error("Pre split")
+        split = self.split_sql_string(sql)
+        self.log.error("POST SPLIT, PRE PARSE")
+        parse_result = self.parse(split)
         if not parse_result:
+            self.log.error("NOT_PARSED")
             return OperatorLineage(job_facets=job_facets)
 
+        self.log.error("Post call parser")
         run_facets: dict[str, BaseFacet] = {}
         if parse_result.errors:
             run_facets["extractionError"] = ExtractionErrorRunFacet(
@@ -271,8 +280,11 @@ class SQLParser:
                 ],
             )
 
+        self.log.error("Before connection usage")
+
         namespace = self.create_namespace(database_info=database_info)
         if use_connection:
+            self.log.error("Use connection")
             inputs, outputs = self.parse_table_schemas(
                 hook=hook,
                 inputs=parse_result.in_tables,
@@ -283,6 +295,7 @@ class SQLParser:
                 sqlalchemy_engine=sqlalchemy_engine,
             )
         else:
+            self.log.error("Use only Parser Metadata")
             inputs, outputs = self.get_metadata_from_parser(
                 inputs=parse_result.in_tables,
                 outputs=parse_result.out_tables,
@@ -335,9 +348,8 @@ class SQLParser:
             return split_statement(sql)
         return [obj for stmt in sql for obj in cls.split_sql_string(stmt) if obj != ""]
 
-    @classmethod
     def create_information_schema_query(
-        cls,
+        self,
         tables: list[DbTableMeta],
         normalize_name: Callable[[str], str],
         is_cross_db: bool,
@@ -349,12 +361,14 @@ class SQLParser:
         sqlalchemy_engine: Engine | None = None,
     ) -> str:
         """Create SELECT statement to query information schema table."""
-        tables_hierarchy = cls._get_tables_hierarchy(
+        self.log.info("Creating tables_hierarchy info for tables %s", tables)
+        tables_hierarchy = self._get_tables_hierarchy(
             tables,
             normalize_name=normalize_name,
             database=database,
             is_cross_db=is_cross_db,
         )
+        self.log.info("Got tables_hierarchy: %s, going to create queries", 
tables_hierarchy)
         return create_information_schema_query(
             columns=information_schema_columns,
             information_schema_table_name=information_schema_table,
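
Note on the parser changes above: the entry points touched here can be
exercised standalone. A minimal usage sketch (requires the OpenLineage
provider with openlineage-sql installed; the statement and dialect are
illustrative):

    from airflow.providers.openlineage.sqlparser import SQLParser

    parser = SQLParser(dialect="snowflake", default_schema="PUBLIC")
    # parse() accepts a single statement or a list of statements and
    # returns None when no metadata could be extracted.
    meta = parser.parse("INSERT INTO orders SELECT * FROM staging_orders")
    if meta:
        print(meta.in_tables, meta.out_tables)
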
diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py
index f959745b93..6bef6ec47c 100644
--- a/airflow/providers/openlineage/utils/sql.py
+++ b/airflow/providers/openlineage/utils/sql.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from collections import defaultdict
 from contextlib import closing
 from enum import IntEnum
@@ -90,14 +91,21 @@ def get_table_schemas(
     if not in_query and not out_query:
         return [], []
 
+    logging.getLogger(__name__).warning("CREATE_HOOK")
+
     with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
+        logging.getLogger(__name__).warning("GOT HOOK")
         if in_query:
+            logging.getLogger(__name__).warning("PRE_IN_EXECUTE")
             cursor.execute(in_query)
+            logging.getLogger(__name__).warning("POST_IN_EXECUTE")
             in_datasets = [x.to_dataset(namespace, database, schema) for x in 
parse_query_result(cursor)]
         else:
             in_datasets = []
         if out_query:
+            logging.getLogger(__name__).warning("PRE_OUT_EXECUTE")
             cursor.execute(out_query)
+            logging.getLogger(__name__).warning("POST_OUT_EXECUTE")
             out_datasets = [x.to_dataset(namespace, database, schema) for x in 
parse_query_result(cursor)]
         else:
             out_datasets = []
diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py
index ff6ad63970..d96b648082 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -406,4 +406,5 @@ def normalize_sql(sql: str | Iterable[str]):
 
 def should_use_external_connection(hook) -> bool:
     # TODO: Add checking overrides
-    return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
+    # Temporarily force in-process connection usage for all hooks.
+    return False
+    # return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"]
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index 978bcf75e1..39e17be7b8 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -473,10 +473,10 @@ class SnowflakeHook(DbApiHook):
         from airflow.providers.openlineage.extractors import OperatorLineage
         from airflow.providers.openlineage.sqlparser import SQLParser
 
-        connection = self.get_connection(getattr(self, self.conn_name_attr))
-        namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
-
         if self.query_ids:
+            self.log.info("Getting connection to retrieve database info")
+            connection = self.get_connection(getattr(self, self.conn_name_attr))
+            namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))
             return OperatorLineage(
                 run_facets={
                     "externalQuery": ExternalQueryRunFacet(
