This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b7f21d3eeda Fix SnowflakeCheckOperator and SnowflakeValueCheckOperator
to use parameters arg correctly (#53837)
b7f21d3eeda is described below
commit b7f21d3eeda5102966b293c436164026e7b81e77
Author: GPK <[email protected]>
AuthorDate: Tue Jul 29 13:15:02 2025 +0100
Fix SnowflakeCheckOperator and SnowflakeValueCheckOperator to use
parameters arg correctly (#53837)
* Fix SnowflakeCheckOperator and SnowflakeValueCheckOperator to use
parameters arg correctly
* Fix tests
---
.../airflow/providers/common/sql/operators/sql.py | 4 +-
.../tests/unit/common/sql/operators/test_sql.py | 4 +-
.../providers/snowflake/hooks/snowflake_sql_api.py | 2 +-
.../providers/snowflake/operators/snowflake.py | 16 ++--
.../unit/snowflake/operators/test_snowflake.py | 89 ++++++++++++++++++----
5 files changed, 85 insertions(+), 30 deletions(-)
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index 95d602bc9b4..0b884d97f0c 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -863,6 +863,7 @@ class SQLValueCheckOperator(BaseSQLOperator):
tolerance: Any = None,
conn_id: str | None = None,
database: str | None = None,
+ parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
):
super().__init__(conn_id=conn_id, database=database, **kwargs)
@@ -871,6 +872,7 @@ class SQLValueCheckOperator(BaseSQLOperator):
tol = _convert_to_float_if_possible(tolerance)
self.tol = tol if isinstance(tol, float) else None
self.has_tolerance = self.tol is not None
+ self.parameters = parameters
def check_value(self, records):
if not records:
@@ -903,7 +905,7 @@ class SQLValueCheckOperator(BaseSQLOperator):
def execute(self, context: Context):
self.log.info("Executing SQL check: %s", self.sql)
- records = self.get_db_hook().get_first(self.sql)
+ records = self.get_db_hook().get_first(self.sql, self.parameters)
self.check_value(records)
def _to_float(self, records):
diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
index 23065cd3a71..dd404d5f59e 100644
--- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
+++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py
@@ -51,7 +51,7 @@ from airflow.providers.common.sql.operators.sql import (
)
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.utils import timezone
+from airflow.utils import timezone # type: ignore[attr-defined]
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
@@ -844,7 +844,7 @@ class TestValueCheckOperator:
operator.execute(None)
- mock_hook.get_first.assert_called_once_with(sql)
+ mock_hook.get_first.assert_called_once_with(sql, None)
@mock.patch.object(SQLValueCheckOperator, "get_db_hook")
def test_execute_fail(self, mock_get_db_hook):
diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 48747381602..98349127204 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -435,7 +435,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
:param url: The URL for the API endpoint.
:param headers: The headers to include in the API call.
:param params: (Optional) The query parameters to include in the API call.
- :param data: (Optional) The data to include in the API call.
+ :param json: (Optional) The data to include in the API call.
:return: The response object from the API call.
"""
with requests.Session() as session:
diff --git a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
index 4f214c681fb..84f6773b2bf 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/operators/snowflake.py
@@ -76,8 +76,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
Template references are recognized by str ending in '.sql'
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
- :param autocommit: if True, each command is automatically committed.
- (default value: True)
:param parameters: (optional) the parameters to render the SQL query with.
:param warehouse: name of warehouse (will overwrite any warehouse
defined in the connection's extra JSON)
@@ -109,8 +107,6 @@ class SnowflakeCheckOperator(SQLCheckOperator):
sql: str,
snowflake_conn_id: str = "snowflake_default",
parameters: Iterable | Mapping[str, Any] | None = None,
- autocommit: bool = True,
- do_xcom_push: bool = True,
warehouse: str | None = None,
database: str | None = None,
role: str | None = None,
@@ -179,8 +175,6 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
tolerance: Any = None,
snowflake_conn_id: str = "snowflake_default",
parameters: Iterable | Mapping[str, Any] | None = None,
- autocommit: bool = True,
- do_xcom_push: bool = True,
warehouse: str | None = None,
database: str | None = None,
role: str | None = None,
@@ -202,7 +196,12 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
**hook_params,
}
super().__init__(
- sql=sql, pass_value=pass_value, tolerance=tolerance, conn_id=snowflake_conn_id, **kwargs
+ sql=sql,
+ pass_value=pass_value,
+ tolerance=tolerance,
+ conn_id=snowflake_conn_id,
+ parameters=parameters,
+ **kwargs,
)
self.query_ids: list[str] = []
@@ -259,9 +258,6 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
date_filter_column: str = "ds",
days_back: SupportsAbs[int] = -7,
snowflake_conn_id: str = "snowflake_default",
- parameters: Iterable | Mapping[str, Any] | None = None,
- autocommit: bool = True,
- do_xcom_push: bool = True,
warehouse: str | None = None,
database: str | None = None,
role: str | None = None,
diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
index 721b72e5781..c4711ccc26d 100644
--- a/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowflake.py
@@ -18,11 +18,13 @@
from __future__ import annotations
from unittest import mock
+from unittest.mock import call
import pendulum
import pytest
from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import Connection
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
@@ -34,7 +36,7 @@ from airflow.providers.snowflake.operators.snowflake import (
SnowflakeValueCheckOperator,
)
from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger
-from airflow.utils import timezone
+from airflow.utils import timezone # type: ignore[attr-defined]
from airflow.utils.types import DagRunType
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
@@ -107,25 +109,80 @@ class TestSnowflakeOperatorForParams:
)
-@pytest.mark.parametrize(
- "operator_class, kwargs",
- [
- (SnowflakeCheckOperator, dict(sql="Select * from test_table")),
- (SnowflakeValueCheckOperator, dict(sql="Select * from test_table", pass_value=95)),
- (SnowflakeIntervalCheckOperator, dict(table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5})),
- ],
-)
-class TestSnowflakeCheckOperators:
- @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook")
+@pytest.fixture(autouse=True)
+def setup_connections(create_connection_without_db):
+ create_connection_without_db(
+ Connection(
+ conn_id="snowflake_default",
+ conn_type="snowflake",
+ host="test_host",
+ port=443,
+ schema="test_schema",
+ login="test_user",
+ password="test_password",
+ )
+ )
+
+
+class TestSnowflakeCheckOperator:
+ @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckOperator.get_db_hook")
def test_get_db_hook(
self,
mock_get_db_hook,
- operator_class,
- kwargs,
):
- operator = operator_class(task_id="snowflake_check", snowflake_conn_id="snowflake_default", **kwargs)
- operator.get_db_hook()
- mock_get_db_hook.assert_called_once()
+ operator = SnowflakeCheckOperator(
+ task_id="snowflake_check",
+ snowflake_conn_id="snowflake_default",
+ sql="Select * from test_table",
+ parameters={"param1": "value1"},
+ )
+ operator.execute({})
+ mock_get_db_hook.assert_has_calls(
+ [call().get_first("Select * from test_table", {"param1": "value1"})]
+ )
+
+
+class TestSnowflakeValueCheckOperator:
+ @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.get_db_hook")
+ @mock.patch("airflow.providers.common.sql.operators.sql.SQLValueCheckOperator.check_value")
+ def test_get_db_hook(
+ self,
+ mock_check_value,
+ mock_get_db_hook,
+ ):
+ mock_get_db_hook.return_value.get_first.return_value = ["test_value"]
+
+ operator = SnowflakeValueCheckOperator(
+ task_id="snowflake_check",
+ sql="Select * from test_table",
+ pass_value=95,
+ parameters={"param1": "value1"},
+ )
+ operator.execute({})
+ mock_get_db_hook.assert_has_calls(
+ [call().get_first("Select * from test_table", {"param1": "value1"})]
+ )
+ assert mock_check_value.call_args == call(["test_value"])
+
+
+class TestSnowflakeIntervalCheckOperator:
+ @mock.patch("airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator.__init__")
+ def test_get_db_hook(
+ self,
+ mock_snowflake_interval_check_operator,
+ ):
+ SnowflakeIntervalCheckOperator(
+ task_id="snowflake_check", table="test-table-id", metrics_thresholds={"COUNT(*)": 1.5}
+ )
+ assert mock_snowflake_interval_check_operator.call_args == mock.call(
+ table="test-table-id",
+ metrics_thresholds={"COUNT(*)": 1.5},
+ date_filter_column="ds",
+ days_back=-7,
+ conn_id="snowflake_default",
+ task_id="snowflake_check",
+ default_args={},
+ )
@pytest.mark.parametrize(