This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 9eab3e199e Use base aws classes in Amazon QuickSight Operators/Sensors (#36776) 9eab3e199e is described below commit 9eab3e199ecfcaca2c39cfcf66ff4d7fe83c69ef Author: Andrey Anshin <andrey.ans...@taragol.is> AuthorDate: Mon Jan 15 03:15:16 2024 +0400 Use base aws classes in Amazon QuickSight Operators/Sensors (#36776) --- airflow/providers/amazon/aws/hooks/base_aws.py | 14 ++ airflow/providers/amazon/aws/hooks/quicksight.py | 51 +++-- .../providers/amazon/aws/operators/quicksight.py | 41 ++-- airflow/providers/amazon/aws/sensors/quicksight.py | 58 +++-- .../operators/quicksight.rst | 5 + tests/providers/amazon/aws/hooks/test_base_aws.py | 4 + .../providers/amazon/aws/hooks/test_quicksight.py | 245 +++++++++++++++------ .../amazon/aws/operators/test_quicksight.py | 42 +++- .../amazon/aws/sensors/test_quicksight.py | 133 ++++++----- 9 files changed, 402 insertions(+), 191 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index d6e0762a1a..635a874e26 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -629,6 +629,20 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]): """Verify or not SSL certificates boto3 client/resource read-only property.""" return self.conn_config.verify + @cached_property + def account_id(self) -> str: + """Return associated AWS Account ID.""" + return ( + self.get_session(region_name=self.region_name) + .client( + service_name="sts", + endpoint_url=self.conn_config.get_service_endpoint_url("sts"), + config=self.config, + verify=self.verify, + ) + .get_caller_identity()["Account"] + ) + def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session: """Get the underlying boto3.session.Session(region_name=region_name).""" return SessionFactory( diff --git a/airflow/providers/amazon/aws/hooks/quicksight.py b/airflow/providers/amazon/aws/hooks/quicksight.py index 6ee7c5bfd4..1106a793c1 100644 --- a/airflow/providers/amazon/aws/hooks/quicksight.py +++ b/airflow/providers/amazon/aws/hooks/quicksight.py @@ -18,13 +18,13 @@ from __future__ import annotations import time +import warnings from functools import cached_property from botocore.exceptions import ClientError -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook -from airflow.providers.amazon.aws.hooks.sts import StsHook class QuickSightHook(AwsBaseHook): @@ -46,10 +46,6 @@ class QuickSightHook(AwsBaseHook): def __init__(self, *args, **kwargs): super().__init__(client_type="quicksight", *args, **kwargs) - @cached_property - def sts_hook(self): - return StsHook(aws_conn_id=self.aws_conn_id) - def create_ingestion( self, data_set_id: str, @@ -57,6 +53,7 @@ class QuickSightHook(AwsBaseHook): ingestion_type: str, wait_for_completion: bool = True, check_interval: int = 30, + aws_account_id: str | None = None, ) -> dict: """ Create and start a new SPICE ingestion for a dataset; refresh the SPICE datasets. @@ -66,18 +63,18 @@ class QuickSightHook(AwsBaseHook): :param data_set_id: ID of the dataset used in the ingestion. :param ingestion_id: ID for the ingestion. - :param ingestion_type: Type of ingestion . "INCREMENTAL_REFRESH"|"FULL_REFRESH" + :param ingestion_type: Type of ingestion: "INCREMENTAL_REFRESH"|"FULL_REFRESH" :param wait_for_completion: if the program should keep running until job finishes :param check_interval: the time interval in seconds which the operator will check the status of QuickSight Ingestion + :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID. :return: Returns descriptive information about the created data ingestion having Ingestion ARN, HTTP status, ingestion ID and ingestion status. """ + aws_account_id = aws_account_id or self.account_id self.log.info("Creating QuickSight Ingestion for data set id %s.", data_set_id) - quicksight_client = self.get_conn() try: - aws_account_id = self.sts_hook.get_account_number() - create_ingestion_response = quicksight_client.create_ingestion( + create_ingestion_response = self.conn.create_ingestion( DataSetId=data_set_id, IngestionId=ingestion_id, IngestionType=ingestion_type, @@ -97,20 +94,21 @@ class QuickSightHook(AwsBaseHook): self.log.error("Failed to run Amazon QuickSight create_ingestion API, error: %s", general_error) raise - def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> str: + def get_status(self, aws_account_id: str | None, data_set_id: str, ingestion_id: str) -> str: """ Get the current status of QuickSight Create Ingestion API. .. seealso:: - :external+boto3:py:meth:`QuickSight.Client.describe_ingestion` - :param aws_account_id: An AWS Account ID + :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID. :param data_set_id: QuickSight Data Set ID :param ingestion_id: QuickSight Ingestion ID :return: An QuickSight Ingestion Status """ + aws_account_id = aws_account_id or self.account_id try: - describe_ingestion_response = self.get_conn().describe_ingestion( + describe_ingestion_response = self.conn.describe_ingestion( AwsAccountId=aws_account_id, DataSetId=data_set_id, IngestionId=ingestion_id ) return describe_ingestion_response["Ingestion"]["IngestionStatus"] @@ -119,17 +117,19 @@ class QuickSightHook(AwsBaseHook): except ClientError as e: raise AirflowException(f"AWS request failed: {e}") - def get_error_info(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> dict | None: + def get_error_info(self, aws_account_id: str | None, data_set_id: str, ingestion_id: str) -> dict | None: """ Get info about the error if any. - :param aws_account_id: An AWS Account ID + :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID. :param data_set_id: QuickSight Data Set ID :param ingestion_id: QuickSight Ingestion ID :return: Error info dict containing the error type (key 'Type') and message (key 'Message') if available. Else, returns None. """ - describe_ingestion_response = self.get_conn().describe_ingestion( + aws_account_id = aws_account_id or self.account_id + + describe_ingestion_response = self.conn.describe_ingestion( AwsAccountId=aws_account_id, DataSetId=data_set_id, IngestionId=ingestion_id ) # using .get() to get None if the key is not present, instead of an exception. @@ -137,7 +137,7 @@ class QuickSightHook(AwsBaseHook): def wait_for_state( self, - aws_account_id: str, + aws_account_id: str | None, data_set_id: str, ingestion_id: str, target_state: set, @@ -146,7 +146,7 @@ class QuickSightHook(AwsBaseHook): """ Check status of a QuickSight Create Ingestion API. - :param aws_account_id: An AWS Account ID + :param aws_account_id: An AWS Account ID, if set to ``None`` then use associated AWS Account ID. :param data_set_id: QuickSight Data Set ID :param ingestion_id: QuickSight Ingestion ID :param target_state: Describes the QuickSight Job's Target State @@ -154,6 +154,8 @@ class QuickSightHook(AwsBaseHook): will check the status of QuickSight Ingestion :return: response of describe_ingestion call after Ingestion is done """ + aws_account_id = aws_account_id or self.account_id + while True: status = self.get_status(aws_account_id, data_set_id, ingestion_id) self.log.info("Current status is %s", status) @@ -168,3 +170,16 @@ class QuickSightHook(AwsBaseHook): self.log.info("QuickSight Ingestion completed") return status + + @cached_property + def sts_hook(self): + warnings.warn( + f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. " + "This property used for obtain AWS Account ID, " + f"please consider to use `{type(self).__name__}.account_id` instead", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + from airflow.providers.amazon.aws.hooks.sts import StsHook + + return StsHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/operators/quicksight.py b/airflow/providers/amazon/aws/operators/quicksight.py index 4268374117..9555e0d63a 100644 --- a/airflow/providers/amazon/aws/operators/quicksight.py +++ b/airflow/providers/amazon/aws/operators/quicksight.py @@ -18,16 +18,15 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields if TYPE_CHECKING: from airflow.utils.context import Context -DEFAULT_CONN_ID = "aws_default" - -class QuickSightCreateIngestionOperator(BaseOperator): +class QuickSightCreateIngestionOperator(AwsBaseOperator[QuickSightHook]): """ Creates and starts a new SPICE ingestion for a dataset; also helps to Refresh existing SPICE datasets. @@ -43,23 +42,25 @@ class QuickSightCreateIngestionOperator(BaseOperator): that the operation waits to check the status of the Amazon QuickSight Ingestion. :param check_interval: if wait is set to be true, this is the time interval in seconds which the operator will check the status of the Amazon QuickSight Ingestion - :param aws_conn_id: The Airflow connection used for AWS credentials. (templated) - If this is None or empty then the default boto3 behaviour is used. If - running Airflow in a distributed manner and aws_conn_id is None or - empty, then the default boto3 configuration would be used (and must be - maintained on each worker node). - :param region: Which AWS region the connection should use. (templated) - If this is None or empty then the default boto3 behaviour is used. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ( + aws_hook_class = QuickSightHook + template_fields: Sequence[str] = aws_template_fields( "data_set_id", "ingestion_id", "ingestion_type", "wait_for_completion", "check_interval", - "aws_conn_id", - "region", ) ui_color = "#ffd700" @@ -70,26 +71,18 @@ class QuickSightCreateIngestionOperator(BaseOperator): ingestion_type: str = "FULL_REFRESH", wait_for_completion: bool = True, check_interval: int = 30, - aws_conn_id: str = DEFAULT_CONN_ID, - region: str | None = None, **kwargs, ): + super().__init__(**kwargs) self.data_set_id = data_set_id self.ingestion_id = ingestion_id self.ingestion_type = ingestion_type self.wait_for_completion = wait_for_completion self.check_interval = check_interval - self.aws_conn_id = aws_conn_id - self.region = region - super().__init__(**kwargs) def execute(self, context: Context): - hook = QuickSightHook( - aws_conn_id=self.aws_conn_id, - region_name=self.region, - ) self.log.info("Running the Amazon QuickSight SPICE Ingestion on Dataset ID: %s", self.data_set_id) - return hook.create_ingestion( + return self.hook.create_ingestion( data_set_id=self.data_set_id, ingestion_id=self.ingestion_id, ingestion_type=self.ingestion_type, diff --git a/airflow/providers/amazon/aws/sensors/quicksight.py b/airflow/providers/amazon/aws/sensors/quicksight.py index fc90ecbe45..ebd8310fe4 100644 --- a/airflow/providers/amazon/aws/sensors/quicksight.py +++ b/airflow/providers/amazon/aws/sensors/quicksight.py @@ -17,19 +17,19 @@ # under the License. from __future__ import annotations +import warnings from functools import cached_property from typing import TYPE_CHECKING, Sequence -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook -from airflow.providers.amazon.aws.hooks.sts import StsHook -from airflow.sensors.base import BaseSensorOperator +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor if TYPE_CHECKING: from airflow.utils.context import Context -class QuickSightSensor(BaseSensorOperator): +class QuickSightSensor(AwsBaseSensor[QuickSightHook]): """ Watches for the status of an Amazon QuickSight Ingestion. @@ -39,27 +39,25 @@ class QuickSightSensor(BaseSensorOperator): :param data_set_id: ID of the dataset used in the ingestion. :param ingestion_id: ID for the ingestion. - :param aws_conn_id: The Airflow connection used for AWS credentials. (templated) - If this is None or empty then the default boto3 behaviour is used. If - running Airflow in a distributed manner and aws_conn_id is None or - empty, then the default boto3 configuration would be used (and must be - maintained on each worker node). + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ + aws_hook_class = QuickSightHook template_fields: Sequence[str] = ("data_set_id", "ingestion_id", "aws_conn_id") - def __init__( - self, - *, - data_set_id: str, - ingestion_id: str, - aws_conn_id: str = "aws_default", - **kwargs, - ) -> None: + def __init__(self, *, data_set_id: str, ingestion_id: str, **kwargs): super().__init__(**kwargs) self.data_set_id = data_set_id self.ingestion_id = ingestion_id - self.aws_conn_id = aws_conn_id self.success_status = "COMPLETED" self.errored_statuses = ("FAILED", "CANCELLED") @@ -71,13 +69,10 @@ class QuickSightSensor(BaseSensorOperator): :return: True if it COMPLETED and False if not. """ self.log.info("Poking for Amazon QuickSight Ingestion ID: %s", self.ingestion_id) - aws_account_id = self.sts_hook.get_account_number() - quicksight_ingestion_state = self.quicksight_hook.get_status( - aws_account_id, self.data_set_id, self.ingestion_id - ) + quicksight_ingestion_state = self.hook.get_status(None, self.data_set_id, self.ingestion_id) self.log.info("QuickSight Status: %s", quicksight_ingestion_state) if quicksight_ingestion_state in self.errored_statuses: - error = self.quicksight_hook.get_error_info(aws_account_id, self.data_set_id, self.ingestion_id) + error = self.hook.get_error_info(None, self.data_set_id, self.ingestion_id) message = f"The QuickSight Ingestion failed. Error info: {error}" if self.soft_fail: raise AirflowSkipException(message) @@ -86,8 +81,23 @@ class QuickSightSensor(BaseSensorOperator): @cached_property def quicksight_hook(self): - return QuickSightHook(aws_conn_id=self.aws_conn_id) + warnings.warn( + f"`{type(self).__name__}.quicksight_hook` property is deprecated, " + f"please use `{type(self).__name__}.hook` property instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + return self.hook @cached_property def sts_hook(self): + warnings.warn( + f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. " + "This property used for obtain AWS Account ID, " + f"please consider to use `{type(self).__name__}.hook.account_id` instead", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + from airflow.providers.amazon.aws.hooks.sts import StsHook + return StsHook(aws_conn_id=self.aws_conn_id) diff --git a/docs/apache-airflow-providers-amazon/operators/quicksight.rst b/docs/apache-airflow-providers-amazon/operators/quicksight.rst index cbca98d7d5..9cc0abe337 100644 --- a/docs/apache-airflow-providers-amazon/operators/quicksight.rst +++ b/docs/apache-airflow-providers-amazon/operators/quicksight.rst @@ -30,6 +30,11 @@ Prerequisite Tasks .. include:: ../_partials/prerequisite_tasks.rst +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + Operators --------- diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py index ba94048421..c87aaa98fd 100644 --- a/tests/providers/amazon/aws/hooks/test_base_aws.py +++ b/tests/providers/amazon/aws/hooks/test_base_aws.py @@ -1031,6 +1031,10 @@ class TestAwsBaseHook: assert mock_mask_secret.mock_calls == expected_calls assert credentials == expected_credentials + @mock_sts + def test_account_id(self): + assert AwsBaseHook(aws_conn_id=None).account_id == DEFAULT_ACCOUNT_ID + class ThrowErrorUntilCount: """Holds counter state for invoking a method several times in a row.""" diff --git a/tests/providers/amazon/aws/hooks/test_quicksight.py b/tests/providers/amazon/aws/hooks/test_quicksight.py index 9c8ef16ce8..6a7795843b 100644 --- a/tests/providers/amazon/aws/hooks/test_quicksight.py +++ b/tests/providers/amazon/aws/hooks/test_quicksight.py @@ -22,19 +22,15 @@ from unittest import mock import pytest from botocore.exceptions import ClientError -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook -from airflow.providers.amazon.aws.hooks.sts import StsHook - -AWS_ACCOUNT_ID = "123456789012" +DEFAULT_AWS_ACCOUNT_ID = "123456789012" MOCK_DATA = { "DataSetId": "DemoDataSet", "IngestionId": "DemoDataSet_Ingestion", "IngestionType": "INCREMENTAL_REFRESH", - "AwsAccountId": AWS_ACCOUNT_ID, } - MOCK_CREATE_INGESTION_RESPONSE = { "Status": 201, "Arn": "arn:aws:quicksight:us-east-1:123456789012:dataset/DemoDataSet/ingestion/DemoDataSet3_Ingestion", @@ -42,7 +38,6 @@ MOCK_CREATE_INGESTION_RESPONSE = { "IngestionStatus": "INITIALIZED", "RequestId": "fc1f7eea-1327-41d6-9af7-c12f097ed343", } - MOCK_DESCRIBE_INGESTION_SUCCESS = { "Status": 200, "Ingestion": { @@ -59,7 +54,6 @@ MOCK_DESCRIBE_INGESTION_SUCCESS = { }, "RequestId": "DemoDataSet_Ingestion_Request_ID", } - MOCK_DESCRIBE_INGESTION_FAILURE = { "Status": 403, "Ingestion": { @@ -76,6 +70,23 @@ MOCK_DESCRIBE_INGESTION_FAILURE = { }, "RequestId": "DemoDataSet_Ingestion_Request_ID", } +ACCOUNT_TEST_CASES = [ + pytest.param(None, DEFAULT_AWS_ACCOUNT_ID, id="default-account-id"), + pytest.param("777777777777", "777777777777", id="custom-account-id"), +] + + +@pytest.fixture +def mocked_account_id(): + with mock.patch.object(QuickSightHook, "account_id", new_callable=mock.PropertyMock) as m: + m.return_value = DEFAULT_AWS_ACCOUNT_ID + yield m + + +@pytest.fixture +def mocked_client(): + with mock.patch.object(QuickSightHook, "conn") as m: + yield m class TestQuicksight: @@ -83,70 +94,174 @@ class TestQuicksight: hook = QuickSightHook(aws_conn_id="aws_default", region_name="us-east-1") assert hook.conn is not None - @mock.patch.object(QuickSightHook, "get_conn") - @mock.patch.object(StsHook, "get_conn") - @mock.patch.object(StsHook, "get_account_number") - def test_create_ingestion(self, mock_get_account_number, sts_conn, mock_conn): - mock_conn.return_value.create_ingestion.return_value = MOCK_CREATE_INGESTION_RESPONSE - mock_get_account_number.return_value = AWS_ACCOUNT_ID - quicksight_hook = QuickSightHook(aws_conn_id="aws_default", region_name="us-east-1") - result = quicksight_hook.create_ingestion( - data_set_id="DemoDataSet", - ingestion_id="DemoDataSet_Ingestion", - ingestion_type="INCREMENTAL_REFRESH", + @pytest.mark.parametrize( + "response, expected_status", + [ + pytest.param(MOCK_DESCRIBE_INGESTION_SUCCESS, "COMPLETED", id="completed"), + pytest.param(MOCK_DESCRIBE_INGESTION_FAILURE, "Failed", id="failed"), + ], + ) + @pytest.mark.parametrize("aws_account_id, expected_account_id", ACCOUNT_TEST_CASES) + def test_get_job_status( + self, response, expected_status, aws_account_id, expected_account_id, mocked_account_id, mocked_client + ): + """Test get job status.""" + mocked_client.describe_ingestion.return_value = response + + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + assert ( + hook.get_status( + data_set_id="DemoDataSet", + ingestion_id="DemoDataSet_Ingestion", + aws_account_id=aws_account_id, + ) + == expected_status + ) + mocked_client.describe_ingestion.assert_called_with( + AwsAccountId=expected_account_id, + DataSetId="DemoDataSet", + IngestionId="DemoDataSet_Ingestion", + ) + + @pytest.mark.parametrize( + "exception, error_match", + [ + pytest.param(KeyError("Foo"), "Could not get status", id="key-error"), + pytest.param( + ClientError(error_response={}, operation_name="fake"), + "AWS request failed", + id="botocore-client", + ), + ], + ) + def test_get_job_status_exception(self, exception, error_match, mocked_client, mocked_account_id): + mocked_client.describe_ingestion.side_effect = exception + + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + with pytest.raises(AirflowException, match=error_match): + assert hook.get_status( + data_set_id="DemoDataSet", + ingestion_id="DemoDataSet_Ingestion", + aws_account_id=None, + ) + + @pytest.mark.parametrize( + "error_info", + [ + pytest.param({"foo": "bar"}, id="error-info-exists"), + pytest.param(None, id="error-info-not-exists"), + ], + ) + @pytest.mark.parametrize("aws_account_id, expected_account_id", ACCOUNT_TEST_CASES) + def test_get_error_info( + self, error_info, aws_account_id, expected_account_id, mocked_client, mocked_account_id + ): + mocked_response = {"Ingestion": {}} + if error_info: + mocked_response["Ingestion"]["ErrorInfo"] = error_info + mocked_client.describe_ingestion.return_value = mocked_response + + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + assert ( + hook.get_error_info( + data_set_id="DemoDataSet", ingestion_id="DemoDataSet_Ingestion", aws_account_id=None + ) + == error_info ) - expected_call_params = MOCK_DATA - mock_conn.return_value.create_ingestion.assert_called_with(**expected_call_params) - assert result == MOCK_CREATE_INGESTION_RESPONSE - @mock.patch.object(QuickSightHook, "get_conn") + @mock.patch.object(QuickSightHook, "get_status", return_value="FAILED") + @mock.patch.object(QuickSightHook, "get_error_info") + @pytest.mark.parametrize("aws_account_id, expected_account_id", ACCOUNT_TEST_CASES) + def test_wait_for_state_failure( + self, + mocked_get_error_info, + mocked_get_status, + aws_account_id, + expected_account_id, + mocked_client, + mocked_account_id, + ): + mocked_get_error_info.return_value = "Something Bad Happen" + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + with pytest.raises(AirflowException, match="Error info: Something Bad Happen"): + hook.wait_for_state( + aws_account_id, "data_set_id", "ingestion_id", target_state={"COMPLETED"}, check_interval=0 + ) + mocked_get_status.assert_called_with(expected_account_id, "data_set_id", "ingestion_id") + mocked_get_error_info.assert_called_with(expected_account_id, "data_set_id", "ingestion_id") + + @mock.patch.object(QuickSightHook, "get_status", return_value="CANCELLED") + def test_wait_for_state_canceled(self, _): + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + with pytest.raises(AirflowException, match="The Amazon QuickSight SPICE ingestion cancelled"): + hook.wait_for_state( + "aws_account_id", "data_set_id", "ingestion_id", target_state={"COMPLETED"}, check_interval=0 + ) + @mock.patch.object(QuickSightHook, "get_status") - def test_fast_failing_ingestion(self, mock_get_status, mock_conn): - quicksight_hook = QuickSightHook(aws_conn_id="aws_default", region_name="us-east-1") - mock_get_status.return_value = "FAILED" - with pytest.raises(AirflowException): - quicksight_hook.wait_for_state( - "account_id", "data_set_id", "ingestion_id", target_state={"COMPLETED"}, check_interval=1 + def test_wait_for_state_completed(self, mocked_get_status): + mocked_get_status.side_effect = ["INITIALIZED", "QUEUED", "RUNNING", "COMPLETED"] + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + assert ( + hook.wait_for_state( + "aws_account_id", "data_set_id", "ingestion_id", target_state={"COMPLETED"}, check_interval=0 ) + == "COMPLETED" + ) + assert mocked_get_status.call_count == 4 + + @pytest.mark.parametrize( + "wait_for_completion", [pytest.param(True, id="wait"), pytest.param(False, id="no-wait")] + ) + @pytest.mark.parametrize("aws_account_id, expected_account_id", ACCOUNT_TEST_CASES) + def test_create_ingestion( + self, wait_for_completion, aws_account_id, expected_account_id, mocked_account_id, mocked_client + ): + mocked_client.create_ingestion.return_value = MOCK_CREATE_INGESTION_RESPONSE + + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + with mock.patch.object(QuickSightHook, "wait_for_state") as mocked_wait_for_state: + assert ( + hook.create_ingestion( + data_set_id="DemoDataSet", + ingestion_id="DemoDataSet_Ingestion", + ingestion_type="INCREMENTAL_REFRESH", + aws_account_id=aws_account_id, + wait_for_completion=wait_for_completion, + check_interval=0, + ) + == MOCK_CREATE_INGESTION_RESPONSE + ) + if wait_for_completion: + mocked_wait_for_state.assert_called_once_with( + aws_account_id=expected_account_id, + data_set_id="DemoDataSet", + ingestion_id="DemoDataSet_Ingestion", + target_state={"COMPLETED"}, + check_interval=0, + ) + else: + mocked_wait_for_state.assert_not_called() - @mock.patch.object(StsHook, "get_conn") - @mock.patch.object(StsHook, "get_account_number") - def test_create_ingestion_exception(self, mock_get_account_number, sts_conn): - mock_get_account_number.return_value = AWS_ACCOUNT_ID - hook = QuickSightHook(aws_conn_id="aws_default") - with pytest.raises(ClientError) as raised_exception: + mocked_client.create_ingestion.assert_called_with(AwsAccountId=expected_account_id, **MOCK_DATA) + + def test_create_ingestion_exception(self, mocked_account_id, mocked_client, caplog): + mocked_client.create_ingestion.side_effect = ValueError("Fake Error") + hook = QuickSightHook(aws_conn_id=None) + with pytest.raises(ValueError, match="Fake Error"): hook.create_ingestion( data_set_id="DemoDataSet", ingestion_id="DemoDataSet_Ingestion", ingestion_type="INCREMENTAL_REFRESH", ) - ex = raised_exception.value - assert ex.operation_name == "CreateIngestion" - - @mock.patch.object(QuickSightHook, "get_conn") - def test_get_job_status(self, mock_conn): - """ - Test get job status - """ - mock_conn.return_value.describe_ingestion.return_value = MOCK_DESCRIBE_INGESTION_SUCCESS - quicksight_hook = QuickSightHook(aws_conn_id="aws_default", region_name="us-east-1") - result = quicksight_hook.get_status( - data_set_id="DemoDataSet", - ingestion_id="DemoDataSet_Ingestion", - aws_account_id="123456789012", - ) - assert result == "COMPLETED" - - @mock.patch.object(QuickSightHook, "get_conn") - def test_get_job_status_failed(self, mock_conn): - """ - Test get job status - """ - mock_conn.return_value.describe_ingestion.return_value = MOCK_DESCRIBE_INGESTION_FAILURE - quicksight_hook = QuickSightHook(aws_conn_id="aws_default", region_name="us-east-1") - result = quicksight_hook.get_status( - data_set_id="DemoDataSet", - ingestion_id="DemoDataSet_Ingestion", - aws_account_id="123456789012", - ) - assert result == "Failed" + assert "create_ingestion API, error: Fake Error" in caplog.text + + def test_deprecated_properties(self): + hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1") + with mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook") as mocked_class, pytest.warns( + AirflowProviderDeprecationWarning, match="consider to use `.*account_id` instead" + ): + mocked_sts_hook = mock.MagicMock(name="FakeStsHook") + mocked_class.return_value = mocked_sts_hook + assert hook.sts_hook is mocked_sts_hook + mocked_class.assert_called_once_with(aws_conn_id=None) diff --git a/tests/providers/amazon/aws/operators/test_quicksight.py b/tests/providers/amazon/aws/operators/test_quicksight.py index 2b7b0dc35f..fd30426293 100644 --- a/tests/providers/amazon/aws/operators/test_quicksight.py +++ b/tests/providers/amazon/aws/operators/test_quicksight.py @@ -38,17 +38,41 @@ MOCK_RESPONSE = { class TestQuickSightCreateIngestionOperator: def setup_method(self): - self.quicksight = QuickSightCreateIngestionOperator( - task_id="test_quicksight_operator", - data_set_id=DATA_SET_ID, - ingestion_id=INGESTION_ID, + self.default_op_kwargs = { + "task_id": "quicksight_create", + "aws_conn_id": None, + "data_set_id": DATA_SET_ID, + "ingestion_id": INGESTION_ID, + } + + def test_init(self): + self.default_op_kwargs.pop("aws_conn_id", None) + + op = QuickSightCreateIngestionOperator( + **self.default_op_kwargs, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="cn-north-1", + verify=False, + botocore_config={"read_timeout": 42}, ) + assert op.hook.client_type == "quicksight" + assert op.hook.resource_type is None + assert op.hook.aws_conn_id == "fake-conn-id" + assert op.hook._region_name == "cn-north-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + op = QuickSightCreateIngestionOperator(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None - @mock.patch.object(QuickSightHook, "get_conn") - @mock.patch.object(QuickSightHook, "create_ingestion") - def test_execute(self, mock_create_ingestion, mock_client): - mock_create_ingestion.return_value = MOCK_RESPONSE - self.quicksight.execute(None) + @mock.patch.object(QuickSightHook, "create_ingestion", return_value=MOCK_RESPONSE) + def test_execute(self, mock_create_ingestion): + QuickSightCreateIngestionOperator(**self.default_op_kwargs).execute({}) mock_create_ingestion.assert_called_once_with( data_set_id=DATA_SET_ID, ingestion_id=INGESTION_ID, diff --git a/tests/providers/amazon/aws/sensors/test_quicksight.py b/tests/providers/amazon/aws/sensors/test_quicksight.py index ba3ce83789..bef78d072d 100644 --- a/tests/providers/amazon/aws/sensors/test_quicksight.py +++ b/tests/providers/amazon/aws/sensors/test_quicksight.py @@ -20,10 +20,8 @@ from __future__ import annotations from unittest import mock import pytest -from moto import mock_sts -from moto.core import DEFAULT_ACCOUNT_ID -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook from airflow.providers.amazon.aws.sensors.quicksight import QuickSightSensor @@ -31,58 +29,91 @@ DATA_SET_ID = "DemoDataSet" INGESTION_ID = "DemoDataSet_Ingestion" +@pytest.fixture +def mocked_get_status(): + with mock.patch.object(QuickSightHook, "get_status") as m: + yield m + + +@pytest.fixture +def mocked_get_error_info(): + with mock.patch.object(QuickSightHook, "get_error_info") as m: + yield m + + class TestQuickSightSensor: def setup_method(self): - self.sensor = QuickSightSensor( - task_id="test_quicksight_sensor", - aws_conn_id="aws_default", - data_set_id="DemoDataSet", - ingestion_id="DemoDataSet_Ingestion", + self.default_op_kwargs = { + "task_id": "quicksight_sensor", + "aws_conn_id": None, + "data_set_id": DATA_SET_ID, + "ingestion_id": INGESTION_ID, + } + + def test_init(self): + self.default_op_kwargs.pop("aws_conn_id", None) + + sensor = QuickSightSensor( + **self.default_op_kwargs, + # Generic hooks parameters + aws_conn_id="fake-conn-id", + region_name="ca-west-1", + verify=True, + botocore_config={"read_timeout": 42}, ) + assert sensor.hook.client_type == "quicksight" + assert sensor.hook.resource_type is None + assert sensor.hook.aws_conn_id == "fake-conn-id" + assert sensor.hook._region_name == "ca-west-1" + assert sensor.hook._verify is True + assert sensor.hook._config is not None + assert sensor.hook._config.read_timeout == 42 + + sensor = QuickSightSensor(**self.default_op_kwargs) + assert sensor.hook.aws_conn_id == "aws_default" + assert sensor.hook._region_name is None + assert sensor.hook._verify is None + assert sensor.hook._config is None + + @pytest.mark.parametrize("status", ["COMPLETED"]) + def test_poke_completed(self, status, mocked_get_status): + mocked_get_status.return_value = status + assert QuickSightSensor(**self.default_op_kwargs).poke({}) is True + mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID) - @mock_sts - @mock.patch.object(QuickSightHook, "get_status") - def test_poke_success(self, mock_get_status): - mock_get_status.return_value = "COMPLETED" - assert self.sensor.poke({}) is True - mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) - - @mock_sts - @mock.patch.object(QuickSightHook, "get_status") - @mock.patch.object(QuickSightHook, "get_error_info") - def test_poke_cancelled(self, _, mock_get_status): - mock_get_status.return_value = "CANCELLED" - with pytest.raises(AirflowException): - self.sensor.poke({}) - mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) - - @mock_sts - @mock.patch.object(QuickSightHook, "get_status") - @mock.patch.object(QuickSightHook, "get_error_info") - def test_poke_failed(self, _, mock_get_status): - mock_get_status.return_value = "FAILED" - with pytest.raises(AirflowException): - self.sensor.poke({}) - mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) - - @mock_sts - @mock.patch.object(QuickSightHook, "get_status") - def test_poke_initialized(self, mock_get_status): - mock_get_status.return_value = "INITIALIZED" - assert self.sensor.poke({}) is False - mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID) + @pytest.mark.parametrize("status", ["INITIALIZED"]) + def test_poke_not_completed(self, status, mocked_get_status): + mocked_get_status.return_value = status + assert QuickSightSensor(**self.default_op_kwargs).poke({}) is False + mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID) + @pytest.mark.parametrize("status", ["FAILED", "CANCELLED"]) @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) + "soft_fail, expected_exception", + [ + pytest.param(True, AirflowSkipException, id="soft-fail"), + pytest.param(False, AirflowException, id="non-soft-fail"), + ], ) - @mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook.get_account_number") - @mock.patch("airflow.providers.amazon.aws.hooks.quicksight.QuickSightHook.get_status") - @mock.patch("airflow.providers.amazon.aws.hooks.quicksight.QuickSightHook.get_error_info") - def test_fail_poke(self, get_error_info, get_status, _, soft_fail, expected_exception): - self.sensor.soft_fail = soft_fail - error = "expected error" - message = f"The QuickSight Ingestion failed. Error info: {error}" - with pytest.raises(expected_exception, match=message): - get_status.return_value = "FAILED" - get_error_info.return_value = message - self.sensor.poke(context={}) + def test_poke_terminated_status( + self, status, soft_fail, expected_exception, mocked_get_status, mocked_get_error_info + ): + mocked_get_status.return_value = status + mocked_get_error_info.return_value = "something bad happen" + with pytest.raises(expected_exception, match="Error info: something bad happen"): + QuickSightSensor(**self.default_op_kwargs, soft_fail=soft_fail).poke({}) + mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID) + mocked_get_error_info.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID) + + def test_deprecated_properties(self): + sensor = QuickSightSensor(**self.default_op_kwargs) + with pytest.warns(AirflowProviderDeprecationWarning, match="please use `.*hook` property instead"): + assert sensor.quicksight_hook is sensor.hook + + with mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook") as mocked_class, pytest.warns( + AirflowProviderDeprecationWarning, match="consider to use `.*hook\.account_id` instead" + ): + mocked_sts_hook = mock.MagicMock(name="FakeStsHook") + mocked_class.return_value = mocked_sts_hook + assert sensor.sts_hook is mocked_sts_hook + mocked_class.assert_called_once_with(aws_conn_id=None)