This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch openlineage-process-execution in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 449dd3c6719781527c2e0a551a512bad7aa5a05d Author: Maciej Obuchowski <[email protected]> AuthorDate: Fri May 31 14:47:28 2024 +0200 openlineage: execute extraction and message sending in separate process Signed-off-by: Maciej Obuchowski <[email protected]> --- airflow/providers/common/sql/operators/sql.py | 6 ++++ airflow/providers/openlineage/plugins/listener.py | 42 ++++++++++++++++++++--- airflow/providers/openlineage/sqlparser.py | 24 ++++++++++--- airflow/providers/openlineage/utils/sql.py | 8 +++++ airflow/providers/openlineage/utils/utils.py | 3 +- airflow/providers/snowflake/hooks/snowflake.py | 6 ++-- 6 files changed, 76 insertions(+), 13 deletions(-) diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index d50a6bf0f5..d8602eec93 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -309,6 +309,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator): except ImportError: return None + self.log.debug("Getting Hook for OL") + hook = self.get_db_hook() try: @@ -319,6 +321,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator): # OpenLineage provider release < 1.8.0 - we always use connection use_external_connection = True + self.log.error("External connection? %s", use_external_connection) + connection = hook.get_connection(getattr(hook, hook.conn_name_attr)) try: database_info = hook.get_openlineage_database_info(connection) @@ -338,6 +342,8 @@ class SQLExecuteQueryOperator(BaseSQLOperator): self.log.debug("%s failed to get database dialect", hook) return None + self.log.error("SQL result? %s", str(sql_parser)) + operator_lineage = sql_parser.generate_openlineage_metadata_from_sql( sql=self.sql, hook=hook, diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index e07c5507d8..57ba05fdbe 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -17,10 +17,12 @@ from __future__ import annotations import logging +import os from concurrent.futures import ProcessPoolExecutor from datetime import datetime from typing import TYPE_CHECKING +import psutil from openlineage.client.serde import Serde from packaging.version import Version @@ -37,6 +39,7 @@ from airflow.providers.openlineage.utils.utils import ( is_selective_lineage_enabled, print_warning, ) +from airflow.settings import configure_orm from airflow.stats import Stats from airflow.utils.timeout import timeout @@ -82,7 +85,7 @@ class OpenLineageListener: ) return - self.log.debug("OpenLineage listener got notification about task instance start") + self.log.debug("OpenLineage listener got notification about task instance start - fork version") dagrun = task_instance.dag_run task = task_instance.task if TYPE_CHECKING: @@ -155,7 +158,7 @@ class OpenLineageListener: len(Serde.to_json(redacted_event).encode("utf-8")), ) - on_running() + self._fork_execute(on_running, "on_running") @hookimpl def on_task_instance_success( @@ -222,7 +225,7 @@ class OpenLineageListener: len(Serde.to_json(redacted_event).encode("utf-8")), ) - on_success() + self._fork_execute(on_success, "on_success") if _IS_AIRFLOW_2_10_OR_HIGHER: @@ -317,10 +320,41 @@ class OpenLineageListener: len(Serde.to_json(redacted_event).encode("utf-8")), ) - on_failure() + self._fork_execute(on_failure, "on_failure") + + def _fork_execute(self, callable, callable_name: str): + self.log.debug("Will fork to execute OpenLineage process.") + if isinstance(callable_name, tuple): + self.log.error("WHY ITS TUPLE?") + pid = os.fork() + if pid: + process = psutil.Process(pid) + try: + self.log.debug("Waiting for process %s", pid) + process.wait(10) + except psutil.TimeoutExpired: + self.log.warning( + "OpenLineage process %s expired. This should not affect process execution.", pid + ) + process.kill() + except BaseException: + # Kill the process. + pass + try: + process.kill() + except Exception: + pass + self.log.info("Process with pid %s finished - parent", pid) + else: + configure_orm() + self.log.debug("After fork - new process with current PID.") + callable() + self.log.debug("Process with current pid finishes after %s", callable_name) + os._exit(0) @property def executor(self) -> ProcessPoolExecutor: + # Executor for dag_run listener def initializer(): # Re-configure the ORM engine as there are issues with multiple processes # if process calls Airflow DB. diff --git a/airflow/providers/openlineage/sqlparser.py b/airflow/providers/openlineage/sqlparser.py index f181ff8cce..38ad2a24a8 100644 --- a/airflow/providers/openlineage/sqlparser.py +++ b/airflow/providers/openlineage/sqlparser.py @@ -39,6 +39,7 @@ from airflow.providers.openlineage.utils.sql import ( get_table_schemas, ) from airflow.typing_compat import TypedDict +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from sqlalchemy.engine import Engine @@ -116,7 +117,7 @@ def from_table_meta( return Dataset(namespace=namespace, name=name if not is_uppercase else name.upper()) -class SQLParser: +class SQLParser(LoggingMixin): """Interface for openlineage-sql. :param dialect: dialect specific to the database @@ -124,11 +125,13 @@ class SQLParser: """ def __init__(self, dialect: str | None = None, default_schema: str | None = None) -> None: + super().__init__() self.dialect = dialect self.default_schema = default_schema def parse(self, sql: list[str] | str) -> SqlMeta | None: """Parse a single or a list of SQL statements.""" + self.log.error("PRE IN PARSER - %s %s %s", sql, self.dialect, self.default_schema) return parse(sql=sql, dialect=self.dialect, default_schema=self.default_schema) def parse_table_schemas( @@ -151,6 +154,7 @@ class SQLParser: "database": database or database_info.database, "use_flat_cross_db_query": database_info.use_flat_cross_db_query, } + self.log.info("PRE getting schemas for input and output tables") return get_table_schemas( hook, namespace, @@ -251,10 +255,15 @@ class SQLParser: :param sqlalchemy_engine: when passed, engine's dialect is used to compile SQL queries """ job_facets: dict[str, BaseFacet] = {"sql": SqlJobFacet(query=self.normalize_sql(sql))} - parse_result = self.parse(self.split_sql_string(sql)) + self.log.error("Pre split") + split = self.split_sql_string(sql) + self.log.error("POST SPLIT, PRE PARSE") + parse_result = self.parse(split) if not parse_result: + self.log.error("NOT_PARSED") return OperatorLineage(job_facets=job_facets) + self.log.error("Post call parser") run_facets: dict[str, BaseFacet] = {} if parse_result.errors: run_facets["extractionError"] = ExtractionErrorRunFacet( @@ -271,8 +280,11 @@ class SQLParser: ], ) + self.log.error("Before connection usage") + namespace = self.create_namespace(database_info=database_info) if use_connection: + self.log.error("Use connection") inputs, outputs = self.parse_table_schemas( hook=hook, inputs=parse_result.in_tables, @@ -283,6 +295,7 @@ class SQLParser: sqlalchemy_engine=sqlalchemy_engine, ) else: + self.log.error("Use only Parser Metadata") inputs, outputs = self.get_metadata_from_parser( inputs=parse_result.in_tables, outputs=parse_result.out_tables, @@ -335,9 +348,8 @@ class SQLParser: return split_statement(sql) return [obj for stmt in sql for obj in cls.split_sql_string(stmt) if obj != ""] - @classmethod def create_information_schema_query( - cls, + self, tables: list[DbTableMeta], normalize_name: Callable[[str], str], is_cross_db: bool, @@ -349,12 +361,14 @@ class SQLParser: sqlalchemy_engine: Engine | None = None, ) -> str: """Create SELECT statement to query information schema table.""" - tables_hierarchy = cls._get_tables_hierarchy( + self.log.info("Creating tables_hierarchy info for tables %s", tables) + tables_hierarchy = self._get_tables_hierarchy( tables, normalize_name=normalize_name, database=database, is_cross_db=is_cross_db, ) + self.log.info("Got tables_hierarchy: %s, going to create queries", tables_hierarchy) return create_information_schema_query( columns=information_schema_columns, information_schema_table_name=information_schema_table, diff --git a/airflow/providers/openlineage/utils/sql.py b/airflow/providers/openlineage/utils/sql.py index f959745b93..6bef6ec47c 100644 --- a/airflow/providers/openlineage/utils/sql.py +++ b/airflow/providers/openlineage/utils/sql.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import logging from collections import defaultdict from contextlib import closing from enum import IntEnum @@ -90,14 +91,21 @@ def get_table_schemas( if not in_query and not out_query: return [], [] + logging.getLogger(__name__).warning("CREATE_HOOK") + with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor: + logging.getLogger(__name__).warning("GOT HOOK") if in_query: + logging.getLogger(__name__).warning("PRE_IN_EXECUTE") cursor.execute(in_query) + logging.getLogger(__name__).warning("POST_IN_EXECUTE") in_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)] else: in_datasets = [] if out_query: + logging.getLogger(__name__).warning("PRE_OUT_EXECUTE") cursor.execute(out_query) + logging.getLogger(__name__).warning("POST_OUT_EXECUTE") out_datasets = [x.to_dataset(namespace, database, schema) for x in parse_query_result(cursor)] else: out_datasets = [] diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index ff6ad63970..d96b648082 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -406,4 +406,5 @@ def normalize_sql(sql: str | Iterable[str]): def should_use_external_connection(hook) -> bool: # TODO: Add checking overrides - return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"] + return False + # return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook"] diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py index 978bcf75e1..39e17be7b8 100644 --- a/airflow/providers/snowflake/hooks/snowflake.py +++ b/airflow/providers/snowflake/hooks/snowflake.py @@ -473,10 +473,10 @@ class SnowflakeHook(DbApiHook): from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import SQLParser - connection = self.get_connection(getattr(self, self.conn_name_attr)) - namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection)) - if self.query_ids: + self.log.info("Getting connector to get database info :sadge:") + connection = self.get_connection(getattr(self, self.conn_name_attr)) + namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection)) return OperatorLineage( run_facets={ "externalQuery": ExternalQueryRunFacet(
