This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 891c2e4019 Add Deferrable switch to SnowflakeSqlApiOperator (#31596)
891c2e4019 is described below

commit 891c2e401928ecafea78f7c6c3b453663ef03dce
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Wed Jul 5 02:19:47 2023 +0530

    Add Deferrable switch to SnowflakeSqlApiOperator (#31596)
---
 .../providers/snowflake/hooks/snowflake_sql_api.py |  43 +++++---
 airflow/providers/snowflake/operators/snowflake.py |  46 +++++++-
 airflow/providers/snowflake/provider.yaml          |   5 +
 airflow/providers/snowflake/triggers/__init__.py   |  16 +++
 .../snowflake/triggers/snowflake_trigger.py        | 109 +++++++++++++++++++
 .../operators/snowflake.rst                        |   2 +
 .../snowflake/hooks/test_snowflake_sql_api.py      |  56 +++++++++-
 .../snowflake/operators/test_snowflake.py          |  96 ++++++++++++++++-
 tests/providers/snowflake/triggers/__init__.py     |  16 +++
 .../providers/snowflake/triggers/test_snowflake.py | 120 +++++++++++++++++++++
 10 files changed, 490 insertions(+), 19 deletions(-)

diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 0d808291ff..eec3c7349e 100644
--- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -21,6 +21,7 @@ from datetime import timedelta
 from pathlib import Path
 from typing import Any
 
+import aiohttp
 import requests
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
@@ -59,7 +60,8 @@ class SnowflakeSqlApiHook(SnowflakeHook):
     :param session_parameters: You can set session-level parameters at
         the time you connect to Snowflake
     :param token_life_time: lifetime of the JWT Token in timedelta
-    :param token_renewal_delta: Renewal time of the JWT Token in  timedelta
+    :param token_renewal_delta: Renewal time of the JWT Token in timedelta
+    :param deferrable: Run operator in the deferrable mode.
     """
 
     LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minute lifetime
@@ -225,17 +227,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
                     f"Response: {e.response.content}, Status Code: 
{e.response.status_code}"
                 )
 
-    def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
-        """
-        Based on the query id async HTTP request is made to snowflake SQL API and return response.
-
-        :param query_id: statement handle id for the individual statements.
-        """
-        self.log.info("Retrieving status for query id %s", {query_id})
-        header, params, url = self.get_request_url_header_params(query_id)
-        response = requests.get(url, params=params, headers=header)
-        status_code = response.status_code
-        resp = response.json()
+    def _process_response(self, status_code, resp):
         self.log.info("Snowflake SQL GET statements status API response: %s", 
resp)
         if status_code == 202:
             return {"status": "running", "message": "Query statements are 
still running"}
@@ -254,3 +246,30 @@ class SnowflakeSqlApiHook(SnowflakeHook):
             }
         else:
             return {"status": "error", "message": resp["message"]}
+
+    def get_sql_api_query_status(self, query_id: str) -> dict[str, str | list[str]]:
+        """
+        Make an HTTP request to the Snowflake SQL API based on the query id and return the response.
+
+        :param query_id: statement handle id for the individual statements.
+        """
+        self.log.info("Retrieving status for query id %s", query_id)
+        header, params, url = self.get_request_url_header_params(query_id)
+        response = requests.get(url, params=params, headers=header)
+        status_code = response.status_code
+        resp = response.json()
+        return self._process_response(status_code, resp)
+
+    async def get_sql_api_query_status_async(self, query_id: str) -> dict[str, str | list[str]]:
+        """
+        Make an async HTTP request to the Snowflake SQL API based on the query id and return the response.
+
+        :param query_id: statement handle id for the individual statements.
+        """
+        self.log.info("Retrieving status for query id %s", query_id)
+        header, params, url = self.get_request_url_header_params(query_id)
+        async with aiohttp.ClientSession(headers=header) as session:
+            async with session.get(url, params=params) as response:
+                status_code = response.status
+                resp = await response.json()
+                return self._process_response(status_code, resp)
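
For reference, here is a minimal sketch of driving the new async status check directly,
outside the trigger. The connection id and query id below are made up for illustration:

    import asyncio

    from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook

    async def show_status() -> None:
        # "snowflake_default" is a hypothetical connection id used for illustration.
        hook = SnowflakeSqlApiHook(snowflake_conn_id="snowflake_default")
        # Returns a dict such as {"status": "running", ...} or
        # {"status": "success", "statement_handles": [...]}.
        status = await hook.get_sql_api_query_status_async("made-up-query-id")
        print(status)

    # asyncio.run(show_status())  # needs a real Snowflake connection, so left commented out
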
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index b56b14b0d5..db35fa0007 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import time
 import warnings
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs
+from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence, SupportsAbs, cast
 
 from airflow import AirflowException
 from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -33,6 +33,7 @@ from airflow.providers.common.sql.operators.sql import (
 from airflow.providers.snowflake.hooks.snowflake_sql_api import (
     SnowflakeSqlApiHook,
 )
+from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -430,6 +431,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
     :param bindings: (Optional) Values of bind variables in the SQL statement.
             When executing the statement, Snowflake replaces placeholders (? and :name) in
             the statement with these specified values.
+    :param deferrable: Run operator in the deferrable mode.
     """  # noqa
 
     LIFETIME = timedelta(minutes=59)  # The tokens will have a 59 minutes lifetime
@@ -450,6 +452,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         token_life_time: timedelta = LIFETIME,
         token_renewal_delta: timedelta = RENEWAL_DELTA,
         bindings: dict[str, Any] | None = None,
+        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
         self.snowflake_conn_id = snowflake_conn_id
@@ -459,6 +462,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         self.token_renewal_delta = token_renewal_delta
         self.bindings = bindings
         self.execute_async = False
+        self.deferrable = deferrable
         if any([warehouse, database, role, schema, authenticator, session_parameters]):  # pragma: no cover
             hook_params = kwargs.pop("hook_params", {})  # pragma: no cover
             kwargs["hook_params"] = {
@@ -482,6 +486,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
             snowflake_conn_id=self.snowflake_conn_id,
             token_life_time=self.token_life_time,
             token_renewal_delta=self.token_renewal_delta,
+            deferrable=self.deferrable,
         )
         self.query_ids = self._hook.execute_query(
             self.sql, statement_count=self.statement_count, bindings=self.bindings  # type: ignore[arg-type]
@@ -491,10 +496,23 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
         if self.do_xcom_push:
             context["ti"].xcom_push(key="query_ids", value=self.query_ids)
 
-        statement_status = self.poll_on_queries()
-        if statement_status["error"]:
-            raise AirflowException(statement_status["error"])
-        self._hook.check_query_output(self.query_ids)
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=SnowflakeSqlApiTrigger(
+                    poll_interval=self.poll_interval,
+                    query_ids=self.query_ids,
+                    snowflake_conn_id=self.snowflake_conn_id,
+                    token_life_time=self.token_life_time,
+                    token_renewal_delta=self.token_renewal_delta,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            statement_status = self.poll_on_queries()
+            if statement_status["error"]:
+                raise AirflowException(statement_status["error"])
+            self._hook.check_query_output(self.query_ids)
 
     def poll_on_queries(self):
         """Poll on requested queries."""
@@ -517,3 +535,21 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
                 queries_in_progress.remove(query_id)
             time.sleep(self.poll_interval)
         return {"success": statement_success_status, "error": 
statement_error_status}
+
+    def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception, otherwise it assumes the execution
+        was successful.
+        """
+        if event:
+            if "status" in event and event["status"] == "error":
+                msg = f"{event['status']}: {event['message']}"
+                raise AirflowException(msg)
+            elif "status" in event and event["status"] == "success":
+                hook = SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id)
+                query_ids = cast(List[str], event["statement_query_ids"])
+                hook.check_query_output(query_ids)
+                self.log.info("%s completed successfully.", self.task_id)
+        else:
+            self.log.info("%s completed successfully.", self.task_id)
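
For context, the event payloads consumed by the new execute_complete callback have two
shapes, matching what the SnowflakeSqlApiTrigger added below yields; the dicts here are
hand-built for illustration:

    # Success event: execute_complete fetches the output for each statement
    # handle via SnowflakeSqlApiHook.check_query_output().
    success_event = {"status": "success", "statement_query_ids": ["uuid", "uuid1"]}

    # Error event: execute_complete raises AirflowException with this message.
    error_event = {"status": "error", "message": "Query statement failed"}
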
diff --git a/airflow/providers/snowflake/provider.yaml b/airflow/providers/snowflake/provider.yaml
index 1e68fbddca..2cea953ab4 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -100,3 +100,8 @@ transfers:
 connection-types:
   - hook-class-name: airflow.providers.snowflake.hooks.snowflake.SnowflakeHook
     connection-type: snowflake
+
+triggers:
+  - integration-name: Snowflake
+    python-modules:
+      - airflow.providers.snowflake.triggers.snowflake_trigger
diff --git a/airflow/providers/snowflake/triggers/__init__.py b/airflow/providers/snowflake/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/snowflake/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/snowflake/triggers/snowflake_trigger.py b/airflow/providers/snowflake/triggers/snowflake_trigger.py
new file mode 100644
index 0000000000..4f1e0cffb2
--- /dev/null
+++ b/airflow/providers/snowflake/triggers/snowflake_trigger.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from datetime import timedelta
+from typing import Any, AsyncIterator
+
+from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class SnowflakeSqlApiTrigger(BaseTrigger):
+    """
+    Fetch the status for the query ids passed.
+
+    :param poll_interval: polling period in seconds to check for the status
+    :param query_ids: List of Query ids to run and poll for the status
+    :param snowflake_conn_id: Reference to Snowflake connection id
+    :param token_life_time: lifetime of the JWT Token in timedelta
+    :param token_renewal_delta: Renewal time of the JWT Token in timedelta
+    """
+
+    def __init__(
+        self,
+        poll_interval: float,
+        query_ids: list[str],
+        snowflake_conn_id: str,
+        token_life_time: timedelta,
+        token_renewal_delta: timedelta,
+    ):
+        super().__init__()
+        self.poll_interval = poll_interval
+        self.query_ids = query_ids
+        self.snowflake_conn_id = snowflake_conn_id
+        self.token_life_time = token_life_time
+        self.token_renewal_delta = token_renewal_delta
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes SnowflakeSqlApiTrigger arguments and classpath."""
+        return (
+            "airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger",
+            {
+                "poll_interval": self.poll_interval,
+                "query_ids": self.query_ids,
+                "snowflake_conn_id": self.snowflake_conn_id,
+                "token_life_time": self.token_life_time,
+                "token_renewal_delta": self.token_renewal_delta,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Wait for the query the snowflake query to complete."""
+        SnowflakeSqlApiHook(
+            self.snowflake_conn_id,
+            self.token_life_time,
+            self.token_renewal_delta,
+        )
+        try:
+            statement_query_ids: list[str] = []
+            for query_id in self.query_ids:
+                while True:
+                    statement_status = await self.get_query_status(query_id)
+                    if statement_status["status"] not in ["running"]:
+                        break
+                    await asyncio.sleep(self.poll_interval)
+                if statement_status["status"] == "error":
+                    yield TriggerEvent(statement_status)
+                    return
+                if statement_status["status"] == "success":
+                    statement_query_ids.extend(statement_status["statement_handles"])
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "statement_query_ids": statement_query_ids,
+                }
+            )
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
+
+    async def get_query_status(self, query_id: str) -> dict[str, Any]:
+        """
+        Async function to check the status of a query statement submitted via the SQL API
+        and return the status response.
+        """
+        hook = SnowflakeSqlApiHook(
+            self.snowflake_conn_id,
+            self.token_life_time,
+            self.token_renewal_delta,
+        )
+        return await hook.get_sql_api_query_status_async(query_id)
+
+    def _set_context(self, context):
+        pass
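
A quick sketch of the trigger's serialize round trip and of consuming its run() generator;
the statement handle and connection id below are made up, mirroring the tests added later
in this commit:

    import asyncio
    from datetime import timedelta

    from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger

    trigger = SnowflakeSqlApiTrigger(
        poll_interval=5.0,
        query_ids=["uuid"],  # made-up statement handle
        snowflake_conn_id="snowflake_default",  # hypothetical connection id
        token_life_time=timedelta(minutes=59),
        token_renewal_delta=timedelta(minutes=54),
    )

    # serialize() returns the classpath and kwargs the triggerer uses to rebuild the trigger.
    classpath, kwargs = trigger.serialize()

    async def first_event():
        # run() is an async generator; the first yielded TriggerEvent ends the wait.
        return await trigger.run().__anext__()

    # asyncio.run(first_event())  # would poll Snowflake, so left commented out
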
diff --git a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
index 1e80f3af29..d3ffbec5a4 100644
--- a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
+++ b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
@@ -66,6 +66,8 @@ SnowflakeSqlApiOperator
 Use the :class:`SnowflakeSqlApiHook <airflow.providers.snowflake.operators.snowflake>` to execute
 SQL commands in a `Snowflake <https://docs.snowflake.com/en/>`__ database.
 
+You can also run this operator in deferrable mode by setting the ``deferrable`` parameter to ``True``.
+This ensures the task releases its Airflow worker slot while it is deferred, and polling for the task status happens on the triggerer.
 
 Using the Operator
 ^^^^^^^^^^^^^^^^^^
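
As a concrete illustration of the documented switch, a minimal DAG sketch; the dag id,
task id, and connection id are made up:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator

    with DAG(dag_id="example_snowflake_sql_api", start_date=datetime(2023, 1, 1), schedule=None):
        run_sql = SnowflakeSqlApiOperator(
            task_id="run_sql",  # hypothetical task id
            snowflake_conn_id="snowflake_default",  # hypothetical connection id
            sql="SELECT 1;",
            statement_count=1,
            deferrable=True,  # hand polling off to the triggerer
        )
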
diff --git a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
index 61e88de864..fd2da72c92 100644
--- a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
+++ b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
@@ -21,6 +21,7 @@ import uuid
 from pathlib import Path
 from typing import Any
 from unittest import mock
+from unittest.mock import AsyncMock
 
 import pytest
 import requests
@@ -396,7 +397,6 @@ class TestSnowflakeSqlApiHook:
         ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."):
             SnowflakeSqlApiHook(snowflake_conn_id="test_conn").get_private_key()
 
-    @pytest.mark.asyncio
     @pytest.mark.parametrize(
         "status_code,response,expected_response",
         [
@@ -456,3 +456,57 @@ class TestSnowflakeSqlApiHook:
         mock_requests.get.return_value = MockResponse(status_code, response)
         hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
         assert hook.get_sql_api_query_status("uuid") == expected_response
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "status_code,response,expected_response",
+        [
+            (
+                200,
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statementHandle": "uuid",
+                },
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statement_handles": ["uuid"],
+                },
+            ),
+            (
+                200,
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statementHandles": ["uuid", "uuid1"],
+                },
+                {
+                    "status": "success",
+                    "message": "Statement executed successfully.",
+                    "statement_handles": ["uuid", "uuid1"],
+                },
+            ),
+            (202, {}, {"status": "running", "message": "Query statements are 
still running"}),
+            (422, {"status": "error", "message": "test"}, {"status": "error", 
"message": "test"}),
+            (404, {"status": "error", "message": "test"}, {"status": "error", 
"message": "test"}),
+        ],
+    )
+    @mock.patch(
+        "airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
+        "get_request_url_header_params"
+    )
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get")
+    async def test_get_sql_api_query_status_async(
+        self, mock_get, mock_geturl_header_params, status_code, response, expected_response
+    ):
+        """Test Async get_sql_api_query_status_async function by mocking the 
status,
+        response and expected response"""
+        req_id = uuid.uuid4()
+        params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
+        mock_geturl_header_params.return_value = HEADERS, params, "/test/airflow/"
+        mock_get.return_value.__aenter__.return_value.status = status_code
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=response)
+        hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+        response = await hook.get_sql_api_query_status_async("uuid")
+        assert response == expected_response
diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py
index 8f32c6e62d..41cbfe6717 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -19,10 +19,13 @@ from __future__ import annotations
 
 from unittest import mock
 
+import pendulum
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models.dag import DAG
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.snowflake.operators.snowflake import (
     SnowflakeCheckOperator,
     SnowflakeIntervalCheckOperator,
@@ -30,7 +33,9 @@ from airflow.providers.snowflake.operators.snowflake import (
     SnowflakeSqlApiOperator,
     SnowflakeValueCheckOperator,
 )
+from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger
 from airflow.utils import timezone
+from airflow.utils.types import DagRunType
 
 DEFAULT_DATE = timezone.datetime(2015, 1, 1)
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
@@ -88,6 +93,34 @@ class TestSnowflakeCheckOperators:
         mock_get_db_hook.assert_called_once()
 
 
+def create_context(task, dag=None):
+    if dag is None:
+        dag = DAG(dag_id="dag")
+    tzinfo = pendulum.timezone("UTC")
+    execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+    dag_run = DagRun(
+        dag_id=dag.dag_id,
+        execution_date=execution_date,
+        run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
+    )
+
+    task_instance = TaskInstance(task=task)
+    task_instance.dag_run = dag_run
+    task_instance.xcom_push = mock.Mock()
+    return {
+        "dag": dag,
+        "ts": execution_date.isoformat(),
+        "task": task,
+        "ti": task_instance,
+        "task_instance": task_instance,
+        "run_id": dag_run.run_id,
+        "dag_run": dag_run,
+        "execution_date": execution_date,
+        "data_interval_end": execution_date,
+        "logical_date": execution_date,
+    }
+
+
 class TestSnowflakeSqlApiOperator:
     @pytest.fixture
     def mock_execute_query(self):
@@ -142,3 +175,64 @@ class TestSnowflakeSqlApiOperator:
         mock_get_sql_api_query_status.side_effect = [{"status": "error"}, {"status": "success"}]
         with pytest.raises(AirflowException):
             operator.execute(context=None)
+
+    @pytest.mark.parametrize("mock_sql, statement_count", 
[(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)])
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query")
+    def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook, mock_sql, statement_count):
+        """
+        Asserts that a task is deferred and a SnowflakeSqlApiTrigger will be fired
+        when the SnowflakeSqlApiOperator is executed.
+        """
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id=CONN_ID,
+            sql=mock_sql,
+            statement_count=statement_count,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(create_context(operator))
+
+        assert isinstance(
+            exc.value.trigger, SnowflakeSqlApiTrigger
+        ), "Trigger is not a SnowflakeSqlApiTrigger"
+
+    def test_snowflake_sql_api_execute_complete_failure(self):
+        """Test SnowflakeSqlApiOperator raise AirflowException of error 
event"""
+
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id=CONN_ID,
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+            deferrable=True,
+        )
+        with pytest.raises(AirflowException):
+            operator.execute_complete(
+                context=None,
+                event={"status": "error", "message": "Test failure message", 
"type": "FAILED_WITH_ERROR"},
+            )
+
+    @pytest.mark.parametrize(
+        "mock_event",
+        [
+            None,
+            ({"status": "success", "statement_query_ids": ["uuid", "uuid"]}),
+        ],
+    )
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.check_query_output")
+    def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event):
+        """Tests execute_complete assert with successful message"""
+
+        operator = SnowflakeSqlApiOperator(
+            task_id=TASK_ID,
+            snowflake_conn_id=CONN_ID,
+            sql=SQL_MULTIPLE_STMTS,
+            statement_count=4,
+            deferrable=True,
+        )
+
+        with mock.patch.object(operator.log, "info") as mock_log_info:
+            operator.execute_complete(context=None, event=mock_event)
+        mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)
diff --git a/tests/providers/snowflake/triggers/__init__.py b/tests/providers/snowflake/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/snowflake/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/snowflake/triggers/test_snowflake.py b/tests/providers/snowflake/triggers/test_snowflake.py
new file mode 100644
index 0000000000..9fc1459162
--- /dev/null
+++ b/tests/providers/snowflake/triggers/test_snowflake.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from datetime import timedelta
+from unittest import mock
+
+import pytest
+
+from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger
+from airflow.triggers.base import TriggerEvent
+
+TASK_ID = "snowflake_check"
+POLL_INTERVAL = 1.0
+LIFETIME = timedelta(minutes=59)
+RENEWAL_DELTA = timedelta(minutes=54)
+MODULE = "airflow.providers.snowflake"
+
+
+class TestSnowflakeSqlApiTrigger:
+    TRIGGER = SnowflakeSqlApiTrigger(
+        poll_interval=POLL_INTERVAL,
+        query_ids=["uuid"],
+        snowflake_conn_id="test_conn",
+        token_life_time=LIFETIME,
+        token_renewal_delta=RENEWAL_DELTA,
+    )
+
+    def test_snowflake_sql_trigger_serialization(self):
+        """
+        Asserts that the SnowflakeSqlApiTrigger correctly serializes its arguments
+        and classpath.
+        """
+        classpath, kwargs = self.TRIGGER.serialize()
+        assert classpath == "airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger"
+        assert kwargs == {
+            "poll_interval": POLL_INTERVAL,
+            "query_ids": ["uuid"],
+            "snowflake_conn_id": "test_conn",
+            "token_life_time": LIFETIME,
+            "token_renewal_delta": RENEWAL_DELTA,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.snowflake_trigger.SnowflakeSqlApiTrigger.get_query_status")
+    @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+    async def test_snowflake_sql_trigger_running(
+        self, mock_get_sql_api_query_status_async, mock_get_query_status
+    ):
+        """Tests that the SnowflakeSqlApiTrigger in running by mocking 
get_query_status to true"""
+        mock_get_query_status.return_value = {"status": "running"}
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.triggers.snowflake_trigger.SnowflakeSqlApiTrigger.get_query_status")
+    @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+    async def test_snowflake_sql_trigger_completed(
+        self, mock_get_sql_api_query_status_async, mock_get_query_status
+    ):
+        """
+        Test the SnowflakeSqlApiTrigger run method with a success status, mocking the
+        get_sql_api_query_status_async result and get_query_status.
+        """
+        mock_get_query_status.return_value = {"status": "success", "statement_handles": ["uuid", "uuid1"]}
+        statement_query_ids = ["uuid", "uuid1"]
+        mock_get_sql_api_query_status_async.return_value = {
+            "message": "Statement executed successfully.",
+            "status": "success",
+            "statement_handles": statement_query_ids,
+        }
+
+        generator = self.TRIGGER.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent({"status": "success", "statement_query_ids": 
statement_query_ids}) == actual
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+    async def test_snowflake_sql_trigger_failure_status(self, mock_get_sql_api_query_status_async):
+        """Test SnowflakeSqlApiTrigger task is executed and triggered with 
failure status."""
+        mock_response = {
+            "status": "error",
+            "message": "An error occurred when executing the statement. Check "
+            "the error code and error message for details",
+        }
+        mock_get_sql_api_query_status_async.return_value = mock_response
+
+        generator = self.TRIGGER.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent(mock_response) == actual
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+    async def test_snowflake_sql_trigger_exception(self, mock_get_sql_api_query_status_async):
+        """Tests the SnowflakeSqlApiTrigger does not fire if there is an 
exception."""
+        mock_get_sql_api_query_status_async.side_effect = Exception("Test exception")
+
+        task = [i async for i in self.TRIGGER.run()]
+        assert len(task) == 1
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) 
in task
