This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b7f21d3eeda Fix SnowflakeCheckOperator and SnowflakeValueCheckOperator to use parameters arg correctly (#53837)
b7f21d3eeda is described below

commit b7f21d3eeda5102966b293c436164026e7b81e77
Author: GPK <[email protected]>
AuthorDate: Tue Jul 29 13:15:02 2025 +0100

    Fix SnowflakeCheckOperator and SnowflakeValueCheckOperator to use parameters arg correctly (#53837)
    
    * Fix SnowflakeCheckOperator and SnowflakeValueCheckOperator to use parameters arg correctly
    
    * Fix tests
---
 .../airflow/providers/common/sql/operators/sql.py  |  4 +-
 .../tests/unit/common/sql/operators/test_sql.py    |  4 +-
 .../providers/snowflake/hooks/snowflake_sql_api.py |  2 +-
 .../providers/snowflake/operators/snowflake.py     | 16 ++--
 .../unit/snowflake/operators/test_snowflake.py     | 89 ++++++++++++++++++----
 5 files changed, 85 insertions(+), 30 deletions(-)
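
For reference, a minimal usage sketch (not part of the commit; the DAG id, table name, and bind values below are hypothetical) of how `parameters` is now forwarded by both operators to the hook's `get_first(sql, parameters)` call:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.snowflake.operators.snowflake import (
        SnowflakeCheckOperator,
        SnowflakeValueCheckOperator,
    )

    with DAG(
        dag_id="example_snowflake_checks",   # hypothetical DAG id
        start_date=datetime(2025, 1, 1),
        schedule=None,
    ):
        # Bind parameters are now passed through to get_first() instead of
        # being dropped; %(param1)s assumes the Snowflake connector's default
        # pyformat paramstyle.
        SnowflakeCheckOperator(
            task_id="check_rows",
            sql="SELECT COUNT(*) FROM test_table WHERE category = %(param1)s",
            parameters={"param1": "value1"},
            snowflake_conn_id="snowflake_default",
        )

        SnowflakeValueCheckOperator(
            task_id="check_value",
            sql="SELECT COUNT(*) FROM test_table WHERE category = %(param1)s",
            pass_value=95,
            parameters={"param1": "value1"},
            snowflake_conn_id="snowflake_default",
        )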

diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index 95d602bc9b4..0b884d97f0c 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -863,6 +863,7 @@ class SQLValueCheckOperator(BaseSQLOperator):
         tolerance: Any = None,
         conn_id: str | None = None,
         database: str | None = None,
+        parameters: Iterable | Mapping[str, Any] | None = None,
         **kwargs,
     ):
         super().__init__(conn_id=conn_id, database=database, **kwargs)
@@ -871,6 +872,7 @@ class SQLValueCheckOperator(BaseSQLOperator):
         tol = _convert_to_float_if_possible(tolerance)
         self.tol = tol if isinstance(tol, float) else None
         self.has_tolerance = self.tol is not None
+        self.parameters = parameters
 
     def check_value(self, records):
         if not records:
@@ -903,7 +905,7 @@ class SQLValueCheckOperator(BaseSQLOperator):
 
     def execute(self, context: Context):
         self.log.info("Executing SQL check: %s", self.sql)
-        records = self.get_db_hook().get_first(self.sql)
+        records = self.get_db_hook().get_first(self.sql, self.parameters)
         self.check_value(records)
 
     def _to_float(self, records):
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
index 23065cd3a71..dd404d5f59e 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
@@ -51,7 +51,7 @@ from airflow.providers.common.sql.operators.sql import (
 )
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.utils import timezone
+from airflow.utils import timezone  # type: ignore[attr-defined]
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
@@ -844,7 +844,7 @@ class TestValueCheckOperator:
 
         operator.execute(None)
 
-        mock_hook.get_first.assert_called_once_with(sql)
+        mock_hook.get_first.assert_called_once_with(sql, None)
 
     @mock.patch.object(SQLValueCheckOperator, "get_db_hook")
     def test_execute_fail(self, mock_get_db_hook):
diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 48747381602..98349127204 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -435,7 +435,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         :param url: The URL for the API endpoint.
         :param headers: The headers to include in the API call.
         :param params: (Optional) The query parameters to include in the API call.
-        :param data: (Optional) The data to include in the API call.
+        :param json: (Optional) The data to include in the API call.
         :return: The response object from the API call.
         """
         with requests.Session() as session:
diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
index 4f214c681fb..84f6773b2bf 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
@@ -76,8 +76,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         Template references are recognized by str ending in '.sql'
     :param snowflake_conn_id: Reference to
         :ref:`Snowflake connection id<howto/connection:snowflake>`
-    :param autocommit: if True, each command is automatically committed.
-        (default value: True)
     :param parameters: (optional) the parameters to render the SQL query with.
     :param warehouse: name of warehouse (will overwrite any warehouse
         defined in the connection's extra JSON)
@@ -109,8 +107,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         sql: str,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -179,8 +175,6 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
         tolerance: Any = None,
         snowflake_conn_id: str = "snowflake_default",
         parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
@@ -202,7 +196,12 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
                 **hook_params,
             }
         super().__init__(
-            sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs
+            sql=sql,
+            pass_value=pass_value,
+            tolerance=tolerance,
+            conn_id=snowflake_conn_id,
+            parameters=parameters,
+            **kwargs,
         )
         self.query_ids: list[str] = []
 
@@ -259,9 +258,6 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
         date_filter_column: str = "ds",
         days_back: SupportsAbs[int] = -7,
         snowflake_conn_id: str = "snowflake_default",
-        parameters: Iterable | Mapping[str, Any] | None = None,
-        autocommit: bool = True,
-        do_xcom_push: bool = True,
         warehouse: str | None = None,
         database: str | None = None,
         role: str | None = None,
diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
index 721b72e5781..c4711ccc26d 100644
--- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
@@ -18,11 +18,13 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import call
 
 import pendulum
 import pytest
 
 from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import Connection
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
@@ -34,7 +36,7 @@ from airflow.providers.snowflake.operators.snowflake import (
     SnowflakeValueCheckOperator,
 )
 from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger
-from airflow.utils import timezone
+from airflow.utils import timezone  # type: ignore[attr-defined]
 from airflow.utils.types import DagRunType
 
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
@@ -107,25 +109,80 @@ class TestSnowflakeOperatorForParams:
         )
 
 
[email protected](
-    "operator_class, kwargs",
-    [
-        (SnowflakeCheckOperator, dict(sql="Select * from test_table")),
-        (SnowflakeValueCheckOperator, dict(sql="Select * from test_table", pass_value=95)),
-        (SnowflakeIntervalCheckOperator, dict(table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5})),
-    ],
-)
-class TestSnowflakeCheckOperators:
-    @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
[email protected](autouse=True)
+def setup_connections(create_connection_without_db):
+    create_connection_without_db(
+        Connection(
+            conn_id="snowflake_default",
+            conn_type="snowflake",
+            host="test_host",
+            port=443,
+            schema="test_schema",
+            login="test_user",
+            password="test_password",
+        )
+    )
+
+
+class TestSnowflakeCheckOperator:
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckOperator.get_db_hook")
     def test_get_db_hook(
         self,
         mock_get_db_hook,
-        operator_class,
-        kwargs,
     ):
-        operator = operator_class(task_id="snowflake_check", snowflake_conn_id="snowflake_default", **kwargs)
-        operator.get_db_hook()
-        mock_get_db_hook.assert_called_once()
+        operator = SnowflakeCheckOperator(
+            task_id="snowflake_check",
+            snowflake_conn_id="snowflake_default",
+            sql="Select * from test_table",
+            parameters={"param1": "value1"},
+        )
+        operator.execute({})
+        mock_get_db_hook.assert_has_calls(
+            [call().get_first("Select * from test_table", {"param1": "value1"})]
+        )
+
+
+class TestSnowflakeValueCheckOperator:
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.get_db_hook")
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.check_value")
+    def test_get_db_hook(
+        self,
+        mock_check_value,
+        mock_get_db_hook,
+    ):
+        mock_get_db_hook.return_value.get_first.return_value = ["test_value"]
+
+        operator = SnowflakeValueCheckOperator(
+            task_id="snowflake_check",
+            sql="Select * from test_table",
+            pass_value=95,
+            parameters={"param1": "value1"},
+        )
+        operator.execute({})
+        mock_get_db_hook.assert_has_calls(
+            [call().get_first("Select * from test_table", {"param1": "value1"})]
+        )
+        assert mock_check_value.call_args == call(["test_value"])
+
+
+class TestSnowflakeIntervalCheckOperator:
+    @mock.patch("airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator.__init__")
+    def test_get_db_hook(
+        self,
+        mock_snowflake_interval_check_operator,
+    ):
+        SnowflakeIntervalCheckOperator(
+            task_id="snowflake_check", table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5}
+        )
+        assert mock_snowflake_interval_check_operator.call_args == mock.call(
+            table="test-table-id",
+            metrics_thresholds={"COUNT(*)": 1.5},
+            date_filter_column="ds",
+            days_back=-7,
+            conn_id="snowflake_default",
+            task_id="snowflake_check",
+            default_args={},
+        )
 
 
 @pytest.mark.parametrize(
