This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f0467c9fd6 Use an AwsBaseWaiterTrigger-based trigger in
EmrAddStepsOperator deferred mode (#34216)
f0467c9fd6 is described below
commit f0467c9fd65e7146b44fc8f9fccb9ad750592371
Author: Pavel Yermalovich <[email protected]>
AuthorDate: Mon Sep 11 13:00:04 2023 +0200
Use an AwsBaseWaiterTrigger-based trigger in EmrAddStepsOperator deferred
mode (#34216)
---
airflow/providers/amazon/aws/operators/emr.py | 12 +-
airflow/providers/amazon/aws/triggers/emr.py | 91 ++++---------
airflow/providers/amazon/aws/waiters/emr.json | 31 +++++
tests/providers/amazon/aws/hooks/test_emr.py | 1 +
tests/providers/amazon/aws/triggers/test_emr.py | 8 ++
.../amazon/aws/triggers/test_emr_trigger.py | 144 ---------------------
.../amazon/aws/waiters/test_custom_waiters.py | 73 +++++++++++
7 files changed, 147 insertions(+), 213 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/emr.py
b/airflow/providers/amazon/aws/operators/emr.py
index 1bf2375a16..77e0167c21 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -100,8 +100,8 @@ class EmrAddStepsOperator(BaseOperator):
aws_conn_id: str = "aws_default",
steps: list[dict] | str | None = None,
wait_for_completion: bool = False,
- waiter_delay: int | None = 30,
- waiter_max_attempts: int | None = 60,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
execution_role_arn: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
@@ -172,8 +172,8 @@ class EmrAddStepsOperator(BaseOperator):
job_flow_id=job_flow_id,
step_ids=step_ids,
aws_conn_id=self.aws_conn_id,
- max_attempts=self.waiter_max_attempts,
- poll_interval=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
),
method_name="execute_complete",
)
@@ -182,10 +182,10 @@ class EmrAddStepsOperator(BaseOperator):
def execute_complete(self, context, event=None):
if event["status"] != "success":
- raise AirflowException(f"Error resuming cluster: {event}")
+ raise AirflowException(f"Error while running steps: {event}")
else:
self.log.info("Steps completed successfully")
- return event["step_ids"]
+ return event["value"]
class EmrStartNotebookExecutionOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/emr.py
b/airflow/providers/amazon/aws/triggers/emr.py
index a928255fe0..32f9049155 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -16,90 +16,55 @@
# under the License.
from __future__ import annotations
-import asyncio
import warnings
-from typing import TYPE_CHECKING, Any
-
-from botocore.exceptions import WaiterError
+from typing import TYPE_CHECKING
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook,
EmrServerlessHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
-from airflow.triggers.base import BaseTrigger, TriggerEvent
if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
-class EmrAddStepsTrigger(BaseTrigger):
+class EmrAddStepsTrigger(AwsBaseWaiterTrigger):
"""
- Asynchronously poll the boto3 API and wait for the steps to finish
executing.
+ Poll for the status of EMR steps until they reach terminal state.
+
+ :param job_flow_id: job_flow_id which contains the steps to check the
state of
+ :param step_ids: steps to check the state of
+ :param waiter_delay: polling period in seconds to check for the status
+ :param waiter_max_attempts: The maximum number of attempts to be made
+ :param aws_conn_id: Reference to AWS connection id
- :param job_flow_id: The id of the job flow.
- :param step_ids: The id of the steps being waited upon.
- :param poll_interval: The amount of time in seconds to wait between
attempts.
- :param max_attempts: The maximum number of attempts to be made.
- :param aws_conn_id: The Airflow connection used for AWS credentials.
"""
def __init__(
self,
job_flow_id: str,
step_ids: list[str],
- aws_conn_id: str,
- max_attempts: int | None,
- poll_interval: int | None,
+ waiter_delay: int,
+ waiter_max_attempts: int,
+ aws_conn_id: str = "aws_default",
):
- self.job_flow_id = job_flow_id
- self.step_ids = step_ids
- self.aws_conn_id = aws_conn_id
- self.max_attempts = max_attempts
- self.poll_interval = poll_interval
-
- def serialize(self) -> tuple[str, dict[str, Any]]:
- return (
- "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger",
- {
- "job_flow_id": str(self.job_flow_id),
- "step_ids": self.step_ids,
- "poll_interval": str(self.poll_interval),
- "max_attempts": str(self.max_attempts),
- "aws_conn_id": str(self.aws_conn_id),
- },
+ super().__init__(
+ serialized_fields={"job_flow_id": job_flow_id, "step_ids":
step_ids},
+ waiter_name="steps_wait_for_terminal",
+ waiter_args={"ClusterId": job_flow_id, "StepIds": step_ids},
+ failure_message=f"Error while waiting for steps {step_ids} to
complete",
+ status_message=f"Step ids: {step_ids}, Steps are still in
non-terminal state",
+ status_queries=[
+ "Steps[].Status.State",
+ "Steps[].Status.FailureDetails",
+ ],
+ return_value=step_ids,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
)
- async def run(self):
- self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
- async with self.hook.async_conn as client:
- for step_id in self.step_ids:
- waiter = client.get_waiter("step_complete")
- for attempt in range(1, 1 + self.max_attempts):
- try:
- await waiter.wait(
- ClusterId=self.job_flow_id,
- StepId=step_id,
- WaiterConfig={
- "Delay": int(self.poll_interval),
- "MaxAttempts": 1,
- },
- )
- break
- except WaiterError as error:
- if "terminal failure" in str(error):
- yield TriggerEvent(
- {"status": "failure", "message": f"Step
{step_id} failed: {error}"}
- )
- break
- self.log.info(
- "Status of step is %s - %s",
- error.last_response["Step"]["Status"]["State"],
-
error.last_response["Step"]["Status"]["StateChangeReason"],
- )
- await asyncio.sleep(int(self.poll_interval))
- if attempt >= int(self.max_attempts):
- yield TriggerEvent({"status": "failure", "message": "Steps failed:
max attempts reached"})
- else:
- yield TriggerEvent({"status": "success", "message": "Steps
completed", "step_ids": self.step_ids})
+ def hook(self) -> AwsGenericHook:
+ return EmrHook(aws_conn_id=self.aws_conn_id)
class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
diff --git a/airflow/providers/amazon/aws/waiters/emr.json
b/airflow/providers/amazon/aws/waiters/emr.json
index 33a90c7751..91c902eed6 100644
--- a/airflow/providers/amazon/aws/waiters/emr.json
+++ b/airflow/providers/amazon/aws/waiters/emr.json
@@ -125,6 +125,37 @@
"state": "failure"
}
]
+ },
+ "steps_wait_for_terminal": {
+ "operation": "ListSteps",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "pathAll",
+ "argument": "Steps[].Status.State",
+ "expected": "COMPLETED",
+ "state": "success"
+ },
+ {
+ "matcher": "pathAny",
+ "argument": "Steps[].Status.State",
+ "expected": "CANCELLED",
+ "state": "failure"
+ },
+ {
+ "matcher": "pathAny",
+ "argument": "Steps[].Status.State",
+ "expected": "FAILED",
+ "state": "failure"
+ },
+ {
+ "matcher": "pathAny",
+ "argument": "Steps[].Status.State",
+ "expected": "INTERRUPTED",
+ "state": "failure"
+ }
+ ]
}
}
}
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py
b/tests/providers/amazon/aws/hooks/test_emr.py
index c68b25fbdf..b9864e84db 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -39,6 +39,7 @@ class TestEmrHook:
"notebook_running",
"notebook_stopped",
"step_wait_for_terminal",
+ "steps_wait_for_terminal",
]
assert sorted(hook.list_waiters()) == sorted([*official_waiters,
*custom_waiters])
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py
b/tests/providers/amazon/aws/triggers/test_emr.py
index f83ad8ccd0..5a1369e89d 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import pytest
from airflow.providers.amazon.aws.triggers.emr import (
+ EmrAddStepsTrigger,
EmrContainerTrigger,
EmrCreateJobFlowTrigger,
EmrStepSensorTrigger,
@@ -40,6 +41,13 @@ class TestEmrTriggers:
@pytest.mark.parametrize(
"trigger",
[
+ EmrAddStepsTrigger(
+ job_flow_id=TEST_JOB_FLOW_ID,
+ step_ids=["my_step1", "my_step2"],
+ aws_conn_id=TEST_AWS_CONN_ID,
+ waiter_delay=TEST_POLL_INTERVAL,
+ waiter_max_attempts=TEST_MAX_ATTEMPTS,
+ ),
EmrCreateJobFlowTrigger(
job_flow_id=TEST_JOB_FLOW_ID,
aws_conn_id=TEST_AWS_CONN_ID,
diff --git a/tests/providers/amazon/aws/triggers/test_emr_trigger.py
b/tests/providers/amazon/aws/triggers/test_emr_trigger.py
index fe28f86edb..187a948fb5 100644
--- a/tests/providers/amazon/aws/triggers/test_emr_trigger.py
+++ b/tests/providers/amazon/aws/triggers/test_emr_trigger.py
@@ -16,21 +16,14 @@
# under the License.
from __future__ import annotations
-from unittest import mock
-from unittest.mock import AsyncMock
-
import pytest
-from botocore.exceptions import WaiterError
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.triggers.emr import (
- EmrAddStepsTrigger,
EmrContainerTrigger,
EmrCreateJobFlowTrigger,
EmrStepSensorTrigger,
EmrTerminateJobFlowTrigger,
)
-from airflow.triggers.base import TriggerEvent
TEST_JOB_FLOW_ID = "test_job_flow_id"
TEST_STEP_IDS = ["step1", "step2"]
@@ -39,143 +32,6 @@ TEST_MAX_ATTEMPTS = 10
TEST_POLL_INTERVAL = 10
-class TestEmrAddStepsTrigger:
- def test_emr_add_steps_trigger_serialize(self):
- emr_add_steps_trigger = EmrAddStepsTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_ids=TEST_STEP_IDS,
- aws_conn_id=TEST_AWS_CONN_ID,
- max_attempts=TEST_MAX_ATTEMPTS,
- poll_interval=TEST_POLL_INTERVAL,
- )
- class_path, args = emr_add_steps_trigger.serialize()
- assert class_path ==
"airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger"
- assert args["job_flow_id"] == TEST_JOB_FLOW_ID
- assert args["step_ids"] == TEST_STEP_IDS
- assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
- assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
- assert args["aws_conn_id"] == TEST_AWS_CONN_ID
-
- @pytest.mark.asyncio
- @mock.patch.object(EmrHook, "async_conn")
- async def test_emr_add_steps_trigger_run(self, mock_async_conn):
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
- a_mock.get_waiter().wait = AsyncMock()
-
- emr_add_steps_trigger = EmrAddStepsTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_ids=TEST_STEP_IDS,
- aws_conn_id=TEST_AWS_CONN_ID,
- max_attempts=TEST_MAX_ATTEMPTS,
- poll_interval=TEST_POLL_INTERVAL,
- )
-
- generator = emr_add_steps_trigger.run()
- response = await generator.asend(None)
-
- assert response == TriggerEvent(
- {"status": "success", "message": "Steps completed", "step_ids":
TEST_STEP_IDS}
- )
-
- @pytest.mark.asyncio
- @mock.patch("asyncio.sleep")
- @mock.patch.object(EmrHook, "async_conn")
- async def test_emr_add_steps_trigger_run_multiple_attempts(self,
mock_async_conn, mock_sleep):
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
- error = WaiterError(
- name="test_name",
- reason="test_reason",
- last_response={"Step": {"Status": {"State": "Running",
"StateChangeReason": "test_reason"}}},
- )
- a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True, error, error, True])
- mock_sleep.return_value = True
-
- emr_add_steps_trigger = EmrAddStepsTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_ids=TEST_STEP_IDS,
- aws_conn_id=TEST_AWS_CONN_ID,
- max_attempts=TEST_MAX_ATTEMPTS,
- poll_interval=TEST_POLL_INTERVAL,
- )
-
- generator = emr_add_steps_trigger.run()
- response = await generator.asend(None)
-
- assert a_mock.get_waiter().wait.call_count == 6
- assert response == TriggerEvent(
- {"status": "success", "message": "Steps completed", "step_ids":
TEST_STEP_IDS}
- )
-
- @pytest.mark.asyncio
- @mock.patch("asyncio.sleep")
- @mock.patch.object(EmrHook, "async_conn")
- async def test_emr_add_steps_trigger_run_attempts_exceeded(self,
mock_async_conn, mock_sleep):
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
- error = WaiterError(
- name="test_name",
- reason="test_reason",
- last_response={"Step": {"Status": {"State": "Running",
"StateChangeReason": "test_reason"}}},
- )
- a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error,
error, True])
- mock_sleep.return_value = True
-
- emr_add_steps_trigger = EmrAddStepsTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_ids=[TEST_STEP_IDS[0]],
- aws_conn_id=TEST_AWS_CONN_ID,
- max_attempts=2,
- poll_interval=TEST_POLL_INTERVAL,
- )
-
- generator = emr_add_steps_trigger.run()
- response = await generator.asend(None)
-
- assert a_mock.get_waiter().wait.call_count == 2
- assert response == TriggerEvent(
- {"status": "failure", "message": "Steps failed: max attempts
reached"}
- )
-
- @pytest.mark.asyncio
- @mock.patch("asyncio.sleep")
- @mock.patch.object(EmrHook, "async_conn")
- async def test_emr_add_steps_trigger_run_attempts_failed(self,
mock_async_conn, mock_sleep):
- a_mock = mock.MagicMock()
- mock_async_conn.__aenter__.return_value = a_mock
- error_running = WaiterError(
- name="test_name",
- reason="test_reason",
- last_response={"Step": {"Status": {"State": "Running",
"StateChangeReason": "test_reason"}}},
- )
- error_failed = WaiterError(
- name="test_name",
- reason="Waiter encountered a terminal failure state:",
- last_response={"Step": {"Status": {"State": "FAILED",
"StateChangeReason": "test_reason"}}},
- )
- a_mock.get_waiter().wait.side_effect = AsyncMock(
- side_effect=[error_running, error_running, error_failed]
- )
- mock_sleep.return_value = True
-
- emr_add_steps_trigger = EmrAddStepsTrigger(
- job_flow_id=TEST_JOB_FLOW_ID,
- step_ids=[TEST_STEP_IDS[0]],
- aws_conn_id=TEST_AWS_CONN_ID,
- max_attempts=TEST_MAX_ATTEMPTS,
- poll_interval=TEST_POLL_INTERVAL,
- )
-
- generator = emr_add_steps_trigger.run()
- response = await generator.asend(None)
-
- assert a_mock.get_waiter().wait.call_count == 3
- assert response == TriggerEvent(
- {"status": "failure", "message": f"Step {TEST_STEP_IDS[0]} failed:
{error_failed}"}
- )
-
-
class TestEmrTriggers:
@pytest.mark.parametrize(
"trigger",
diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
index 229f6ce377..19f9296b6a 100644
--- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
+++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import json
+from typing import Sequence
from unittest import mock
import boto3
@@ -31,6 +32,7 @@ from airflow.providers.amazon.aws.hooks.batch_client import
BatchClientHook
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook,
EcsTaskDefinitionStates
from airflow.providers.amazon.aws.hooks.eks import EksHook
+from airflow.providers.amazon.aws.hooks.emr import EmrHook
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
@@ -351,3 +353,74 @@ class TestCustomBatchServiceWaiters:
with pytest.raises(WaiterError, match="Waiter encountered a terminal
failure state"):
waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01,
"MaxAttempts": 2})
+
+
+class TestCustomEmrServiceWaiters:
+ """Test waiters from ``amazon/aws/waiters/emr.json``."""
+
+ JOBFLOW_ID = "test_jobflow_id"
+ STEP_ID1 = "test_step_id_1"
+ STEP_ID2 = "test_step_id_2"
+
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self, monkeypatch):
+ self.client = boto3.client("emr", region_name="eu-west-3")
+ monkeypatch.setattr(EmrHook, "conn", self.client)
+
+ @pytest.fixture
+ def mock_list_steps(self):
+ """Mock ``EmrHook.Client.list_steps`` method."""
+ with mock.patch.object(self.client, "list_steps") as m:
+ yield m
+
+ def test_service_waiters(self):
+ hook_waiters = EmrHook(aws_conn_id=None).list_waiters()
+ assert "steps_wait_for_terminal" in hook_waiters
+
+ @staticmethod
+ def list_steps(step_records: Sequence[tuple[str, str]]):
+ """
+ Helper function to generate minimal ListSteps response.
+ https://docs.aws.amazon.com/emr/latest/APIReference/API_ListSteps.html
+ """
+ return {
+ "Steps": [
+ {
+ "Id": step_record[0],
+ "Status": {
+ "State": step_record[1],
+ },
+ }
+ for step_record in step_records
+ ],
+ }
+
+ def test_steps_succeeded(self, mock_list_steps):
+ """Test steps succeeded"""
+ mock_list_steps.side_effect = [
+ self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2,
"RUNNING")]),
+ self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2,
"COMPLETED")]),
+ self.list_steps([(self.STEP_ID1, "COMPLETED"), (self.STEP_ID2,
"COMPLETED")]),
+ ]
+ waiter =
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
+ waiter.wait(
+ ClusterId=self.JOBFLOW_ID,
+ StepIds=[self.STEP_ID1, self.STEP_ID2],
+ WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+ )
+
+ def test_steps_failed(self, mock_list_steps):
+ """Test steps failed"""
+ mock_list_steps.side_effect = [
+ self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2,
"RUNNING")]),
+ self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2,
"COMPLETED")]),
+ self.list_steps([(self.STEP_ID1, "FAILED"), (self.STEP_ID2,
"COMPLETED")]),
+ ]
+ waiter =
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
+
+ with pytest.raises(WaiterError, match="Waiter encountered a terminal
failure state"):
+ waiter.wait(
+ ClusterId=self.JOBFLOW_ID,
+ StepIds=[self.STEP_ID1, self.STEP_ID2],
+ WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+ )