This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 45d5f64127 Add `sql_hook_params` parameter to `SqlToS3Operator` (#33425)
45d5f64127 is described below

commit 45d5f6412731f81002be7e9c86c11060394875cf
Author: Alex Begg <[email protected]>
AuthorDate: Wed Aug 16 07:22:12 2023 -0700

    Add `sql_hook_params` parameter to `SqlToS3Operator` (#33425)
    
    Adding `sql_hook_params` parameter to `SqlToS3Operator`. This allows you to pass extra config params to the underlying SQL hook.
---
 airflow/providers/amazon/aws/transfers/sql_to_s3.py    |  6 +++++-
 tests/providers/amazon/aws/transfers/test_sql_to_s3.py | 18 ++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index 1b327def97..92c1906629 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -65,6 +65,8 @@ class SqlToS3Operator(BaseOperator):
     :param s3_key: desired key for the file. It includes the name of the file. 
(templated)
     :param replace: whether or not to replace the file in S3 if it previously 
existed
     :param sql_conn_id: reference to a specific database.
+    :param sql_hook_params: Extra config params to be passed to the underlying 
hook.
+        Should match the desired hook constructor params.
     :param parameters: (optional) the parameters to render the SQL query with.
     :param aws_conn_id: reference to a specific S3 connection
     :param verify: Whether or not to verify SSL certificates for S3 connection.
@@ -100,6 +102,7 @@ class SqlToS3Operator(BaseOperator):
         s3_bucket: str,
         s3_key: str,
         sql_conn_id: str,
+        sql_hook_params: dict | None = None,
         parameters: None | Mapping | Iterable = None,
         replace: bool = False,
         aws_conn_id: str = "aws_default",
@@ -120,6 +123,7 @@ class SqlToS3Operator(BaseOperator):
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
         self.groupby_kwargs = groupby_kwargs or {}
+        self.sql_hook_params = sql_hook_params
 
         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, 
please remove it")
@@ -200,7 +204,7 @@ class SqlToS3Operator(BaseOperator):
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
         conn = BaseHook.get_connection(self.sql_conn_id)
-        hook = conn.get_hook()
+        hook = conn.get_hook(hook_params=self.sql_hook_params)
         if not callable(getattr(hook, "get_pandas_df", None)):
             raise AirflowException(
                 "This hook is not supported. The hook class must have 
get_pandas_df method."
diff --git a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
index a0e4e6f603..cc56fd064a 100644
--- a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -24,6 +24,7 @@ import pandas as pd
 import pytest
 
 from airflow.exceptions import AirflowException
+from airflow.models import Connection
 from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator
 
 
@@ -269,3 +270,20 @@ class TestSqlToS3Operator:
                 }
             )
         )
+
+    
@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
+    def test_hook_params(self, mock_get_conn):
+        mock_get_conn.return_value = Connection(conn_id="postgres_test", 
conn_type="postgres")
+        op = SqlToS3Operator(
+            query="query",
+            s3_bucket="bucket",
+            s3_key="key",
+            sql_conn_id="postgres_test",
+            task_id="task_id",
+            sql_hook_params={
+                "log_sql": False,
+            },
+            dag=None,
+        )
+        hook = op._get_hook()
+        assert hook.log_sql == op.sql_hook_params["log_sql"]

Reply via email to