This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9eab3e199e Use base aws classes in Amazon QuickSight Operators/Sensors (#36776)
9eab3e199e is described below

commit 9eab3e199ecfcaca2c39cfcf66ff4d7fe83c69ef
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Mon Jan 15 03:15:16 2024 +0400

    Use base aws classes in Amazon QuickSight Operators/Sensors (#36776)
---
 airflow/providers/amazon/aws/hooks/base_aws.py     |  14 ++
 airflow/providers/amazon/aws/hooks/quicksight.py   |  51 +++--
 .../providers/amazon/aws/operators/quicksight.py   |  41 ++--
 airflow/providers/amazon/aws/sensors/quicksight.py |  58 +++--
 .../operators/quicksight.rst                       |   5 +
 tests/providers/amazon/aws/hooks/test_base_aws.py  |   4 +
 .../providers/amazon/aws/hooks/test_quicksight.py  | 245 +++++++++++++++------
 .../amazon/aws/operators/test_quicksight.py        |  42 +++-
 .../amazon/aws/sensors/test_quicksight.py          | 133 ++++++-----
 9 files changed, 402 insertions(+), 191 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index d6e0762a1a..635a874e26 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -629,6 +629,20 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """Verify or not SSL certificates boto3 client/resource read-only 
property."""
         return self.conn_config.verify
 
+    @cached_property
+    def account_id(self) -> str:
+        """Return associated AWS Account ID."""
+        return (
+            self.get_session(region_name=self.region_name)
+            .client(
+                service_name="sts",
+                endpoint_url=self.conn_config.get_service_endpoint_url("sts"),
+                config=self.config,
+                verify=self.verify,
+            )
+            .get_caller_identity()["Account"]
+        )
+
     def get_session(self, region_name: str | None = None, deferrable: bool = 
False) -> boto3.session.Session:
         """Get the underlying 
boto3.session.Session(region_name=region_name)."""
         return SessionFactory(
diff --git a/airflow/providers/amazon/aws/hooks/quicksight.py b/airflow/providers/amazon/aws/hooks/quicksight.py
index 6ee7c5bfd4..1106a793c1 100644
--- a/airflow/providers/amazon/aws/hooks/quicksight.py
+++ b/airflow/providers/amazon/aws/hooks/quicksight.py
@@ -18,13 +18,13 @@
 from __future__ import annotations
 
 import time
+import warnings
 from functools import cached_property
 
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.sts import StsHook
 
 
 class QuickSightHook(AwsBaseHook):
@@ -46,10 +46,6 @@ class QuickSightHook(AwsBaseHook):
     def __init__(self, *args, **kwargs):
         super().__init__(client_type="quicksight", *args, **kwargs)
 
-    @cached_property
-    def sts_hook(self):
-        return StsHook(aws_conn_id=self.aws_conn_id)
-
     def create_ingestion(
         self,
         data_set_id: str,
@@ -57,6 +53,7 @@ class QuickSightHook(AwsBaseHook):
         ingestion_type: str,
         wait_for_completion: bool = True,
         check_interval: int = 30,
+        aws_account_id: str | None = None,
     ) -> dict:
         """
         Create and start a new SPICE ingestion for a dataset; refresh the 
SPICE datasets.
@@ -66,18 +63,18 @@ class QuickSightHook(AwsBaseHook):
 
         :param data_set_id:  ID of the dataset used in the ingestion.
         :param ingestion_id: ID for the ingestion.
-        :param ingestion_type: Type of ingestion . 
"INCREMENTAL_REFRESH"|"FULL_REFRESH"
+        :param ingestion_type: Type of ingestion: 
"INCREMENTAL_REFRESH"|"FULL_REFRESH"
         :param wait_for_completion: if the program should keep running until 
job finishes
         :param check_interval: the time interval in seconds which the operator
             will check the status of QuickSight Ingestion
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use 
associated AWS Account ID.
         :return: Returns descriptive information about the created data 
ingestion
             having Ingestion ARN, HTTP status, ingestion ID and ingestion 
status.
         """
+        aws_account_id = aws_account_id or self.account_id
         self.log.info("Creating QuickSight Ingestion for data set id %s.", 
data_set_id)
-        quicksight_client = self.get_conn()
         try:
-            aws_account_id = self.sts_hook.get_account_number()
-            create_ingestion_response = quicksight_client.create_ingestion(
+            create_ingestion_response = self.conn.create_ingestion(
                 DataSetId=data_set_id,
                 IngestionId=ingestion_id,
                 IngestionType=ingestion_type,
@@ -97,20 +94,21 @@ class QuickSightHook(AwsBaseHook):
             self.log.error("Failed to run Amazon QuickSight create_ingestion 
API, error: %s", general_error)
             raise
 
-    def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: 
str) -> str:
+    def get_status(self, aws_account_id: str | None, data_set_id: str, 
ingestion_id: str) -> str:
         """
         Get the current status of QuickSight Create Ingestion API.
 
         .. seealso::
             - :external+boto3:py:meth:`QuickSight.Client.describe_ingestion`
 
-        :param aws_account_id: An AWS Account ID
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use 
associated AWS Account ID.
         :param data_set_id: QuickSight Data Set ID
         :param ingestion_id: QuickSight Ingestion ID
         :return: An QuickSight Ingestion Status
         """
+        aws_account_id = aws_account_id or self.account_id
         try:
-            describe_ingestion_response = self.get_conn().describe_ingestion(
+            describe_ingestion_response = self.conn.describe_ingestion(
                 AwsAccountId=aws_account_id, DataSetId=data_set_id, 
IngestionId=ingestion_id
             )
             return describe_ingestion_response["Ingestion"]["IngestionStatus"]
@@ -119,17 +117,19 @@ class QuickSightHook(AwsBaseHook):
         except ClientError as e:
             raise AirflowException(f"AWS request failed: {e}")
 
-    def get_error_info(self, aws_account_id: str, data_set_id: str, 
ingestion_id: str) -> dict | None:
+    def get_error_info(self, aws_account_id: str | None, data_set_id: str, 
ingestion_id: str) -> dict | None:
         """
         Get info about the error if any.
 
-        :param aws_account_id: An AWS Account ID
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use 
associated AWS Account ID.
         :param data_set_id: QuickSight Data Set ID
         :param ingestion_id: QuickSight Ingestion ID
         :return: Error info dict containing the error type (key 'Type') and 
message (key 'Message')
             if available. Else, returns None.
         """
-        describe_ingestion_response = self.get_conn().describe_ingestion(
+        aws_account_id = aws_account_id or self.account_id
+
+        describe_ingestion_response = self.conn.describe_ingestion(
             AwsAccountId=aws_account_id, DataSetId=data_set_id, 
IngestionId=ingestion_id
         )
         # using .get() to get None if the key is not present, instead of an 
exception.
@@ -137,7 +137,7 @@ class QuickSightHook(AwsBaseHook):
 
     def wait_for_state(
         self,
-        aws_account_id: str,
+        aws_account_id: str | None,
         data_set_id: str,
         ingestion_id: str,
         target_state: set,
@@ -146,7 +146,7 @@ class QuickSightHook(AwsBaseHook):
         """
         Check status of a QuickSight Create Ingestion API.
 
-        :param aws_account_id: An AWS Account ID
+        :param aws_account_id: An AWS Account ID, if set to ``None`` then use 
associated AWS Account ID.
         :param data_set_id: QuickSight Data Set ID
         :param ingestion_id: QuickSight Ingestion ID
         :param target_state: Describes the QuickSight Job's Target State
@@ -154,6 +154,8 @@ class QuickSightHook(AwsBaseHook):
             will check the status of QuickSight Ingestion
         :return: response of describe_ingestion call after Ingestion is done
         """
+        aws_account_id = aws_account_id or self.account_id
+
         while True:
             status = self.get_status(aws_account_id, data_set_id, ingestion_id)
             self.log.info("Current status is %s", status)
@@ -168,3 +170,16 @@ class QuickSightHook(AwsBaseHook):
 
         self.log.info("QuickSight Ingestion completed")
         return status
+
+    @cached_property
+    def sts_hook(self):
+        warnings.warn(
+            f"`{type(self).__name__}.sts_hook` property is deprecated and will 
be removed in the future. "
+            "This property used for obtain AWS Account ID, "
+            f"please consider to use `{type(self).__name__}.account_id` 
instead",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        from airflow.providers.amazon.aws.hooks.sts import StsHook
+
+        return StsHook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/operators/quicksight.py b/airflow/providers/amazon/aws/operators/quicksight.py
index 4268374117..9555e0d63a 100644
--- a/airflow/providers/amazon/aws/operators/quicksight.py
+++ b/airflow/providers/amazon/aws/operators/quicksight.py
@@ -18,16 +18,15 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-DEFAULT_CONN_ID = "aws_default"
 
-
-class QuickSightCreateIngestionOperator(BaseOperator):
+class QuickSightCreateIngestionOperator(AwsBaseOperator[QuickSightHook]):
     """
     Creates and starts a new SPICE ingestion for a dataset;  also helps to 
Refresh existing SPICE datasets.
 
@@ -43,23 +42,25 @@ class QuickSightCreateIngestionOperator(BaseOperator):
         that the operation waits to check the status of the Amazon QuickSight 
Ingestion.
     :param check_interval: if wait is set to be true, this is the time interval
         in seconds which the operator will check the status of the Amazon 
QuickSight Ingestion
-    :param aws_conn_id: The Airflow connection used for AWS credentials. 
(templated)
-         If this is None or empty then the default boto3 behaviour is used. If
-         running Airflow in a distributed manner and aws_conn_id is None or
-         empty, then the default boto3 configuration would be used (and must be
-         maintained on each worker node).
-    :param region: Which AWS region the connection should use. (templated)
-         If this is None or empty then the default boto3 behaviour is used.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = (
+    aws_hook_class = QuickSightHook
+    template_fields: Sequence[str] = aws_template_fields(
         "data_set_id",
         "ingestion_id",
         "ingestion_type",
         "wait_for_completion",
         "check_interval",
-        "aws_conn_id",
-        "region",
     )
     ui_color = "#ffd700"
 
@@ -70,26 +71,18 @@ class QuickSightCreateIngestionOperator(BaseOperator):
         ingestion_type: str = "FULL_REFRESH",
         wait_for_completion: bool = True,
         check_interval: int = 30,
-        aws_conn_id: str = DEFAULT_CONN_ID,
-        region: str | None = None,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.data_set_id = data_set_id
         self.ingestion_id = ingestion_id
         self.ingestion_type = ingestion_type
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
-        self.aws_conn_id = aws_conn_id
-        self.region = region
-        super().__init__(**kwargs)
 
     def execute(self, context: Context):
-        hook = QuickSightHook(
-            aws_conn_id=self.aws_conn_id,
-            region_name=self.region,
-        )
         self.log.info("Running the Amazon QuickSight SPICE Ingestion on 
Dataset ID: %s", self.data_set_id)
-        return hook.create_ingestion(
+        return self.hook.create_ingestion(
             data_set_id=self.data_set_id,
             ingestion_id=self.ingestion_id,
             ingestion_type=self.ingestion_type,
diff --git a/airflow/providers/amazon/aws/sensors/quicksight.py b/airflow/providers/amazon/aws/sensors/quicksight.py
index fc90ecbe45..ebd8310fe4 100644
--- a/airflow/providers/amazon/aws/sensors/quicksight.py
+++ b/airflow/providers/amazon/aws/sensors/quicksight.py
@@ -17,19 +17,19 @@
 # under the License.
 from __future__ import annotations
 
+import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
-from airflow.providers.amazon.aws.hooks.sts import StsHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class QuickSightSensor(BaseSensorOperator):
+class QuickSightSensor(AwsBaseSensor[QuickSightHook]):
     """
     Watches for the status of an Amazon QuickSight Ingestion.
 
@@ -39,27 +39,25 @@ class QuickSightSensor(BaseSensorOperator):
 
     :param data_set_id:  ID of the dataset used in the ingestion.
     :param ingestion_id: ID for the ingestion.
-    :param aws_conn_id: The Airflow connection used for AWS credentials. 
(templated)
-         If this is None or empty then the default boto3 behaviour is used. If
-         running Airflow in a distributed manner and aws_conn_id is None or
-         empty, then the default boto3 configuration would be used (and must be
-         maintained on each worker node).
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
+    aws_hook_class = QuickSightHook
     template_fields: Sequence[str] = ("data_set_id", "ingestion_id", 
"aws_conn_id")
 
-    def __init__(
-        self,
-        *,
-        data_set_id: str,
-        ingestion_id: str,
-        aws_conn_id: str = "aws_default",
-        **kwargs,
-    ) -> None:
+    def __init__(self, *, data_set_id: str, ingestion_id: str, **kwargs):
         super().__init__(**kwargs)
         self.data_set_id = data_set_id
         self.ingestion_id = ingestion_id
-        self.aws_conn_id = aws_conn_id
         self.success_status = "COMPLETED"
         self.errored_statuses = ("FAILED", "CANCELLED")
 
@@ -71,13 +69,10 @@ class QuickSightSensor(BaseSensorOperator):
         :return: True if it COMPLETED and False if not.
         """
         self.log.info("Poking for Amazon QuickSight Ingestion ID: %s", 
self.ingestion_id)
-        aws_account_id = self.sts_hook.get_account_number()
-        quicksight_ingestion_state = self.quicksight_hook.get_status(
-            aws_account_id, self.data_set_id, self.ingestion_id
-        )
+        quicksight_ingestion_state = self.hook.get_status(None, 
self.data_set_id, self.ingestion_id)
         self.log.info("QuickSight Status: %s", quicksight_ingestion_state)
         if quicksight_ingestion_state in self.errored_statuses:
-            error = self.quicksight_hook.get_error_info(aws_account_id, 
self.data_set_id, self.ingestion_id)
+            error = self.hook.get_error_info(None, self.data_set_id, 
self.ingestion_id)
             message = f"The QuickSight Ingestion failed. Error info: {error}"
             if self.soft_fail:
                 raise AirflowSkipException(message)
@@ -86,8 +81,23 @@ class QuickSightSensor(BaseSensorOperator):
 
     @cached_property
     def quicksight_hook(self):
-        return QuickSightHook(aws_conn_id=self.aws_conn_id)
+        warnings.warn(
+            f"`{type(self).__name__}.quicksight_hook` property is deprecated, "
+            f"please use `{type(self).__name__}.hook` property instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.hook
 
     @cached_property
     def sts_hook(self):
+        warnings.warn(
+            f"`{type(self).__name__}.sts_hook` property is deprecated and will 
be removed in the future. "
+            "This property used for obtain AWS Account ID, "
+            f"please consider to use `{type(self).__name__}.hook.account_id` 
instead",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        from airflow.providers.amazon.aws.hooks.sts import StsHook
+
         return StsHook(aws_conn_id=self.aws_conn_id)
diff --git a/docs/apache-airflow-providers-amazon/operators/quicksight.rst b/docs/apache-airflow-providers-amazon/operators/quicksight.rst
index cbca98d7d5..9cc0abe337 100644
--- a/docs/apache-airflow-providers-amazon/operators/quicksight.rst
+++ b/docs/apache-airflow-providers-amazon/operators/quicksight.rst
@@ -30,6 +30,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index ba94048421..c87aaa98fd 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -1031,6 +1031,10 @@ class TestAwsBaseHook:
         assert mock_mask_secret.mock_calls == expected_calls
         assert credentials == expected_credentials
 
+    @mock_sts
+    def test_account_id(self):
+        assert AwsBaseHook(aws_conn_id=None).account_id == DEFAULT_ACCOUNT_ID
+
 
 class ThrowErrorUntilCount:
     """Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/hooks/test_quicksight.py b/tests/providers/amazon/aws/hooks/test_quicksight.py
index 9c8ef16ce8..6a7795843b 100644
--- a/tests/providers/amazon/aws/hooks/test_quicksight.py
+++ b/tests/providers/amazon/aws/hooks/test_quicksight.py
@@ -22,19 +22,15 @@ from unittest import mock
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
-from airflow.providers.amazon.aws.hooks.sts import StsHook
-
-AWS_ACCOUNT_ID = "123456789012"
 
+DEFAULT_AWS_ACCOUNT_ID = "123456789012"
 MOCK_DATA = {
     "DataSetId": "DemoDataSet",
     "IngestionId": "DemoDataSet_Ingestion",
     "IngestionType": "INCREMENTAL_REFRESH",
-    "AwsAccountId": AWS_ACCOUNT_ID,
 }
-
 MOCK_CREATE_INGESTION_RESPONSE = {
     "Status": 201,
     "Arn": 
"arn:aws:quicksight:us-east-1:123456789012:dataset/DemoDataSet/ingestion/DemoDataSet3_Ingestion",
@@ -42,7 +38,6 @@ MOCK_CREATE_INGESTION_RESPONSE = {
     "IngestionStatus": "INITIALIZED",
     "RequestId": "fc1f7eea-1327-41d6-9af7-c12f097ed343",
 }
-
 MOCK_DESCRIBE_INGESTION_SUCCESS = {
     "Status": 200,
     "Ingestion": {
@@ -59,7 +54,6 @@ MOCK_DESCRIBE_INGESTION_SUCCESS = {
     },
     "RequestId": "DemoDataSet_Ingestion_Request_ID",
 }
-
 MOCK_DESCRIBE_INGESTION_FAILURE = {
     "Status": 403,
     "Ingestion": {
@@ -76,6 +70,23 @@ MOCK_DESCRIBE_INGESTION_FAILURE = {
     },
     "RequestId": "DemoDataSet_Ingestion_Request_ID",
 }
+ACCOUNT_TEST_CASES = [
+    pytest.param(None, DEFAULT_AWS_ACCOUNT_ID, id="default-account-id"),
+    pytest.param("777777777777", "777777777777", id="custom-account-id"),
+]
+
+
+@pytest.fixture
+def mocked_account_id():
+    with mock.patch.object(QuickSightHook, "account_id", 
new_callable=mock.PropertyMock) as m:
+        m.return_value = DEFAULT_AWS_ACCOUNT_ID
+        yield m
+
+
+@pytest.fixture
+def mocked_client():
+    with mock.patch.object(QuickSightHook, "conn") as m:
+        yield m
 
 
 class TestQuicksight:
@@ -83,70 +94,174 @@ class TestQuicksight:
         hook = QuickSightHook(aws_conn_id="aws_default", 
region_name="us-east-1")
         assert hook.conn is not None
 
-    @mock.patch.object(QuickSightHook, "get_conn")
-    @mock.patch.object(StsHook, "get_conn")
-    @mock.patch.object(StsHook, "get_account_number")
-    def test_create_ingestion(self, mock_get_account_number, sts_conn, 
mock_conn):
-        mock_conn.return_value.create_ingestion.return_value = 
MOCK_CREATE_INGESTION_RESPONSE
-        mock_get_account_number.return_value = AWS_ACCOUNT_ID
-        quicksight_hook = QuickSightHook(aws_conn_id="aws_default", 
region_name="us-east-1")
-        result = quicksight_hook.create_ingestion(
-            data_set_id="DemoDataSet",
-            ingestion_id="DemoDataSet_Ingestion",
-            ingestion_type="INCREMENTAL_REFRESH",
+    @pytest.mark.parametrize(
+        "response, expected_status",
+        [
+            pytest.param(MOCK_DESCRIBE_INGESTION_SUCCESS, "COMPLETED", 
id="completed"),
+            pytest.param(MOCK_DESCRIBE_INGESTION_FAILURE, "Failed", 
id="failed"),
+        ],
+    )
+    @pytest.mark.parametrize("aws_account_id, expected_account_id", 
ACCOUNT_TEST_CASES)
+    def test_get_job_status(
+        self, response, expected_status, aws_account_id, expected_account_id, 
mocked_account_id, mocked_client
+    ):
+        """Test get job status."""
+        mocked_client.describe_ingestion.return_value = response
+
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        assert (
+            hook.get_status(
+                data_set_id="DemoDataSet",
+                ingestion_id="DemoDataSet_Ingestion",
+                aws_account_id=aws_account_id,
+            )
+            == expected_status
+        )
+        mocked_client.describe_ingestion.assert_called_with(
+            AwsAccountId=expected_account_id,
+            DataSetId="DemoDataSet",
+            IngestionId="DemoDataSet_Ingestion",
+        )
+
+    @pytest.mark.parametrize(
+        "exception, error_match",
+        [
+            pytest.param(KeyError("Foo"), "Could not get status", 
id="key-error"),
+            pytest.param(
+                ClientError(error_response={}, operation_name="fake"),
+                "AWS request failed",
+                id="botocore-client",
+            ),
+        ],
+    )
+    def test_get_job_status_exception(self, exception, error_match, 
mocked_client, mocked_account_id):
+        mocked_client.describe_ingestion.side_effect = exception
+
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        with pytest.raises(AirflowException, match=error_match):
+            assert hook.get_status(
+                data_set_id="DemoDataSet",
+                ingestion_id="DemoDataSet_Ingestion",
+                aws_account_id=None,
+            )
+
+    @pytest.mark.parametrize(
+        "error_info",
+        [
+            pytest.param({"foo": "bar"}, id="error-info-exists"),
+            pytest.param(None, id="error-info-not-exists"),
+        ],
+    )
+    @pytest.mark.parametrize("aws_account_id, expected_account_id", 
ACCOUNT_TEST_CASES)
+    def test_get_error_info(
+        self, error_info, aws_account_id, expected_account_id, mocked_client, 
mocked_account_id
+    ):
+        mocked_response = {"Ingestion": {}}
+        if error_info:
+            mocked_response["Ingestion"]["ErrorInfo"] = error_info
+        mocked_client.describe_ingestion.return_value = mocked_response
+
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        assert (
+            hook.get_error_info(
+                data_set_id="DemoDataSet", 
ingestion_id="DemoDataSet_Ingestion", aws_account_id=None
+            )
+            == error_info
         )
-        expected_call_params = MOCK_DATA
-        
mock_conn.return_value.create_ingestion.assert_called_with(**expected_call_params)
-        assert result == MOCK_CREATE_INGESTION_RESPONSE
 
-    @mock.patch.object(QuickSightHook, "get_conn")
+    @mock.patch.object(QuickSightHook, "get_status", return_value="FAILED")
+    @mock.patch.object(QuickSightHook, "get_error_info")
+    @pytest.mark.parametrize("aws_account_id, expected_account_id", 
ACCOUNT_TEST_CASES)
+    def test_wait_for_state_failure(
+        self,
+        mocked_get_error_info,
+        mocked_get_status,
+        aws_account_id,
+        expected_account_id,
+        mocked_client,
+        mocked_account_id,
+    ):
+        mocked_get_error_info.return_value = "Something Bad Happen"
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        with pytest.raises(AirflowException, match="Error info: Something Bad 
Happen"):
+            hook.wait_for_state(
+                aws_account_id, "data_set_id", "ingestion_id", 
target_state={"COMPLETED"}, check_interval=0
+            )
+        mocked_get_status.assert_called_with(expected_account_id, 
"data_set_id", "ingestion_id")
+        mocked_get_error_info.assert_called_with(expected_account_id, 
"data_set_id", "ingestion_id")
+
+    @mock.patch.object(QuickSightHook, "get_status", return_value="CANCELLED")
+    def test_wait_for_state_canceled(self, _):
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        with pytest.raises(AirflowException, match="The Amazon QuickSight 
SPICE ingestion cancelled"):
+            hook.wait_for_state(
+                "aws_account_id", "data_set_id", "ingestion_id", 
target_state={"COMPLETED"}, check_interval=0
+            )
+
     @mock.patch.object(QuickSightHook, "get_status")
-    def test_fast_failing_ingestion(self, mock_get_status, mock_conn):
-        quicksight_hook = QuickSightHook(aws_conn_id="aws_default", 
region_name="us-east-1")
-        mock_get_status.return_value = "FAILED"
-        with pytest.raises(AirflowException):
-            quicksight_hook.wait_for_state(
-                "account_id", "data_set_id", "ingestion_id", 
target_state={"COMPLETED"}, check_interval=1
+    def test_wait_for_state_completed(self, mocked_get_status):
+        mocked_get_status.side_effect = ["INITIALIZED", "QUEUED", "RUNNING", 
"COMPLETED"]
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        assert (
+            hook.wait_for_state(
+                "aws_account_id", "data_set_id", "ingestion_id", 
target_state={"COMPLETED"}, check_interval=0
             )
+            == "COMPLETED"
+        )
+        assert mocked_get_status.call_count == 4
+
+    @pytest.mark.parametrize(
+        "wait_for_completion", [pytest.param(True, id="wait"), 
pytest.param(False, id="no-wait")]
+    )
+    @pytest.mark.parametrize("aws_account_id, expected_account_id", 
ACCOUNT_TEST_CASES)
+    def test_create_ingestion(
+        self, wait_for_completion, aws_account_id, expected_account_id, 
mocked_account_id, mocked_client
+    ):
+        mocked_client.create_ingestion.return_value = 
MOCK_CREATE_INGESTION_RESPONSE
+
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        with mock.patch.object(QuickSightHook, "wait_for_state") as 
mocked_wait_for_state:
+            assert (
+                hook.create_ingestion(
+                    data_set_id="DemoDataSet",
+                    ingestion_id="DemoDataSet_Ingestion",
+                    ingestion_type="INCREMENTAL_REFRESH",
+                    aws_account_id=aws_account_id,
+                    wait_for_completion=wait_for_completion,
+                    check_interval=0,
+                )
+                == MOCK_CREATE_INGESTION_RESPONSE
+            )
+            if wait_for_completion:
+                mocked_wait_for_state.assert_called_once_with(
+                    aws_account_id=expected_account_id,
+                    data_set_id="DemoDataSet",
+                    ingestion_id="DemoDataSet_Ingestion",
+                    target_state={"COMPLETED"},
+                    check_interval=0,
+                )
+            else:
+                mocked_wait_for_state.assert_not_called()
 
-    @mock.patch.object(StsHook, "get_conn")
-    @mock.patch.object(StsHook, "get_account_number")
-    def test_create_ingestion_exception(self, mock_get_account_number, 
sts_conn):
-        mock_get_account_number.return_value = AWS_ACCOUNT_ID
-        hook = QuickSightHook(aws_conn_id="aws_default")
-        with pytest.raises(ClientError) as raised_exception:
+        
mocked_client.create_ingestion.assert_called_with(AwsAccountId=expected_account_id,
 **MOCK_DATA)
+
+    def test_create_ingestion_exception(self, mocked_account_id, 
mocked_client, caplog):
+        mocked_client.create_ingestion.side_effect = ValueError("Fake Error")
+        hook = QuickSightHook(aws_conn_id=None)
+        with pytest.raises(ValueError, match="Fake Error"):
             hook.create_ingestion(
                 data_set_id="DemoDataSet",
                 ingestion_id="DemoDataSet_Ingestion",
                 ingestion_type="INCREMENTAL_REFRESH",
             )
-        ex = raised_exception.value
-        assert ex.operation_name == "CreateIngestion"
-
-    @mock.patch.object(QuickSightHook, "get_conn")
-    def test_get_job_status(self, mock_conn):
-        """
-        Test get job status
-        """
-        mock_conn.return_value.describe_ingestion.return_value = 
MOCK_DESCRIBE_INGESTION_SUCCESS
-        quicksight_hook = QuickSightHook(aws_conn_id="aws_default", 
region_name="us-east-1")
-        result = quicksight_hook.get_status(
-            data_set_id="DemoDataSet",
-            ingestion_id="DemoDataSet_Ingestion",
-            aws_account_id="123456789012",
-        )
-        assert result == "COMPLETED"
-
-    @mock.patch.object(QuickSightHook, "get_conn")
-    def test_get_job_status_failed(self, mock_conn):
-        """
-        Test get job status
-        """
-        mock_conn.return_value.describe_ingestion.return_value = 
MOCK_DESCRIBE_INGESTION_FAILURE
-        quicksight_hook = QuickSightHook(aws_conn_id="aws_default", 
region_name="us-east-1")
-        result = quicksight_hook.get_status(
-            data_set_id="DemoDataSet",
-            ingestion_id="DemoDataSet_Ingestion",
-            aws_account_id="123456789012",
-        )
-        assert result == "Failed"
+        assert "create_ingestion API, error: Fake Error" in caplog.text
+
+    def test_deprecated_properties(self):
+        hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+        with mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook") as 
mocked_class, pytest.warns(
+            AirflowProviderDeprecationWarning, match="consider to use 
`.*account_id` instead"
+        ):
+            mocked_sts_hook = mock.MagicMock(name="FakeStsHook")
+            mocked_class.return_value = mocked_sts_hook
+            assert hook.sts_hook is mocked_sts_hook
+            mocked_class.assert_called_once_with(aws_conn_id=None)
diff --git a/tests/providers/amazon/aws/operators/test_quicksight.py b/tests/providers/amazon/aws/operators/test_quicksight.py
index 2b7b0dc35f..fd30426293 100644
--- a/tests/providers/amazon/aws/operators/test_quicksight.py
+++ b/tests/providers/amazon/aws/operators/test_quicksight.py
@@ -38,17 +38,41 @@ MOCK_RESPONSE = {
 
 class TestQuickSightCreateIngestionOperator:
     def setup_method(self):
-        self.quicksight = QuickSightCreateIngestionOperator(
-            task_id="test_quicksight_operator",
-            data_set_id=DATA_SET_ID,
-            ingestion_id=INGESTION_ID,
+        self.default_op_kwargs = {
+            "task_id": "quicksight_create",
+            "aws_conn_id": None,
+            "data_set_id": DATA_SET_ID,
+            "ingestion_id": INGESTION_ID,
+        }
+
+    def test_init(self):
+        self.default_op_kwargs.pop("aws_conn_id", None)
+
+        op = QuickSightCreateIngestionOperator(
+            **self.default_op_kwargs,
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="cn-north-1",
+            verify=False,
+            botocore_config={"read_timeout": 42},
         )
+        assert op.hook.client_type == "quicksight"
+        assert op.hook.resource_type is None
+        assert op.hook.aws_conn_id == "fake-conn-id"
+        assert op.hook._region_name == "cn-north-1"
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = QuickSightCreateIngestionOperator(**self.default_op_kwargs)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
 
-    @mock.patch.object(QuickSightHook, "get_conn")
-    @mock.patch.object(QuickSightHook, "create_ingestion")
-    def test_execute(self, mock_create_ingestion, mock_client):
-        mock_create_ingestion.return_value = MOCK_RESPONSE
-        self.quicksight.execute(None)
+    @mock.patch.object(QuickSightHook, "create_ingestion", 
return_value=MOCK_RESPONSE)
+    def test_execute(self, mock_create_ingestion):
+        QuickSightCreateIngestionOperator(**self.default_op_kwargs).execute({})
         mock_create_ingestion.assert_called_once_with(
             data_set_id=DATA_SET_ID,
             ingestion_id=INGESTION_ID,
diff --git a/tests/providers/amazon/aws/sensors/test_quicksight.py 
b/tests/providers/amazon/aws/sensors/test_quicksight.py
index ba3ce83789..bef78d072d 100644
--- a/tests/providers/amazon/aws/sensors/test_quicksight.py
+++ b/tests/providers/amazon/aws/sensors/test_quicksight.py
@@ -20,10 +20,8 @@ from __future__ import annotations
 from unittest import mock
 
 import pytest
-from moto import mock_sts
-from moto.core import DEFAULT_ACCOUNT_ID
 
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
 from airflow.providers.amazon.aws.sensors.quicksight import QuickSightSensor
 
@@ -31,58 +29,91 @@ DATA_SET_ID = "DemoDataSet"
 INGESTION_ID = "DemoDataSet_Ingestion"
 
 
+@pytest.fixture
+def mocked_get_status():
+    with mock.patch.object(QuickSightHook, "get_status") as m:
+        yield m
+
+
+@pytest.fixture
+def mocked_get_error_info():
+    with mock.patch.object(QuickSightHook, "get_error_info") as m:
+        yield m
+
+
 class TestQuickSightSensor:
     def setup_method(self):
-        self.sensor = QuickSightSensor(
-            task_id="test_quicksight_sensor",
-            aws_conn_id="aws_default",
-            data_set_id="DemoDataSet",
-            ingestion_id="DemoDataSet_Ingestion",
+        self.default_op_kwargs = {
+            "task_id": "quicksight_sensor",
+            "aws_conn_id": None,
+            "data_set_id": DATA_SET_ID,
+            "ingestion_id": INGESTION_ID,
+        }
+
+    def test_init(self):
+        self.default_op_kwargs.pop("aws_conn_id", None)
+
+        sensor = QuickSightSensor(
+            **self.default_op_kwargs,
+            # Generic hooks parameters
+            aws_conn_id="fake-conn-id",
+            region_name="ca-west-1",
+            verify=True,
+            botocore_config={"read_timeout": 42},
         )
+        assert sensor.hook.client_type == "quicksight"
+        assert sensor.hook.resource_type is None
+        assert sensor.hook.aws_conn_id == "fake-conn-id"
+        assert sensor.hook._region_name == "ca-west-1"
+        assert sensor.hook._verify is True
+        assert sensor.hook._config is not None
+        assert sensor.hook._config.read_timeout == 42
+
+        sensor = QuickSightSensor(**self.default_op_kwargs)
+        assert sensor.hook.aws_conn_id == "aws_default"
+        assert sensor.hook._region_name is None
+        assert sensor.hook._verify is None
+        assert sensor.hook._config is None
+
+    @pytest.mark.parametrize("status", ["COMPLETED"])
+    def test_poke_completed(self, status, mocked_get_status):
+        mocked_get_status.return_value = status
+        assert QuickSightSensor(**self.default_op_kwargs).poke({}) is True
+        mocked_get_status.assert_called_once_with(None, DATA_SET_ID, 
INGESTION_ID)
 
-    @mock_sts
-    @mock.patch.object(QuickSightHook, "get_status")
-    def test_poke_success(self, mock_get_status):
-        mock_get_status.return_value = "COMPLETED"
-        assert self.sensor.poke({}) is True
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, 
DATA_SET_ID, INGESTION_ID)
-
-    @mock_sts
-    @mock.patch.object(QuickSightHook, "get_status")
-    @mock.patch.object(QuickSightHook, "get_error_info")
-    def test_poke_cancelled(self, _, mock_get_status):
-        mock_get_status.return_value = "CANCELLED"
-        with pytest.raises(AirflowException):
-            self.sensor.poke({})
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, 
DATA_SET_ID, INGESTION_ID)
-
-    @mock_sts
-    @mock.patch.object(QuickSightHook, "get_status")
-    @mock.patch.object(QuickSightHook, "get_error_info")
-    def test_poke_failed(self, _, mock_get_status):
-        mock_get_status.return_value = "FAILED"
-        with pytest.raises(AirflowException):
-            self.sensor.poke({})
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, 
DATA_SET_ID, INGESTION_ID)
-
-    @mock_sts
-    @mock.patch.object(QuickSightHook, "get_status")
-    def test_poke_initialized(self, mock_get_status):
-        mock_get_status.return_value = "INITIALIZED"
-        assert self.sensor.poke({}) is False
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, 
DATA_SET_ID, INGESTION_ID)
+    @pytest.mark.parametrize("status", ["INITIALIZED"])
+    def test_poke_not_completed(self, status, mocked_get_status):
+        mocked_get_status.return_value = status
+        assert QuickSightSensor(**self.default_op_kwargs).poke({}) is False
+        mocked_get_status.assert_called_once_with(None, DATA_SET_ID, 
INGESTION_ID)
 
+    @pytest.mark.parametrize("status", ["FAILED", "CANCELLED"])
     @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, 
AirflowSkipException))
+        "soft_fail, expected_exception",
+        [
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+            pytest.param(False, AirflowException, id="non-soft-fail"),
+        ],
     )
-    
@mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook.get_account_number")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.quicksight.QuickSightHook.get_status")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.quicksight.QuickSightHook.get_error_info")
-    def test_fail_poke(self, get_error_info, get_status, _, soft_fail, 
expected_exception):
-        self.sensor.soft_fail = soft_fail
-        error = "expected error"
-        message = f"The QuickSight Ingestion failed. Error info: {error}"
-        with pytest.raises(expected_exception, match=message):
-            get_status.return_value = "FAILED"
-            get_error_info.return_value = message
-            self.sensor.poke(context={})
+    def test_poke_terminated_status(
+        self, status, soft_fail, expected_exception, mocked_get_status, 
mocked_get_error_info
+    ):
+        mocked_get_status.return_value = status
+        mocked_get_error_info.return_value = "something bad happen"
+        with pytest.raises(expected_exception, match="Error info: something 
bad happen"):
+            QuickSightSensor(**self.default_op_kwargs, 
soft_fail=soft_fail).poke({})
+        mocked_get_status.assert_called_once_with(None, DATA_SET_ID, 
INGESTION_ID)
+        mocked_get_error_info.assert_called_once_with(None, DATA_SET_ID, 
INGESTION_ID)
+
+    def test_deprecated_properties(self):
+        sensor = QuickSightSensor(**self.default_op_kwargs)
+        with pytest.warns(AirflowProviderDeprecationWarning, match="please use 
`.*hook` property instead"):
+            assert sensor.quicksight_hook is sensor.hook
+
+        with mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook") as 
mocked_class, pytest.warns(
+            AirflowProviderDeprecationWarning, match="consider to use 
`.*hook\.account_id` instead"
+        ):
+            mocked_sts_hook = mock.MagicMock(name="FakeStsHook")
+            mocked_class.return_value = mocked_sts_hook
+            assert sensor.sts_hook is mocked_sts_hook
+            mocked_class.assert_called_once_with(aws_conn_id=None)


Reply via email to