This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f0b91ac6a7 Add `deferrable` param in `EmrContainerSensor` (#30945)
f0b91ac6a7 is described below
commit f0b91ac6a75a9f6f74663f8300078db09337cb16
Author: Pankaj Singh <[email protected]>
AuthorDate: Tue Jun 20 01:08:13 2023 +0530
Add `deferrable` param in `EmrContainerSensor` (#30945)
* Add deferrable param in emr container sensor
Add the deferrable param in EmrContainerSensor.
This will allow running EmrContainerSensor in an async way:
we only submit the job from the worker, then defer
to the trigger for polling and waiting for the job status,
so the worker slot won't be occupied for the whole period of
task execution.
---
airflow/providers/amazon/aws/sensors/emr.py | 38 ++++++-
airflow/providers/amazon/aws/triggers/emr.py | 75 +++++++++++++-
.../amazon/aws/waiters/emr-containers.json | 30 ++++++
.../amazon/aws/sensors/test_emr_containers.py | 13 ++-
tests/providers/amazon/aws/triggers/test_emr.py | 112 ++++++++++++++++++++-
5 files changed, 260 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/amazon/aws/sensors/emr.py
b/airflow/providers/amazon/aws/sensors/emr.py
index 6272c048c9..f7644e379d 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable, Sequence
@@ -25,6 +26,7 @@ from deprecated import deprecated
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook,
EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink,
EmrLogsLink, get_log_uri
+from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -241,6 +243,7 @@ class EmrContainerSensor(BaseSensorOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param poll_interval: Time in seconds to wait between two consecutive call
to
check query status on athena, defaults to 10
+ :param deferrable: Run sensor in the deferrable mode.
"""
INTERMEDIATE_STATES = (
@@ -267,6 +270,7 @@ class EmrContainerSensor(BaseSensorOperator):
max_retries: int | None = None,
aws_conn_id: str = "aws_default",
poll_interval: int = 10,
+ deferrable: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -275,6 +279,11 @@ class EmrContainerSensor(BaseSensorOperator):
self.job_id = job_id
self.poll_interval = poll_interval
self.max_retries = max_retries
+ self.deferrable = deferrable
+
+ @cached_property
+ def hook(self) -> EmrContainerHook:
+ return EmrContainerHook(self.aws_conn_id,
virtual_cluster_id=self.virtual_cluster_id)
def poke(self, context: Context) -> bool:
state = self.hook.poll_query_status(
@@ -290,10 +299,31 @@ class EmrContainerSensor(BaseSensorOperator):
return False
return True
- @cached_property
- def hook(self) -> EmrContainerHook:
- """Create and return an EmrContainerHook."""
- return EmrContainerHook(self.aws_conn_id,
virtual_cluster_id=self.virtual_cluster_id)
+ def execute(self, context: Context):
+ if not self.deferrable:
+ super().execute(context=context)
+ else:
+ timeout = (
+ timedelta(seconds=self.max_retries * self.poll_interval + 60)
+ if self.max_retries
+ else self.execution_timeout
+ )
+ self.defer(
+ timeout=timeout,
+ trigger=EmrContainerSensorTrigger(
+ virtual_cluster_id=self.virtual_cluster_id,
+ job_id=self.job_id,
+ aws_conn_id=self.aws_conn_id,
+ poll_interval=self.poll_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+ else:
+ self.log.info(event["message"])
class EmrNotebookExecutionSensor(EmrBaseSensor):
diff --git a/airflow/providers/amazon/aws/triggers/emr.py
b/airflow/providers/amazon/aws/triggers/emr.py
index 1c3c8bb833..6ea2c3b25c 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -17,12 +17,13 @@
from __future__ import annotations
import asyncio
-from typing import Any
+from functools import cached_property
+from typing import Any, AsyncIterator
from botocore.exceptions import WaiterError
from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict
@@ -246,3 +247,73 @@ class EmrTerminateJobFlowTrigger(BaseTrigger):
"message": "JobFlow terminated successfully",
}
)
+
+
+class EmrContainerSensorTrigger(BaseTrigger):
+ """
+ Poll for the status of EMR container until reaches terminal state.
+
+ :param virtual_cluster_id: Reference Emr cluster id
+ :param job_id: job_id to check the state
+ :param aws_conn_id: Reference to AWS connection id
+ :param poll_interval: polling period in seconds to check for the status
+ """
+
+ def __init__(
+ self,
+ virtual_cluster_id: str,
+ job_id: str,
+ aws_conn_id: str = "aws_default",
+ poll_interval: int = 30,
+ **kwargs: Any,
+ ):
+ self.virtual_cluster_id = virtual_cluster_id
+ self.job_id = job_id
+ self.aws_conn_id = aws_conn_id
+ self.poll_interval = poll_interval
+ super().__init__(**kwargs)
+
+ @cached_property
+ def hook(self) -> EmrContainerHook:
+ return EmrContainerHook(self.aws_conn_id,
virtual_cluster_id=self.virtual_cluster_id)
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes EmrContainerSensorTrigger arguments and classpath."""
+ return (
+
"airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
+ {
+ "virtual_cluster_id": self.virtual_cluster_id,
+ "job_id": self.job_id,
+ "aws_conn_id": self.aws_conn_id,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ async with self.hook.async_conn as client:
+ waiter = self.hook.get_waiter("container_job_complete",
deferrable=True, client=client)
+ attempt = 0
+ while True:
+ attempt = attempt + 1
+ try:
+ await waiter.wait(
+ id=self.job_id,
+ virtualClusterId=self.virtual_cluster_id,
+ WaiterConfig={
+ "Delay": self.poll_interval,
+ "MaxAttempts": 1,
+ },
+ )
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ yield TriggerEvent({"status": "failure", "message":
f"Job Failed: {error}"})
+ break
+ self.log.info(
+ "Job status is %s. Retrying attempt %s",
+ error.last_response["jobRun"]["state"],
+ attempt,
+ )
+ await asyncio.sleep(int(self.poll_interval))
+
+ yield TriggerEvent({"status": "success", "job_id": self.job_id})
diff --git a/airflow/providers/amazon/aws/waiters/emr-containers.json
b/airflow/providers/amazon/aws/waiters/emr-containers.json
new file mode 100644
index 0000000000..a4174b0536
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/emr-containers.json
@@ -0,0 +1,30 @@
+{
+ "version": 2,
+ "waiters": {
+ "container_job_complete": {
+ "operation": "DescribeJobRun",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "COMPLETED",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "FAILED",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "jobRun.state",
+ "expected": "CANCELLED",
+ "state": "failure"
+ }
+ ]
+ }
+ }
+}
diff --git a/tests/providers/amazon/aws/sensors/test_emr_containers.py
b/tests/providers/amazon/aws/sensors/test_emr_containers.py
index 38d7688f66..0df3657288 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_containers.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_containers.py
@@ -21,9 +21,10 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor
+from airflow.providers.amazon.aws.triggers.emr import EmrContainerSensorTrigger
class TestEmrContainerSensor:
@@ -73,3 +74,13 @@ class TestEmrContainerSensor:
with pytest.raises(AirflowException) as ctx:
self.sensor.poke(None)
assert "EMR Containers sensor failed" in str(ctx.value)
+
+
@mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor.poke")
+ def test_sensor_defer(self, mock_poke):
+ self.sensor.deferrable = True
+ mock_poke.return_value = False
+ with pytest.raises(TaskDeferred) as exc:
+ self.sensor.execute(context=None)
+ assert isinstance(
+ exc.value.trigger, EmrContainerSensorTrigger
+ ), "Trigger is not a EmrContainerSensorTrigger"
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py
b/tests/providers/amazon/aws/triggers/test_emr.py
index 5d599801a2..86e54cb94a 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -24,13 +24,21 @@ from botocore.exceptions import WaiterError
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrHook
-from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger,
EmrTerminateJobFlowTrigger
+from airflow.providers.amazon.aws.triggers.emr import (
+ EmrContainerSensorTrigger,
+ EmrCreateJobFlowTrigger,
+ EmrTerminateJobFlowTrigger,
+)
from airflow.triggers.base import TriggerEvent
TEST_JOB_FLOW_ID = "test-job-flow-id"
TEST_POLL_INTERVAL = 10
TEST_MAX_ATTEMPTS = 10
TEST_AWS_CONN_ID = "test-aws-id"
+VIRTUAL_CLUSTER_ID = "vzwemreks"
+JOB_ID = "job-1234"
+AWS_CONN_ID = "aws_emr_conn"
+POLL_INTERVAL = 60
class TestEmrCreateJobFlowTrigger:
@@ -350,3 +358,105 @@ class TestEmrTerminateJobFlowTrigger:
assert str(exc.value) == f"JobFlow termination failed: {error_failed}"
assert mock_get_waiter().wait.call_count == 3
+
+
+class TestEmrContainerSensorTrigger:
+ def test_emr_container_sensor_trigger_serialize(self):
+ emr_trigger = EmrContainerSensorTrigger(
+ virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+ job_id=JOB_ID,
+ aws_conn_id=AWS_CONN_ID,
+ poll_interval=POLL_INTERVAL,
+ )
+ class_path, args = emr_trigger.serialize()
+ assert class_path ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger"
+ assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID
+ assert args["job_id"] == JOB_ID
+ assert args["aws_conn_id"] == AWS_CONN_ID
+ assert args["poll_interval"] == POLL_INTERVAL
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+ async def test_emr_container_trigger_run(self, mock_async_conn,
mock_get_waiter):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ emr_trigger = EmrContainerSensorTrigger(
+ virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+ job_id=JOB_ID,
+ aws_conn_id=AWS_CONN_ID,
+ poll_interval=POLL_INTERVAL,
+ )
+
+ generator = emr_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "job_id":
JOB_ID})
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+ async def test_emr_trigger_run_multiple_attempts(self, mock_async_conn,
mock_get_waiter, mock_sleep):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ error = WaiterError(
+ name="test_name",
+ reason="test_reason",
+ last_response={"jobRun": {"state": "RUNNING"}},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
+ mock_sleep.return_value = True
+
+ emr_trigger = EmrContainerSensorTrigger(
+ virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+ job_id=JOB_ID,
+ aws_conn_id=AWS_CONN_ID,
+ poll_interval=POLL_INTERVAL,
+ )
+
+ generator = emr_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 3
+ assert response == TriggerEvent({"status": "success", "job_id":
JOB_ID})
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter")
+
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn")
+ async def test_emr_trigger_run_attempts_failed(self, mock_async_conn,
mock_get_waiter, mock_sleep):
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ error_available = WaiterError(
+ name="test_name",
+ reason="Max attempts exceeded",
+ last_response={"jobRun": {"state": "FAILED"}},
+ )
+ error_failed = WaiterError(
+ name="test_name",
+ reason="Waiter encountered a terminal failure state",
+ last_response={"jobRun": {"state": "FAILED"}},
+ )
+ mock_get_waiter().wait.side_effect = AsyncMock(
+ side_effect=[error_available, error_available, error_failed]
+ )
+ mock_sleep.return_value = True
+
+ emr_trigger = EmrContainerSensorTrigger(
+ virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+ job_id=JOB_ID,
+ aws_conn_id=AWS_CONN_ID,
+ poll_interval=POLL_INTERVAL,
+ )
+
+ generator = emr_trigger.run()
+ response = await generator.asend(None)
+
+ assert mock_get_waiter().wait.call_count == 3
+ assert response == TriggerEvent({"status": "failure", "message": f"Job
Failed: {error_failed}"})