This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9563dc573b add deferrable mode to RedshiftDataOperator (#36586)
9563dc573b is described below
commit 9563dc573bc53b2c84640c88371b62cccdd811ff
Author: Wei Lee <[email protected]>
AuthorDate: Fri Jan 19 02:35:43 2024 +0800
add deferrable mode to RedshiftDataOperator (#36586)
* feat(providers/amazon): add deferrable mode to RedshiftDataOperator
* test(providers/amazon): add test case to RedshiftDataHook async methods
* test(providers/amazon): add test case to RedshiftDataOperator when deferrable = True
* refactor(providers/amazon): extract common operator initialization as deferrable_operator fixture
* refactor(providers/amazon): rename region to region_name
* feat(providers/amazon): add verify and botocore_config as suggested
* refactor(providers/amazon): use async_conn from aws hook and add missing await
* feat(providers/amazon): make RedshiftDataTrigger.hook a cached_property
* refactor(providers/amazon): unify how the async and sync versions of check_query_is_finished are implemented
* style(providers/amazon): fix mypy failure
* fix(providers/amazon): fix async_conn call
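For illustration, a minimal usage sketch of the new mode (the database name and
cluster identifier below are placeholders, not part of this change):

    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    # Hypothetical example: with deferrable=True the operator submits the
    # statement, then defers to RedshiftDataTrigger until the query finishes.
    run_query = RedshiftDataOperator(
        task_id="run_query",
        database="my_database",            # placeholder
        cluster_identifier="my_cluster",   # placeholder
        sql="SELECT 1;",
        wait_for_completion=True,
        poll_interval=10,
        deferrable=True,
    )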
---
.../providers/amazon/aws/hooks/redshift_data.py | 83 ++++++++---
.../amazon/aws/operators/redshift_data.py | 55 +++++++-
.../providers/amazon/aws/triggers/redshift_data.py | 113 +++++++++++++++
airflow/providers/amazon/provider.yaml | 1 +
.../amazon/aws/hooks/test_redshift_data.py | 61 +++++++-
.../amazon/aws/operators/test_redshift_data.py | 115 +++++++++++++++-
.../amazon/aws/triggers/test_redshift_data.py | 153 +++++++++++++++++++++
7 files changed, 560 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index f7df0fd744..538e5cee96 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -26,6 +26,21 @@ from airflow.providers.amazon.aws.utils import trim_none_values
if TYPE_CHECKING:
from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa
+ from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef
+
+FINISHED_STATE = "FINISHED"
+FAILED_STATE = "FAILED"
+ABORTED_STATE = "ABORTED"
+FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
+RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
+
+
+class RedshiftDataQueryFailedError(ValueError):
+ """Raise an error that redshift data query failed."""
+
+
+class RedshiftDataQueryAbortedError(ValueError):
+ """Raise an error that redshift data query was aborted."""
class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
@@ -108,27 +123,40 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
return statement_id
- def wait_for_results(self, statement_id, poll_interval):
+ def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
while True:
self.log.info("Polling statement %s", statement_id)
- resp = self.conn.describe_statement(
- Id=statement_id,
- )
- status = resp["Status"]
- if status == "FINISHED":
- num_rows = resp.get("ResultRows")
- if num_rows is not None:
- self.log.info("Processed %s rows", num_rows)
- return status
- elif status in ("FAILED", "ABORTED"):
- raise ValueError(
- f"Statement {statement_id!r} terminated with status
{status}. "
- f"Response details: {pformat(resp)}"
- )
- else:
- self.log.info("Query %s", status)
+ is_finished = self.check_query_is_finished(statement_id)
+ if is_finished:
+ return FINISHED_STATE
+
time.sleep(poll_interval)
+ def check_query_is_finished(self, statement_id: str) -> bool:
+ """Check whether query finished, raise exception is failed."""
+ resp = self.conn.describe_statement(Id=statement_id)
+ return self.parse_statement_response(resp)
+
+ def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
+ """Parse the response of describe_statement."""
+ status = resp["Status"]
+ if status == FINISHED_STATE:
+ num_rows = resp.get("ResultRows")
+ if num_rows is not None:
+ self.log.info("Processed %s rows", num_rows)
+ return True
+ elif status in FAILURE_STATES:
+ exception_cls = (
+ RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
+ )
+ raise exception_cls(
+ f"Statement {resp['Id']} terminated with status {status}. "
+ f"Response details: {pformat(resp)}"
+ )
+
+ self.log.info("Query status: %s", status)
+ return False
+
def get_table_primary_key(
self,
table: str,
@@ -201,3 +229,24 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
break
return pk_columns or None
+
+ async def is_still_running(self, statement_id: str) -> bool:
+ """Async function to check whether the query is still running.
+
+ :param statement_id: the UUID of the statement
+ """
+ async with self.async_conn as client:
+ desc = await client.describe_statement(Id=statement_id)
+ return desc["Status"] in RUNNING_STATES
+
+ async def check_query_is_finished_async(self, statement_id: str) -> bool:
+ """Async function to check whether a statement is finished.
+
+ It makes an async connection to Redshift Data, fetches the status of the query
+ identified by statement_id, and returns whether it has finished.
+
+ :param statement_id: the UUID of the statement
+ """
+ async with self.async_conn as client:
+ resp = await client.describe_statement(Id=statement_id)
+ return self.parse_statement_response(resp)
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index b454ad76ec..71ee82069e 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,10 +17,13 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
@@ -92,6 +95,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
poll_interval: int = 10,
return_sql_result: bool = False,
workgroup_name: str | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -114,11 +118,17 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
)
self.return_sql_result = return_sql_result
self.statement_id: str | None = None
+ self.deferrable = deferrable
def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
"""Execute a statement against Amazon Redshift."""
self.log.info("Executing statement: %s", self.sql)
+ # Set wait_for_completion to False so that it waits for the status in the deferred task.
+ wait_for_completion = self.wait_for_completion
+ if self.deferrable and self.wait_for_completion:
+ self.wait_for_completion = False
+
self.statement_id = self.hook.execute_query(
database=self.database,
sql=self.sql,
@@ -129,10 +139,27 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
secret_arn=self.secret_arn,
statement_name=self.statement_name,
with_event=self.with_event,
- wait_for_completion=self.wait_for_completion,
+ wait_for_completion=wait_for_completion,
poll_interval=self.poll_interval,
)
+ if self.deferrable:
+ is_finished = self.hook.check_query_is_finished(self.statement_id)
+ if not is_finished:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=RedshiftDataTrigger(
+ statement_id=self.statement_id,
+ task_id=self.task_id,
+ poll_interval=self.poll_interval,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ botocore_config=self.botocore_config,
+ ),
+ method_name="execute_complete",
+ )
+
if self.return_sql_result:
result = self.hook.conn.get_statement_result(Id=self.statement_id)
self.log.debug("Statement result: %s", result)
@@ -140,6 +167,30 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
else:
return self.statement_id
+ def execute_complete(
+ self, context: Context, event: dict[str, Any] | None = None
+ ) -> GetStatementResultResponseTypeDef | str:
+ if event is None:
+ err_msg = "Trigger error: event is None"
+ self.log.info(err_msg)
+ raise AirflowException(err_msg)
+
+ if event["status"] == "error":
+ msg = f"context: {context}, error message: {event['message']}"
+ raise AirflowException(msg)
+
+ statement_id = event["statement_id"]
+ if not statement_id:
+ raise AirflowException("statement_id should not be empty.")
+
+ self.log.info("%s completed successfully.", self.task_id)
+ if self.return_sql_result:
+ result = self.hook.conn.get_statement_result(Id=statement_id)
+ self.log.debug("Statement result: %s", result)
+ return result
+
+ return statement_id
+
def on_kill(self) -> None:
"""Cancel the submitted redshift query."""
if self.statement_id:
diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py
new file mode 100644
index 0000000000..2d0ecbc594
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from functools import cached_property
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+ ABORTED_STATE,
+ FAILED_STATE,
+ RedshiftDataHook,
+ RedshiftDataQueryAbortedError,
+ RedshiftDataQueryFailedError,
+)
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class RedshiftDataTrigger(BaseTrigger):
+ """
+ RedshiftDataTrigger is fired as a deferred class with params to run the task in the triggerer.
+
+ :param statement_id: the UUID of the statement
+ :param task_id: task ID of the DAG
+ :param poll_interval: polling period in seconds to check for the status
+ :param aws_conn_id: AWS connection ID for Redshift
+ :param region_name: AWS region to use
+ :param verify: whether or not to verify SSL certificates
+ :param botocore_config: configuration for the botocore client
+ """
+
+ def __init__(
+ self,
+ statement_id: str,
+ task_id: str,
+ poll_interval: int,
+ aws_conn_id: str | None = "aws_default",
+ region_name: str | None = None,
+ verify: bool | str | None = None,
+ botocore_config: dict | None = None,
+ ):
+ super().__init__()
+ self.statement_id = statement_id
+ self.task_id = task_id
+ self.poll_interval = poll_interval
+
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.verify = verify
+ self.botocore_config = botocore_config
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes RedshiftDataTrigger arguments and classpath."""
+ return (
+ "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
+ {
+ "statement_id": self.statement_id,
+ "task_id": self.task_id,
+ "aws_conn_id": self.aws_conn_id,
+ "poll_interval": self.poll_interval,
+ "region_name": self.region_name,
+ "verify": self.verify,
+ "botocore_config": self.botocore_config,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> RedshiftDataHook:
+ return RedshiftDataHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ config=self.botocore_config,
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ try:
+ while await self.hook.is_still_running(self.statement_id):
+ await asyncio.sleep(self.poll_interval)
+
+ is_finished = await self.hook.check_query_is_finished_async(self.statement_id)
+ if is_finished:
+ response = {"status": "success", "statement_id": self.statement_id}
+ else:
+ response = {
+ "status": "error",
+ "statement_id": self.statement_id,
+ "message": f"{self.task_id} failed",
+ }
+ yield TriggerEvent(response)
+ except (RedshiftDataQueryFailedError, RedshiftDataQueryAbortedError) as error:
+ response = {
+ "status": "error",
+ "statement_id": self.statement_id,
+ "message": str(error),
+ "type": FAILED_STATE if isinstance(error,
RedshiftDataQueryFailedError) else ABORTED_STATE,
+ }
+ yield TriggerEvent(response)
+ except Exception as error:
+ yield TriggerEvent({"status": "error", "statement_id":
self.statement_id, "message": str(error)})
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 1b90089db2..bcbb5c18e3 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -621,6 +621,7 @@ triggers:
- integration-name: Amazon Redshift
python-modules:
- airflow.providers.amazon.aws.triggers.redshift_cluster
+ - airflow.providers.amazon.aws.triggers.redshift_data
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.triggers.sagemaker
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index cc174a872c..126585b432 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -22,7 +22,11 @@ from unittest import mock
import pytest
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+ RedshiftDataHook,
+ RedshiftDataQueryAbortedError,
+ RedshiftDataQueryFailedError,
+)
SQL = "sql"
DATABASE = "database"
@@ -292,3 +296,58 @@ class TestRedshiftDataHook:
wait_for_completion=True,
)
assert "Processed " not in caplog.text
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "describe_statement_response, expected_result",
+ [
+ ({"Status": "PICKED"}, True),
+ ({"Status": "STARTED"}, True),
+ ({"Status": "SUBMITTED"}, True),
+ ({"Status": "FINISHED"}, False),
+ ({"Status": "FAILED"}, False),
+ ({"Status": "ABORTED"}, False),
+ ],
+ )
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+ async def test_is_still_running(self, mock_conn, describe_statement_response, expected_result):
+ hook = RedshiftDataHook()
+ mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
+ response = await hook.is_still_running("uuid")
+ assert response == expected_result
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
+ async def test_check_query_is_finished_async(self, mock_is_still_running, mock_conn):
+ hook = RedshiftDataHook()
+ mock_is_still_running.return_value = False
+ mock_conn.describe_statement = mock.AsyncMock()
+ mock_conn.__aenter__.return_value.describe_statement.return_value = {
+ "Id": "uuid",
+ "Status": "FINISHED",
+ }
+ is_finished = await hook.check_query_is_finished_async(statement_id="uuid")
+ assert is_finished is True
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "describe_statement_response, expected_exception",
+ (
+ (
+ {"Id": "uuid", "Status": "FAILED", "QueryString": "select 1",
"Error": "Test error"},
+ RedshiftDataQueryFailedError,
+ ),
+ ({"Id": "uuid", "Status": "ABORTED"},
RedshiftDataQueryAbortedError),
+ ),
+ )
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
+ async def test_check_query_is_finished_async_exception(
+ self, mock_is_still_running, mock_conn, describe_statement_response, expected_exception
+ ):
+ hook = RedshiftDataHook()
+ mock_is_still_running.return_value = False
+ mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
+ with pytest.raises(expected_exception):
+ await hook.check_query_is_finished_async(statement_id="uuid")
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 4b921b7142..fa22c98218 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -21,8 +21,9 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
CONN_ID = "aws_conn_test"
TASK_ID = "task_id"
@@ -31,6 +32,32 @@ DATABASE = "database"
STATEMENT_ID = "statement_id"
+@pytest.fixture
+def deferrable_operator():
+ cluster_identifier = "cluster_identifier"
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ poll_interval = 5
+
+ operator = RedshiftDataOperator(
+ aws_conn_id=CONN_ID,
+ task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=False,
+ poll_interval=poll_interval,
+ deferrable=True,
+ )
+ return operator
+
+
class TestRedshiftDataOperator:
def test_init(self):
op = RedshiftDataOperator(
@@ -202,3 +229,89 @@ class TestRedshiftDataOperator:
mock_conn.get_statement_result.assert_called_once_with(
Id=STATEMENT_ID,
)
+
+ @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished",
+ return_value=True,
+ )
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+ def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_finished, mock_defer):
+ cluster_identifier = "cluster_identifier"
+ workgroup_name = None
+ db_user = "db_user"
+ secret_arn = "secret_arn"
+ statement_name = "statement_name"
+ parameters = [{"name": "id", "value": "1"}]
+ poll_interval = 5
+
+ operator = RedshiftDataOperator(
+ aws_conn_id=CONN_ID,
+ task_id=TASK_ID,
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ wait_for_completion=False,
+ poll_interval=poll_interval,
+ deferrable=True,
+ )
+ operator.execute(None)
+
+ assert not mock_defer.called
+ mock_exec_query.assert_called_once_with(
+ sql=SQL,
+ database=DATABASE,
+ cluster_identifier=cluster_identifier,
+ workgroup_name=workgroup_name,
+ db_user=db_user,
+ secret_arn=secret_arn,
+ statement_name=statement_name,
+ parameters=parameters,
+ with_event=False,
+ wait_for_completion=False,
+ poll_interval=poll_interval,
+ )
+
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished",
+ return_value=False,
+ )
+ @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+ def test_execute_defer(self, mock_exec_query, check_query_is_finished, deferrable_operator):
+ with pytest.raises(TaskDeferred) as exc:
+ deferrable_operator.execute(None)
+
+ assert isinstance(exc.value.trigger, RedshiftDataTrigger)
+
+ def test_execute_complete_failure(self, deferrable_operator):
+ """Tests that an AirflowException is raised in case of error event"""
+ with pytest.raises(AirflowException):
+ deferrable_operator.execute_complete(
+ context=None, event={"status": "error", "message": "test
failure message"}
+ )
+
+ def test_execute_complete_exception(self, deferrable_operator):
+ """Tests that an AirflowException is raised in case of an empty event."""
+ with pytest.raises(AirflowException) as exc:
+ deferrable_operator.execute_complete(context=None, event=None)
+ assert exc.value.args[0] == "Trigger error: event is None"
+
+ def test_execute_complete(self, deferrable_operator):
+ """Asserts that logging occurs as expected"""
+
+ deferrable_operator.statement_id = "uuid"
+
+ with mock.patch.object(deferrable_operator.log, "info") as
mock_log_info:
+ assert (
+ deferrable_operator.execute_complete(
+ context=None,
+ event={"status": "success", "message": "Job completed",
"statement_id": "uuid"},
+ )
+ == "uuid"
+ )
+ mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py
new file mode 100644
index 0000000000..49c0862af2
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+ ABORTED_STATE,
+ FAILED_STATE,
+ RedshiftDataQueryAbortedError,
+ RedshiftDataQueryFailedError,
+)
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from airflow.triggers.base import TriggerEvent
+
+TEST_CONN_ID = "aws_default"
+TEST_TASK_ID = "123"
+POLL_INTERVAL = 4.0
+
+
+class TestRedshiftDataTrigger:
+ def test_redshift_data_trigger_serialization(self):
+ """
+ Asserts that the RedshiftDataTrigger correctly serializes its arguments
+ and classpath.
+ """
+ trigger = RedshiftDataTrigger(
+ statement_id=[],
+ task_id=TEST_TASK_ID,
+ aws_conn_id=TEST_CONN_ID,
+ poll_interval=POLL_INTERVAL,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger"
+ assert kwargs == {
+ "statement_id": [],
+ "task_id": TEST_TASK_ID,
+ "poll_interval": POLL_INTERVAL,
+ "aws_conn_id": TEST_CONN_ID,
+ "region_name": None,
+ "botocore_config": None,
+ "verify": None,
+ }
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "return_value, response",
+ [
+ (
+ True,
+ TriggerEvent({"status": "success", "statement_id": "uuid"}),
+ ),
+ (
+ False,
+ TriggerEvent(
+ {"status": "error", "message": f"{TEST_TASK_ID} failed",
"statement_id": "uuid"}
+ ),
+ ),
+ ],
+ )
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async"
+ )
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running",
+ return_value=False,
+ )
+ async def test_redshift_data_trigger_run(
+ self, mocked_is_still_running, mock_check_query_is_finished_async, return_value, response
+ ):
+ """
+ Tests that RedshiftDataTrigger only fires once the query execution
reaches a successful state.
+ """
+ mock_check_query_is_finised_async.return_value = return_value
+ trigger = RedshiftDataTrigger(
+ statement_id="uuid",
+ task_id=TEST_TASK_ID,
+ poll_interval=POLL_INTERVAL,
+ aws_conn_id=TEST_CONN_ID,
+ )
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert response == actual
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "raised_exception, expected_response",
+ [
+ (
+ RedshiftDataQueryFailedError("Failed"),
+ {
+ "status": "error",
+ "statement_id": "uuid",
+ "message": "Failed",
+ "type": FAILED_STATE,
+ },
+ ),
+ (
+ RedshiftDataQueryAbortedError("Aborted"),
+ {
+ "status": "error",
+ "statement_id": "uuid",
+ "message": "Aborted",
+ "type": ABORTED_STATE,
+ },
+ ),
+ (
+ Exception(f"{TEST_TASK_ID} failed"),
+ {"status": "error", "statement_id": "uuid", "message":
f"{TEST_TASK_ID} failed"},
+ ),
+ ],
+ )
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async"
+ )
+ @mock.patch(
+ "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running",
+ return_value=False,
+ )
+ async def test_redshift_data_trigger_exception(
+ self, mocked_is_still_running, mock_check_query_is_finished_async, raised_exception, expected_response
+ ):
+ """
+ Test that RedshiftDataTrigger fires the correct event in case of an
error.
+ """
+ mock_check_query_is_finised_async.side_effect = raised_exception
+
+ trigger = RedshiftDataTrigger(
+ statement_id="uuid",
+ task_id=TEST_TASK_ID,
+ poll_interval=POLL_INTERVAL,
+ aws_conn_id=TEST_CONN_ID,
+ )
+ task = [i async for i in trigger.run()]
+ assert len(task) == 1
+ assert TriggerEvent(expected_response) in task