This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1708da9233 Adding configurable fetch_all_handler for JdbcOperator (#25412)
1708da9233 is described below

commit 1708da9233c13c3821d76e56dbe0e383ff67b0fd
Author: Dmytro Kazanzhy <[email protected]>
AuthorDate: Sun Aug 7 12:18:21 2022 +0300

    Adding configurable fetch_all_handler for JdbcOperator (#25412)
---
 airflow/providers/jdbc/operators/jdbc.py    |  9 +++++++--
 tests/providers/jdbc/operators/test_jdbc.py | 16 ++++++++++++++--
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py
index 6b38366b41..f45d112c43 100644
--- a/airflow/providers/jdbc/operators/jdbc.py
+++ b/airflow/providers/jdbc/operators/jdbc.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, 
Sequence, Union
 
 from airflow.models import BaseOperator
 from airflow.providers.common.sql.hooks.sql import fetch_all_handler
@@ -57,6 +57,7 @@ class JdbcOperator(BaseOperator):
         jdbc_conn_id: str = 'jdbc_default',
         autocommit: bool = False,
         parameters: Optional[Union[Iterable, Mapping]] = None,
+        handler: Callable[[Any], Any] = fetch_all_handler,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -64,9 +65,13 @@ class JdbcOperator(BaseOperator):
         self.sql = sql
         self.jdbc_conn_id = jdbc_conn_id
         self.autocommit = autocommit
+        self.handler = handler
         self.hook = None
 
     def execute(self, context: 'Context'):
         self.log.info('Executing: %s', self.sql)
         hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id)
-        return hook.run(self.sql, self.autocommit, parameters=self.parameters, 
handler=fetch_all_handler)
+        if self.do_xcom_push:
+            return hook.run(self.sql, self.autocommit, 
parameters=self.parameters, handler=self.handler)
+        else:
+            return hook.run(self.sql, self.autocommit, 
parameters=self.parameters)
diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py
index 9168674c56..7b40e48340 100644
--- a/tests/providers/jdbc/operators/test_jdbc.py
+++ b/tests/providers/jdbc/operators/test_jdbc.py
@@ -28,8 +28,8 @@ class TestJdbcOperator(unittest.TestCase):
         self.kwargs = dict(sql='sql', task_id='test_jdbc_operator', dag=None)
 
     @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook')
-    def test_execute(self, mock_jdbc_hook):
-        jdbc_operator = JdbcOperator(**self.kwargs)
+    def test_execute_do_push(self, mock_jdbc_hook):
+        jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=True)
         jdbc_operator.execute(context={})
 
         
mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id)
@@ -39,3 +39,15 @@ class TestJdbcOperator(unittest.TestCase):
             parameters=jdbc_operator.parameters,
             handler=fetch_all_handler,
         )
+
+    @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook')
+    def test_execute_dont_push(self, mock_jdbc_hook):
+        jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=False)
+        jdbc_operator.execute(context={})
+
+        
mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id)
+        mock_jdbc_hook.return_value.run.assert_called_once_with(
+            jdbc_operator.sql,
+            jdbc_operator.autocommit,
+            parameters=jdbc_operator.parameters,
+        )

Reply via email to