This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b7408270b86 Only defer `EmrCreateJobFlowOperator` when `wait_policy`
is set (#56077)
b7408270b86 is described below
commit b7408270b8660bf78a809040fb15f389445070eb
Author: LAKSH KRISHNA SHARMA
<[email protected]>
AuthorDate: Sat Sep 27 01:49:12 2025 +0530
Only defer `EmrCreateJobFlowOperator` when `wait_policy` is set (#56077)
* Fix EMR operator deferral logic
* Fix failing test and ruff lint
---
.../airflow/providers/amazon/aws/operators/emr.py | 46 +++++++++++-----------
.../aws/operators/test_emr_create_job_flow.py | 15 ++++++-
2 files changed, 38 insertions(+), 23 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 4377ccd89bc..23487b687ba 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -748,30 +748,32 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
job_flow_id=self._job_flow_id,
log_uri=get_log_uri(emr_client=self.hook.conn,
job_flow_id=self._job_flow_id),
)
- if self.deferrable:
- self.defer(
- trigger=EmrCreateJobFlowTrigger(
- job_flow_id=self._job_flow_id,
- aws_conn_id=self.aws_conn_id,
- waiter_delay=self.waiter_delay,
- waiter_max_attempts=self.waiter_max_attempts,
- ),
- method_name="execute_complete",
- # timeout is set to ensure that if a trigger dies, the timeout
does not restart
- # 60 seconds is added to allow the trigger to exit gracefully
(i.e. yield TriggerEvent)
- timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay + 60),
- )
if self.wait_policy:
waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy]
- self.hook.get_waiter(waiter_name).wait(
- ClusterId=self._job_flow_id,
- WaiterConfig=prune_dict(
- {
- "Delay": self.waiter_delay,
- "MaxAttempts": self.waiter_max_attempts,
- }
- ),
- )
+
+ if self.deferrable:
+ self.defer(
+ trigger=EmrCreateJobFlowTrigger(
+ job_flow_id=self._job_flow_id,
+ aws_conn_id=self.aws_conn_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ ),
+ method_name="execute_complete",
+ # timeout is set to ensure that if a trigger dies, the
timeout does not restart
+ # 60 seconds is added to allow the trigger to exit
gracefully (i.e. yield TriggerEvent)
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay + 60),
+ )
+ else:
+ self.hook.get_waiter(waiter_name).wait(
+ ClusterId=self._job_flow_id,
+ WaiterConfig=prune_dict(
+ {
+ "Delay": self.waiter_delay,
+ "MaxAttempts": self.waiter_max_attempts,
+ }
+ ),
+ )
return self._job_flow_id
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> str:
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
index 43624ef501d..437d3c52d28 100644
---
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
@@ -238,11 +238,12 @@ class TestEmrCreateJobFlowOperator:
def test_create_job_flow_deferrable(self, mocked_hook_client):
"""
Test to make sure that the operator raises a TaskDeferred exception
- if run in deferrable mode.
+ if run in deferrable mode and wait_policy is set.
"""
mocked_hook_client.run_job_flow.return_value =
RUN_JOB_FLOW_SUCCESS_RETURN
self.operator.deferrable = True
+ self.operator.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION
with pytest.raises(TaskDeferred) as exc:
self.operator.execute(self.mock_context)
@@ -250,5 +251,17 @@ class TestEmrCreateJobFlowOperator:
"Trigger is not a EmrCreateJobFlowTrigger"
)
+ def test_create_job_flow_deferrable_no_wait(self, mocked_hook_client):
+ """
+ Test to make sure that the operator does NOT raise a TaskDeferred
exception
+ if run in deferrable mode but wait_policy is not set.
+ """
+ mocked_hook_client.run_job_flow.return_value =
RUN_JOB_FLOW_SUCCESS_RETURN
+
+ self.operator.deferrable = True
+ # wait_policy is None by default
+ result = self.operator.execute(self.mock_context)
+ assert result == JOB_FLOW_ID
+
def test_template_fields(self):
validate_template_fields(self.operator)