This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ab78ec441 Fix connection parameters of `SnowflakeValueCheckOperator` 
(#32605)
2ab78ec441 is described below

commit 2ab78ec441a748ae4d99e429fe336b80a601d7b1
Author: Marcin Molak <[email protected]>
AuthorDate: Mon Jul 31 21:21:00 2023 +0200

    Fix connection parameters of `SnowflakeValueCheckOperator` (#32605)
---
 airflow/providers/snowflake/operators/snowflake.py | 65 +++++++++----------
 .../snowflake/operators/test_snowflake.py          | 73 ++++++++++++++++++++++
 2 files changed, 106 insertions(+), 32 deletions(-)

diff --git a/airflow/providers/snowflake/operators/snowflake.py 
b/airflow/providers/snowflake/operators/snowflake.py
index 1de82218ff..8f29eefd5b 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -212,18 +212,18 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         session_parameters: dict | None = None,
         **kwargs,
     ) -> None:
+        if any([warehouse, database, role, schema, authenticator, 
session_parameters]):
+            hook_params = kwargs.pop("hook_params", {})
+            kwargs["hook_params"] = {
+                "warehouse": warehouse,
+                "database": database,
+                "role": role,
+                "schema": schema,
+                "authenticator": authenticator,
+                "session_parameters": session_parameters,
+                **hook_params,
+            }
         super().__init__(sql=sql, parameters=parameters, 
conn_id=snowflake_conn_id, **kwargs)
-        self.snowflake_conn_id = snowflake_conn_id
-        self.sql = sql
-        self.autocommit = autocommit
-        self.do_xcom_push = do_xcom_push
-        self.parameters = parameters
-        self.warehouse = warehouse
-        self.database = database
-        self.role = role
-        self.schema = schema
-        self.authenticator = authenticator
-        self.session_parameters = session_parameters
         self.query_ids: list[str] = []
 
 
@@ -277,20 +277,20 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
         session_parameters: dict | None = None,
         **kwargs,
     ) -> None:
+        if any([warehouse, database, role, schema, authenticator, 
session_parameters]):
+            hook_params = kwargs.pop("hook_params", {})
+            kwargs["hook_params"] = {
+                "warehouse": warehouse,
+                "database": database,
+                "role": role,
+                "schema": schema,
+                "authenticator": authenticator,
+                "session_parameters": session_parameters,
+                **hook_params,
+            }
         super().__init__(
             sql=sql, pass_value=pass_value, tolerance=tolerance, 
conn_id=snowflake_conn_id, **kwargs
         )
-        self.snowflake_conn_id = snowflake_conn_id
-        self.sql = sql
-        self.autocommit = autocommit
-        self.do_xcom_push = do_xcom_push
-        self.parameters = parameters
-        self.warehouse = warehouse
-        self.database = database
-        self.role = role
-        self.schema = schema
-        self.authenticator = authenticator
-        self.session_parameters = session_parameters
         self.query_ids: list[str] = []
 
 
@@ -352,6 +352,17 @@ class 
SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
         session_parameters: dict | None = None,
         **kwargs,
     ) -> None:
+        if any([warehouse, database, role, schema, authenticator, 
session_parameters]):
+            hook_params = kwargs.pop("hook_params", {})
+            kwargs["hook_params"] = {
+                "warehouse": warehouse,
+                "database": database,
+                "role": role,
+                "schema": schema,
+                "authenticator": authenticator,
+                "session_parameters": session_parameters,
+                **hook_params,
+            }
         super().__init__(
             table=table,
             metrics_thresholds=metrics_thresholds,
@@ -360,16 +371,6 @@ class 
SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
             conn_id=snowflake_conn_id,
             **kwargs,
         )
-        self.snowflake_conn_id = snowflake_conn_id
-        self.autocommit = autocommit
-        self.do_xcom_push = do_xcom_push
-        self.parameters = parameters
-        self.warehouse = warehouse
-        self.database = database
-        self.role = role
-        self.schema = schema
-        self.authenticator = authenticator
-        self.session_parameters = session_parameters
         self.query_ids: list[str] = []
 
 
diff --git a/tests/providers/snowflake/operators/test_snowflake.py 
b/tests/providers/snowflake/operators/test_snowflake.py
index 41cbfe6717..ea9d8333f3 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -72,6 +72,36 @@ class TestSnowflakeOperator:
         operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, 
ignore_ti_state=True)
 
 
+class TestSnowflakeOperatorForParams:
+    
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
+    def test_overwrite_params(self, mock_base_op):
+        sql = "Select * from test_table"
+        SnowflakeOperator(
+            sql=sql,
+            task_id="snowflake_params_check",
+            snowflake_conn_id="snowflake_default",
+            warehouse="test_warehouse",
+            database="test_database",
+            role="test_role",
+            schema="test_schema",
+            authenticator="oath",
+            session_parameters={"QUERY_TAG": "test_tag"},
+        )
+        mock_base_op.assert_called_once_with(
+            conn_id="snowflake_default",
+            task_id="snowflake_params_check",
+            hook_params={
+                "warehouse": "test_warehouse",
+                "database": "test_database",
+                "role": "test_role",
+                "schema": "test_schema",
+                "authenticator": "oath",
+                "session_parameters": {"QUERY_TAG": "test_tag"},
+            },
+            default_args={},
+        )
+
+
 @pytest.mark.parametrize(
     "operator_class, kwargs",
     [
@@ -93,6 +123,49 @@ class TestSnowflakeCheckOperators:
         mock_get_db_hook.assert_called_once()
 
 
[email protected](
+    "operator_class, kwargs",
+    [
+        (SnowflakeCheckOperator, dict(sql="Select * from test_table")),
+        (SnowflakeValueCheckOperator, dict(sql="Select * from test_table", 
pass_value=95)),
+        (SnowflakeIntervalCheckOperator, dict(table="test-table-id", 
metrics_thresholds={"COUNT(*)": 1.5})),
+    ],
+)
+class TestSnowflakeCheckOperatorsForParams:
+    
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
+    def test_overwrite_params(
+        self,
+        mock_base_op,
+        operator_class,
+        kwargs,
+    ):
+        operator_class(
+            task_id="snowflake_params_check",
+            snowflake_conn_id="snowflake_default",
+            warehouse="test_warehouse",
+            database="test_database",
+            role="test_role",
+            schema="test_schema",
+            authenticator="oath",
+            session_parameters={"QUERY_TAG": "test_tag"},
+            **kwargs,
+        )
+        mock_base_op.assert_called_once_with(
+            conn_id="snowflake_default",
+            database=None,
+            task_id="snowflake_params_check",
+            hook_params={
+                "warehouse": "test_warehouse",
+                "database": "test_database",
+                "role": "test_role",
+                "schema": "test_schema",
+                "authenticator": "oath",
+                "session_parameters": {"QUERY_TAG": "test_tag"},
+            },
+            default_args={},
+        )
+
+
 def create_context(task, dag=None):
     if dag is None:
         dag = DAG(dag_id="dag")

Reply via email to the commits mailing list.