This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1708da9233 Adding configurable fetch_all_handler for JdbcOperator (#25412)
1708da9233 is described below
commit 1708da9233c13c3821d76e56dbe0e383ff67b0fd
Author: Dmytro Kazanzhy <[email protected]>
AuthorDate: Sun Aug 7 12:18:21 2022 +0300
Adding configurable fetch_all_handler for JdbcOperator (#25412)
---
airflow/providers/jdbc/operators/jdbc.py | 9 +++++++--
tests/providers/jdbc/operators/test_jdbc.py | 16 ++++++++++++++--
2 files changed, 21 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py
index 6b38366b41..f45d112c43 100644
--- a/airflow/providers/jdbc/operators/jdbc.py
+++ b/airflow/providers/jdbc/operators/jdbc.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
@@ -57,6 +57,7 @@ class JdbcOperator(BaseOperator):
jdbc_conn_id: str = 'jdbc_default',
autocommit: bool = False,
parameters: Optional[Union[Iterable, Mapping]] = None,
+ handler: Callable[[Any], Any] = fetch_all_handler,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -64,9 +65,13 @@ class JdbcOperator(BaseOperator):
self.sql = sql
self.jdbc_conn_id = jdbc_conn_id
self.autocommit = autocommit
+ self.handler = handler
self.hook = None
def execute(self, context: 'Context'):
self.log.info('Executing: %s', self.sql)
hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id)
- return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=fetch_all_handler)
+ if self.do_xcom_push:
+ return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=self.handler)
+ else:
+ return hook.run(self.sql, self.autocommit, parameters=self.parameters)
diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py
index 9168674c56..7b40e48340 100644
--- a/tests/providers/jdbc/operators/test_jdbc.py
+++ b/tests/providers/jdbc/operators/test_jdbc.py
@@ -28,8 +28,8 @@ class TestJdbcOperator(unittest.TestCase):
self.kwargs = dict(sql='sql', task_id='test_jdbc_operator', dag=None)
@patch('airflow.providers.jdbc.operators.jdbc.JdbcHook')
- def test_execute(self, mock_jdbc_hook):
- jdbc_operator = JdbcOperator(**self.kwargs)
+ def test_execute_do_push(self, mock_jdbc_hook):
+ jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=True)
jdbc_operator.execute(context={})
mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id)
@@ -39,3 +39,15 @@ class TestJdbcOperator(unittest.TestCase):
parameters=jdbc_operator.parameters,
handler=fetch_all_handler,
)
+
+ @patch('airflow.providers.jdbc.operators.jdbc.JdbcHook')
+ def test_execute_dont_push(self, mock_jdbc_hook):
+ jdbc_operator = JdbcOperator(**self.kwargs, do_xcom_push=False)
+ jdbc_operator.execute(context={})
+
+ mock_jdbc_hook.assert_called_once_with(jdbc_conn_id=jdbc_operator.jdbc_conn_id)
+ mock_jdbc_hook.return_value.run.assert_called_once_with(
+ jdbc_operator.sql,
+ jdbc_operator.autocommit,
+ parameters=jdbc_operator.parameters,
+ )