This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 891c2e4019 Add Deferrable switch to SnowflakeSqlApiOperator (#31596)
891c2e4019 is described below
commit 891c2e401928ecafea78f7c6c3b453663ef03dce
Author: Utkarsh Sharma <[email protected]>
AuthorDate: Wed Jul 5 02:19:47 2023 +0530
Add Deferrable switch to SnowflakeSqlApiOperator (#31596)
---
.../providers/snowflake/hooks/snowflake_sql_api.py | 43 +++++---
airflow/providers/snowflake/operators/snowflake.py | 46 +++++++-
airflow/providers/snowflake/provider.yaml | 5 +
airflow/providers/snowflake/triggers/__init__.py | 16 +++
.../snowflake/triggers/snowflake_trigger.py | 109 +++++++++++++++++++
.../operators/snowflake.rst | 2 +
.../snowflake/hooks/test_snowflake_sql_api.py | 56 +++++++++-
.../snowflake/operators/test_snowflake.py | 96 ++++++++++++++++-
tests/providers/snowflake/triggers/__init__.py | 16 +++
.../providers/snowflake/triggers/test_snowflake.py | 120 +++++++++++++++++++++
10 files changed, 490 insertions(+), 19 deletions(-)
diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 0d808291ff..eec3c7349e 100644
--- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -21,6 +21,7 @@ from datetime import timedelta
from pathlib import Path
from typing import Any
+import aiohttp
import requests
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
@@ -59,7 +60,8 @@ class SnowflakeSqlApiHook(SnowflakeHook):
:param session_parameters: You can set session-level parameters at
the time you connect to Snowflake
:param token_life_time: lifetime of the JWT Token in timedelta
- :param token_renewal_delta: Renewal time of the JWT Token in timedelta
+ :param token_renewal_delta: Renewal time of the JWT Token in timedelta
+ :param deferrable: Run operator in the deferrable mode.
"""
LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minute
lifetime
@@ -225,17 +227,7 @@ class SnowflakeSqlApiHook(SnowflakeHook):
f"Response: {e.response.content}, Status Code:
{e.response.status_code}"
)
- def get_sql_api_query_status(self, query_id: str) -> dict[str, str |
list[str]]:
- """
- Based on the query id async HTTP request is made to snowflake SQL API
and return response.
-
- :param query_id: statement handle id for the individual statements.
- """
- self.log.info("Retrieving status for query id %s", {query_id})
- header, params, url = self.get_request_url_header_params(query_id)
- response = requests.get(url, params=params, headers=header)
- status_code = response.status_code
- resp = response.json()
+ def _process_response(self, status_code, resp):
self.log.info("Snowflake SQL GET statements status API response: %s",
resp)
if status_code == 202:
return {"status": "running", "message": "Query statements are
still running"}
@@ -254,3 +246,30 @@ class SnowflakeSqlApiHook(SnowflakeHook):
}
else:
return {"status": "error", "message": resp["message"]}
+
+ def get_sql_api_query_status(self, query_id: str) -> dict[str, str |
list[str]]:
+ """
+ Based on the query id async HTTP request is made to snowflake SQL API
and return response.
+
+ :param query_id: statement handle id for the individual statements.
+ """
+ self.log.info("Retrieving status for query id %s", query_id)
+ header, params, url = self.get_request_url_header_params(query_id)
+ response = requests.get(url, params=params, headers=header)
+ status_code = response.status_code
+ resp = response.json()
+ return self._process_response(status_code, resp)
+
+ async def get_sql_api_query_status_async(self, query_id: str) -> dict[str,
str | list[str]]:
+ """
+ Based on the query id async HTTP request is made to snowflake SQL API
and return response.
+
+ :param query_id: statement handle id for the individual statements.
+ """
+ self.log.info("Retrieving status for query id %s", query_id)
+ header, params, url = self.get_request_url_header_params(query_id)
+ async with aiohttp.ClientSession(headers=header) as session:
+ async with session.get(url, params=params) as response:
+ status_code = response.status
+ resp = await response.json()
+ return self._process_response(status_code, resp)
diff --git a/airflow/providers/snowflake/operators/snowflake.py
b/airflow/providers/snowflake/operators/snowflake.py
index b56b14b0d5..db35fa0007 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import time
import warnings
from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, SupportsAbs
+from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Sequence,
SupportsAbs, cast
from airflow import AirflowException
from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -33,6 +33,7 @@ from airflow.providers.common.sql.operators.sql import (
from airflow.providers.snowflake.hooks.snowflake_sql_api import (
SnowflakeSqlApiHook,
)
+from airflow.providers.snowflake.triggers.snowflake_trigger import
SnowflakeSqlApiTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -430,6 +431,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
:param bindings: (Optional) Values of bind variables in the SQL statement.
When executing the statement, Snowflake replaces placeholders (?
and :name) in
the statement with these specified values.
+ :param deferrable: Run operator in the deferrable mode.
""" # noqa
LIFETIME = timedelta(minutes=59) # The tokens will have a 59 minutes
lifetime
@@ -450,6 +452,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
token_life_time: timedelta = LIFETIME,
token_renewal_delta: timedelta = RENEWAL_DELTA,
bindings: dict[str, Any] | None = None,
+ deferrable: bool = False,
**kwargs: Any,
) -> None:
self.snowflake_conn_id = snowflake_conn_id
@@ -459,6 +462,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
self.token_renewal_delta = token_renewal_delta
self.bindings = bindings
self.execute_async = False
+ self.deferrable = deferrable
if any([warehouse, database, role, schema, authenticator,
session_parameters]): # pragma: no cover
hook_params = kwargs.pop("hook_params", {}) # pragma: no cover
kwargs["hook_params"] = {
@@ -482,6 +486,7 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
snowflake_conn_id=self.snowflake_conn_id,
token_life_time=self.token_life_time,
token_renewal_delta=self.token_renewal_delta,
+ deferrable=self.deferrable,
)
self.query_ids = self._hook.execute_query(
self.sql, statement_count=self.statement_count,
bindings=self.bindings # type: ignore[arg-type]
@@ -491,10 +496,23 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
if self.do_xcom_push:
context["ti"].xcom_push(key="query_ids", value=self.query_ids)
- statement_status = self.poll_on_queries()
- if statement_status["error"]:
- raise AirflowException(statement_status["error"])
- self._hook.check_query_output(self.query_ids)
+ if self.deferrable:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=SnowflakeSqlApiTrigger(
+ poll_interval=self.poll_interval,
+ query_ids=self.query_ids,
+ snowflake_conn_id=self.snowflake_conn_id,
+ token_life_time=self.token_life_time,
+ token_renewal_delta=self.token_renewal_delta,
+ ),
+ method_name="execute_complete",
+ )
+ else:
+ statement_status = self.poll_on_queries()
+ if statement_status["error"]:
+ raise AirflowException(statement_status["error"])
+ self._hook.check_query_output(self.query_ids)
def poll_on_queries(self):
"""Poll on requested queries."""
@@ -517,3 +535,21 @@ class SnowflakeSqlApiOperator(SQLExecuteQueryOperator):
queries_in_progress.remove(query_id)
time.sleep(self.poll_interval)
return {"success": statement_success_status, "error":
statement_error_status}
+
+ def execute_complete(self, context: Context, event: dict[str, str |
list[str]] | None = None) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if event:
+ if "status" in event and event["status"] == "error":
+ msg = f"{event['status']}: {event['message']}"
+ raise AirflowException(msg)
+ elif "status" in event and event["status"] == "success":
+ hook =
SnowflakeSqlApiHook(snowflake_conn_id=self.snowflake_conn_id)
+ query_ids = cast(List[str], event["statement_query_ids"])
+ hook.check_query_output(query_ids)
+ self.log.info("%s completed successfully.", self.task_id)
+ else:
+ self.log.info("%s completed successfully.", self.task_id)
diff --git a/airflow/providers/snowflake/provider.yaml
b/airflow/providers/snowflake/provider.yaml
index 1e68fbddca..2cea953ab4 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -100,3 +100,8 @@ transfers:
connection-types:
- hook-class-name: airflow.providers.snowflake.hooks.snowflake.SnowflakeHook
connection-type: snowflake
+
+triggers:
+ - integration-name: Snowflake
+ python-modules:
+ - airflow.providers.snowflake.triggers.snowflake_trigger
diff --git a/airflow/providers/snowflake/triggers/__init__.py
b/airflow/providers/snowflake/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/snowflake/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/snowflake/triggers/snowflake_trigger.py
b/airflow/providers/snowflake/triggers/snowflake_trigger.py
new file mode 100644
index 0000000000..4f1e0cffb2
--- /dev/null
+++ b/airflow/providers/snowflake/triggers/snowflake_trigger.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from datetime import timedelta
+from typing import Any, AsyncIterator
+
+from airflow.providers.snowflake.hooks.snowflake_sql_api import
SnowflakeSqlApiHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class SnowflakeSqlApiTrigger(BaseTrigger):
+ """
+ Fetch the status for the query ids passed.
+
+ :param poll_interval: polling period in seconds to check for the status
+ :param query_ids: List of Query ids to run and poll for the status
+ :param snowflake_conn_id: Reference to Snowflake connection id
+ :param token_life_time: lifetime of the JWT Token in timedelta
+ :param token_renewal_delta: Renewal time of the JWT Token in timedelta
+ """
+
+ def __init__(
+ self,
+ poll_interval: float,
+ query_ids: list[str],
+ snowflake_conn_id: str,
+ token_life_time: timedelta,
+ token_renewal_delta: timedelta,
+ ):
+ super().__init__()
+ self.poll_interval = poll_interval
+ self.query_ids = query_ids
+ self.snowflake_conn_id = snowflake_conn_id
+ self.token_life_time = token_life_time
+ self.token_renewal_delta = token_renewal_delta
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes SnowflakeSqlApiTrigger arguments and classpath."""
+ return (
+
"airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger",
+ {
+ "poll_interval": self.poll_interval,
+ "query_ids": self.query_ids,
+ "snowflake_conn_id": self.snowflake_conn_id,
+ "token_life_time": self.token_life_time,
+ "token_renewal_delta": self.token_renewal_delta,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """Wait for the query the snowflake query to complete."""
+ SnowflakeSqlApiHook(
+ self.snowflake_conn_id,
+ self.token_life_time,
+ self.token_renewal_delta,
+ )
+ try:
+ statement_query_ids: list[str] = []
+ for query_id in self.query_ids:
+ while True:
+ statement_status = await self.get_query_status(query_id)
+ if statement_status["status"] not in ["running"]:
+ break
+ await asyncio.sleep(self.poll_interval)
+ if statement_status["status"] == "error":
+ yield TriggerEvent(statement_status)
+ return
+ if statement_status["status"] == "success":
+
statement_query_ids.extend(statement_status["statement_handles"])
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "statement_query_ids": statement_query_ids,
+ }
+ )
+ except Exception as e:
+ yield TriggerEvent({"status": "error", "message": str(e)})
+
+ async def get_query_status(self, query_id: str) -> dict[str, Any]:
+ """
+ Async function to check whether the query statement submitted via SQL
API is still
+ running state and returns True if it is still running else
+ return False.
+ """
+ hook = SnowflakeSqlApiHook(
+ self.snowflake_conn_id,
+ self.token_life_time,
+ self.token_renewal_delta,
+ )
+ return await hook.get_sql_api_query_status_async(query_id)
+
+ def _set_context(self, context):
+ pass
diff --git a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
index 1e80f3af29..d3ffbec5a4 100644
--- a/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
+++ b/docs/apache-airflow-providers-snowflake/operators/snowflake.rst
@@ -66,6 +66,8 @@ SnowflakeSqlApiOperator
Use the :class:`SnowflakeSqlApiHook
<airflow.providers.snowflake.operators.snowflake>` to execute
SQL commands in a `Snowflake <https://docs.snowflake.com/en/>`__ database.
+You can also run this operator in deferrable mode by setting the ``deferrable`` param to ``True``.
+The task is then deferred, which frees its Airflow worker slot, and polling for the query status happens in the triggerer.
Using the Operator
^^^^^^^^^^^^^^^^^^
diff --git a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
index 61e88de864..fd2da72c92 100644
--- a/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
+++ b/tests/providers/snowflake/hooks/test_snowflake_sql_api.py
@@ -21,6 +21,7 @@ import uuid
from pathlib import Path
from typing import Any
from unittest import mock
+from unittest.mock import AsyncMock
import pytest
import requests
@@ -396,7 +397,6 @@ class TestSnowflakeSqlApiHook:
), pytest.raises(TypeError, match="Password was given but private key
is not encrypted."):
SnowflakeSqlApiHook(snowflake_conn_id="test_conn").get_private_key()
- @pytest.mark.asyncio
@pytest.mark.parametrize(
"status_code,response,expected_response",
[
@@ -456,3 +456,57 @@ class TestSnowflakeSqlApiHook:
mock_requests.get.return_value = MockResponse(status_code, response)
hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
assert hook.get_sql_api_query_status("uuid") == expected_response
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "status_code,response,expected_response",
+ [
+ (
+ 200,
+ {
+ "status": "success",
+ "message": "Statement executed successfully.",
+ "statementHandle": "uuid",
+ },
+ {
+ "status": "success",
+ "message": "Statement executed successfully.",
+ "statement_handles": ["uuid"],
+ },
+ ),
+ (
+ 200,
+ {
+ "status": "success",
+ "message": "Statement executed successfully.",
+ "statementHandles": ["uuid", "uuid1"],
+ },
+ {
+ "status": "success",
+ "message": "Statement executed successfully.",
+ "statement_handles": ["uuid", "uuid1"],
+ },
+ ),
+ (202, {}, {"status": "running", "message": "Query statements are
still running"}),
+ (422, {"status": "error", "message": "test"}, {"status": "error",
"message": "test"}),
+ (404, {"status": "error", "message": "test"}, {"status": "error",
"message": "test"}),
+ ],
+ )
+ @mock.patch(
+
"airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook."
+ "get_request_url_header_params"
+ )
+
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.aiohttp.ClientSession.get")
+ async def test_get_sql_api_query_status_async(
+ self, mock_get, mock_geturl_header_params, status_code, response,
expected_response
+ ):
+ """Test Async get_sql_api_query_status_async function by mocking the
status,
+ response and expected response"""
+ req_id = uuid.uuid4()
+ params = {"requestId": str(req_id), "page": 2, "pageSize": 10}
+ mock_geturl_header_params.return_value = HEADERS, params,
"/test/airflow/"
+ mock_get.return_value.__aenter__.return_value.status = status_code
+ mock_get.return_value.__aenter__.return_value.json =
AsyncMock(return_value=response)
+ hook = SnowflakeSqlApiHook(snowflake_conn_id="test_conn")
+ response = await hook.get_sql_api_query_status_async("uuid")
+ assert response == expected_response
diff --git a/tests/providers/snowflake/operators/test_snowflake.py
b/tests/providers/snowflake/operators/test_snowflake.py
index 8f32c6e62d..41cbfe6717 100644
--- a/tests/providers/snowflake/operators/test_snowflake.py
+++ b/tests/providers/snowflake/operators/test_snowflake.py
@@ -19,10 +19,13 @@ from __future__ import annotations
from unittest import mock
+import pendulum
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models.dag import DAG
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.snowflake.operators.snowflake import (
SnowflakeCheckOperator,
SnowflakeIntervalCheckOperator,
@@ -30,7 +33,9 @@ from airflow.providers.snowflake.operators.snowflake import (
SnowflakeSqlApiOperator,
SnowflakeValueCheckOperator,
)
+from airflow.providers.snowflake.triggers.snowflake_trigger import
SnowflakeSqlApiTrigger
from airflow.utils import timezone
+from airflow.utils.types import DagRunType
DEFAULT_DATE = timezone.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
@@ -88,6 +93,34 @@ class TestSnowflakeCheckOperators:
mock_get_db_hook.assert_called_once()
+def create_context(task, dag=None):
+ if dag is None:
+ dag = DAG(dag_id="dag")
+ tzinfo = pendulum.timezone("UTC")
+ execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+ dag_run = DagRun(
+ dag_id=dag.dag_id,
+ execution_date=execution_date,
+ run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
+ )
+
+ task_instance = TaskInstance(task=task)
+ task_instance.dag_run = dag_run
+ task_instance.xcom_push = mock.Mock()
+ return {
+ "dag": dag,
+ "ts": execution_date.isoformat(),
+ "task": task,
+ "ti": task_instance,
+ "task_instance": task_instance,
+ "run_id": dag_run.run_id,
+ "dag_run": dag_run,
+ "execution_date": execution_date,
+ "data_interval_end": execution_date,
+ "logical_date": execution_date,
+ }
+
+
class TestSnowflakeSqlApiOperator:
@pytest.fixture
def mock_execute_query(self):
@@ -142,3 +175,64 @@ class TestSnowflakeSqlApiOperator:
mock_get_sql_api_query_status.side_effect = [{"status": "error"},
{"status": "success"}]
with pytest.raises(AirflowException):
operator.execute(context=None)
+
+ @pytest.mark.parametrize("mock_sql, statement_count",
[(SQL_MULTIPLE_STMTS, 4), (SINGLE_STMT, 1)])
+
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.execute_query")
+ def test_snowflake_sql_api_execute_operator_async(self, mock_db_hook,
mock_sql, statement_count):
+ """
+ Asserts that a task is deferred and an SnowflakeSqlApiTrigger will be
fired
+ when the SnowflakeSqlApiOperator is executed.
+ """
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ sql=mock_sql,
+ statement_count=statement_count,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(create_context(operator))
+
+ assert isinstance(
+ exc.value.trigger, SnowflakeSqlApiTrigger
+ ), "Trigger is not a SnowflakeSqlApiTrigger"
+
+ def test_snowflake_sql_api_execute_complete_failure(self):
+ """Test SnowflakeSqlApiOperator raise AirflowException of error
event"""
+
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ deferrable=True,
+ )
+ with pytest.raises(AirflowException):
+ operator.execute_complete(
+ context=None,
+ event={"status": "error", "message": "Test failure message",
"type": "FAILED_WITH_ERROR"},
+ )
+
+ @pytest.mark.parametrize(
+ "mock_event",
+ [
+ None,
+ ({"status": "success", "statement_query_ids": ["uuid", "uuid"]}),
+ ],
+ )
+
@mock.patch("airflow.providers.snowflake.hooks.snowflake_sql_api.SnowflakeSqlApiHook.check_query_output")
+ def test_snowflake_sql_api_execute_complete(self, mock_conn, mock_event):
+ """Tests execute_complete assert with successful message"""
+
+ operator = SnowflakeSqlApiOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ sql=SQL_MULTIPLE_STMTS,
+ statement_count=4,
+ deferrable=True,
+ )
+
+ with mock.patch.object(operator.log, "info") as mock_log_info:
+ operator.execute_complete(context=None, event=mock_event)
+ mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)
diff --git a/tests/providers/snowflake/triggers/__init__.py
b/tests/providers/snowflake/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/snowflake/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/snowflake/triggers/test_snowflake.py
b/tests/providers/snowflake/triggers/test_snowflake.py
new file mode 100644
index 0000000000..9fc1459162
--- /dev/null
+++ b/tests/providers/snowflake/triggers/test_snowflake.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from datetime import timedelta
+from unittest import mock
+
+import pytest
+
+from airflow.providers.snowflake.triggers.snowflake_trigger import
SnowflakeSqlApiTrigger
+from airflow.triggers.base import TriggerEvent
+
+# NOTE(review): TASK_ID is not referenced by the tests visible in this module.
+TASK_ID = "snowflake_check"
+POLL_INTERVAL = 1.0  # seconds between status polls in the trigger under test
+LIFETIME = timedelta(minutes=59)  # JWT token lifetime handed to the trigger
+RENEWAL_DELTA = timedelta(minutes=54)  # JWT token renewal delta handed to the trigger
+MODULE = "airflow.providers.snowflake"  # prefix for the mock.patch target paths below
+
+
+class TestSnowflakeSqlApiTrigger:
+ TRIGGER = SnowflakeSqlApiTrigger(
+ poll_interval=POLL_INTERVAL,
+ query_ids=["uuid"],
+ snowflake_conn_id="test_conn",
+ token_life_time=LIFETIME,
+ token_renewal_delta=RENEWAL_DELTA,
+ )
+
+ def test_snowflake_sql_trigger_serialization(self):
+ """
+ Asserts that the SnowflakeSqlApiTrigger correctly serializes its
arguments
+ and classpath.
+ """
+ classpath, kwargs = self.TRIGGER.serialize()
+ assert classpath ==
"airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger"
+ assert kwargs == {
+ "poll_interval": POLL_INTERVAL,
+ "query_ids": ["uuid"],
+ "snowflake_conn_id": "test_conn",
+ "token_life_time": LIFETIME,
+ "token_renewal_delta": RENEWAL_DELTA,
+ }
+
+ @pytest.mark.asyncio
+
@mock.patch(f"{MODULE}.triggers.snowflake_trigger.SnowflakeSqlApiTrigger.get_query_status")
+
@mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+ async def test_snowflake_sql_trigger_running(
+ self, mock_get_sql_api_query_status_async, mock_get_query_status
+ ):
+ """Tests that the SnowflakeSqlApiTrigger in running by mocking
get_query_status to true"""
+ mock_get_query_status.return_value = {"status": "running"}
+
+ task = asyncio.create_task(self.TRIGGER.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+
@mock.patch(f"{MODULE}.triggers.snowflake_trigger.SnowflakeSqlApiTrigger.get_query_status")
+
@mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+ async def test_snowflake_sql_trigger_completed(
+ self, mock_get_sql_api_query_status_async, mock_get_query_status
+ ):
+ """
+ Test SnowflakeSqlApiTrigger run method with success status and mock
the get_sql_api_query_status
+ result and get_query_status to False.
+ """
+ mock_get_query_status.return_value = {"status": "success",
"statement_handles": ["uuid", "uuid1"]}
+ statement_query_ids = ["uuid", "uuid1"]
+ mock_get_sql_api_query_status_async.return_value = {
+ "message": "Statement executed successfully.",
+ "status": "success",
+ "statement_handles": statement_query_ids,
+ }
+
+ generator = self.TRIGGER.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "success", "statement_query_ids":
statement_query_ids}) == actual
+
+ @pytest.mark.asyncio
+
@mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+ async def test_snowflake_sql_trigger_failure_status(self,
mock_get_sql_api_query_status_async):
+ """Test SnowflakeSqlApiTrigger task is executed and triggered with
failure status."""
+ mock_response = {
+ "status": "error",
+ "message": "An error occurred when executing the statement. Check "
+ "the error code and error message for details",
+ }
+ mock_get_sql_api_query_status_async.return_value = mock_response
+
+ generator = self.TRIGGER.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent(mock_response) == actual
+
+ @pytest.mark.asyncio
+
@mock.patch(f"{MODULE}.hooks.snowflake_sql_api.SnowflakeSqlApiHook.get_sql_api_query_status_async")
+ async def test_snowflake_sql_trigger_exception(self,
mock_get_sql_api_query_status_async):
+ """Tests the SnowflakeSqlApiTrigger does not fire if there is an
exception."""
+ mock_get_sql_api_query_status_async.side_effect = Exception("Test
exception")
+
+ task = [i async for i in self.TRIGGER.run()]
+ assert len(task) == 1
+ assert TriggerEvent({"status": "error", "message": "Test exception"})
in task