This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 137042831a8 Add MwaaDagRunSensor to Amazon Provider Package (#46945)
137042831a8 is described below
commit 137042831a8e9e0d7b9ccc8eb1e18346b114cb10
Author: Ramit Kataria <[email protected]>
AuthorDate: Mon Feb 24 13:51:25 2025 -0800
Add MwaaDagRunSensor to Amazon Provider Package (#46945)
Includes the doc page, unit tests and system test.
Support for deferrable mode will be added soon.
---
providers/amazon/docs/operators/mwaa.rst | 30 ++++--
providers/amazon/provider.yaml | 3 +
.../airflow/providers/amazon/aws/sensors/mwaa.py | 113 +++++++++++++++++++++
.../airflow/providers/amazon/get_provider_info.py | 4 +
.../amazon/tests/system/amazon/aws/example_mwaa.py | 13 ++-
.../tests/unit/amazon/aws/sensors/test_mwaa.py | 75 ++++++++++++++
6 files changed, 231 insertions(+), 7 deletions(-)
diff --git a/providers/amazon/docs/operators/mwaa.rst
b/providers/amazon/docs/operators/mwaa.rst
index 021998b0a10..7eb9bab3983 100644
--- a/providers/amazon/docs/operators/mwaa.rst
+++ b/providers/amazon/docs/operators/mwaa.rst
@@ -24,6 +24,9 @@ is a managed service for Apache Airflow that lets you use
your current, familiar
your workflows. You gain improved scalability, availability, and security
without the operational burden of managing
underlying infrastructure.
+Note: Unlike Airflow's built-in operators, these operators are meant for
interaction with external Airflow environments
+hosted on AWS MWAA.
+
Prerequisite Tasks
------------------
@@ -45,12 +48,8 @@ Trigger a DAG run in an Amazon MWAA environment
To trigger a DAG run in an Amazon MWAA environment you can use the
:class:`~airflow.providers.amazon.aws.operators.mwaa.MwaaTriggerDagRunOperator`
-Note: Unlike
:class:`~airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunOperator`,
this operator is capable of
-triggering a DAG in a separate Airflow environment as long as the environment
with the DAG being triggered is running on
-AWS MWAA.
-
-In the following example, the task ``trigger_dag_run`` triggers a dag run for
a DAG with with the ID ``hello_world`` in
-the environment ``MyAirflowEnvironment``.
+In the following example, the task ``trigger_dag_run`` triggers a DAG run for
the DAG ``hello_world`` in the environment
+``MyAirflowEnvironment``.
.. exampleinclude::
/../../providers/amazon/tests/system/amazon/aws/example_mwaa.py
:language: python
@@ -58,6 +57,25 @@ the environment ``MyAirflowEnvironment``.
:start-after: [START howto_operator_mwaa_trigger_dag_run]
:end-before: [END howto_operator_mwaa_trigger_dag_run]
+Sensors
+-------
+
+.. _howto/sensor:MwaaDagRunSensor:
+
+Wait on the state of an AWS MWAA DAG Run
+========================================
+
+To wait until a DAG run in an Amazon MWAA environment reaches one of the given
states, you can use the
+:class:`~airflow.providers.amazon.aws.sensors.mwaa.MwaaDagRunSensor`
+
+In the following example, the task ``wait_for_dag_run`` waits for the DAG run
created in the above task to complete.
+
+.. exampleinclude::
/../../providers/amazon/tests/system/amazon/aws/example_mwaa.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_mwaa_dag_run]
+ :end-before: [END howto_sensor_mwaa_dag_run]
+
References
----------
diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml
index 038ad0db3a1..ca2077bb794 100644
--- a/providers/amazon/provider.yaml
+++ b/providers/amazon/provider.yaml
@@ -485,6 +485,9 @@ sensors:
- integration-name: Amazon Managed Service for Apache Flink
python-modules:
- airflow.providers.amazon.aws.sensors.kinesis_analytics
+ - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+ python-modules:
+ - airflow.providers.amazon.aws.sensors.mwaa
- integration-name: Amazon OpenSearch Serverless
python-modules:
- airflow.providers.amazon.aws.sensors.opensearch_serverless
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
new file mode 100644
index 00000000000..9007379e22c
--- /dev/null
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Collection, Sequence
+from typing import TYPE_CHECKING
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.state import State
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
    """
    Waits for a DAG run in an MWAA Environment to reach a terminal state.

    If the DAG run reaches one of the configured failure states, an
    AirflowException is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:MwaaDagRunSensor`

    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
        (templated)
    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
        (templated)
    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
    :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
        ``airflow.utils.state.State.success_states`` (templated)
    :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
        AirflowException, default is ``airflow.utils.state.State.failed_states`` (templated)
    """

    aws_hook_class = MwaaHook
    template_fields: Sequence[str] = aws_template_fields(
        "external_env_name",
        "external_dag_id",
        "external_dag_run_id",
        "success_states",
        "failure_states",
    )

    def __init__(
        self,
        *,
        external_env_name: str,
        external_dag_id: str,
        external_dag_run_id: str,
        success_states: Collection[str] | None = None,
        failure_states: Collection[str] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fall back to Airflow's built-in terminal-state sets when the caller
        # does not supply explicit ones; copy into sets so membership tests and
        # the overlap check below are cheap and independent of the caller's object.
        self.success_states = set(success_states if success_states else State.success_states)
        self.failure_states = set(failure_states if failure_states else State.failed_states)

        # Overlapping states would make the sensor's outcome ambiguous.
        # BUGFIX: the original message referred to "allowed_states" and
        # "failed_states", which are not this sensor's parameter names.
        if self.success_states & self.failure_states:
            raise AirflowException("success_states and failure_states must not have any values in common")

        self.external_env_name = external_env_name
        self.external_dag_id = external_dag_id
        self.external_dag_run_id = external_dag_run_id

    def poke(self, context: Context) -> bool:
        """Return True once the watched DAG run reaches a success state; raise on a failure state."""
        self.log.info(
            "Poking for DAG run %s of DAG %s in MWAA environment %s",
            self.external_dag_run_id,
            self.external_dag_id,
            self.external_env_name,
        )
        response = self.hook.invoke_rest_api(
            env_name=self.external_env_name,
            path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}",
            method="GET",
        )

        # If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has
        # happened in the API and KeyError would be raised
        # If RestApiStatusCode >= 300, a botocore exception would've already been raised during the
        # self.hook.invoke_rest_api call
        # The scope of this sensor is going to only be raising AirflowException due to failure of the DAGRun

        state = response["RestApiResponse"]["state"]
        if state in self.success_states:
            return True

        if state in self.failure_states:
            raise AirflowException(
                f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
                f"failed with state {state}."
            )
        # Still in a non-terminal state (e.g. running/queued): keep poking.
        return False
diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
index f0d0c47d1a1..88b6c0ef55c 100644
--- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
+++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py
@@ -604,6 +604,10 @@ def get_provider_info():
"integration-name": "Amazon Managed Service for Apache Flink",
"python-modules":
["airflow.providers.amazon.aws.sensors.kinesis_analytics"],
},
+ {
+ "integration-name": "Amazon Managed Workflows for Apache
Airflow (MWAA)",
+ "python-modules":
["airflow.providers.amazon.aws.sensors.mwaa"],
+ },
{
"integration-name": "Amazon OpenSearch Serverless",
"python-modules":
["airflow.providers.amazon.aws.sensors.opensearch_serverless"],
diff --git a/providers/amazon/tests/system/amazon/aws/example_mwaa.py
b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
index 01fd1afcbb8..fb8a2d220df 100644
--- a/providers/amazon/tests/system/amazon/aws/example_mwaa.py
+++ b/providers/amazon/tests/system/amazon/aws/example_mwaa.py
@@ -21,6 +21,7 @@ from datetime import datetime
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.mwaa import
MwaaTriggerDagRunOperator
+from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
from system.amazon.aws.utils import SystemTestContextBuilder
DAG_ID = "example_mwaa"
@@ -29,7 +30,6 @@ DAG_ID = "example_mwaa"
EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME"
EXISTING_DAG_ID_KEY = "DAG_ID"
-
sys_test_context_task = (
SystemTestContextBuilder()
# NOTE: Creating a functional MWAA environment is time-consuming and
requires
@@ -67,11 +67,22 @@ with DAG(
)
# [END howto_operator_mwaa_trigger_dag_run]
+ # [START howto_sensor_mwaa_dag_run]
+ wait_for_dag_run = MwaaDagRunSensor(
+ task_id="wait_for_dag_run",
+ external_env_name=env_name,
+ external_dag_id=trigger_dag_id,
+ external_dag_run_id="{{
task_instance.xcom_pull(task_ids='trigger_dag_run')['RestApiResponse']['dag_run_id']
}}",
+ poke_interval=5,
+ )
+ # [END howto_sensor_mwaa_dag_run]
+
chain(
# TEST SETUP
test_context,
# TEST BODY
trigger_dag_run,
+ wait_for_dag_run,
)
from tests_common.test_utils.watcher import watcher
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
new file mode 100644
index 00000000000..8ab39ecf1ad
--- /dev/null
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
+from airflow.utils.state import State
+
# Constructor arguments shared by every MwaaDagRunSensor built in the tests below.
SENSOR_KWARGS = {
    "task_id": "test_mwaa_sensor",
    "external_env_name": "test_env",
    "external_dag_id": "test_dag",
    "external_dag_run_id": "test_run_id",
}


@pytest.fixture
def mock_invoke_rest_api():
    """Patch ``MwaaHook.invoke_rest_api`` for the duration of a single test."""
    with mock.patch.object(MwaaHook, "invoke_rest_api") as mocked_api:
        yield mocked_api
+
+
# CONSISTENCY FIX: renamed from TestMwaaDagRunSuccessSensor — the suite covers
# success, not-yet-completed, and failure paths of MwaaDagRunSensor, so the
# "Success" qualifier was misleading. Pytest discovery is unaffected.
class TestMwaaDagRunSensor:
    """Unit tests for MwaaDagRunSensor."""

    def test_init_success(self):
        # Custom, disjoint state sets must be stored verbatim alongside the identifiers.
        success_states = {"state1", "state2"}
        failure_states = {"state3", "state4"}
        sensor = MwaaDagRunSensor(
            **SENSOR_KWARGS, success_states=success_states, failure_states=failure_states
        )
        assert sensor.external_env_name == SENSOR_KWARGS["external_env_name"]
        assert sensor.external_dag_id == SENSOR_KWARGS["external_dag_id"]
        assert sensor.external_dag_run_id == SENSOR_KWARGS["external_dag_run_id"]
        assert set(sensor.success_states) == success_states
        assert set(sensor.failure_states) == failure_states

    def test_init_failure(self):
        # Overlapping success/failure states ("state2") must be rejected at construction time.
        with pytest.raises(AirflowException):
            MwaaDagRunSensor(
                **SENSOR_KWARGS, success_states={"state1", "state2"}, failure_states={"state2", "state3"}
            )

    @pytest.mark.parametrize("status", sorted(State.success_states))
    def test_poke_completed(self, mock_invoke_rest_api, status):
        # Any default success state should make poke() return True.
        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
        assert MwaaDagRunSensor(**SENSOR_KWARGS).poke({})

    @pytest.mark.parametrize("status", ["running", "queued"])
    def test_poke_not_completed(self, mock_invoke_rest_api, status):
        # Non-terminal states should make poke() return False so the sensor keeps waiting.
        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
        assert not MwaaDagRunSensor(**SENSOR_KWARGS).poke({})

    @pytest.mark.parametrize("status", sorted(State.failed_states))
    def test_poke_terminated(self, mock_invoke_rest_api, status):
        # Any default failed state should raise instead of returning.
        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
        with pytest.raises(AirflowException):
            MwaaDagRunSensor(**SENSOR_KWARGS).poke({})