This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f0b91ac6a7 Add `deferrable` param in `EmrContainerSensor` (#30945)
f0b91ac6a7 is described below

commit f0b91ac6a75a9f6f74663f8300078db09337cb16
Author: Pankaj Singh <[email protected]>
AuthorDate: Tue Jun 20 01:08:13 2023 +0530

    Add `deferrable` param in `EmrContainerSensor` (#30945)
    
    * Add deferrable param in emr container sensor
    
    Add the deferrable param in EmrContainerSensor.
    This allows running EmrContainerSensor asynchronously: the worker
    only starts the task and then defers to the trigger, which polls
    and waits for the job status, so the worker slot is not occupied
    for the whole period of task execution.
---
 airflow/providers/amazon/aws/sensors/emr.py        |  38 ++++++-
 airflow/providers/amazon/aws/triggers/emr.py       |  75 +++++++++++++-
 .../amazon/aws/waiters/emr-containers.json         |  30 ++++++
 .../amazon/aws/sensors/test_emr_containers.py      |  13 ++-
 tests/providers/amazon/aws/triggers/test_emr.py    | 112 ++++++++++++++++++++-
 5 files changed, 260 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/emr.py 
b/airflow/providers/amazon/aws/sensors/emr.py
index 6272c048c9..f7644e379d 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
@@ -25,6 +26,7 @@ from deprecated import deprecated
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, 
EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, 
EmrLogsLink, get_log_uri
+from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -241,6 +243,7 @@ class EmrContainerSensor(BaseSensorOperator):
     :param aws_conn_id: aws connection to use, defaults to 'aws_default'
     :param poll_interval: Time in seconds to wait between two consecutive call 
to
         check query status on athena, defaults to 10
+    :param deferrable: Run sensor in the deferrable mode.
     """
 
     INTERMEDIATE_STATES = (
@@ -267,6 +270,7 @@ class EmrContainerSensor(BaseSensorOperator):
         max_retries: int | None = None,
         aws_conn_id: str = "aws_default",
         poll_interval: int = 10,
+        deferrable: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -275,6 +279,11 @@ class EmrContainerSensor(BaseSensorOperator):
         self.job_id = job_id
         self.poll_interval = poll_interval
         self.max_retries = max_retries
+        self.deferrable = deferrable
+
+    @cached_property
+    def hook(self) -> EmrContainerHook:
+        return EmrContainerHook(self.aws_conn_id, 
virtual_cluster_id=self.virtual_cluster_id)
 
     def poke(self, context: Context) -> bool:
         state = self.hook.poll_query_status(
@@ -290,10 +299,31 @@ class EmrContainerSensor(BaseSensorOperator):
             return False
         return True
 
-    @cached_property
-    def hook(self) -> EmrContainerHook:
-        """Create and return an EmrContainerHook."""
-        return EmrContainerHook(self.aws_conn_id, 
virtual_cluster_id=self.virtual_cluster_id)
+    def execute(self, context: Context):
+        if not self.deferrable:
+            super().execute(context=context)
+        else:
+            timeout = (
+                timedelta(seconds=self.max_retries * self.poll_interval + 60)
+                if self.max_retries
+                else self.execution_timeout
+            )
+            self.defer(
+                timeout=timeout,
+                trigger=EmrContainerSensorTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info(event["message"])
 
 
 class EmrNotebookExecutionSensor(EmrBaseSensor):
diff --git a/airflow/providers/amazon/aws/triggers/emr.py 
b/airflow/providers/amazon/aws/triggers/emr.py
index 1c3c8bb833..6ea2c3b25c 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -17,12 +17,13 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any
+from functools import cached_property
+from typing import Any, AsyncIterator
 
 from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.helpers import prune_dict
 
@@ -246,3 +247,73 @@ class EmrTerminateJobFlowTrigger(BaseTrigger):
                     "message": "JobFlow terminated successfully",
                 }
             )
+
+
+class EmrContainerSensorTrigger(BaseTrigger):
+    """
+    Poll for the status of EMR container until reaches terminal state.
+
+    :param virtual_cluster_id: Reference Emr cluster id
+    :param job_id:  job_id to check the state
+    :param aws_conn_id: Reference to AWS connection id
+    :param poll_interval: polling period in seconds to check for the status
+    """
+
+    def __init__(
+        self,
+        virtual_cluster_id: str,
+        job_id: str,
+        aws_conn_id: str = "aws_default",
+        poll_interval: int = 30,
+        **kwargs: Any,
+    ):
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.aws_conn_id = aws_conn_id
+        self.poll_interval = poll_interval
+        super().__init__(**kwargs)
+
+    @cached_property
+    def hook(self) -> EmrContainerHook:
+        return EmrContainerHook(self.aws_conn_id, 
virtual_cluster_id=self.virtual_cluster_id)
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes EmrContainerSensorTrigger arguments and classpath."""
+        return (
+            
"airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
+            {
+                "virtual_cluster_id": self.virtual_cluster_id,
+                "job_id": self.job_id,
+                "aws_conn_id": self.aws_conn_id,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("container_job_complete", 
deferrable=True, client=client)
+            attempt = 0
+            while True:
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        id=self.job_id,
+                        virtualClusterId=self.virtual_cluster_id,
+                        WaiterConfig={
+                            "Delay": self.poll_interval,
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        yield TriggerEvent({"status": "failure", "message": 
f"Job Failed: {error}"})
+                        break
+                    self.log.info(
+                        "Job status is %s. Retrying attempt %s",
+                        error.last_response["jobRun"]["state"],
+                        attempt,
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+
+            yield TriggerEvent({"status": "success", "job_id": self.job_id})
diff --git a/airflow/providers/amazon/aws/waiters/emr-containers.json 
b/airflow/providers/amazon/aws/waiters/emr-containers.json
new file mode 100644
index 0000000000..a4174b0536
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/emr-containers.json
@@ -0,0 +1,30 @@
+{
+    "version": 2,
+    "waiters": {
+        "container_job_complete": {
+            "operation": "DescribeJobRun",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "jobRun.state",
+                    "expected": "COMPLETED",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "jobRun.state",
+                    "expected": "FAILED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "jobRun.state",
+                    "expected": "CANCELLED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py 
b/tests/providers/amazon/aws/sensors/test_emr_containers.py
index 38d7688f66..0df3657288 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_containers.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py
@@ -21,9 +21,10 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor
+from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
 
 
 class TestEmrContainerSensor:
@@ -73,3 +74,13 @@ class TestEmrContainerSensor:
         with pytest.raises(AirflowException) as ctx:
             self.sensor.poke(None)
         assert "EMR Containers sensor failed" in str(ctx.value)
+
+    
@mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor.poke")
+    def test_sensor_defer(self, mock_poke):
+        self.sensor.deferrable = True
+        mock_poke.return_value = False
+        with pytest.raises(TaskDeferred) as exc:
+            self.sensor.execute(context=None)
+        assert isinstance(
+            exc.value.trigger, EmrContainerSensorTrigger
+        ), "Trigger is not a EmrContainerSensorTrigger"
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py 
b/tests/providers/amazon/aws/triggers/test_emr.py
index 5d599801a2..86e54cb94a 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -24,13 +24,21 @@ from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrHook
-from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger, 
EmrTerminateJobFlowTrigger
+from airflow.providers.amazon.aws.triggers.emr import (
+    EmrContainerSensorTrigger,
+    EmrCreateJobFlowTrigger,
+    EmrTerminateJobFlowTrigger,
+)
 from airflow.triggers.base import TriggerEvent
 
 TEST_JOB_FLOW_ID = "test-job-flow-id"
 TEST_POLL_INTERVAL = 10
 TEST_MAX_ATTEMPTS = 10
 TEST_AWS_CONN_ID = "test-aws-id"
+VIRTUAL_CLUSTER_ID = "vzwemreks"
+JOB_ID = "job-1234"
+AWS_CONN_ID = "aws_emr_conn"
+POLL_INTERVAL = 60
 
 
 class TestEmrCreateJobFlowTrigger:
@@ -350,3 +358,105 @@ class TestEmrTerminateJobFlowTrigger:
 
         assert str(exc.value) == f"JobFlow termination failed: {error_failed}"
         assert mock_get_waiter().wait.call_count == 3
+
+
+class TestEmrContainerSensorTrigger:
+    def test_emr_container_sensor_trigger_serialize(self):
+        emr_trigger = EmrContainerSensorTrigger(
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+        class_path, args = emr_trigger.serialize()
+        assert class_path == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger"
+        assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID
+        assert args["job_id"] == JOB_ID
+        assert args["aws_conn_id"] == AWS_CONN_ID
+        assert args["poll_interval"] == POLL_INTERVAL
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+    async def test_emr_container_trigger_run(self, mock_async_conn, 
mock_get_waiter):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        emr_trigger = EmrContainerSensorTrigger(
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "job_id": 
JOB_ID})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+    async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn, 
mock_get_waiter, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error = WaiterError(
+            name="test_name",
+            reason="test_reason",
+            last_response={"jobRun": {"state": "RUNNING"}},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error, 
error, True])
+        mock_sleep.return_value = True
+
+        emr_trigger = EmrContainerSensorTrigger(
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent({"status": "success", "job_id": 
JOB_ID})
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+    async def test_emr_trigger_run_attempts_failed(self, mock_async_conn, 
mock_get_waiter, mock_sleep):
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        error_available = WaiterError(
+            name="test_name",
+            reason="Max attempts exceeded",
+            last_response={"jobRun": {"state": "FAILED"}},
+        )
+        error_failed = WaiterError(
+            name="test_name",
+            reason="Waiter encountered a terminal failure state",
+            last_response={"jobRun": {"state": "FAILED"}},
+        )
+        mock_get_waiter().wait.side_effect = AsyncMock(
+            side_effect=[error_available, error_available, error_failed]
+        )
+        mock_sleep.return_value = True
+
+        emr_trigger = EmrContainerSensorTrigger(
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+
+        generator = emr_trigger.run()
+        response = await generator.asend(None)
+
+        assert mock_get_waiter().wait.call_count == 3
+        assert response == TriggerEvent({"status": "failure", "message": f"Job 
Failed: {error_failed}"})

Reply via email to