This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ccb9d04 Add hook_params in BaseSqlOperator (#18718)
ccb9d04 is described below
commit ccb9d04c22dfbe1efa8d0b690ddd929894acff9a
Author: Benjamin <[email protected]>
AuthorDate: Mon Nov 15 00:22:04 2021 -0500
Add hook_params in BaseSqlOperator (#18718)
---
airflow/models/connection.py | 8 +++++---
airflow/operators/sql.py | 12 ++++++++++--
tests/operators/test_sql.py | 23 +++++++++++++++++++++++
3 files changed, 38 insertions(+), 5 deletions(-)
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 5357088..19ba82d 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -285,8 +285,8 @@ class Connection(Base, LoggingMixin):
if self._extra and self.is_extra_encrypted:
self._extra = fernet.rotate(self._extra.encode('utf-8')).decode()
- def get_hook(self):
- """Return hook based on conn_type."""
+ def get_hook(self, *, hook_kwargs=None):
+ """Return hook based on conn_type"""
(
hook_class_name,
conn_id_param,
@@ -304,7 +304,9 @@ class Connection(Base, LoggingMixin):
"Could not import %s when discovering %s %s", hook_class_name,
hook_name, package_name
)
raise
- return hook_class(**{conn_id_param: self.conn_id})
+ if hook_kwargs is None:
+ hook_kwargs = {}
+ return hook_class(**{conn_id_param: self.conn_id}, **hook_kwargs)
def __repr__(self):
return self.conn_id
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index ab916c0..e44208b 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -46,10 +46,18 @@ class BaseSQLOperator(BaseOperator):
You can custom the behavior by overriding the .get_db_hook() method.
"""
- def __init__(self, *, conn_id: Optional[str] = None, database:
Optional[str] = None, **kwargs):
+ def __init__(
+ self,
+ *,
+ conn_id: Optional[str] = None,
+ database: Optional[str] = None,
+ hook_params: Optional[Dict] = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.conn_id = conn_id
self.database = database
+ self.hook_params = {} if hook_params is None else hook_params
@cached_property
def _hook(self):
@@ -57,7 +65,7 @@ class BaseSQLOperator(BaseOperator):
self.log.debug("Get connection for %s", self.conn_id)
conn = BaseHook.get_connection(self.conn_id)
- hook = conn.get_hook()
+ hook = conn.get_hook(hook_kwargs=self.hook_params)
if not isinstance(hook, DbApiHook):
raise AirflowException(
f'The connection type is not supported by
{self.__class__.__name__}. '
diff --git a/tests/operators/test_sql.py b/tests/operators/test_sql.py
index 3706fd6..05f296d 100644
--- a/tests/operators/test_sql.py
+++ b/tests/operators/test_sql.py
@@ -90,6 +90,29 @@ class TestSQLCheckOperatorDbHook:
with pytest.raises(AirflowException, match=r"The connection type is
not supported"):
self._operator._hook
+ def test_sql_operator_hook_params_snowflake(self, mock_get_conn):
+ mock_get_conn.return_value = Connection(conn_id='snowflake_default',
conn_type='snowflake')
+ self._operator.hook_params = {
+ 'warehouse': 'warehouse',
+ 'database': 'database',
+ 'role': 'role',
+ 'schema': 'schema',
+ }
+ assert self._operator._hook.conn_type == 'snowflake'
+ assert self._operator._hook.warehouse == 'warehouse'
+ assert self._operator._hook.database == 'database'
+ assert self._operator._hook.role == 'role'
+ assert self._operator._hook.schema == 'schema'
+
+ def test_sql_operator_hook_params_biguery(self, mock_get_conn):
+ mock_get_conn.return_value = Connection(
+ conn_id='google_cloud_bigquery_default', conn_type='gcpbigquery'
+ )
+ self._operator.hook_params = {'use_legacy_sql': True, 'location':
'us-east1'}
+ assert self._operator._hook.conn_type == 'gcpbigquery'
+ assert self._operator._hook.use_legacy_sql
+ assert self._operator._hook.location == 'us-east1'
+
class TestCheckOperator(unittest.TestCase):
def setUp(self):