This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0df2f594140 Add best-effort cleanup to EmrCreateJobFlowOperator on post-creation failure (#61010)
0df2f594140 is described below

commit 0df2f594140270f1442b0b5afa09320fa1c14bda
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Feb 10 16:27:43 2026 +0000

    Add best-effort cleanup to EmrCreateJobFlowOperator on post-creation failure (#61010)
---
 .../airflow/providers/amazon/aws/operators/emr.py  | 120 +++++++++++++--------
 .../aws/operators/test_emr_create_job_flow.py      |  76 +++++++++++++
 2 files changed, 151 insertions(+), 45 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 6241436ad74..3c9d0c75d48 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -24,6 +24,8 @@ from datetime import timedelta
 from typing import TYPE_CHECKING, Any
 from uuid import uuid4
 
+from botocore.exceptions import WaiterError
+
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import (
     EmrClusterLink,
@@ -665,6 +667,9 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
     :param deferrable: If True, the operator will wait asynchronously for the crawl to complete.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param terminate_job_flow_on_failure: If True, attempts best-effort termination of the EMR job flow
+        when a failure occurs after the job flow has been created. Cleanup failures do not mask the
+        original exception. (default: True)
     """
 
     aws_hook_class = EmrHook
@@ -691,6 +696,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
         waiter_max_attempts: int | None = None,
         waiter_delay: int | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        terminate_job_flow_on_failure: bool = True,
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
@@ -699,6 +705,7 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
         self.waiter_max_attempts = waiter_max_attempts or 60
         self.waiter_delay = waiter_delay or 60
         self.deferrable = deferrable
+        self.terminate_job_flow_on_failure = terminate_job_flow_on_failure
         self.wait_policy = wait_policy
 
         # Backwards-compatible default: if the user requested waiting for
@@ -746,58 +753,81 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
 
         self._job_flow_id = response["JobFlowId"]
         self.log.info("Job flow with id %s created", self._job_flow_id)
-        EmrClusterLink.persist(
-            context=context,
-            operator=self,
-            region_name=self.hook.conn_region_name,
-            aws_partition=self.hook.conn_partition,
-            job_flow_id=self._job_flow_id,
-        )
-        if self._job_flow_id:
-            EmrLogsLink.persist(
+        try:
+            EmrClusterLink.persist(
                 context=context,
                 operator=self,
                 region_name=self.hook.conn_region_name,
                 aws_partition=self.hook.conn_partition,
                 job_flow_id=self._job_flow_id,
-                log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
             )
-        if self.wait_for_completion:
-            # Determine which waiter to use. Prefer explicit wait_policy when provided,
-            # otherwise default to WAIT_FOR_COMPLETION.
-            wp = self.wait_policy
-            if wp is not None:
-                waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
-            else:
-                waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
-
-            if self.deferrable:
-                # Pass the selected waiter_name to the trigger so deferrable mode waits
-                # according to the requested policy as well.
-                self.defer(
-                    trigger=EmrCreateJobFlowTrigger(
-                        job_flow_id=self._job_flow_id,
-                        aws_conn_id=self.aws_conn_id,
-                        waiter_delay=self.waiter_delay,
-                        waiter_max_attempts=self.waiter_max_attempts,
-                        waiter_name=waiter_name,
-                    ),
-                    method_name="execute_complete",
-                    # timeout is set to ensure that if a trigger dies, the timeout does not restart
-                    # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
-                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
-                )
-            else:
-                self.hook.get_waiter(waiter_name).wait(
-                    ClusterId=self._job_flow_id,
-                    WaiterConfig=prune_dict(
-                        {
-                            "Delay": self.waiter_delay,
-                            "MaxAttempts": self.waiter_max_attempts,
-                        }
-                    ),
+            if self._job_flow_id:
+                EmrLogsLink.persist(
+                    context=context,
+                    operator=self,
+                    region_name=self.hook.conn_region_name,
+                    aws_partition=self.hook.conn_partition,
+                    job_flow_id=self._job_flow_id,
+                    log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
                 )
-        return self._job_flow_id
+            if self.wait_for_completion:
+                # Determine which waiter to use. Prefer explicit wait_policy when provided,
+                # otherwise default to WAIT_FOR_COMPLETION.
+                wp = self.wait_policy
+                if wp is not None:
+                    waiter_name = WAITER_POLICY_NAME_MAPPING[wp]
+                else:
+                    waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
+
+                if self.deferrable:
+                    # Pass the selected waiter_name to the trigger so deferrable mode waits
+                    # according to the requested policy as well.
+                    self.defer(
+                        trigger=EmrCreateJobFlowTrigger(
+                            job_flow_id=self._job_flow_id,
+                            aws_conn_id=self.aws_conn_id,
+                            waiter_delay=self.waiter_delay,
+                            waiter_max_attempts=self.waiter_max_attempts,
+                            waiter_name=waiter_name,
+                        ),
+                        method_name="execute_complete",
+                        # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                        # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                        timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+                    )
+                else:
+                    self.hook.get_waiter(waiter_name).wait(
+                        ClusterId=self._job_flow_id,
+                        WaiterConfig=prune_dict(
+                            {
+                                "Delay": self.waiter_delay,
+                                "MaxAttempts": self.waiter_max_attempts,
+                            }
+                        ),
+                    )
+            return self._job_flow_id
+
+        # Best-effort cleanup when post-creation steps fail (e.g. IAM/permission errors).
+        except WaiterError:
+            if self._job_flow_id:
+                if self.terminate_job_flow_on_failure:
+                    self.log.warning(
+                        "Task failed after creating EMR job flow %s.",
+                        self._job_flow_id,
+                    )
+                    try:
+                        self.log.info(
+                            "Attempting termination of EMR job flow %s.",
+                            self._job_flow_id,
+                        )
+
+                        self.hook.conn.terminate_job_flows(JobFlowIds=[self._job_flow_id])
+                    except Exception:
+                        self.log.exception(
+                            "Failed to terminate EMR job flow %s after task failure.",
+                            self._job_flow_id,
+                        )
+            raise
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
         validated_event = validate_execute_complete_event(event)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
index 8389979ec95..e51535b5768 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
@@ -23,6 +23,7 @@ from unittest import mock
 from unittest.mock import MagicMock, patch
 
 import pytest
+from botocore.exceptions import ClientError, WaiterError
 from botocore.waiter import Waiter
 from jinja2 import StrictUndefined
 
@@ -231,6 +232,7 @@ class TestEmrCreateJobFlowOperator:
 
         self.operator.deferrable = True
         self.operator.wait_for_completion = True
+
         with pytest.raises(TaskDeferred) as exc:
             self.operator.execute(self.mock_context)
 
@@ -281,3 +283,77 @@ class TestEmrCreateJobFlowOperator:
         )
         assert getattr(op, "wait_policy") == WaitPolicy.WAIT_FOR_STEPS_COMPLETION
         assert op.wait_for_completion is True
+
+    def test_cleanup_on_post_create_failure(self, mocked_hook_client):
+        """
+        Ensure that if the job flow is created successfully but a subsequent
+        post-create step fails (e.g. waiter / DescribeCluster),
+        the operator attempts best-effort cleanup.
+        """
+        mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
+
+        self.operator.wait_for_completion = True
+        self.operator.terminate_job_flow_on_failure = True
+
+        waiter_error = WaiterError(
+            "ClusterRunning",
+            "You are not authorized to perform this operation",
+            {},
+        )
+
+        with (
+            patch.object(self.operator.hook, "get_waiter") as mock_get_waiter,
+            patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate,
+        ):
+            mock_get_waiter.return_value.wait.side_effect = waiter_error
+
+            with pytest.raises(WaiterError) as exc:
+                self.operator.execute(self.mock_context)
+
+            # Original exception must be propagated unchanged
+            assert exc.value is waiter_error
+
+            # Cleanup must be attempted
+            mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])
+
+    def test_cleanup_failure_does_not_mask_original_exception(self, mocked_hook_client):
+        """
+        Ensure that failure during cleanup does not override
+        the original post-create exception.
+        """
+        mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
+
+        self.operator.wait_for_completion = True
+        self.operator.terminate_job_flow_on_failure = True
+
+        waiter_error = WaiterError(
+            "ClusterRunning",
+            "You are not authorized to perform this operation",
+            {},
+        )
+
+        cleanup_error = ClientError(
+            error_response={
+                "Error": {
+                    "Code": "UnauthorizedOperation",
+                    "Message": "You are not authorized to perform this operation",
+                }
+            },
+            operation_name="TerminateJobFlows",
+        )
+
+        with (
+            patch.object(self.operator.hook, "get_waiter") as mock_get_waiter,
+            patch.object(self.operator.hook.conn, "terminate_job_flows") as mock_terminate,
+        ):
+            mock_get_waiter.return_value.wait.side_effect = waiter_error
+            mock_terminate.side_effect = cleanup_error
+
+            with pytest.raises(WaiterError) as exc:
+                self.operator.execute(self.mock_context)
+
+            # Original exception must be preserved
+            assert exc.value is waiter_error
+
+            # Cleanup attempted despite failure
+            mock_terminate.assert_called_once_with(JobFlowIds=[JOB_FLOW_ID])

Reply via email to