This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 137042831a8 Add MwaaDagRunSensor to Amazon Provider Package (#46945)
137042831a8 is described below

commit 137042831a8e9e0d7b9ccc8eb1e18346b114cb10
Author: Ramit Kataria <[email protected]>
AuthorDate: Mon Feb 24 13:51:25 2025 -0800

    Add MwaaDagRunSensor to Amazon Provider Package (#46945)
    
    Includes the doc page, unit tests and system test.
    
    Support for deferrable mode will be added soon.
---
 providers/amazon/docs/operators/mwaa.rst           |  30 ++++--
 providers/amazon/provider.yaml                     |   3 +
 .../airflow/providers/amazon/aws/sensors/mwaa.py   | 113 +++++++++++++++++++++
 .../airflow/providers/amazon/get_provider_info.py  |   4 +
 .../amazon/tests/system/amazon/aws/example_mwaa.py |  13 ++-
 .../tests/unit/amazon/aws/sensors/test_mwaa.py     |  75 ++++++++++++++
 6 files changed, 231 insertions(+), 7 deletions(-)

diff --git a/providers/amazon/docs/operators/mwaa.rst b/providers/amazon/docs/operators/mwaa.rst
index 021998b0a10..7eb9bab3983 100644
--- a/providers/amazon/docs/operators/mwaa.rst
+++ b/providers/amazon/docs/operators/mwaa.rst
@@ -24,6 +24,9 @@ is a managed service for Apache Airflow that lets you use your current, familiar
 your workflows. You gain improved scalability, availability, and security without the operational burden of managing
 underlying infrastructure.
 
+Note: Unlike Airflow's built-in operators, these operators are meant for interaction with external Airflow environments
+hosted on AWS MWAA.
+
 Prerequisite Tasks
 ------------------
 
@@ -45,12 +48,8 @@ Trigger a DAG run in an Amazon MWAA environment
 To trigger a DAG run in an Amazon MWAA environment you can use the
 :class:`~airflow.providers.amazon.aws.operators.mwaa.MwaaTriggerDagRunOperator`
 
-Note: Unlike :class:`~airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunOperator`, this operator is capable of
-triggering a DAG in a separate Airflow environment as long as the environment with the DAG being triggered is running on
-AWS MWAA.
-
-In the following example, the task ``trigger_dag_run`` triggers a dag run for a DAG with with the ID ``hello_world`` in
-the environment ``MyAirflowEnvironment``.
+In the following example, the task ``trigger_dag_run`` triggers a DAG run for the DAG ``hello_world`` in the environment
+``MyAirflowEnvironment``.
 
 .. exampleinclude:: /../../providers/amazon/tests/system/amazon/aws/example_mwaa.py
     :language: python
@@ -58,6 +57,25 @@ the environment ``MyAirflowEnvironment``.
     :start-after: [START howto_operator_mwaa_trigger_dag_run]
     :end-before: [END howto_operator_mwaa_trigger_dag_run]
 
+Sensors
+-------
+
+.. _howto/sensor:MwaaDagRunSensor:
+
+Wait on the state of an AWS MWAA DAG Run
+========================================
+
+To wait for a DAG run in an Amazon MWAA environment to reach one of the given states, you can use the
+:class:`~airflow.providers.amazon.aws.sensors.mwaa.MwaaDagRunSensor`
+
+In the following example, the task ``wait_for_dag_run`` waits for the DAG run created in the above task to complete.
+
+.. exampleinclude:: /../../providers/amazon/tests/system/amazon/aws/example_mwaa.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_mwaa_dag_run]
+    :end-before: [END howto_sensor_mwaa_dag_run]
+
 References
 ----------
 
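For a quick read of how the two documented pieces fit together, here is a minimal
sketch of a DAG using both. The environment name and remote DAG ID are placeholders,
and the trigger operator's env_name/trigger_dag_id kwargs are assumed from the system
test further down rather than shown in this diff:

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
    from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor

    with DAG(dag_id="mwaa_remote_dag_sketch", start_date=datetime(2025, 1, 1), schedule=None):
        trigger = MwaaTriggerDagRunOperator(
            task_id="trigger_dag_run",
            env_name="MyAirflowEnvironment",  # assumed kwarg; placeholder environment
            trigger_dag_id="hello_world",     # assumed kwarg; placeholder remote DAG
        )
        wait = MwaaDagRunSensor(
            task_id="wait_for_dag_run",
            external_env_name="MyAirflowEnvironment",
            external_dag_id="hello_world",
            # run ID pulled from the trigger task's XCom, as in the system test
            external_dag_run_id=(
                "{{ task_instance.xcom_pull(task_ids='trigger_dag_run')"
                "['RestApiResponse']['dag_run_id'] }}"
            ),
            poke_interval=60,
        )
        trigger >> wait
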
diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 038ad0db3a1..ca2077bb794 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -485,6 +485,9 @@ sensors:
   - integration-name: Amazon Managed Service for Apache Flink
     python-modules:
       - airflow.providers.amazon.aws.sensors.kinesis_analytics
+  - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+    python-modules:
+      - airflow.providers.amazon.aws.sensors.mwaa
   - integration-name: Amazon OpenSearch Serverless
     python-modules:
       - airflow.providers.amazon.aws.sensors.opensearch_serverless
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
new file mode 100644
index 00000000000..9007379e22c
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Collection, Sequence
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.state import State
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
+    """
+    Waits for a DAG Run in an MWAA Environment to complete.
+
+    If the DAG Run fails, an AirflowException is raised.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:MwaaDagRunSensor`
+
+    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
+        (templated)
+    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
+    :param success_states: Collection of DAG Run states that mark this task as successful, default is
+        ``airflow.utils.state.State.success_states`` (templated)
+    :param failure_states: Collection of DAG Run states that mark this task as failed and raise an
+        AirflowException, default is ``airflow.utils.state.State.failed_states`` (templated)
+    """
+
+    aws_hook_class = MwaaHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "external_env_name",
+        "external_dag_id",
+        "external_dag_run_id",
+        "success_states",
+        "failure_states",
+    )
+
+    def __init__(
+        self,
+        *,
+        external_env_name: str,
+        external_dag_id: str,
+        external_dag_run_id: str,
+        success_states: Collection[str] | None = None,
+        failure_states: Collection[str] | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.success_states = set(success_states if success_states else State.success_states)
+        self.failure_states = set(failure_states if failure_states else State.failed_states)
+
+        if len(self.success_states & self.failure_states):
+            raise AirflowException("success_states and failure_states must not have any values in common")
+
+        self.external_env_name = external_env_name
+        self.external_dag_id = external_dag_id
+        self.external_dag_run_id = external_dag_run_id
+
+    def poke(self, context: Context) -> bool:
+        self.log.info(
+            "Poking for DAG run %s of DAG %s in MWAA environment %s",
+            self.external_dag_run_id,
+            self.external_dag_id,
+            self.external_env_name,
+        )
+        response = self.hook.invoke_rest_api(
+            env_name=self.external_env_name,
+            path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}",
+            method="GET",
+        )
+
+        # If RestApiStatusCode == 200, the RestApiResponse must contain the "state" key; otherwise the
+        # API returned an unexpected payload and a KeyError is raised.
+        # If RestApiStatusCode >= 300, a botocore exception has already been raised during the
+        # self.hook.invoke_rest_api call.
+        # This sensor itself only raises AirflowException when the DAG run ends in a failure state.
+
+        state = response["RestApiResponse"]["state"]
+        if state in self.success_states:
+            return True
+
+        if state in self.failure_states:
+            raise AirflowException(
+                f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
+                f"failed with state {state}."
+            )
+        return False
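
The poke logic above depends only on the "state" field nested under RestApiResponse
in the hook's return value. A minimal sketch of that contract, using a stubbed
payload shaped like the ones in the unit tests below (real responses carry more keys):

    from airflow.exceptions import AirflowException
    from airflow.utils.state import State

    def classify(response: dict) -> bool:
        # Mirrors the sensor's poke(): True -> done, False -> keep poking,
        # AirflowException -> the remote DAG run ended in a failure state.
        state = response["RestApiResponse"]["state"]
        if state in State.success_states:
            return True
        if state in State.failed_states:
            raise AirflowException(f"DAG run failed with state {state}")
        return False

    assert classify({"RestApiResponse": {"state": "success"}}) is True
    assert classify({"RestApiResponse": {"state": "running"}}) is False
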
diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
index f0d0c47d1a1..88b6c0ef55c 100644
--- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
+++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
@@ -604,6 +604,10 @@ def get_provider_info():
                 "integration-name": "Amazon Managed Service for Apache Flink",
                 "python-modules": ["airflow.providers.amazon.aws.sensors.kinesis_analytics"],
             },
+            {
+                "integration-name": "Amazon Managed Workflows for Apache Airflow (MWAA)",
+                "python-modules": ["airflow.providers.amazon.aws.sensors.mwaa"],
+            },
             {
                 "integration-name": "Amazon OpenSearch Serverless",
                 "python-modules": ["airflow.providers.amazon.aws.sensors.opensearch_serverless"],
diff --git a/providers/amazon/tests/system/amazon/aws/example_mwaa.py b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
index 01fd1afcbb8..fb8a2d220df 100644
--- a/providers/amazon/tests/system/amazon/aws/example_mwaa.py
+++ b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
@@ -21,6 +21,7 @@ from datetime import datetime
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
 from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
+from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
 from system.amazon.aws.utils import SystemTestContextBuilder
 
 DAG_ID = "example_mwaa"
@@ -29,7 +30,6 @@ DAG_ID = "example_mwaa"
 EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME"
 EXISTING_DAG_ID_KEY = "DAG_ID"
 
-
 sys_test_context_task = (
     SystemTestContextBuilder()
     # NOTE: Creating a functional MWAA environment is time-consuming and requires
@@ -67,11 +67,22 @@ with DAG(
     )
     # [END howto_operator_mwaa_trigger_dag_run]
 
+    # [START howto_sensor_mwaa_dag_run]
+    wait_for_dag_run = MwaaDagRunSensor(
+        task_id="wait_for_dag_run",
+        external_env_name=env_name,
+        external_dag_id=trigger_dag_id,
+        external_dag_run_id="{{ task_instance.xcom_pull(task_ids='trigger_dag_run')['RestApiResponse']['dag_run_id'] }}",
+        poke_interval=5,
+    )
+    # [END howto_sensor_mwaa_dag_run]
+
     chain(
         # TEST SETUP
         test_context,
         # TEST BODY
         trigger_dag_run,
+        wait_for_dag_run,
     )
 
     from tests_common.test_utils.watcher import watcher
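
One detail worth noting in the system test: external_dag_run_id is a Jinja template
that indexes into the XCom pushed by trigger_dag_run. With a hypothetical payload
shaped the way the template expects, it reduces to plain dict indexing:

    # hypothetical XCom value from the trigger task; only the shape matters here
    xcom_value = {"RestApiResponse": {"dag_run_id": "manual__2025-02-24T21:51:25+00:00"}}

    # {{ task_instance.xcom_pull(task_ids='trigger_dag_run')['RestApiResponse']['dag_run_id'] }}
    # renders to the same string as:
    run_id = xcom_value["RestApiResponse"]["dag_run_id"]
    assert run_id == "manual__2025-02-24T21:51:25+00:00"
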
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
new file mode 100644
index 00000000000..8ab39ecf1ad
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
+from airflow.utils.state import State
+
+SENSOR_KWARGS = {
+    "task_id": "test_mwaa_sensor",
+    "external_env_name": "test_env",
+    "external_dag_id": "test_dag",
+    "external_dag_run_id": "test_run_id",
+}
+
+
[email protected]
+def mock_invoke_rest_api():
+    with mock.patch.object(MwaaHook, "invoke_rest_api") as m:
+        yield m
+
+
+class TestMwaaDagRunSuccessSensor:
+    def test_init_success(self):
+        success_states = {"state1", "state2"}
+        failure_states = {"state3", "state4"}
+        sensor = MwaaDagRunSensor(
+            **SENSOR_KWARGS, success_states=success_states, failure_states=failure_states
+        )
+        assert sensor.external_env_name == SENSOR_KWARGS["external_env_name"]
+        assert sensor.external_dag_id == SENSOR_KWARGS["external_dag_id"]
+        assert sensor.external_dag_run_id == SENSOR_KWARGS["external_dag_run_id"]
+        assert set(sensor.success_states) == success_states
+        assert set(sensor.failure_states) == failure_states
+
+    def test_init_failure(self):
+        with pytest.raises(AirflowException):
+            MwaaDagRunSensor(
+                **SENSOR_KWARGS, success_states={"state1", "state2"}, failure_states={"state2", "state3"}
+            )
+
+    @pytest.mark.parametrize("status", sorted(State.success_states))
+    def test_poke_completed(self, mock_invoke_rest_api, status):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
+        assert MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+
+    @pytest.mark.parametrize("status", ["running", "queued"])
+    def test_poke_not_completed(self, mock_invoke_rest_api, status):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
+        assert not MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
+
+    @pytest.mark.parametrize("status", sorted(State.failed_states))
+    def test_poke_terminated(self, mock_invoke_rest_api, status):
+        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
+        with pytest.raises(AirflowException):
+            MwaaDagRunSensor(**SENSOR_KWARGS).poke({})
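
Because the fixture patches MwaaHook.invoke_rest_api at the class level, these tests
need no AWS credentials or network access; poke() is driven entirely by hand-built
response dicts. From a provider development checkout they should run with plain
pytest, e.g.:

    pytest providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py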
