This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 18324dcf5fa Add deferrable support for MwaaDagRunSensor (#47527)
18324dcf5fa is described below
commit 18324dcf5fa58fe4bbc46f67b908de72fab4c950
Author: Ramit Kataria <[email protected]>
AuthorDate: Tue Mar 11 10:38:47 2025 -0700
Add deferrable support for MwaaDagRunSensor (#47527)
* Add MwaaDagRunCompletedTrigger and deferrable support for MwaaDagRunSensor
Also includes:
- Unit tests
- Support for `AwsGenericHook` and `AwsBaseWaiterTrigger` to allow
overriding boto waiter config for custom success and failure states in the
sensor
- Changes to `aws.utils.waiter_with_logging.async_wait` to include info
about latest response in exception
---
providers/amazon/provider.yaml | 3 +
.../airflow/providers/amazon/aws/hooks/base_aws.py | 10 +-
.../airflow/providers/amazon/aws/sensors/mwaa.py | 69 +++++++++--
.../airflow/providers/amazon/aws/triggers/base.py | 11 +-
.../airflow/providers/amazon/aws/triggers/mwaa.py | 129 +++++++++++++++++++++
.../amazon/aws/utils/waiter_with_logging.py | 7 +-
.../airflow/providers/amazon/aws/waiters/mwaa.json | 36 ++++++
.../airflow/providers/amazon/get_provider_info.py | 4 +
.../tests/unit/amazon/aws/sensors/test_mwaa.py | 55 +++++----
.../tests/unit/amazon/aws/triggers/test_mwaa.py | 108 +++++++++++++++++
10 files changed, 394 insertions(+), 38 deletions(-)
diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 0a508923a67..94ad23eab46 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -702,6 +702,9 @@ triggers:
- integration-name: AWS Lambda
python-modules:
- airflow.providers.amazon.aws.triggers.lambda_function
+ - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.mwaa
- integration-name: Amazon Managed Service for Apache Flink
python-modules:
- airflow.providers.amazon.aws.triggers.kinesis_analytics
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
index 862a210cd23..0230dea91f1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -943,6 +943,7 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
self,
waiter_name: str,
parameters: dict[str, str] | None = None,
+ config_overrides: dict[str, Any] | None = None,
deferrable: bool = False,
client=None,
) -> Waiter:
@@ -962,6 +963,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
:param parameters: will scan the waiter config for the keys of that
dict,
and replace them with the corresponding value. If a custom waiter
has
such keys to be expanded, they need to be provided here.
+ Note: cannot be used if parameters are included in config_overrides
+ :param config_overrides: will update values of provided keys in the
waiter's
+ config. Only specified keys will be updated.
:param deferrable: If True, the waiter is going to be an async custom
waiter.
An async client must be provided in that case.
:param client: The client to use for the waiter's operations
@@ -970,14 +974,18 @@ class AwsGenericHook(BaseHook,
Generic[BaseAwsConnection]):
if deferrable and not client:
raise ValueError("client must be provided for a deferrable
waiter.")
+ if parameters is not None and config_overrides is not None and
"acceptors" in config_overrides:
+ raise ValueError('parameters must be None when "acceptors" is
included in config_overrides')
# Currently, the custom waiter doesn't work with resource_type, only
client_type is supported.
client = client or self._client
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
# Technically if waiter_name is in custom_waiters then
self.waiter_path must
# exist but MyPy doesn't like the fact that self.waiter_path could
be None.
with open(self.waiter_path) as config_file:
- config = json.loads(config_file.read())
+ config: dict = json.loads(config_file.read())
+ if config_overrides is not None:
+ config["waiters"][waiter_name].update(config_overrides)
config = self._apply_parameters_value(config, waiter_name,
parameters)
return BaseBotoWaiter(client=client, model_config=config,
deferrable=deferrable).waiter(
waiter_name
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
index 9007379e22c..bbd23c60cb3 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -18,13 +18,16 @@
from __future__ import annotations
from collections.abc import Collection, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.mwaa import
MwaaDagRunCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -46,9 +49,24 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
(templated)
:param external_dag_run_id: The DAG Run ID in the external MWAA
environment that you want to wait for (templated)
:param success_states: Collection of DAG Run states that would make this
task marked as successful, default is
- ``airflow.utils.state.State.success_states`` (templated)
+ ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
:param failure_states: Collection of DAG Run states that would make this
task marked as failed and raise an
- AirflowException, default is
``airflow.utils.state.State.failed_states`` (templated)
+ AirflowException, default is
``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+ :param deferrable: If True, the sensor will operate in deferrable mode.
This mode requires aiobotocore
+ module to be installed.
+ (default: False, but can be overridden in config file by setting
default_deferrable to True)
+ :param poke_interval: Polling period in seconds to check for the status of
the job. (default: 60)
+ :param max_retries: Number of times before returning the current state.
(default: 720)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
aws_hook_class = MwaaHook
@@ -58,6 +76,9 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
"external_dag_run_id",
"success_states",
"failure_states",
+ "deferrable",
+ "max_retries",
+ "poke_interval",
)
def __init__(
@@ -68,19 +89,25 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
external_dag_run_id: str,
success_states: Collection[str] | None = None,
failure_states: Collection[str] | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ poke_interval: int = 60,
+ max_retries: int = 720,
**kwargs,
):
super().__init__(**kwargs)
- self.success_states = set(success_states if success_states else
State.success_states)
- self.failure_states = set(failure_states if failure_states else
State.failed_states)
+ self.success_states = set(success_states) if success_states else
{DagRunState.SUCCESS.value}
+ self.failure_states = set(failure_states) if failure_states else
{DagRunState.FAILED.value}
if len(self.success_states & self.failure_states):
- raise AirflowException("allowed_states and failed_states must not
have any values in common")
+ raise ValueError("success_states and failure_states must not have
any values in common")
self.external_env_name = external_env_name
self.external_dag_id = external_dag_id
self.external_dag_run_id = external_dag_run_id
+ self.deferrable = deferrable
+ self.poke_interval = poke_interval
+ self.max_retries = max_retries
def poke(self, context: Context) -> bool:
self.log.info(
@@ -102,12 +129,32 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
# The scope of this sensor is going to only be raising
AirflowException due to failure of the DAGRun
state = response["RestApiResponse"]["state"]
- if state in self.success_states:
- return True
if state in self.failure_states:
raise AirflowException(
f"The DAG run {self.external_dag_run_id} of DAG
{self.external_dag_id} in MWAA environment {self.external_env_name} "
- f"failed with state {state}."
+ f"failed with state: {state}"
)
- return False
+
+ return state in self.success_states
+
+ def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
+ validate_execute_complete_event(event)
+
+ def execute(self, context: Context):
+ if self.deferrable:
+ self.defer(
+ trigger=MwaaDagRunCompletedTrigger(
+ external_env_name=self.external_env_name,
+ external_dag_id=self.external_dag_id,
+ external_dag_run_id=self.external_dag_run_id,
+ success_states=self.success_states,
+ failure_states=self.failure_states,
+ waiter_delay=self.poke_interval,
+ waiter_max_attempts=self.max_retries,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ )
+ else:
+ super().execute(context=context)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
index f2c71a99adc..6183020a8f9 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):
:param waiter_delay: The amount of time in seconds to wait between
attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
+ :param waiter_config_overrides: A dict to update waiter's default
configuration. Only specified keys will
+ be updated.
:param aws_conn_id: The Airflow connection used for AWS credentials. To be
used to build the hook.
:param region_name: The AWS region where the resources to watch are. To be
used to build the hook.
:param verify: Whether or not to verify SSL certificates. To be used to
build the hook.
@@ -77,6 +79,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
return_value: Any,
waiter_delay: int,
waiter_max_attempts: int,
+ waiter_config_overrides: dict[str, Any] | None = None,
aws_conn_id: str | None,
region_name: str | None = None,
verify: bool | str | None = None,
@@ -91,6 +94,7 @@ class AwsBaseWaiterTrigger(BaseTrigger):
self.failure_message = failure_message
self.status_message = status_message
self.status_queries = status_queries
+ self.waiter_config_overrides = waiter_config_overrides
self.return_key = return_key
self.return_value = return_value
@@ -140,7 +144,12 @@ class AwsBaseWaiterTrigger(BaseTrigger):
async def run(self) -> AsyncIterator[TriggerEvent]:
hook = self.hook()
async with await hook.get_async_conn() as client:
- waiter = hook.get_waiter(self.waiter_name, deferrable=True,
client=client)
+ waiter = hook.get_waiter(
+ self.waiter_name,
+ deferrable=True,
+ client=client,
+ config_overrides=self.waiter_config_overrides,
+ )
await async_wait(
waiter,
self.waiter_delay,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
new file mode 100644
index 00000000000..bb6306d288e
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from collections.abc import Collection
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.state import DagRunState
+
+if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
+ """
+ Trigger when an MWAA Dag Run is complete.
+
+ :param external_env_name: The external MWAA environment name that contains
the DAG Run you want to wait for
+ (templated)
+ :param external_dag_id: The DAG ID in the external MWAA environment that
contains the DAG Run you want to wait for
+ (templated)
+ :param external_dag_run_id: The DAG Run ID in the external MWAA
environment that you want to wait for (templated)
+ :param success_states: Collection of DAG Run states that would make this
task marked as successful, default is
+ ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
+ :param failure_states: Collection of DAG Run states that would make this
task marked as failed and raise an
+ AirflowException, default is
``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+ :param waiter_delay: The amount of time in seconds to wait between
attempts. (default: 60)
+ :param waiter_max_attempts: The maximum number of attempts to be made.
(default: 720)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(
+ self,
+ *,
+ external_env_name: str,
+ external_dag_id: str,
+ external_dag_run_id: str,
+ success_states: Collection[str] | None = None,
+ failure_states: Collection[str] | None = None,
+ waiter_delay: int = 60,
+ waiter_max_attempts: int = 720,
+ aws_conn_id: str | None = None,
+ ) -> None:
+ self.success_states = set(success_states) if success_states else
{DagRunState.SUCCESS.value}
+ self.failure_states = set(failure_states) if failure_states else
{DagRunState.FAILED.value}
+
+ if len(self.success_states & self.failure_states):
+ raise ValueError("success_states and failure_states must not have
any values in common")
+
+ in_progress_states = {s.value for s in DagRunState} -
self.success_states - self.failure_states
+
+ super().__init__(
+ serialized_fields={
+ "external_env_name": external_env_name,
+ "external_dag_id": external_dag_id,
+ "external_dag_run_id": external_dag_run_id,
+ "success_states": success_states,
+ "failure_states": failure_states,
+ },
+ waiter_name="mwaa_dag_run_complete",
+ waiter_args={
+ "Name": external_env_name,
+ "Path":
f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
+ "Method": "GET",
+ },
+ failure_message=f"The DAG run {external_dag_run_id} of DAG
{external_dag_id} in MWAA environment {external_env_name} failed with state",
+ status_message="State of DAG run",
+ status_queries=["RestApiResponse.state"],
+ return_key="dag_run_id",
+ return_value=external_dag_run_id,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
+ waiter_config_overrides={
+ "acceptors": _build_waiter_acceptors(
+ success_states=self.success_states,
+ failure_states=self.failure_states,
+ in_progress_states=in_progress_states,
+ )
+ },
+ )
+
+ def hook(self) -> AwsGenericHook:
+ return MwaaHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ config=self.botocore_config,
+ )
+
+
+def _build_waiter_acceptors(
+ success_states: set[str], failure_states: set[str], in_progress_states:
set[str]
+) -> list:
+ def build_acceptor(dag_run_state: str, state_waiter_category: str):
+ return {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": dag_run_state,
+ "state": state_waiter_category,
+ }
+
+ acceptors = []
+ for state_set, state_waiter_category in (
+ (success_states, "success"),
+ (failure_states, "failure"),
+ (in_progress_states, "retry"),
+ ):
+ for dag_run_state in state_set:
+ acceptors.append(build_acceptor(dag_run_state,
state_waiter_category))
+
+ return acceptors
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
index 43d8bdf26d3..575a089382a 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -136,15 +136,16 @@ async def async_wait(
last_response = error.last_response
if "terminal failure" in error_reason:
- log.error("%s: %s", failure_message,
_LazyStatusFormatter(status_args, last_response))
- raise AirflowException(f"{failure_message}: {error}")
+ raise AirflowException(
+ f"{failure_message}: {_LazyStatusFormatter(status_args,
last_response)}\n{error}"
+ )
if (
"An error occurred" in error_reason
and isinstance(last_response.get("Error"), dict)
and "Code" in last_response.get("Error")
):
- raise AirflowException(f"{failure_message}: {error}")
+ raise
AirflowException(f"{failure_message}\n{last_response}\n{error}")
log.info("%s: %s", status_message,
_LazyStatusFormatter(status_args, last_response))
else:
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
new file mode 100644
index 00000000000..c1de661aa7b
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json
@@ -0,0 +1,36 @@
+{
+ "version": 2,
+ "waiters": {
+ "mwaa_dag_run_complete": {
+ "delay": 60,
+ "maxAttempts": 720,
+ "operation": "InvokeRestApi",
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "queued",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "running",
+ "state": "retry"
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "success",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": "failed",
+ "state": "failure"
+ }
+ ]
+ }
+ }
+}
diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
index 19bfdb0ae35..b634eb2f8ce 100644
--- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
+++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
@@ -881,6 +881,10 @@ def get_provider_info():
"integration-name": "AWS Lambda",
"python-modules":
["airflow.providers.amazon.aws.triggers.lambda_function"],
},
+ {
+ "integration-name": "Amazon Managed Workflows for Apache
Airflow (MWAA)",
+ "python-modules":
["airflow.providers.amazon.aws.triggers.mwaa"],
+ },
{
"integration-name": "Amazon Managed Service for Apache Flink",
"python-modules":
["airflow.providers.amazon.aws.triggers.kinesis_analytics"],
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
index 8ab39ecf1ad..345d4838412 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
@@ -23,13 +23,21 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
SENSOR_KWARGS = {
"task_id": "test_mwaa_sensor",
"external_env_name": "test_env",
"external_dag_id": "test_dag",
"external_dag_run_id": "test_run_id",
+ "deferrable": False,
+ "poke_interval": 5,
+ "max_retries": 100,
+}
+
+SENSOR_STATE_KWARGS = {
+ "success_states": ["a", "b"],
+ "failure_states": ["c", "d"],
}
@@ -41,35 +49,38 @@ def mock_invoke_rest_api():
class TestMwaaDagRunSuccessSensor:
def test_init_success(self):
- success_states = {"state1", "state2"}
- failure_states = {"state3", "state4"}
- sensor = MwaaDagRunSensor(
- **SENSOR_KWARGS, success_states=success_states,
failure_states=failure_states
- )
+ sensor = MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS)
assert sensor.external_env_name == SENSOR_KWARGS["external_env_name"]
assert sensor.external_dag_id == SENSOR_KWARGS["external_dag_id"]
assert sensor.external_dag_run_id ==
SENSOR_KWARGS["external_dag_run_id"]
- assert set(sensor.success_states) == success_states
- assert set(sensor.failure_states) == failure_states
+ assert set(sensor.success_states) ==
set(SENSOR_STATE_KWARGS["success_states"])
+ assert set(sensor.failure_states) ==
set(SENSOR_STATE_KWARGS["failure_states"])
+ assert sensor.deferrable == SENSOR_KWARGS["deferrable"]
+ assert sensor.poke_interval == SENSOR_KWARGS["poke_interval"]
+ assert sensor.max_retries == SENSOR_KWARGS["max_retries"]
+
+ sensor = MwaaDagRunSensor(**SENSOR_KWARGS)
+ assert sensor.success_states == {DagRunState.SUCCESS.value}
+ assert sensor.failure_states == {DagRunState.FAILED.value}
def test_init_failure(self):
- with pytest.raises(AirflowException):
+ with pytest.raises(ValueError,
match=r".*success_states.*failure_states.*"):
MwaaDagRunSensor(
**SENSOR_KWARGS, success_states={"state1", "state2"},
failure_states={"state2", "state3"}
)
- @pytest.mark.parametrize("status", sorted(State.success_states))
- def test_poke_completed(self, mock_invoke_rest_api, status):
- mock_invoke_rest_api.return_value = {"RestApiResponse": {"state":
status}}
- assert MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+ @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["success_states"])
+ def test_poke_completed(self, mock_invoke_rest_api, state):
+ mock_invoke_rest_api.return_value = {"RestApiResponse": {"state":
state}}
+ assert MwaaDagRunSensor(**SENSOR_KWARGS,
**SENSOR_STATE_KWARGS).poke({})
- @pytest.mark.parametrize("status", ["running", "queued"])
- def test_poke_not_completed(self, mock_invoke_rest_api, status):
- mock_invoke_rest_api.return_value = {"RestApiResponse": {"state":
status}}
- assert not MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+ @pytest.mark.parametrize("state", ["e", "f"])
+ def test_poke_not_completed(self, mock_invoke_rest_api, state):
+ mock_invoke_rest_api.return_value = {"RestApiResponse": {"state":
state}}
+ assert not MwaaDagRunSensor(**SENSOR_KWARGS,
**SENSOR_STATE_KWARGS).poke({})
- @pytest.mark.parametrize("status", sorted(State.failed_states))
- def test_poke_terminated(self, mock_invoke_rest_api, status):
- mock_invoke_rest_api.return_value = {"RestApiResponse": {"state":
status}}
- with pytest.raises(AirflowException):
- MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+ @pytest.mark.parametrize("state", SENSOR_STATE_KWARGS["failure_states"])
+ def test_poke_terminated(self, mock_invoke_rest_api, state):
+ mock_invoke_rest_api.return_value = {"RestApiResponse": {"state":
state}}
+ with pytest.raises(AirflowException, match=f".*{state}.*"):
+ MwaaDagRunSensor(**SENSOR_KWARGS, **SENSOR_STATE_KWARGS).poke({})
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
new file mode 100644
index 00000000000..18c53e11f18
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_mwaa.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.triggers.mwaa import
MwaaDagRunCompletedTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.utils.state import DagRunState
+from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type
+
+BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.mwaa."
+TRIGGER_KWARGS = {
+ "external_env_name": "test_env",
+ "external_dag_id": "test_dag",
+ "external_dag_run_id": "test_run_id",
+}
+
+
+class TestMwaaDagRunCompletedTrigger:
+ def test_init_states(self):
+ trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS)
+ assert trigger.success_states == {DagRunState.SUCCESS.value}
+ assert trigger.failure_states == {DagRunState.FAILED.value}
+ acceptors = trigger.waiter_config_overrides["acceptors"]
+ expected_acceptors = [
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": DagRunState.SUCCESS.value,
+ "state": "success",
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": DagRunState.FAILED.value,
+ "state": "failure",
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": DagRunState.RUNNING.value,
+ "state": "retry",
+ },
+ {
+ "matcher": "path",
+ "argument": "RestApiResponse.state",
+ "expected": DagRunState.QUEUED.value,
+ "state": "retry",
+ },
+ ]
+ assert len(acceptors) == len(DagRunState)
+ assert {tuple(sorted(a.items())) for a in acceptors} == {
+ tuple(sorted(a.items())) for a in expected_acceptors
+ }
+
+ def test_init_fail(self):
+ with pytest.raises(ValueError,
match=r".*success_states.*failure_states.*"):
+ MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS, success_states=("a",
"b"), failure_states=("b", "c"))
+
+ def test_serialization(self):
+ success_states = ["a", "b"]
+ failure_states = ["c", "d"]
+ trigger = MwaaDagRunCompletedTrigger(
+ **TRIGGER_KWARGS, success_states=success_states,
failure_states=failure_states
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == BASE_TRIGGER_CLASSPATH +
"MwaaDagRunCompletedTrigger"
+ assert kwargs.get("external_env_name") ==
TRIGGER_KWARGS["external_env_name"]
+ assert kwargs.get("external_dag_id") ==
TRIGGER_KWARGS["external_dag_id"]
+ assert kwargs.get("external_dag_run_id") ==
TRIGGER_KWARGS["external_dag_run_id"]
+ assert kwargs.get("success_states") == success_states
+ assert kwargs.get("failure_states") == failure_states
+
+ @pytest.mark.asyncio
+ @mock.patch.object(MwaaHook, "get_waiter")
+ @mock.patch.object(MwaaHook, "async_conn")
+ async def test_run_success(self, mock_async_conn, mock_get_waiter):
+ mock_async_conn.__aenter__.return_value = mock.MagicMock()
+ mock_get_waiter().wait = AsyncMock()
+ trigger = MwaaDagRunCompletedTrigger(**TRIGGER_KWARGS)
+
+ generator = trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent(
+ {"status": "success", "dag_run_id":
TRIGGER_KWARGS["external_dag_run_id"]}
+ )
+ assert_expected_waiter_type(mock_get_waiter, "mwaa_dag_run_complete")
+ mock_get_waiter().wait.assert_called_once()