This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1407e277ae Add `sql_hook_params` parameter to `S3ToSqlOperator` 
(#33427)
1407e277ae is described below

commit 1407e277aeb059cbfd1bb96fb3f43c4bf4f15cea
Author: Alex Begg <[email protected]>
AuthorDate: Sat Aug 19 00:52:40 2023 -0700

    Add `sql_hook_params` parameter to `S3ToSqlOperator` (#33427)
    
    Adding `sql_hook_params` parameter to `S3ToSqlOperator`. This will allow 
you to pass extra config params to the underlying SQL hook.
    
    This uses the same `sql_hook_params` parameter name as already used in 
`SqlToSlackOperator`.
---
 airflow/providers/amazon/aws/transfers/s3_to_sql.py    |  7 ++++++-
 tests/providers/amazon/aws/transfers/test_s3_to_sql.py | 17 ++++++++++++++---
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sql.py 
b/airflow/providers/amazon/aws/transfers/s3_to_sql.py
index 8e0613ea6d..667e3c174b 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_sql.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_sql.py
@@ -44,6 +44,8 @@ class S3ToSqlOperator(BaseOperator):
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: reference to a specific S3 key
     :param sql_conn_id: reference to a specific SQL database. Must be of type 
DBApiHook
+    :param sql_hook_params: Extra config params to be passed to the underlying 
hook.
+        Should match the desired hook constructor params.
     :param aws_conn_id: reference to a specific S3 / AWS connection
     :param column_list: list of column names to use in the insert SQL.
     :param commit_every: The maximum number of rows to insert in one
@@ -83,6 +85,7 @@ class S3ToSqlOperator(BaseOperator):
         commit_every: int = 1000,
         schema: str | None = None,
         sql_conn_id: str = "sql_default",
+        sql_hook_params: dict | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ) -> None:
@@ -96,6 +99,7 @@ class S3ToSqlOperator(BaseOperator):
         self.column_list = column_list
         self.commit_every = commit_every
         self.parser = parser
+        self.sql_hook_params = sql_hook_params
 
     def execute(self, context: Context) -> None:
         self.log.info("Loading %s to SQL table %s...", self.s3_key, self.table)
@@ -120,7 +124,8 @@ class S3ToSqlOperator(BaseOperator):
     @cached_property
     def db_hook(self):
         self.log.debug("Get connection for %s", self.sql_conn_id)
-        hook = BaseHook.get_hook(self.sql_conn_id)
+        conn = BaseHook.get_connection(self.sql_conn_id)
+        hook = conn.get_hook(hook_params=self.sql_hook_params)
         if not callable(getattr(hook, "insert_rows", None)):
             raise AirflowException(
                 "This hook is not supported. The hook class must have an 
`insert_rows` method."
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_sql.py 
b/tests/providers/amazon/aws/transfers/test_s3_to_sql.py
index 0239624034..3e6e6913e7 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_sql.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_sql.py
@@ -76,7 +76,7 @@ class TestS3ToSqlTransfer:
         return bad_hook
 
     
@patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
-    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook")
+    @patch("airflow.models.connection.Connection.get_hook")
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
     def test_execute(self, mock_get_key, mock_hook, mock_tempfile, 
mock_parser):
 
@@ -93,7 +93,7 @@ class TestS3ToSqlTransfer:
 
         
mock_parser.assert_called_once_with(mock_tempfile.return_value.__enter__.return_value.name)
 
-        mock_hook.get_hook.return_value.insert_rows.assert_called_once_with(
+        mock_hook.return_value.insert_rows.assert_called_once_with(
             table=self.s3_to_sql_transfer_kwargs["table"],
             schema=self.s3_to_sql_transfer_kwargs["schema"],
             target_fields=self.s3_to_sql_transfer_kwargs["column_list"],
@@ -102,13 +102,24 @@ class TestS3ToSqlTransfer:
         )
 
     
@patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
-    
@patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_hook", 
return_value=mock_bad_hook)
+    @patch("airflow.models.connection.Connection.get_hook", 
return_value=mock_bad_hook)
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
     def test_execute_with_bad_hook(self, mock_get_key, mock_bad_hook, 
mock_tempfile, mock_parser):
 
         with pytest.raises(AirflowException):
             S3ToSqlOperator(parser=mock_parser, 
**self.s3_to_sql_transfer_kwargs).execute({})
 
+    def test_hook_params(self, mock_parser):
+        op = S3ToSqlOperator(
+            parser=mock_parser,
+            sql_hook_params={
+                "log_sql": False,
+            },
+            **self.s3_to_sql_transfer_kwargs,
+        )
+        hook = op.db_hook
+        assert hook.log_sql == op.sql_hook_params["log_sql"]
+
     def teardown_method(self):
         with create_session() as session:
             (

Reply via email to