This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 97c4fdce71 Fix `EmrServerlessStartJobOperator` (#41103)
97c4fdce71 is described below
commit 97c4fdce71e0665997b7c3a8f78324af616c91b4
Author: Vincent <[email protected]>
AuthorDate: Mon Jul 29 16:14:11 2024 -0400
Fix `EmrServerlessStartJobOperator` (#41103)
---
airflow/providers/amazon/aws/operators/emr.py | 46 +++++++++++-----------
.../amazon/aws/operators/test_emr_serverless.py | 21 ++++++++++
2 files changed, 44 insertions(+), 23 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index c13c622937..fb2f5de478 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -1382,30 +1382,30 @@ class EmrServerlessStartJobOperator(BaseOperator):
self.persist_links(context)
- if self.deferrable:
- self.defer(
- trigger=EmrServerlessStartJobTrigger(
- application_id=self.application_id,
- job_id=self.job_id,
- waiter_delay=self.waiter_delay,
- waiter_max_attempts=self.waiter_max_attempts,
- aws_conn_id=self.aws_conn_id,
- ),
- method_name="execute_complete",
- timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
- )
-
if self.wait_for_completion:
- waiter = self.hook.get_waiter("serverless_job_completed")
- wait(
- waiter=waiter,
- waiter_max_attempts=self.waiter_max_attempts,
- waiter_delay=self.waiter_delay,
- args={"applicationId": self.application_id, "jobRunId": self.job_id},
- failure_message="Serverless Job failed",
- status_message="Serverless Job status is",
- status_args=["jobRun.state", "jobRun.stateDetails"],
- )
+ if self.deferrable:
+ self.defer(
+ trigger=EmrServerlessStartJobTrigger(
+ application_id=self.application_id,
+ job_id=self.job_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ aws_conn_id=self.aws_conn_id,
+ ),
+ method_name="execute_complete",
+ timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+ )
+ else:
+ waiter = self.hook.get_waiter("serverless_job_completed")
+ wait(
+ waiter=waiter,
+ waiter_max_attempts=self.waiter_max_attempts,
+ waiter_delay=self.waiter_delay,
+ args={"applicationId": self.application_id, "jobRunId": self.job_id},
+ failure_message="Serverless Job failed",
+ status_message="Serverless Job status is",
+ status_args=["jobRun.state", "jobRun.stateDetails"],
+ )
return self.job_id
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 4804f22869..12c5cc9380 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -836,6 +836,27 @@ class TestEmrServerlessStartJobOperator:
with pytest.raises(TaskDeferred):
operator.execute(self.mock_context)
+ @mock.patch.object(EmrServerlessHook, "conn")
+ def test_start_job_deferrable_without_wait_for_completion(self, mock_conn):
+ mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
+ mock_conn.start_job_run.return_value = {
+ "jobRunId": job_run_id,
+ "ResponseMetadata": {"HTTPStatusCode": 200},
+ }
+ operator = EmrServerlessStartJobOperator(
+ task_id=task_id,
+ application_id=application_id,
+ execution_role_arn=execution_role_arn,
+ job_driver=job_driver,
+ configuration_overrides=configuration_overrides,
+ deferrable=True,
+ wait_for_completion=False,
+ )
+
+ result = operator.execute(self.mock_context)
+
+ assert result == job_run_id
+
@mock.patch.object(EmrServerlessHook, "get_waiter")
@mock.patch.object(EmrServerlessHook, "conn")
def test_start_job_deferrable_app_not_started(self, mock_conn, mock_get_waiter):