This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new edd7133a13 Add conditional output processing in SQL operators (#31136)
edd7133a13 is described below

commit edd7133a1336c9553d77ba13c83bc7f48d4c63f0
Author: Jarek Potiuk <[email protected]>
AuthorDate: Tue May 9 13:11:41 2023 +0200

    Add conditional output processing in SQL operators (#31136)
    
    The change makes output processing conditional on a
    `_should_run_output_processing()` check that operators extending the
    common.sql SQLExecuteQueryOperator can override. Previously, output
    processing only happened when "do_xcom_push" was enabled, but in some
    cases we want to run it even when do_xcom_push is disabled (for
    example, the Databricks SQL operator needs it when the output is
    redirected to a file).
    
    This change enables it.
    
    Fixes: #31080
---
 airflow/providers/common/sql/operators/sql.py      |  7 +++++--
 airflow/providers/common/sql/provider.yaml         |  1 +
 .../databricks/operators/databricks_sql.py         |  3 +++
 airflow/providers/databricks/provider.yaml         |  2 +-
 generated/provider_dependencies.json               |  2 +-
 .../databricks/operators/test_databricks_sql.py    | 23 ++++++++++++++++++++--
 6 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index eef4ba4e67..723afe6b8a 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -258,6 +258,9 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             self.log.info("Operator output is: %s", results)
         return results
 
+    def _should_run_output_processing(self) -> bool:
+        return self.do_xcom_push
+
     def execute(self, context):
         self.log.info("Executing: %s", self.sql)
         hook = self.get_db_hook()
@@ -269,11 +272,11 @@ class SQLExecuteQueryOperator(BaseSQLOperator):
             sql=self.sql,
             autocommit=self.autocommit,
             parameters=self.parameters,
-            handler=self.handler if self.do_xcom_push else None,
+            handler=self.handler if self._should_run_output_processing() else None,
             return_last=self.return_last,
             **extra_kwargs,
         )
-        if not self.do_xcom_push:
+        if not self._should_run_output_processing():
             return None
         if return_single_query_results(self.sql, self.return_last, self.split_statements):
             # For simplicity, we pass always list as input to _process_output, regardless if
diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml
index 2aedf9752e..b5d7d1931c 100644
--- a/airflow/providers/common/sql/provider.yaml
+++ b/airflow/providers/common/sql/provider.yaml
@@ -23,6 +23,7 @@ description: |
 
 suspended: false
 versions:
+  - 1.5.0
   - 1.4.0
   - 1.3.4
   - 1.3.3
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 178afc8d98..4a708ec6c8 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -120,6 +120,9 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
         }
         return DatabricksSqlHook(self.databricks_conn_id, **hook_params)
 
+    def _should_run_output_processing(self) -> bool:
+        return self.do_xcom_push or bool(self._output_path)
+
     def _process_output(self, results: list[Any], descriptions: list[Sequence[Sequence] | None]) -> list[Any]:
         if not self._output_path:
             return list(zip(descriptions, results))
diff --git a/airflow/providers/databricks/provider.yaml b/airflow/providers/databricks/provider.yaml
index a45518f2f4..0b14851140 100644
--- a/airflow/providers/databricks/provider.yaml
+++ b/airflow/providers/databricks/provider.yaml
@@ -46,7 +46,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.4.0
-  - apache-airflow-providers-common-sql>=1.3.1
+  - apache-airflow-providers-common-sql>=1.5.0
   - requests>=2.27,<3
   - databricks-sql-connector>=2.0.0, <3.0.0
   - aiohttp>=3.6.3, <4
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 8626eee899..1f1a8e1f5c 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -237,7 +237,7 @@
   "databricks": {
     "deps": [
       "aiohttp>=3.6.3, <4",
-      "apache-airflow-providers-common-sql>=1.3.1",
+      "apache-airflow-providers-common-sql>=1.5.0",
       "apache-airflow>=2.4.0",
       "databricks-sql-connector>=2.0.0, <3.0.0",
       "requests>=2.27,<3"
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 8489f45095..dd0c9b0187 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -152,7 +152,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
 
 
 @pytest.mark.parametrize(
-    "return_last, split_statements, sql, descriptions, hook_results",
+    "return_last, split_statements, sql, descriptions, hook_results, do_xcom_push",
     [
         pytest.param(
             True,
@@ -160,6 +160,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             "select * from dummy",
             [[("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and split_statement  False",
         ),
         pytest.param(
@@ -168,6 +169,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             "select * from dummy",
             [[("id",), ("value",)]],
             [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            True,
             id="Non-Scalar: return_last False and split_statement True",
         ),
         pytest.param(
@@ -176,6 +178,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             "select * from dummy",
             [[("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and no split_statement True",
         ),
         pytest.param(
@@ -184,6 +187,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             "select * from dummy",
             [[("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last False and split_statement is False",
         ),
         pytest.param(
@@ -195,6 +199,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
                 [Row(id2=1, value2="value1"), Row(id2=2, value2="value2")],
                 [Row(id=1, value="value1"), Row(id=2, value="value2")],
             ],
+            True,
             id="Non-Scalar: return_last False and split_statement is True",
         ),
         pytest.param(
@@ -203,6 +208,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             "select * from dummy2; select * from dummy",
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and split_statement is True",
         ),
         pytest.param(
@@ -211,6 +217,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             "select * from dummy2; select * from dummy",
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            True,
             id="Scalar: return_last True and split_statement is True",
         ),
         pytest.param(
@@ -219,6 +226,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             ["select * from dummy2", "select * from dummy"],
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            True,
             id="Non-Scalar: sql is list and return_last is True",
         ),
         pytest.param(
@@ -227,11 +235,21 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
             ["select * from dummy2", "select * from dummy"],
             [[("id2",), ("value2",)], [("id",), ("value",)]],
             [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            True,
             id="Non-Scalar: sql is list and return_last is False",
         ),
+        pytest.param(
+            False,
+            True,
+            ["select * from dummy2", "select * from dummy"],
+            [[("id2",), ("value2",)], [("id",), ("value",)]],
+            [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            False,
+            id="Write output when do_xcom_push is False",
+        ),
     ],
 )
-def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results):
+def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_results, do_xcom_push):
     """
     Test the execute function in case where SQL query was successful and data is written as CSV
     """
@@ -242,6 +260,7 @@ def test_exec_write_file(return_last, split_statements, sql, descriptions, hook_
             sql=sql,
             output_path=tempfile_path,
             return_last=return_last,
+            do_xcom_push=do_xcom_push,
             split_statements=split_statements,
         )
         db_mock = db_mock_class.return_value

Reply via email to