This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ccb9d04  Add hook_params in BaseSqlOperator (#18718)
ccb9d04 is described below

commit ccb9d04c22dfbe1efa8d0b690ddd929894acff9a
Author: Benjamin <[email protected]>
AuthorDate: Mon Nov 15 00:22:04 2021 -0500

    Add hook_params in BaseSqlOperator (#18718)
---
 airflow/models/connection.py |  8 +++++---
 airflow/operators/sql.py     | 12 ++++++++++--
 tests/operators/test_sql.py  | 23 +++++++++++++++++++++++
 3 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 5357088..19ba82d 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -285,8 +285,8 @@ class Connection(Base, LoggingMixin):
         if self._extra and self.is_extra_encrypted:
             self._extra = fernet.rotate(self._extra.encode('utf-8')).decode()
 
-    def get_hook(self):
-        """Return hook based on conn_type."""
+    def get_hook(self, *, hook_kwargs=None):
+        """Return hook based on conn_type"""
         (
             hook_class_name,
             conn_id_param,
@@ -304,7 +304,9 @@ class Connection(Base, LoggingMixin):
                 "Could not import %s when discovering %s %s", hook_class_name, 
hook_name, package_name
             )
             raise
-        return hook_class(**{conn_id_param: self.conn_id})
+        if hook_kwargs is None:
+            hook_kwargs = {}
+        return hook_class(**{conn_id_param: self.conn_id}, **hook_kwargs)
 
     def __repr__(self):
         return self.conn_id
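
For reference, a minimal sketch of how the extended get_hook() signature could be
exercised on its own (it mirrors the Snowflake keyword arguments used in the new
test below; the connection values are illustrative and the snowflake provider is
assumed to be installed):

    from airflow.models.connection import Connection

    # hook_kwargs is forwarded verbatim to the hook's constructor, on top of
    # the conn_id parameter the resolved hook class expects.
    conn = Connection(conn_id='snowflake_default', conn_type='snowflake')
    hook = conn.get_hook(hook_kwargs={'warehouse': 'warehouse', 'role': 'role'})
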
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index ab916c0..e44208b 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -46,10 +46,18 @@ class BaseSQLOperator(BaseOperator):
     You can customize the behavior by overriding the .get_db_hook() method.
     """
 
-    def __init__(self, *, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs):
+    def __init__(
+        self,
+        *,
+        conn_id: Optional[str] = None,
+        database: Optional[str] = None,
+        hook_params: Optional[Dict] = None,
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.conn_id = conn_id
         self.database = database
+        self.hook_params = {} if hook_params is None else hook_params
 
     @cached_property
     def _hook(self):
@@ -57,7 +65,7 @@ class BaseSQLOperator(BaseOperator):
         self.log.debug("Get connection for %s", self.conn_id)
         conn = BaseHook.get_connection(self.conn_id)
 
-        hook = conn.get_hook()
+        hook = conn.get_hook(hook_kwargs=self.hook_params)
         if not isinstance(hook, DbApiHook):
             raise AirflowException(
                 f'The connection type is not supported by {self.__class__.__name__}. '
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
index 3706fd6..05f296d 100644
--- a/tests/operators/test_sql.py
+++ b/tests/operators/test_sql.py
@@ -90,6 +90,29 @@ class TestSQLCheckOperatorDbHook:
         with pytest.raises(AirflowException, match=r"The connection type is 
not supported"):
             self._operator._hook
 
+    def test_sql_operator_hook_params_snowflake(self, mock_get_conn):
+        mock_get_conn.return_value = Connection(conn_id='snowflake_default', conn_type='snowflake')
+        self._operator.hook_params = {
+            'warehouse': 'warehouse',
+            'database': 'database',
+            'role': 'role',
+            'schema': 'schema',
+        }
+        assert self._operator._hook.conn_type == 'snowflake'
+        assert self._operator._hook.warehouse == 'warehouse'
+        assert self._operator._hook.database == 'database'
+        assert self._operator._hook.role == 'role'
+        assert self._operator._hook.schema == 'schema'
+
+    def test_sql_operator_hook_params_bigquery(self, mock_get_conn):
+        mock_get_conn.return_value = Connection(
+            conn_id='google_cloud_bigquery_default', conn_type='gcpbigquery'
+        )
+        self._operator.hook_params = {'use_legacy_sql': True, 'location': 'us-east1'}
+        assert self._operator._hook.conn_type == 'gcpbigquery'
+        assert self._operator._hook.use_legacy_sql
+        assert self._operator._hook.location == 'us-east1'
+
 
 class TestCheckOperator(unittest.TestCase):
     def setUp(self):

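For completeness, a minimal sketch of how the new argument might be used when
instantiating one of the concrete SQL operators in a DAG (the task_id, conn_id,
SQL statement, and hook keyword values are illustrative; it assumes the concrete
operator forwards extra keyword arguments, including hook_params, on to
BaseSQLOperator via **kwargs):

    from airflow.operators.sql import SQLCheckOperator

    # hook_params is handed to Connection.get_hook(hook_kwargs=...) the first
    # time the cached _hook property is accessed, so its keys must match
    # constructor arguments of the underlying DbApiHook (Snowflake shown here).
    check = SQLCheckOperator(
        task_id='check_row_count',
        conn_id='snowflake_default',
        sql='SELECT COUNT(*) FROM my_table',
        hook_params={'warehouse': 'warehouse', 'role': 'role'},
    )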