This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ae98b824d Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)
3ae98b824d is described below

commit 3ae98b824db437b2db928a73ac8b50c0a2f80124
Author: Wil Molina <[email protected]>
AuthorDate: Mon Nov 14 11:32:50 2022 -0800

    Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)
---
 airflow/providers/common/sql/operators/sql.py      | 12 ++++++++++--
 airflow/providers/snowflake/operators/snowflake.py |  2 +-
 tests/providers/common/sql/operators/test_sql.py   |  7 ++++++-
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
index 66244a858d..8dee6ed968 100644
--- a/airflow/providers/common/sql/operators/sql.py
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -608,6 +608,7 @@ class SQLCheckOperator(BaseSQLOperator):
     :param sql: the sql to be executed. (templated)
     :param conn_id: the connection ID used to connect to the database.
     :param database: name of database which overwrite the defined one in 
connection
+    :param parameters: (optional) the parameters to render the SQL query with.
     """
 
     template_fields: Sequence[str] = ("sql",)
@@ -619,14 +620,21 @@ class SQLCheckOperator(BaseSQLOperator):
     ui_color = "#fff7e6"
 
     def __init__(
-        self, *, sql: str, conn_id: str | None = None, database: str | None = 
None, **kwargs
+        self,
+        *,
+        sql: str,
+        conn_id: str | None = None,
+        database: str | None = None,
+        parameters: Iterable | Mapping | None = None,
+        **kwargs,
     ) -> None:
         super().__init__(conn_id=conn_id, database=database, **kwargs)
         self.sql = sql
+        self.parameters = parameters
 
     def execute(self, context: Context):
         self.log.info("Executing SQL check: %s", self.sql)
-        records = self.get_db_hook().get_first(self.sql)
+        records = self.get_db_hook().get_first(self.sql, self.parameters)
 
         self.log.info("Record: %s", records)
         if not records:
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 2546ddfb5e..cf7835ef65 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -179,7 +179,7 @@ class SnowflakeCheckOperator(SQLCheckOperator):
         session_parameters: dict | None = None,
         **kwargs,
     ) -> None:
-        super().__init__(sql=sql, **kwargs)
+        super().__init__(sql=sql, parameters=parameters, **kwargs)
         self.snowflake_conn_id = snowflake_conn_id
         self.sql = sql
         self.autocommit = autocommit
diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py
index 51f013f7fc..3741a93ed3 100644
--- a/tests/providers/common/sql/operators/test_sql.py
+++ b/tests/providers/common/sql/operators/test_sql.py
@@ -483,7 +483,7 @@ class TestSQLCheckOperatorDbHook:
 
 class TestCheckOperator(unittest.TestCase):
     def setUp(self):
-        self._operator = SQLCheckOperator(task_id="test_task", sql="sql")
+        self._operator = SQLCheckOperator(task_id="test_task", sql="sql", 
parameters="parameters")
 
     @mock.patch.object(SQLCheckOperator, "get_db_hook")
     def test_execute_no_records(self, mock_get_db_hook):
@@ -499,6 +499,11 @@ class TestCheckOperator(unittest.TestCase):
         with pytest.raises(AirflowException, match=r"Test failed."):
             self._operator.execute({})
 
+    @mock.patch.object(SQLCheckOperator, "get_db_hook")
+    def test_sqlcheckoperator_parameters(self, mock_get_db_hook):
+        self._operator.execute({})
+        mock_get_db_hook.return_value.get_first.assert_called_once_with("sql", 
"parameters")
+
 
 class TestValueCheckOperator(unittest.TestCase):
     def setUp(self):

Reply via email to