This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 29ace334ac9 chore: Move OpenLineage methods to BaseSQLOperator (#58897)
29ace334ac9 is described below
commit 29ace334ac9708fd32183bd3b6c115e503035e7e
Author: Kacper Muda <[email protected]>
AuthorDate: Tue Dec 2 00:34:14 2025 +0100
chore: Move OpenLineage methods to BaseSQLOperator (#58897)
---
.../airflow/providers/common/sql/operators/sql.py | 209 ++++++++++++---------
.../tests/unit/common/sql/operators/test_sql.py | 20 +-
2 files changed, 137 insertions(+), 92 deletions(-)
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index 1e82a01bd1e..3132fcf1cbc 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -213,6 +213,86 @@ class BaseSQLOperator(BaseOperator):
raise AirflowException(exception_string)
raise AirflowFailException(exception_string)
+ def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
+ """Generate OpenLineage facets on start for SQL operators."""
+ try:
+ from airflow.providers.openlineage.extractors import
OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+ except ImportError:
+ self.log.debug("OpenLineage could not import required classes.
Skipping.")
+ return None
+
+ sql = getattr(self, "sql", None)
+ if not sql:
+ self.log.debug("OpenLineage could not find 'sql' attribute on
`%s`.", type(self).__name__)
+ return OperatorLineage()
+
+ hook = self.get_db_hook()
+ try:
+ from airflow.providers.openlineage.utils.utils import
should_use_external_connection
+
+ use_external_connection = should_use_external_connection(hook)
+ except ImportError:
+ # OpenLineage provider release < 1.8.0 - we always use connection
+ use_external_connection = True
+
+ connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
+ try:
+ database_info = hook.get_openlineage_database_info(connection)
+ except AttributeError:
+ self.log.debug("%s has no database info provided", hook)
+ database_info = None
+
+ if database_info is None:
+ self.log.debug("OpenLineage could not retrieve database
information. Skipping.")
+ return OperatorLineage()
+
+ try:
+ sql_parser = SQLParser(
+ dialect=hook.get_openlineage_database_dialect(connection),
+ default_schema=hook.get_openlineage_default_schema(),
+ )
+ except AttributeError:
+ self.log.debug("%s failed to get database dialect", hook)
+ return None
+
+ operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
+ sql=sql,
+ hook=hook,
+ database_info=database_info,
+ database=self.database,
+ sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+ use_connection=use_external_connection,
+ )
+
+ return operator_lineage
+
+ def get_openlineage_facets_on_complete(self, task_instance) ->
OperatorLineage | None:
+ """Generate OpenLineage facets when task completes."""
+ try:
+ from airflow.providers.openlineage.extractors import
OperatorLineage
+ except ImportError:
+ self.log.debug("OpenLineage could not import required classes.
Skipping.")
+ return None
+
+ operator_lineage = self.get_openlineage_facets_on_start() or
OperatorLineage()
+ hook = self.get_db_hook()
+ try:
+ database_specific_lineage =
hook.get_openlineage_database_specific_lineage(task_instance)
+ except AttributeError:
+ self.log.debug("%s has no database specific lineage provided",
hook)
+ database_specific_lineage = None
+
+ if database_specific_lineage is None:
+ return operator_lineage
+
+ return OperatorLineage(
+ inputs=operator_lineage.inputs + database_specific_lineage.inputs,
+ outputs=operator_lineage.outputs +
database_specific_lineage.outputs,
+ run_facets=merge_dicts(operator_lineage.run_facets,
database_specific_lineage.run_facets),
+ job_facets=merge_dicts(operator_lineage.job_facets,
database_specific_lineage.job_facets),
+ )
+
class SQLExecuteQueryOperator(BaseSQLOperator):
"""
@@ -343,76 +423,6 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
if isinstance(self.parameters, str):
self.parameters = ast.literal_eval(self.parameters)
- def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
- try:
- from airflow.providers.openlineage.sqlparser import SQLParser
- except ImportError:
- return None
-
- hook = self.get_db_hook()
-
- try:
- from airflow.providers.openlineage.utils.utils import
should_use_external_connection
-
- use_external_connection = should_use_external_connection(hook)
- except ImportError:
- # OpenLineage provider release < 1.8.0 - we always use connection
- use_external_connection = True
-
- connection = hook.get_connection(getattr(hook, hook.conn_name_attr))
- try:
- database_info = hook.get_openlineage_database_info(connection)
- except AttributeError:
- self.log.debug("%s has no database info provided", hook)
- database_info = None
-
- if database_info is None:
- return None
-
- try:
- sql_parser = SQLParser(
- dialect=hook.get_openlineage_database_dialect(connection),
- default_schema=hook.get_openlineage_default_schema(),
- )
- except AttributeError:
- self.log.debug("%s failed to get database dialect", hook)
- return None
-
- operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
- sql=self.sql,
- hook=hook,
- database_info=database_info,
- database=self.database,
- sqlalchemy_engine=hook.get_sqlalchemy_engine(),
- use_connection=use_external_connection,
- )
-
- return operator_lineage
-
- def get_openlineage_facets_on_complete(self, task_instance) ->
OperatorLineage | None:
- try:
- from airflow.providers.openlineage.extractors import
OperatorLineage
- except ImportError:
- return None
-
- operator_lineage = self.get_openlineage_facets_on_start() or
OperatorLineage()
-
- hook = self.get_db_hook()
- try:
- database_specific_lineage =
hook.get_openlineage_database_specific_lineage(task_instance)
- except AttributeError:
- database_specific_lineage = None
-
- if database_specific_lineage is None:
- return operator_lineage
-
- return OperatorLineage(
- inputs=operator_lineage.inputs + database_specific_lineage.inputs,
- outputs=operator_lineage.outputs +
database_specific_lineage.outputs,
- run_facets=merge_dicts(operator_lineage.run_facets,
database_specific_lineage.run_facets),
- job_facets=merge_dicts(operator_lineage.job_facets,
database_specific_lineage.job_facets),
- )
-
class SQLColumnCheckOperator(BaseSQLOperator):
"""
@@ -999,8 +1009,13 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
self.sql1 = f"{sqlt}'{{{{ ds }}}}'"
self.sql2 = f"{sqlt}'{{{{ macros.ds_add(ds, {self.days_back}) }}}}'"
+ # Save all queries as `sql` attr - similar to other sql operators (to
be used by listeners).
+ self.sql: list[str] = [self.sql1, self.sql2]
def execute(self, context: Context):
+ # Re-set with templated queries
+ self.sql = [self.sql1, self.sql2]
+
hook = self.get_db_hook()
self.log.info("Using ratio formula: %s", self.ratio_formula)
self.log.info("Executing SQL check: %s", self.sql2)
@@ -1017,25 +1032,36 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
reference = dict(zip(self.metrics_sorted, row2))
ratios: dict[str, int | None] = {}
- test_results = {}
+ # Save all details about all tests to be used in error message if
needed
+ all_tests_results: dict[str, dict[str, Any]] = {}
for metric in self.metrics_sorted:
cur = current[metric]
ref = reference[metric]
threshold = self.metrics_thresholds[metric]
+ single_metric_results = {
+ "metric": metric,
+ "current_metric": cur,
+ "past_metric": ref,
+ "threshold": threshold,
+ "ignore_zero": self.ignore_zero,
+ }
if cur == 0 or ref == 0:
ratios[metric] = None
- test_results[metric] = self.ignore_zero
+ single_metric_results["ratio"] = None
+ single_metric_results["success"] = self.ignore_zero
else:
ratio_metric =
self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
ratios[metric] = ratio_metric
+ single_metric_results["ratio"] = ratio_metric
if ratio_metric is not None:
- test_results[metric] = ratio_metric < threshold
+ single_metric_results["success"] = ratio_metric < threshold
else:
- test_results[metric] = self.ignore_zero
+ single_metric_results["success"] = self.ignore_zero
+ all_tests_results[metric] = single_metric_results
self.log.info(
- ("Current metric for %s: %s\nPast metric for %s: %s\nRatio for
%s: %s\nThreshold: %s\n"),
+ "Current metric for %s: %s\nPast metric for %s: %s\nRatio for
%s: %s\nThreshold: %s\n",
metric,
cur,
metric,
@@ -1045,21 +1071,24 @@ class SQLIntervalCheckOperator(BaseSQLOperator):
threshold,
)
- if not all(test_results.values()):
- failed_tests = [it[0] for it in test_results.items() if not it[1]]
+ failed_tests = [single for single in all_tests_results.values() if not
single["success"]]
+ if failed_tests:
self.log.warning(
"The following %s tests out of %s failed:",
len(failed_tests),
len(self.metrics_sorted),
)
- for k in failed_tests:
+ for single_filed_test in failed_tests:
self.log.warning(
"'%s' check failed. %s is above %s",
- k,
- ratios[k],
- self.metrics_thresholds[k],
+ single_filed_test["metric"],
+ single_filed_test["ratio"],
+ single_filed_test["threshold"],
)
- self._raise_exception(f"The following tests have failed:\n {',
'.join(sorted(failed_tests))}")
+ failed_test_details = "; ".join(
+ f"{t['metric']}: {t}" for t in sorted(failed_tests, key=lambda
x: x["metric"])
+ )
+ self._raise_exception(f"The following tests have failed:\n
{failed_test_details}")
self.log.info("All tests have passed")
@@ -1206,6 +1235,8 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
self.parameters = parameters
self.follow_task_ids_if_true = follow_task_ids_if_true
self.follow_task_ids_if_false = follow_task_ids_if_false
+ # Chosen branch, after evaluating condition, set during execution, to
be used by listeners
+ self.follow_branch: list[str] | None = None
def execute(self, context: Context):
self.log.info(
@@ -1232,32 +1263,30 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
self.log.info("Query returns %s, type '%s'", query_result,
type(query_result))
- follow_branch = None
try:
if isinstance(query_result, bool):
if query_result:
- follow_branch = self.follow_task_ids_if_true
+ self.follow_branch = self.follow_task_ids_if_true
elif isinstance(query_result, str):
# return result is not Boolean, try to convert from String to
Boolean
if _parse_boolean(query_result):
- follow_branch = self.follow_task_ids_if_true
+ self.follow_branch = self.follow_task_ids_if_true
elif isinstance(query_result, int):
if bool(query_result):
- follow_branch = self.follow_task_ids_if_true
+ self.follow_branch = self.follow_task_ids_if_true
else:
raise AirflowException(
f"Unexpected query return result '{query_result}' type
'{type(query_result)}'"
)
- if follow_branch is None:
- follow_branch = self.follow_task_ids_if_false
+ if self.follow_branch is None:
+ self.follow_branch = self.follow_task_ids_if_false
except ValueError:
raise AirflowException(
f"Unexpected query return result '{query_result}' type
'{type(query_result)}'"
)
- # TODO(potiuk) remove the type ignore once we solve provider <-> Task
SDK relationship
- self.skip_all_except(context["ti"], follow_branch)
+ self.skip_all_except(context["ti"], self.follow_branch)
class SQLInsertRowsOperator(BaseSQLOperator):
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
index 07d590c06dc..e47ff81ddd0 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
@@ -939,7 +939,16 @@ class TestIntervalCheckOperator:
ignore_zero=True,
)
- with pytest.raises(AirflowException, match="f0, f1, f2"):
+ expected_err_message = (
+ "The following tests have failed:\n "
+ "f0: {'metric': 'f0', 'current_metric': 1, 'past_metric': 2,
'threshold': 1.0,"
+ " 'ignore_zero': True, 'ratio': 2.0, 'success': False}; "
+ "f1: {'metric': 'f1', 'current_metric': 1, 'past_metric': 2,
'threshold': 1.5,"
+ " 'ignore_zero': True, 'ratio': 2.0, 'success': False}; "
+ "f2: {'metric': 'f2', 'current_metric': 1, 'past_metric': 2,
'threshold': 2.0,"
+ " 'ignore_zero': True, 'ratio': 2.0, 'success': False}"
+ )
+ with pytest.raises(AirflowException, match=expected_err_message):
operator.execute(context=MagicMock())
@mock.patch.object(SQLIntervalCheckOperator, "get_db_hook")
@@ -969,7 +978,14 @@ class TestIntervalCheckOperator:
ignore_zero=True,
)
- with pytest.raises(AirflowException, match="f0, f1"):
+ expected_err_message = (
+ "The following tests have failed:\n "
+ "f0: {'metric': 'f0', 'current_metric': 1, 'past_metric': 3,
'threshold': 0.5, "
+ "'ignore_zero': True, 'ratio': 0.6666666666666666, 'success':
False}; "
+ "f1: {'metric': 'f1', 'current_metric': 1, 'past_metric': 3,
'threshold': 0.6, "
+ "'ignore_zero': True, 'ratio': 0.6666666666666666, 'success':
False}"
+ )
+ with pytest.raises(AirflowException, match=expected_err_message):
operator.execute(context=MagicMock())