This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 18324dcf5fa Add deferrable support for MwaaDagRunSensor (#47527)
18324dcf5fa is described below

commit 18324dcf5fa58fe4bbc46f67b908de72fab4c950
Author: Ramit Kataria <[email protected]>
AuthorDate: Tue Mar 11 10:38:47 2025 -0700

    Add deferrable support for MwaaDagRunSensor (#47527)
    
    * Add MwaaDagRunCompletedTrigger and deferrable support for MwaaDagRunSensor
    
    Also includes:
    - Unit tests
    - Support for `AwsGenericHook` and `AwsBaseWaiterTrigger` to allow 
overriding boto waiter config for custom success and failure states in the 
sensor
    - Changes to `aws.utils.waiter_with_logging.async_wait` to include info 
about latest response in exception
---
 providers/amazon/provider.yaml                     |   3 +
 .../airflow/providers/amazon/aws/hooks/base_aws.py |  10 +-
 .../airflow/providers/amazon/aws/sensors/mwaa.py   |  69 +++++++++--
 .../airflow/providers/amazon/aws/triggers/base.py  |  11 +-
 .../airflow/providers/amazon/aws/triggers/mwaa.py  | 129 +++++++++++++++++++++
 .../amazon/aws/utils/waiter_with_logging.py        |   7 +-
 .../airflow/providers/amazon/aws/waiters/mwaa.json |  36 ++++++
 .../airflow/providers/amazon/get_provider_info.py  |   4 +
 .../tests/unit/amazon/aws/sensors/test_mwaa.py     |  55 +++++----
 .../tests/unit/amazon/aws/triggers/test_mwaa.py    | 108 +++++++++++++++++
 10 files changed, 394 insertions(+), 38 deletions(-)

diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 0a508923a67..94ad23eab46 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -702,6 +702,9 @@ triggers:
   - integration-name: AWS Lambda
     python-modules:
       - airflow.providers.amazon.aws.triggers.lambda_function
+  - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.mwaa
   - integration-name: Amazon Managed Service for Apache Flink
     python-modules:
       - airflow.providers.amazon.aws.triggers.kinesis_analytics
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py 
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
index 862a210cd23..0230dea91f1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -943,6 +943,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self,
         waiter_name: str,
         parameters: dict[str, str] | None = None,
+        config_overrides: dict[str, Any] | None = None,
         deferrable: bool = False,
         client=None,
     ) -> Waiter:
@@ -962,6 +963,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         :param parameters: will scan the waiter config for the keys of that 
dict,
             and replace them with the corresponding value. If a custom waiter 
has
             such keys to be expanded, they need to be provided here.
+            Note: cannot be used if parameters are included in config_overrides
+        :param config_overrides: will update values of provided keys in the 
waiter's
+            config. Only specified keys will be updated.
         :param deferrable: If True, the waiter is going to be an async custom 
waiter.
             An async client must be provided in that case.
         :param client: The client to use for the waiter's operations
@@ -970,14 +974,18 @@ class AwsGenericHook(BaseHook, 
Generic[BaseAwsConnection]):
 
         if deferrable and not client:
             raise ValueError("client must be provided for a deferrable 
waiter.")
+        if parameters is not None and config_overrides is not None and 
"acceptors" in config_overrides:
+            raise ValueError('parameters must be None when "acceptors" is 
included in config_overrides')
         # Currently, the custom waiter doesn't work with resource_type, only 
client_type is supported.
         client = client or self._client
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
             # Technically if waiter_name is in custom_waiters then 
self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could 
be None.
             with open(self.waiter_path) as config_file:
-                config = json.loads(config_file.read())
+                config: dict = json.loads(config_file.read())
 
+            if config_overrides is not None:
+                config["waiters"][waiter_name].update(config_overrides)
             config = self._apply_parameters_value(config, waiter_name, 
parameters)
             return BaseBotoWaiter(client=client, model_config=config, 
deferrable=deferrable).waiter(
                 waiter_name
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py 
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
index 9007379e22c..bbd23c60cb3 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -18,13 +18,16 @@
 from __future__ import annotations
 
 from collections.abc import Collection, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.mwaa import 
MwaaDagRunCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -46,9 +49,24 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         (templated)
     :param external_dag_run_id: The DAG Run ID in the external MWAA 
environment that you want to wait for (templated)
     :param success_states: Collection of DAG Run states that would make this 
task marked as successful, default is
-        ``airflow.utils.state.State.success_states`` (templated)
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
     :param failure_states: Collection of DAG Run states that would make this 
task marked as failed and raise an
-        AirflowException, default is 
``airflow.utils.state.State.failed_states`` (templated)
+        AirflowException, default is 
``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param deferrable: If True, the sensor will operate in deferrable mode. 
This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting 
default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of 
the job. (default: 60)
+    :param max_retries: Number of times before returning the current state. 
(default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     aws_hook_class = MwaaHook
@@ -58,6 +76,9 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         "external_dag_run_id",
         "success_states",
         "failure_states",
+        "deferrable",
+        "max_retries",
+        "poke_interval",
     )
 
     def __init__(
@@ -68,19 +89,25 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         external_dag_run_id: str,
         success_states: Collection[str] | None = None,
         failure_states: Collection[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        poke_interval: int = 60,
+        max_retries: int = 720,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
-        self.success_states = set(success_states if success_states else 
State.success_states)
-        self.failure_states = set(failure_states if failure_states else 
State.failed_states)
+        self.success_states = set(success_states) if success_states else 
{DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else 
{DagRunState.FAILED.value}
 
         if len(self.success_states & self.failure_states):
-            raise AirflowException("allowed_states and failed_states must not 
have any values in common")
+            raise ValueError("success_states and failure_states must not have 
any values in common")
 
         self.external_env_name = external_env_name
         self.external_dag_id = external_dag_id
         self.external_dag_run_id = external_dag_run_id
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
 
     def poke(self, context: Context) -> bool:
         self.log.info(
@@ -102,12 +129,32 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         # The scope of this sensor is going to only be raising 
AirflowException due to failure of the DAGRun
 
         state = response["RestApiResponse"]["state"]
-        if state in self.success_states:
-            return True
 
         if state in self.failure_states:
             raise AirflowException(
                 f"The DAG run {self.external_dag_run_id} of DAG 
{self.external_dag_id} in MWAA environment {self.external_env_name} "
-                f"failed with state {state}."
+                f"failed with state: {state}"
             )
-        return False
+
+        return state in self.success_states
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> None:
+        validate_execute_complete_event(event)
+
+    def execute(self, context: Context):
+        if self.deferrable:
+            self.defer(
+                trigger=MwaaDagRunCompletedTrigger(
+                    external_env_name=self.external_env_name,
+                    external_dag_id=self.external_dag_id,
+                    external_dag_run_id=self.external_dag_run_id,
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    waiter_delay=self.poke_interval,
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
index f2c71a99adc..6183020a8f9 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):
 
     :param waiter_delay: The amount of time in seconds to wait between 
attempts.
     :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param waiter_config_overrides: A dict to update waiter's default 
configuration. Only specified keys will
+        be updated.
     :param aws_conn_id: The Airflow connection used for AWS credentials. To be 
used to build the hook.
     :param region_name: The AWS region where the resources to watch are. To be 
used to build the hook.
     :param verify: Whether or not to verify SSL certificates. To be used to 
build the hook.
@@ -77,6 +79,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         return_value: Any,
         waiter_delay: int,
         waiter_max_attempts: int,
+        waiter_config_overrides: dict[str, Any] | None = None,
         aws_conn_id: str | None,
         region_name: str | None = None,
         verify: bool | str | None = None,
@@ -91,6 +94,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
         self.failure_message = failure_message
         self.status_message = status_message
         self.status_queries = status_queries
+        self.waiter_config_overrides = waiter_config_overrides
 
         self.return_key = return_key
         self.return_value = return_value
@@ -140,7 +144,12 @@ class AwsBaseWaiterTrigger(BaseTrigger):
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
         async with await hook.get_async_conn() as client:
-            waiter = hook.get_waiter(self.waiter_name, deferrable=True, 
client=client)
+            waiter = hook.get_waiter(
+                self.waiter_name,
+                deferrable=True,
+                client=client,
+                config_overrides=self.waiter_config_overrides,
+            )
             await async_wait(
                 waiter,
                 self.waiter_delay,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py 
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
new file mode 100644
index 00000000000..bb6306d288e
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Collection
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.state import DagRunState
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when an MWAA Dag Run is complete.
+
+    :param external_env_name: The external MWAA environment name that contains 
the DAG Run you want to wait for
+        (templated)
+    :param external_dag_id: The DAG ID in the external MWAA environment that 
contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_run_id: The DAG Run ID in the external MWAA 
environment that you want to wait for (templated)
+    :param success_states: Collection of DAG Run states that would make this 
task marked as successful, default is
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
+    :param failure_states: Collection of DAG Run states that would make this 
task marked as failed and raise an
+        AirflowException, default is 
``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param waiter_delay: The amount of time in seconds to wait between 
attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. 
(default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 720,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        self.success_states = set(success_states) if success_states else 
{DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else 
{DagRunState.FAILED.value}
+
+        if len(self.success_states & self.failure_states):
+            raise ValueError("success_states and failure_states must not have 
any values in common")
+
+        in_progress_states = {s.value for s in DagRunState} - 
self.success_states - self.failure_states
+
+        super().__init__(
+            serialized_fields={
+                "external_env_name": external_env_name,
+                "external_dag_id": external_dag_id,
+                "external_dag_run_id": external_dag_run_id,
+                "success_states": success_states,
+                "failure_states": failure_states,
+            },
+            waiter_name="mwaa_dag_run_complete",
+            waiter_args={
+                "Name": external_env_name,
+                "Path": 
f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
+                "Method": "GET",
+            },
+            failure_message=f"The DAG run {external_dag_run_id} of DAG 
{external_dag_id} in MWAA environment {external_env_name} failed with state",
+            status_message="State of DAG run",
+            status_queries=["RestApiResponse.state"],
+            return_key="dag_run_id",
+            return_value=external_dag_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            waiter_config_overrides={
+                "acceptors": _build_waiter_acceptors(
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    in_progress_states=in_progress_states,
+                )
+            },
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return MwaaHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+def _build_waiter_acceptors(
+    success_states: set[str], failure_states: set[str], in_progress_states: 
set[str]
+) -> list:
+    def build_acceptor(dag_run_state: str, state_waiter_category: str):
+        return {
+            "matcher": "path",
+            "argument": "RestApiResponse.state",
+            "expected": dag_run_state,
+            "state": state_waiter_category,
+        }
+
+    acceptors = []
+    for state_set, state_waiter_category in (
+        (success_states, "success"),
+        (failure_states, "failure"),
+        (in_progress_states, "retry"),
+    ):
+        for dag_run_state in state_set:
+            acceptors.append(build_acceptor(dag_run_state, 
state_waiter_category))
+
+    return acceptors
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
 
b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
index 43d8bdf26d3..575a089382a 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -136,15 +136,16 @@ async def async_wait(
             last_response = error.last_response
 
             if "terminal failure" in error_reason:
-                log.error("%s: %s", failure_message, 
_LazyStatusFormatter(status_args, last_response))
-                raise AirflowException(f"{failure_message}: {error}")
+                raise AirflowException(
+                    f"{failure_message}: {_LazyStatusFormatter(status_args, 
last_response)}\n{error}"
+                )
 
             if (
                 "An error occurred" in error_reason
                 and isinstance(last_response.get("Error"), dict)
                 and "Code" in last_response.get("Error")
             ):
-                raise AirflowException(f"{failure_message}: {error}")
+                raise 
AirflowException(f"{failure_message}\n{last_response}\n{error}")
 
             log.info("%s: %s", status_message, 
_LazyStatusFormatter(status_args, last_response))
         else:
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json 
b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
new file mode 100644
index 00000000000..c1de661aa7b
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
@@ -0,0 +1,36 @@
+{
+  "version": 2,
+  "waiters": {
+    "mwaa_dag_run_complete": {
+      "delay": 60,
+      "maxAttempts": 720,
+      "operation": "InvokeRestApi",
+      "acceptors": [
+        {
+          "matcher": "path",
+          "argument": "RestApiResponse.state",
+          "expected": "queued",
+          "state": "retry"
+        },
+        {
+          "matcher": "path",
+          "argument": "RestApiResponse.state",
+          "expected": "running",
+          "state": "retry"
+        },
+        {
+          "matcher": "path",
+          "argument": "RestApiResponse.state",
+          "expected": "success",
+          "state": "success"
+        },
+        {
+          "matcher": "path",
+          "argument": "RestApiResponse.state",
+          "expected": "failed",
+          "state": "failure"
+        }
+      ]
+    }
+  }
+}
diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py 
b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
index 19bfdb0ae35..b634eb2f8ce 100644
--- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
+++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
@@ -881,6 +881,10 @@ def get_provider_info():
                 "integration-name": "AWS Lambda",
                 "python-modules": 
["airflow.providers.amazon.aws.triggers.lambda_function"],
             },
+            {
+                "integration-name": "Amazon Managed Workflows for Apache 
Airflow (MWAA)",
+                "python-modules": 
["airflow.providers.amazon.aws.triggers.mwaa"],
+            },
             {
                 "integration-name": "Amazon Managed Service for Apache Flink",
                 "python-modules": 
["airflow.providers.amazon.aws.triggers.kinesis_analytics"],
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py 
b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
index 8ab39ecf1ad..345d4838412 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
@@ -23,13 +23,21 @@ import pytest
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 
 SENSOR_KWARGS = {
     "task_id": "test_mwaa_sensor",
     "external_env_name": "test_env",
     "external_dag_id": "test_dag",
     "external_dag_run_id": "test_run_id",
+    "deferrable": False,
+    "poke_interval": 5,
+    "max_retries": 100,
+}
+
+SENSOR_STATE_KWARGS = {
+    "success_states": ["a", "b"],
+    "failure_states": ["c", "d"],
 }
 
 
@@ -41,35 +49,38 @@ def mock_invoke_rest_api():
 
 class TestMwaaDagRunSuccessSensor:
     def test_init_success(self):
-        success_states = {"state1", "state2"}
-        failure_states = {"state3", "state4"}
-        sensor = MwaaDagRunSensor(
-            **SENSOR_KWARGS, success_states=success_states, 
failure_states=failure_states
-        )
+        sensor = MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS)
         assert sensor.external_env_name == SENSOR_KWARGS["external_env_name"]
         assert sensor.external_dag_id == SENSOR_KWARGS["external_dag_id"]
         assert sensor.external_dag_run_id == 
SENSOR_KWARGS["external_dag_run_id"]
-        assert set(sensor.success_states) == success_states
-        assert set(sensor.failure_states) == failure_states
+        assert set(sensor.success_states) == 
set(SENSOR_STATE_KWARGS["success_states"])
+        assert set(sensor.failure_states) == 
set(SENSOR_STATE_KWARGS["failure_states"])
+        assert sensor.deferrable == SENSOR_KWARGS["deferrable"]
+        assert sensor.poke_interval == SENSOR_KWARGS["poke_interval"]
+        assert sensor.max_retries == SENSOR_KWARGS["max_retries"]
+
+        sensor = MwaaDagRunSensor(**SENSOR_KWARGS)
+        assert sensor.success_states == {DagRunState.SUCCESS.value}
+        assert sensor.failure_states == {DagRunState.FAILED.value}
 
     def test_init_failure(self):
-        with pytest.raises(AirflowException):
+        with pytest.raises(ValueError, 
match=r".*success_states.*failure_states.*"):
             MwaaDagRunSensor(
                 **SENSOR_KWARGS, success_states={"state1", "state2"}, 
failure_states={"state2", "state3"}
             )
 
-    @pytest.mark.parametrize("status", sorted(State.success_states))
-    def test_poke_completed(self, mock_invoke_rest_api, status):
-        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": 
status}}
-        assert MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+    @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["success_states"])
+    def test_poke_completed(self, mock_invoke_rest_api, state):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": 
state}}
+        assert MwaaDagRunSensor(**SENSOR_KWARGS, 
**SENSOR_STATE_KWARGS).poke({})
 
-    @pytest.mark.parametrize("status", ["running", "queued"])
-    def test_poke_not_completed(self, mock_invoke_rest_api, status):
-        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": 
status}}
-        assert not MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+    @pytest.mark.parametrize("state", ["e", "f"])
+    def test_poke_not_completed(self, mock_invoke_rest_api, state):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": 
state}}
+        assert not MwaaDagRunSensor(**SENSOR_KWARGS, 
**SENSOR_STATE_KWARGS).poke({})
 
-    @pytest.mark.parametrize("status", sorted(State.failed_states))
-    def test_poke_terminated(self, mock_invoke_rest_api, status):
-        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": 
status}}
-        with pytest.raises(AirflowException):
-            MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+    @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["failure_states"])
+    def test_poke_terminated(self, mock_invoke_rest_api, state):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": 
state}}
+        with pytest.raises(AirflowException, match=f".*{state}.*"):
+            MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({})
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py 
b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
new file mode 100644
index 00000000000..18c53e11f18
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.mwaa import 
MwaaDagRunCompletedTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.utils.state import DagRunState
+from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type
+
+BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.mwaa."
+TRIGGER_KWARGS = {
+    "external_env_name": "test_env",
+    "external_dag_id": "test_dag",
+    "external_dag_run_id": "test_run_id",
+}
+
+
+class TestMwaaDagRunCompletedTrigger:
+    def test_init_states(self):
+        trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS)
+        assert trigger.success_states == {DagRunState.SUCCESS.value}
+        assert trigger.failure_states == {DagRunState.FAILED.value}
+        acceptors = trigger.waiter_config_overrides["acceptors"]
+        expected_acceptors = [
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.SUCCESS.value,
+                "state": "success",
+            },
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.FAILED.value,
+                "state": "failure",
+            },
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.RUNNING.value,
+                "state": "retry",
+            },
+            {
+                "matcher": "path",
+                "argument": "RestApiResponse.state",
+                "expected": DagRunState.QUEUED.value,
+                "state": "retry",
+            },
+        ]
+        assert len(acceptors) == len(DagRunState)
+        assert {tuple(sorted(a.items())) for a in acceptors} == {
+            tuple(sorted(a.items())) for a in expected_acceptors
+        }
+
+    def test_init_fail(self):
+        with pytest.raises(ValueError, 
match=r".*success_states.*failure_states.*"):
+            MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS, success_states=("a", 
"b"), failure_states=("b", "c"))
+
+    def test_serialization(self):
+        success_states = ["a", "b"]
+        failure_states = ["c", "d"]
+        trigger = MwaaDagRunCompletedTrigger(
+            **TRIGGER_KWARGS, success_states=success_states, 
failure_states=failure_states
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == BASE_TRIGGER_CLASSPATH + 
"MwaaDagRunCompletedTrigger"
+        assert kwargs.get("external_env_name") == 
TRIGGER_KWARGS["external_env_name"]
+        assert kwargs.get("external_dag_id") == 
TRIGGER_KWARGS["external_dag_id"]
+        assert kwargs.get("external_dag_run_id") == 
TRIGGER_KWARGS["external_dag_run_id"]
+        assert kwargs.get("success_states") == success_states
+        assert kwargs.get("failure_states") == failure_states
+
+    @pytest.mark.asyncio
+    @mock.patch.object(MwaaHook, "get_waiter")
+    @mock.patch.object(MwaaHook, "async_conn")
+    async def test_run_success(self, mock_async_conn, mock_get_waiter):
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock()
+        trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS)
+
+        generator = trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent(
+            {"status": "success", "dag_run_id": 
TRIGGER_KWARGS["external_dag_run_id"]}
+        )
+        assert_expected_waiter_type(mock_get_waiter, "mwaa_dag_run_complete")
+        mock_get_waiter().wait.assert_called_once()

Reply via email to