This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0df2f594140 Add best-effort cleanup to EmrCreateJobFlowOperator on post-creation failure (#61010)
0df2f594140 is described below
commit 0df2f594140270f1442b0b5afa09320fa1c14bda
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Feb 10 16:27:43 2026 +0000
    Add best-effort cleanup to EmrCreateJobFlowOperator on post-creation failure (#61010)
---
.../airflow/providers/amazon/aws/operators/emr.py | 120 +++++++++++++--------
.../aws/operators/test_emr_create_job_flow.py | 76 +++++++++++++
2 files changed, 151 insertions(+), 45 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 6241436ad74..3c9d0c75d48 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -24,6 +24,8 @@ from datetime import timedelta
from typing import TYPE_CHECKING, Any
from uuid import uuid4
+from botocore.exceptions import WaiterError
+
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import (
EmrClusterLink,
@@ -665,6 +667,9 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
    :param deferrable: If True, the operator will wait asynchronously for the crawl to complete.
        This implies waiting for completion. This mode requires aiobotocore module to be installed.
        (default: False)
+    :param terminate_job_flow_on_failure: If True, attempts best-effort termination of the EMR job flow
+        when a failure occurs after the job flow has been created. Cleanup failures do not mask the
+        original exception. (default: True)
"""
aws_hook_class = EmrHook
@@ -691,6 +696,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
waiter_max_attempts: int | None = None,
waiter_delay: int | None = None,
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ terminate_job_flow_on_failure: bool = True,
**kwargs: Any,
):
super().__init__(**kwargs)
@@ -699,6 +705,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
self.waiter_max_attempts = waiter_max_attempts or 60
self.waiter_delay = waiter_delay or 60
self.deferrable = deferrable
+ self.terminate_job_flow_on_failure = terminate_job_flow_on_failure
self.wait_policy = wait_policy
# Backwards-compatible default: if the user requested waiting for
@@ -746,58 +753,81 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
self._job_flow_id = response["JobFlowId"]
self.log.info("Job flow with id %s created", self._job_flow_id)
- EmrClusterLink.persist(
- context=context,
- operator=self,
- region_name=self.hook.conn_region_name,
- aws_partition=self.hook.conn_partition,
- job_flow_id=self._job_flow_id,
- )
- if self._job_flow_id:
- EmrLogsLink.persist(
+ try:
+ EmrClusterLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_flow_id=self._job_flow_id,
-            log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
- if self.wait_for_completion:
-            # Determine which waiter to use. Prefer explicit wait_policy when provided,
- # otherwise default to WAIT_FOR_COMPLETION.
- wp = self.wait_policy
- if wp is not None:
- waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
- else:
-                waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
-
- if self.deferrable:
-                # Pass the selected waiter_name to the trigger so deferrable mode waits
- # according to the requested policy as well.
- self.defer(
- trigger=EmrCreateJobFlowTrigger(
- job_flow_id=self._job_flow_id,
- aws_conn_id=self.aws_conn_id,
- waiter_delay=self.waiter_delay,
- waiter_max_attempts=self.waiter_max_attempts,
- waiter_name=waiter_name,
- ),
- method_name="execute_complete",
-                    # timeout is set to ensure that if a trigger dies, the timeout does not restart
-                    # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
-                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
- )
- else:
- self.hook.get_waiter(waiter_name).wait(
- ClusterId=self._job_flow_id,
- WaiterConfig=prune_dict(
- {
- "Delay": self.waiter_delay,
- "MaxAttempts": self.waiter_max_attempts,
- }
- ),
+ if self._job_flow_id:
+ EmrLogsLink.persist(
+ context=context,
+ operator=self,
+ region_name=self.hook.conn_region_name,
+ aws_partition=self.hook.conn_partition,
+ job_flow_id=self._job_flow_id,
+                    log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
)
- return self._job_flow_id
+ if self.wait_for_completion:
+                # Determine which waiter to use. Prefer explicit wait_policy when provided,
+ # otherwise default to WAIT_FOR_COMPLETION.
+ wp = self.wait_policy
+ if wp is not None:
+ waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
+ else:
+                    waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
+
+ if self.deferrable:
+                    # Pass the selected waiter_name to the trigger so deferrable mode waits
+ # according to the requested policy as well.
+ self.defer(
+ trigger=EmrCreateJobFlowTrigger(
+ job_flow_id=self._job_flow_id,
+ aws_conn_id=self.aws_conn_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ waiter_name=waiter_name,
+ ),
+ method_name="execute_complete",
+                        # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                        # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                        timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+ )
+ else:
+ self.hook.get_waiter(waiter_name).wait(
+ ClusterId=self._job_flow_id,
+ WaiterConfig=prune_dict(
+ {
+ "Delay": self.waiter_delay,
+ "MaxAttempts": self.waiter_max_attempts,
+ }
+ ),
+ )
+ return self._job_flow_id
+
+        # Best-effort cleanup when post-creation steps fail (e.g. IAM/permission errors).
+ except WaiterError:
+ if self._job_flow_id:
+ if self.terminate_job_flow_on_failure:
+ self.log.warning(
+ "Task failed after creating EMR job flow %s.",
+ self._job_flow_id,
+ )
+ try:
+ self.log.info(
+ "Attempting termination of EMR job flow %s.",
+ self._job_flow_id,
+ )
+
+                        self.hook.conn.terminate_job_flows(JobFlowIds=[self._job_flow_id])
+ except Exception:
+ self.log.exception(
+                            "Failed to terminate EMR job flow %s after task failure.",
+ self._job_flow_id,
+ )
+ raise
    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
validated_event = validate_execute_complete_event(event)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
index 8389979ec95..e51535b5768 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
@@ -23,6 +23,7 @@ from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
+from botocore.exceptions import ClientError, WaiterError
from botocore.waiter import Waiter
from jinja2 import StrictUndefined
@@ -231,6 +232,7 @@ class TestEmrCreateJobFlowOperator:
self.operator.deferrable = True
self.operator.wait_for_completion = True
+
with pytest.raises(TaskDeferred) as exc:
self.operator.execute(self.mock_context)
@@ -281,3 +283,77 @@ class TestEmrCreateJobFlowOperator:
)
        assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
assert op.wait_for_completion is True
+
+ def test_cleanup_on_post_create_failure(self, mocked_hook_client):
+ """
+ Ensure that if the job flow is created successfully but a subsequent
+ post-create step fails (e.g. waiter / DescribeCluster),
+ the operator attempts best-effort cleanup.
+ """
+        mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
+
+ self.operator.wait_for_completion = True
+ self.operator.terminate_job_flow_on_failure = True
+
+ waiter_error = WaiterError(
+ "ClusterRunning",
+ "You are not authorized to perform this operation",
+ {},
+ )
+
+ with (
+ patch.object(self.operator.hook, "get_waiter") as mock_get_waiter,
+            patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate,
+ ):
+ mock_get_waiter.return_value.wait.side_effect = waiter_error
+
+ with pytest.raises(WaiterError) as exc:
+ self.operator.execute(self.mock_context)
+
+ # Original exception must be propagated unchanged
+ assert exc.value is waiter_error
+
+ # Cleanup must be attempted
+ mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])
+
+    def test_cleanup_failure_does_not_mask_original_exception(self, mocked_hook_client):
+ """
+ Ensure that failure during cleanup does not override
+ the original post-create exception.
+ """
+        mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
+
+ self.operator.wait_for_completion = True
+ self.operator.terminate_job_flow_on_failure = True
+
+ waiter_error = WaiterError(
+ "ClusterRunning",
+ "You are not authorized to perform this operation",
+ {},
+ )
+
+ cleanup_error = ClientError(
+ error_response={
+ "Error": {
+ "Code": "UnauthorizedOperation",
+                    "Message": "You are not authorized to perform this operation",
+ }
+ },
+ operation_name="TerminateJobFlows",
+ )
+
+ with (
+ patch.object(self.operator.hook, "get_waiter") as mock_get_waiter,
+            patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate,
+ ):
+ mock_get_waiter.return_value.wait.side_effect = waiter_error
+ mock_terminate.side_effect = cleanup_error
+
+ with pytest.raises(WaiterError) as exc:
+ self.operator.execute(self.mock_context)
+
+ # Original exception must be preserved
+ assert exc.value is waiter_error
+
+ # Cleanup attempted despite failure
+ mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])