This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new edd7133a13 Add conditional output processing in SQL operators (#31136)
edd7133a13 is described below
commit edd7133a1336c9553d77ba13c83bc7f48d4c63f0
Author: Jarek Potiuk <[email protected]>
AuthorDate: Tue May 9 13:11:41 2023 +0200
Add conditional output processing in SQL operators (#31136)
The change adds conditional processing of output based on
criteria that an operator extending the common.sql
BaseSQLOperator can override. Originally, output processing
only happened when "do_xcom_push" was enabled, but in some
cases we also want to run it when do_xcom_push is disabled
(for example, the Databricks SQL operator may need it when
the output is redirected to a file).
This change enables it.
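
For illustration, an operator extending SQLExecuteQueryOperator can
now opt in by overriding the new hook method. A minimal sketch (the
operator name and results_path attribute are hypothetical;
_should_run_output_processing is the method added in this commit):

    from __future__ import annotations

    from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

    class FileWritingSqlOperator(SQLExecuteQueryOperator):
        """Hypothetical operator that writes query results to a file."""

        def __init__(self, *, results_path: str | None = None, **kwargs):
            super().__init__(**kwargs)
            self._results_path = results_path

        def _should_run_output_processing(self) -> bool:
            # Run output processing when XCom push is enabled OR when
            # results are redirected to a file, mirroring the
            # DatabricksSqlOperator override in this commit.
            return self.do_xcom_push or bool(self._results_path)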
Fixes: #31080
---
airflow/providers/common/sql/operators/sql.py | 7 +++++--
airflow/providers/common/sql/provider.yaml | 1 +
.../databricks/operators/databricks_sql.py | 3 +++
airflow/providers/databricks/provider.yaml | 2 +-
generated/provider_dependencies.json | 2 +-
.../databricks/operators/test_databricks_sql.py | 23 ++++++++++++++++++++--
6 files changed, 32 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index eef4ba4e67..723afe6b8a 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -258,6 +258,9 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
self.log.info("Operator output is: %s", results)
return results
+ def _should_run_output_processing(self) -> bool:
+ return self.do_xcom_push
+
def execute(self, context):
self.log.info("Executing: %s", self.sql)
hook = self.get_db_hook()
@@ -269,11 +272,11 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
sql=self.sql,
autocommit=self.autocommit,
parameters=self.parameters,
- handler=self.handler if self.do_xcom_push else None,
+ handler=self.handler if self._should_run_output_processing() else None,
return_last=self.return_last,
**extra_kwargs,
)
- if not self.do_xcom_push:
+ if not self._should_run_output_processing():
return None
if return_single_query_results(self.sql, self.return_last, self.split_statements):
# For simplicity, we pass always list as input to _process_output, regardless if
diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml
index 2aedf9752e..b5d7d1931c 100644
--- a/airflow/providers/common/sql/provider.yaml
+++ b/airflow/providers/common/sql/provider.yaml
@@ -23,6 +23,7 @@ description: |
suspended: false
versions:
+ - 1.5.0
- 1.4.0
- 1.3.4
- 1.3.3
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 178afc8d98..4a708ec6c8 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -120,6 +120,9 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
}
return DatabricksSqlHook(self.databricks_conn_id, **hook_params)
+ def _should_run_output_processing(self) -> bool:
+ return self.do_xcom_push or bool(self._output_path)
+
def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
if not self._output_path:
return list(zip(descriptions, results))
diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml
index a45518f2f4..0b14851140 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -46,7 +46,7 @@ versions:
dependencies:
- apache-airflow>=2.4.0
- - apache-airflow-providers-common-sql>=1.3.1
+ - apache-airflow-providers-common-sql>=1.5.0
- requests>=2.27,<3
- databricks-sql-connector>=2.0.0, <3.0.0
- aiohttp>=3.6.3, <4
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 8626eee899..1f1a8e1f5c 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -237,7 +237,7 @@
"databricks": {
"deps": [
"aiohttp>=3.6.3, <4",
- "apache-airflow-providers-common-sql>=1.3.1",
+ "apache-airflow-providers-common-sql>=1.5.0",
"apache-airflow>=2.4.0",
"databricks-sql-connector>=2.0.0, <3.0.0",
"requests>=2.27,<3"
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 8489f45095..dd0c9b0187 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -152,7 +152,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
@pytest.mark.parametrize(
- "return_last, split_statements, sql, descriptions, hook_results",
+ "return_last, split_statements, sql, descriptions, hook_results,
do_xcom_push",
[
pytest.param(
True,
@@ -160,6 +160,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
"select * from dummy",
[[("id",), ("value",)]],
[Row(id=1, value="value1"), Row(id=2, value="value2")],
+ True,
id="Scalar: return_last True and split_statement False",
),
pytest.param(
@@ -168,6 +169,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
"select * from dummy",
[[("id",), ("value",)]],
[[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ True,
id="Non-Scalar: return_last False and split_statement True",
),
pytest.param(
@@ -176,6 +178,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
"select * from dummy",
[[("id",), ("value",)]],
[Row(id=1, value="value1"), Row(id=2, value="value2")],
+ True,
id="Scalar: return_last True and no split_statement True",
),
pytest.param(
@@ -184,6 +187,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
"select * from dummy",
[[("id",), ("value",)]],
[Row(id=1, value="value1"), Row(id=2, value="value2")],
+ True,
id="Scalar: return_last False and split_statement is False",
),
pytest.param(
@@ -195,6 +199,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
[Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
[Row(id=1, value="value1"), Row(id=2, value="value2")],
],
+ True,
id="Non-Scalar: return_last False and split_statement is True",
),
pytest.param(
@@ -203,6 +208,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
"select * from dummy2; select * from dummy",
[[("id2",), ("value2",)], [("id",), ("value",)]],
[Row(id=1, value="value1"), Row(id=2, value="value2")],
+ True,
id="Scalar: return_last True and split_statement is True",
),
pytest.param(
@@ -211,6 +217,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
"select * from dummy2; select * from dummy",
[[("id2",), ("value2",)], [("id",), ("value",)]],
[Row(id=1, value="value1"), Row(id=2, value="value2")],
+ True,
id="Scalar: return_last True and split_statement is True",
),
pytest.param(
@@ -219,6 +226,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
["select * from dummy2", "select * from dummy"],
[[("id2",), ("value2",)], [("id",), ("value",)]],
[[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ True,
id="Non-Scalar: sql is list and return_last is True",
),
pytest.param(
@@ -227,11 +235,21 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
["select * from dummy2", "select * from dummy"],
[[("id2",), ("value2",)], [("id",), ("value",)]],
[[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ True,
id="Non-Scalar: sql is list and return_last is False",
),
+ pytest.param(
+ False,
+ True,
+ ["select * from dummy2", "select * from dummy"],
+ [[("id2",), ("value2",)], [("id",), ("value",)]],
+ [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+ False,
+ id="Write output when do_xcom_push is False",
+ ),
],
)
-def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results):
+def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results, do_xcom_push):
"""
Test the execute function in case where SQL query was successful and data is written as CSV
"""
@@ -242,6 +260,7 @@ def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_
sql=sql,
output_path=tempfile_path,
return_last=return_last,
+ do_xcom_push=do_xcom_push,
split_statements=split_statements,
)
db_mock = db_mock_class.return_value
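
For reference, once the updated providers are released, a task could
write results to a file with XCom push disabled, as in this usage
sketch (the task_id, connection id, and output path are placeholders):

    from airflow.providers.databricks.operators.databricks_sql import DatabricksSqlOperator

    export_task = DatabricksSqlOperator(
        task_id="export_dummy_table",
        databricks_conn_id="databricks_default",
        sql="select * from dummy",
        output_path="/tmp/dummy.csv",  # makes _should_run_output_processing() return True
        do_xcom_push=False,  # previously this would have skipped output processing
    )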