This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2ab78ec441 Fix connection parameters of `SnowflakeValueCheckOperator` (#32605)
2ab78ec441 is described below
commit 2ab78ec441a748ae4d99e429fe336b80a601d7b1
Author: Marcin Molak <[email protected]>
AuthorDate: Mon Jul 31 21:21:00 2023 +0200
Fix connection parameters of `SnowflakeValueCheckOperator` (#32605)
---
airflow/providers/snowflake/operators/snowflake.py | 65 +++++++++----------
.../snowflake/operators/test_snowflake.py | 73 ++++++++++++++++++++++
2 files changed, 106 insertions(+), 32 deletions(-)
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 1de82218ff..8f29eefd5b 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -212,18 +212,18 @@ class SnowflakeCheckOperator(SQLCheckOperator):
session_parameters: dict | None = None,
**kwargs,
) -> None:
+ if any([warehouse, database, role, schema, authenticator, session_parameters]):
+ hook_params = kwargs.pop("hook_params", {})
+ kwargs["hook_params"] = {
+ "warehouse": warehouse,
+ "database": database,
+ "role": role,
+ "schema": schema,
+ "authenticator": authenticator,
+ "session_parameters": session_parameters,
+ **hook_params,
+ }
super().__init__(sql=sql, parameters=parameters, conn_id=snowflake_conn_id, **kwargs)
- self.snowflake_conn_id = snowflake_conn_id
- self.sql = sql
- self.autocommit = autocommit
- self.do_xcom_push = do_xcom_push
- self.parameters = parameters
- self.warehouse = warehouse
- self.database = database
- self.role = role
- self.schema = schema
- self.authenticator = authenticator
- self.session_parameters = session_parameters
self.query_ids: list[str] = []
@@ -277,20 +277,20 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
session_parameters: dict | None = None,
**kwargs,
) -> None:
+ if any([warehouse, database, role, schema, authenticator, session_parameters]):
+ hook_params = kwargs.pop("hook_params", {})
+ kwargs["hook_params"] = {
+ "warehouse": warehouse,
+ "database": database,
+ "role": role,
+ "schema": schema,
+ "authenticator": authenticator,
+ "session_parameters": session_parameters,
+ **hook_params,
+ }
super().__init__(
sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs
)
- self.snowflake_conn_id = snowflake_conn_id
- self.sql = sql
- self.autocommit = autocommit
- self.do_xcom_push = do_xcom_push
- self.parameters = parameters
- self.warehouse = warehouse
- self.database = database
- self.role = role
- self.schema = schema
- self.authenticator = authenticator
- self.session_parameters = session_parameters
self.query_ids: list[str] = []
@@ -352,6 +352,17 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
session_parameters: dict | None = None,
**kwargs,
) -> None:
+ if any([warehouse, database, role, schema, authenticator, session_parameters]):
+ hook_params = kwargs.pop("hook_params", {})
+ kwargs["hook_params"] = {
+ "warehouse": warehouse,
+ "database": database,
+ "role": role,
+ "schema": schema,
+ "authenticator": authenticator,
+ "session_parameters": session_parameters,
+ **hook_params,
+ }
super().__init__(
table=table,
metrics_thresholds=metrics_thresholds,
@@ -360,16 +371,6 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
conn_id=snowflake_conn_id,
**kwargs,
)
- self.snowflake_conn_id = snowflake_conn_id
- self.autocommit = autocommit
- self.do_xcom_push = do_xcom_push
- self.parameters = parameters
- self.warehouse = warehouse
- self.database = database
- self.role = role
- self.schema = schema
- self.authenticator = authenticator
- self.session_parameters = session_parameters
self.query_ids: list[str] = []
diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py
index 41cbfe6717..ea9d8333f3 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -72,6 +72,36 @@ class TestSnowflakeOperator:
operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+class TestSnowflakeOperatorForParams:
+ @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
+ def test_overwrite_params(self, mock_base_op):
+ sql = "Select * from test_table"
+ SnowflakeOperator(
+ sql=sql,
+ task_id="snowflake_params_check",
+ snowflake_conn_id="snowflake_default",
+ warehouse="test_warehouse",
+ database="test_database",
+ role="test_role",
+ schema="test_schema",
+ authenticator="oath",
+ session_parameters={"QUERY_TAG": "test_tag"},
+ )
+ mock_base_op.assert_called_once_with(
+ conn_id="snowflake_default",
+ task_id="snowflake_params_check",
+ hook_params={
+ "warehouse": "test_warehouse",
+ "database": "test_database",
+ "role": "test_role",
+ "schema": "test_schema",
+ "authenticator": "oath",
+ "session_parameters": {"QUERY_TAG": "test_tag"},
+ },
+ default_args={},
+ )
+
+
@pytest.mark.parametrize(
"operator_class, kwargs",
[
@@ -93,6 +123,49 @@ class TestSnowflakeCheckOperators:
mock_get_db_hook.assert_called_once()
+@pytest.mark.parametrize(
+ "operator_class, kwargs",
+ [
+ (SnowflakeCheckOperator, dict(sql="Select * from test_table")),
+ (SnowflakeValueCheckOperator, dict(sql="Select * from test_table", pass_value=95)),
+ (SnowflakeIntervalCheckOperator, dict(table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5})),
+ ],
+)
+class TestSnowflakeCheckOperatorsForParams:
+ @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
+ def test_overwrite_params(
+ self,
+ mock_base_op,
+ operator_class,
+ kwargs,
+ ):
+ operator_class(
+ task_id="snowflake_params_check",
+ snowflake_conn_id="snowflake_default",
+ warehouse="test_warehouse",
+ database="test_database",
+ role="test_role",
+ schema="test_schema",
+ authenticator="oath",
+ session_parameters={"QUERY_TAG": "test_tag"},
+ **kwargs,
+ )
+ mock_base_op.assert_called_once_with(
+ conn_id="snowflake_default",
+ database=None,
+ task_id="snowflake_params_check",
+ hook_params={
+ "warehouse": "test_warehouse",
+ "database": "test_database",
+ "role": "test_role",
+ "schema": "test_schema",
+ "authenticator": "oath",
+ "session_parameters": {"QUERY_TAG": "test_tag"},
+ },
+ default_args={},
+ )
+
+
def create_context(task, dag=None):
if dag is None:
dag = DAG(dag_id="dag")