This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 29ace334ac9 chore: Move OpenLineage methods to BaseSQLOperator (#58897)
29ace334ac9 is described below

commit 29ace334ac9708fd32183bd3b6c115e503035e7e
Author: Kacper Muda <[email protected]>
AuthorDate: Tue Dec 2 00:34:14 2025 +0100

    chore: Move OpenLineage methods to BaseSQLOperator (#58897)
---
 .../airflow/providers/common/sql/operators/sql.py  | 209 ++++++++++++---------
 .../tests/unit/common/sql/operators/test_sql.py    |  20 +-
 2 files changed, 137 insertions(+), 92 deletions(-)

diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index 1e82a01bd1e..3132fcf1cbc 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -213,6 +213,86 @@ class BaseSQLOperator(BaseOperator):
             raise AirflowException(exception_string)
         raise AirflowFailException(exception_string)
 
+    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+        """Generate OpenLineage facets on start for SQL operators."""
+        try:
+            from airflow.providers.openlineage.extractors import OperatorLineage
+            from airflow.providers.openlineage.sqlparser import SQLParser
+        except ImportError:
+            self.log.debug("OpenLineage could not import required classes. 
Skipping.")
+            return None
+
+        sql = getattr(self, "sql", None)
+        if not sql:
+            self.log.debug("OpenLineage could not find 'sql' attribute on 
`%s`.", type(self).__name__)
+            return OperatorLineage()
+
+        hook = self.get_db_hook()
+        try:
+            from airflow.providers.openlineage.utils.utils import should_use_external_connection
+
+            use_external_connection = should_use_external_connection(hook)
+        except ImportError:
+            # OpenLineage provider release < 1.8.0 - we always use connection
+            use_external_connection = True
+
+        connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
+        try:
+            database_info = hook.get_openlineage_database_info(connection)
+        except AttributeError:
+            self.log.debug("%s has no database info provided", hook)
+            database_info = None
+
+        if database_info is None:
+            self.log.debug("OpenLineage could not retrieve database 
information. Skipping.")
+            return OperatorLineage()
+
+        try:
+            sql_parser = SQLParser(
+                dialect=hook.get_openlineage_database_dialect(connection),
+                default_schema=hook.get_openlineage_default_schema(),
+            )
+        except AttributeError:
+            self.log.debug("%s failed to get database dialect", hook)
+            return None
+
+        operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
+            sql=sql,
+            hook=hook,
+            database_info=database_info,
+            database=self.database,
+            sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+            use_connection=use_external_connection,
+        )
+
+        return operator_lineage
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        """Generate OpenLineage facets when task completes."""
+        try:
+            from airflow.providers.openlineage.extractors import OperatorLineage
+        except ImportError:
+            self.log.debug("OpenLineage could not import required classes. 
Skipping.")
+            return None
+
+        operator_lineage = self.get_openlineage_facets_on_start() or OperatorLineage()
+        hook = self.get_db_hook()
+        try:
+            database_specific_lineage = hook.get_openlineage_database_specific_lineage(task_instance)
+        except AttributeError:
+            self.log.debug("%s has no database specific lineage provided", 
hook)
+            database_specific_lineage = None
+
+        if database_specific_lineage is None:
+            return operator_lineage
+
+        return OperatorLineage(
+            inputs=operator_lineage.inputs + database_specific_lineage.inputs,
+            outputs=operator_lineage.outputs + database_specific_lineage.outputs,
+            run_facets=merge_dicts(operator_lineage.run_facets, database_specific_lineage.run_facets),
+            job_facets=merge_dicts(operator_lineage.job_facets, database_specific_lineage.job_facets),
+        )
+
 
 class SQLExecuteQueryOperator(BaseSQLOperator):
     """
@@ -343,76 +423,6 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
         if isinstance(self.parameters, str):
             self.parameters = ast.literal_eval(self.parameters)
 
-    def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
-        try:
-            from airflow.providers.openlineage.sqlparser import SQLParser
-        except ImportError:
-            return None
-
-        hook = self.get_db_hook()
-
-        try:
-            from airflow.providers.openlineage.utils.utils import should_use_external_connection
-
-            use_external_connection = should_use_external_connection(hook)
-        except ImportError:
-            # OpenLineage provider release < 1.8.0 - we always use connection
-            use_external_connection = True
-
-        connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
-        try:
-            database_info = hook.get_openlineage_database_info(connection)
-        except AttributeError:
-            self.log.debug("%s has no database info provided", hook)
-            database_info = None
-
-        if database_info is None:
-            return None
-
-        try:
-            sql_parser = SQLParser(
-                dialect=hook.get_openlineage_database_dialect(connection),
-                default_schema=hook.get_openlineage_default_schema(),
-            )
-        except AttributeError:
-            self.log.debug("%s failed to get database dialect", hook)
-            return None
-
-        operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
-            sql=self.sql,
-            hook=hook,
-            database_info=database_info,
-            database=self.database,
-            sqlalchemy_engine=hook.get_sqlalchemy_engine(),
-            use_connection=use_external_connection,
-        )
-
-        return operator_lineage
-
-    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
-        try:
-            from airflow.providers.openlineage.extractors import OperatorLineage
-        except ImportError:
-            return None
-
-        operator_lineage = self.get_openlineage_facets_on_start() or OperatorLineage()
-
-        hook = self.get_db_hook()
-        try:
-            database_specific_lineage = hook.get_openlineage_database_specific_lineage(task_instance)
-        except AttributeError:
-            database_specific_lineage = None
-
-        if database_specific_lineage is None:
-            return operator_lineage
-
-        return OperatorLineage(
-            inputs=operator_lineage.inputs + database_specific_lineage.inputs,
-            outputs=operator_lineage.outputs + database_specific_lineage.outputs,
-            run_facets=merge_dicts(operator_lineage.run_facets, database_specific_lineage.run_facets),
-            job_facets=merge_dicts(operator_lineage.job_facets, database_specific_lineage.job_facets),
-        )
-
 
 class SQLColumnCheckOperator(BaseSQLOperator):
     """
@@ -999,8 +1009,13 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
 
         self.sql1 = f"{sqlt}'{{{{ ds }}}}'"
         self.sql2 = f"{sqlt}'{{{{ macros.ds_add(ds, {self.days_back}) }}}}'"
+        # Save all queries as `sql` attr - similar to other sql operators (to be used by listeners).
+        self.sql: list[str] = [self.sql1, self.sql2]
 
     def execute(self, context: Context):
+        # Re-set with templated queries
+        self.sql = [self.sql1, self.sql2]
+
         hook = self.get_db_hook()
         self.log.info("Using ratio formula: %s", self.ratio_formula)
         self.log.info("Executing SQL check: %s", self.sql2)
@@ -1017,25 +1032,36 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
         reference = dict(zip(self.metrics_sorted, row2))
 
         ratios: dict[str, int | None] = {}
-        test_results = {}
+        # Save all details about all tests to be used in error message if needed
+        all_tests_results: dict[str, dict[str, Any]] = {}
 
         for metric in self.metrics_sorted:
             cur = current[metric]
             ref = reference[metric]
             threshold = self.metrics_thresholds[metric]
+            single_metric_results = {
+                "metric": metric,
+                "current_metric": cur,
+                "past_metric": ref,
+                "threshold": threshold,
+                "ignore_zero": self.ignore_zero,
+            }
             if cur == 0 or ref == 0:
                 ratios[metric] = None
-                test_results[metric] = self.ignore_zero
+                single_metric_results["ratio"] = None
+                single_metric_results["success"] = self.ignore_zero
             else:
                 ratio_metric = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
                 ratios[metric] = ratio_metric
+                single_metric_results["ratio"] = ratio_metric
                 if ratio_metric is not None:
-                    test_results[metric] = ratio_metric < threshold
+                    single_metric_results["success"] = ratio_metric < threshold
                 else:
-                    test_results[metric] = self.ignore_zero
+                    single_metric_results["success"] = self.ignore_zero
 
+            all_tests_results[metric] = single_metric_results
             self.log.info(
-                ("Current metric for %s: %s\nPast metric for %s: %s\nRatio for 
%s: %s\nThreshold: %s\n"),
+                "Current metric for %s: %s\nPast metric for %s: %s\nRatio for 
%s: %s\nThreshold: %s\n",
                 metric,
                 cur,
                 metric,
@@ -1045,21 +1071,24 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
                 threshold,
             )
 
-        if not all(test_results.values()):
-            failed_tests = [it[0] for it in test_results.items() if not it[1]]
+        failed_tests = [single for single in all_tests_results.values() if not single["success"]]
+        if failed_tests:
             self.log.warning(
                 "The following %s tests out of %s failed:",
                 len(failed_tests),
                 len(self.metrics_sorted),
             )
-            for k in failed_tests:
+            for single_failed_test in failed_tests:
                 self.log.warning(
                     "'%s' check failed. %s is above %s",
-                    k,
-                    ratios[k],
-                    self.metrics_thresholds[k],
+                    single_filed_test["metric"],
+                    single_filed_test["ratio"],
+                    single_filed_test["threshold"],
                 )
-            self._raise_exception(f"The following tests have failed:\n {', 
'.join(sorted(failed_tests))}")
+            failed_test_details = "; ".join(
+                f"{t['metric']}: {t}" for t in sorted(failed_tests, key=lambda 
x: x["metric"])
+            )
+            self._raise_exception(f"The following tests have failed:\n 
{failed_test_details}")
 
         self.log.info("All tests have passed")
 
@@ -1206,6 +1235,8 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
         self.parameters = parameters
         self.follow_task_ids_if_true = follow_task_ids_if_true
         self.follow_task_ids_if_false = follow_task_ids_if_false
+        # Chosen branch, after evaluating condition, set during execution, to be used by listeners
+        self.follow_branch: list[str] | None = None
 
     def execute(self, context: Context):
         self.log.info(
@@ -1232,32 +1263,30 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
 
         self.log.info("Query returns %s, type '%s'", query_result, 
type(query_result))
 
-        follow_branch = None
         try:
             if isinstance(query_result, bool):
                 if query_result:
-                    follow_branch = self.follow_task_ids_if_true
+                    self.follow_branch = self.follow_task_ids_if_true
             elif isinstance(query_result, str):
                 # return result is not Boolean, try to convert from String to Boolean
                 if _parse_boolean(query_result):
-                    follow_branch = self.follow_task_ids_if_true
+                    self.follow_branch = self.follow_task_ids_if_true
             elif isinstance(query_result, int):
                 if bool(query_result):
-                    follow_branch = self.follow_task_ids_if_true
+                    self.follow_branch = self.follow_task_ids_if_true
             else:
                 raise AirflowException(
                     f"Unexpected query return result '{query_result}' type 
'{type(query_result)}'"
                 )
 
-            if follow_branch is None:
-                follow_branch = self.follow_task_ids_if_false
+            if self.follow_branch is None:
+                self.follow_branch = self.follow_task_ids_if_false
         except ValueError:
             raise AirflowException(
                 f"Unexpected query return result '{query_result}' type 
'{type(query_result)}'"
             )
 
-        # TODO(potiuk) remove the type ignore once we solve provider <-> Task SDK relationship
-        self.skip_all_except(context["ti"], follow_branch)
+        self.skip_all_except(context["ti"], self.follow_branch)
 
 
 class SQLInsertRowsOperator(BaseSQLOperator):
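
Net effect of the sql.py changes above: lineage collection now lives on BaseSQLOperator, so any subclass that exposes a `sql` attribute and returns a DB-API hook from get_db_hook() inherits it without reimplementing the two methods. A minimal sketch of such a subclass (the operator name and statement handling below are illustrative, not part of this commit):

    from __future__ import annotations

    from airflow.providers.common.sql.operators.sql import BaseSQLOperator


    class MyVendorSQLOperator(BaseSQLOperator):  # hypothetical subclass, for illustration only
        template_fields = ("sql",)

        def __init__(self, *, sql: str, **kwargs):
            super().__init__(**kwargs)
            # The inherited get_openlineage_facets_on_start() looks the statement up via
            # getattr(self, "sql"), so storing it under this attribute name opts the
            # operator into lineage extraction.
            self.sql = sql

        def execute(self, context):
            # get_db_hook() resolves the hook from conn_id / hook_params on BaseSQLOperator.
            return self.get_db_hook().run(self.sql)
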
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
index 07d590c06dc..e47ff81ddd0 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
@@ -939,7 +939,16 @@ class TestIntervalCheckOperator:
             ignore_zero=True,
         )
 
-        with pytest.raises(AirflowException, match="f0, f1, f2"):
+        expected_err_message = (
+            "The following tests have failed:\n "
+            "f0: {'metric': 'f0', 'current_metric': 1, 'past_metric': 2, 
'threshold': 1.0,"
+            " 'ignore_zero': True, 'ratio': 2.0, 'success': False}; "
+            "f1: {'metric': 'f1', 'current_metric': 1, 'past_metric': 2, 
'threshold': 1.5,"
+            " 'ignore_zero': True, 'ratio': 2.0, 'success': False}; "
+            "f2: {'metric': 'f2', 'current_metric': 1, 'past_metric': 2, 
'threshold': 2.0,"
+            " 'ignore_zero': True, 'ratio': 2.0, 'success': False}"
+        )
+        with pytest.raises(AirflowException, match=expected_err_message):
             operator.execute(context=MagicMock())
 
     @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
@@ -969,7 +978,14 @@ class TestIntervalCheckOperator:
             ignore_zero=True,
         )
 
-        with pytest.raises(AirflowException, match="f0, f1"):
+        expected_err_message = (
+            "The following tests have failed:\n "
+            "f0: {'metric': 'f0', 'current_metric': 1, 'past_metric': 3, 
'threshold': 0.5, "
+            "'ignore_zero': True, 'ratio': 0.6666666666666666, 'success': 
False}; "
+            "f1: {'metric': 'f1', 'current_metric': 1, 'past_metric': 3, 
'threshold': 0.6, "
+            "'ignore_zero': True, 'ratio': 0.6666666666666666, 'success': 
False}"
+        )
+        with pytest.raises(AirflowException, match=expected_err_message):
             operator.execute(context=MagicMock())
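
The expected_err_message literals in the updated tests mirror the detail string now built in SQLIntervalCheckOperator.execute(). A standalone sketch of that construction, using values from the first test case above (illustrative only, not part of the patch):

    # Per-metric result dicts as collected by execute(); values from the first test case.
    failed_tests = [
        {"metric": "f0", "current_metric": 1, "past_metric": 2, "threshold": 1.0,
         "ignore_zero": True, "ratio": 2.0, "success": False},
        {"metric": "f1", "current_metric": 1, "past_metric": 2, "threshold": 1.5,
         "ignore_zero": True, "ratio": 2.0, "success": False},
    ]
    # Same formatting as in execute(): one "metric: {full result dict}" entry per failure.
    failed_test_details = "; ".join(
        f"{t['metric']}: {t}" for t in sorted(failed_tests, key=lambda x: x["metric"])
    )
    print(f"The following tests have failed:\n {failed_test_details}")
    # Prints:
    # The following tests have failed:
    #  f0: {'metric': 'f0', 'current_metric': 1, 'past_metric': 2, 'threshold': 1.0,
    #  'ignore_zero': True, 'ratio': 2.0, 'success': False}; f1: {... analogous dict ...}
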
 
 
