This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1407e277ae Add `sql_hook_params` parameter to `S3ToSqlOperator` (#33427)
1407e277ae is described below
commit 1407e277aeb059cbfd1bb96fb3f43c4bf4f15cea
Author: Alex Begg <[email protected]>
AuthorDate: Sat Aug 19 00:52:40 2023 -0700
Add `sql_hook_params` parameter to `S3ToSqlOperator` (#33427)
Adding `sql_hook_params` parameter to `S3ToSqlOperator`. This allows you
to pass extra config params to the underlying SQL hook.
This uses the same `sql_hook_params` parameter name as already used in
`SqlToSlackOperator`.
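
For illustration, a minimal usage sketch (not part of this commit): the
task id, bucket, key, table, parser, and `postgres_default` connection
below are hypothetical; `log_sql` is a DbApiHook constructor param, which
`sql_hook_params` forwards to the hook via `Connection.get_hook`.

    from airflow.providers.amazon.aws.transfers.s3_to_sql import S3ToSqlOperator

    def parse_csv(filepath):
        # Hypothetical parser: yields rows from the file downloaded from S3.
        import csv

        with open(filepath, newline="") as file:
            yield from csv.reader(file)

    load_from_s3 = S3ToSqlOperator(
        task_id="load_from_s3",
        s3_bucket="my-bucket",               # hypothetical bucket
        s3_key="data/users.csv",             # hypothetical key
        table="users",                       # hypothetical target table
        parser=parse_csv,
        sql_conn_id="postgres_default",      # any DBApiHook-backed connection
        sql_hook_params={"log_sql": False},  # passed to the hook constructor
    )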
---
airflow/providers/amazon/aws/transfers/s3_to_sql.py | 7 ++++++-
tests/providers/amazon/aws/transfers/test_s3_to_sql.py | 17 ++++++++++++++---
2 files changed, 20 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sql.py b/airflow/providers/amazon/aws/transfers/s3_to_sql.py
index 8e0613ea6d..667e3c174b 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_sql.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_sql.py
@@ -44,6 +44,8 @@ class S3ToSqlOperator(BaseOperator):
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: reference to a specific S3 key
     :param sql_conn_id: reference to a specific SQL database. Must be of type DBApiHook
+    :param sql_hook_params: Extra config params to be passed to the underlying hook.
+        Should match the desired hook constructor params.
     :param aws_conn_id: reference to a specific S3 / AWS connection
     :param column_list: list of column names to use in the insert SQL.
     :param commit_every: The maximum number of rows to insert in one transaction.
@@ -83,6 +85,7 @@ class S3ToSqlOperator(BaseOperator):
         commit_every: int = 1000,
         schema: str | None = None,
         sql_conn_id: str = "sql_default",
+        sql_hook_params: dict | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ) -> None:
@@ -96,6 +99,7 @@ class S3ToSqlOperator(BaseOperator):
         self.column_list = column_list
         self.commit_every = commit_every
         self.parser = parser
+        self.sql_hook_params = sql_hook_params

     def execute(self, context: Context) -> None:
         self.log.info("Loading %s to SQL table %s...", self.s3_key, self.table)
@@ -120,7 +124,8 @@ class S3ToSqlOperator(BaseOperator):
     @cached_property
     def db_hook(self):
         self.log.debug("Get connection for %s", self.sql_conn_id)
-        hook = BaseHook.get_hook(self.sql_conn_id)
+        conn = BaseHook.get_connection(self.sql_conn_id)
+        hook = conn.get_hook(hook_params=self.sql_hook_params)
         if not callable(getattr(hook, "insert_rows", None)):
             raise AirflowException(
                 "This hook is not supported. The hook class must have an `insert_rows` method."
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_sql.py b/tests/providers/amazon/aws/transfers/test_s3_to_sql.py
index 0239624034..3e6e6913e7 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_sql.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_sql.py
@@ -76,7 +76,7 @@ class TestS3ToSqlTransfer:
         return bad_hook

     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
-    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook")
+    @patch("airflow.models.connection.Connection.get_hook")
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
     def test_execute(self, mock_get_key, mock_hook, mock_tempfile, mock_parser):
@@ -93,7 +93,7 @@ class TestS3ToSqlTransfer:
         mock_parser.assert_called_once_with(mock_tempfile.return_value.__enter__.return_value.name)

-        mock_hook.get_hook.return_value.insert_rows.assert_called_once_with(
+        mock_hook.return_value.insert_rows.assert_called_once_with(
             table=self.s3_to_sql_transfer_kwargs["table"],
             schema=self.s3_to_sql_transfer_kwargs["schema"],
             target_fields=self.s3_to_sql_transfer_kwargs["column_list"],
@@ -102,13 +102,24 @@ class TestS3ToSqlTransfer:
         )

     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.NamedTemporaryFile")
-    @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.BaseHook.get_hook", return_value=mock_bad_hook)
+    @patch("airflow.models.connection.Connection.get_hook", return_value=mock_bad_hook)
     @patch("airflow.providers.amazon.aws.transfers.s3_to_sql.S3Hook.get_key")
     def test_execute_with_bad_hook(self, mock_get_key, mock_bad_hook, mock_tempfile, mock_parser):
         with pytest.raises(AirflowException):
             S3ToSqlOperator(parser=mock_parser, **self.s3_to_sql_transfer_kwargs).execute({})

+    def test_hook_params(self, mock_parser):
+        op = S3ToSqlOperator(
+            parser=mock_parser,
+            sql_hook_params={
+                "log_sql": False,
+            },
+            **self.s3_to_sql_transfer_kwargs,
+        )
+        hook = op.db_hook
+        assert hook.log_sql == op.sql_hook_params["log_sql"]
+
     def teardown_method(self):
         with create_session() as session:
             (