This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 97c4fdce71 Fix `EmrServerlessStartJobOperator` (#41103)
97c4fdce71 is described below

commit 97c4fdce71e0665997b7c3a8f78324af616c91b4
Author: Vincent <[email protected]>
AuthorDate: Mon Jul 29 16:14:11 2024 -0400

    Fix `EmrServerlessStartJobOperator` (#41103)
---
 airflow/providers/amazon/aws/operators/emr.py      | 46 +++++++++++-----------
 .../amazon/aws/operators/test_emr_serverless.py    | 21 ++++++++++
 2 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index c13c622937..fb2f5de478 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -1382,30 +1382,30 @@ class EmrServerlessStartJobOperator(BaseOperator):
 
         self.persist_links(context)
 
-        if self.deferrable:
-            self.defer(
-                trigger=EmrServerlessStartJobTrigger(
-                    application_id=self.application_id,
-                    job_id=self.job_id,
-                    waiter_delay=self.waiter_delay,
-                    waiter_max_attempts=self.waiter_max_attempts,
-                    aws_conn_id=self.aws_conn_id,
-                ),
-                method_name="execute_complete",
-                timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay),
-            )
-
         if self.wait_for_completion:
-            waiter = self.hook.get_waiter("serverless_job_completed")
-            wait(
-                waiter=waiter,
-                waiter_max_attempts=self.waiter_max_attempts,
-                waiter_delay=self.waiter_delay,
-                args={"applicationId": self.application_id, "jobRunId": 
self.job_id},
-                failure_message="Serverless Job failed",
-                status_message="Serverless Job status is",
-                status_args=["jobRun.state", "jobRun.stateDetails"],
-            )
+            if self.deferrable:
+                self.defer(
+                    trigger=EmrServerlessStartJobTrigger(
+                        application_id=self.application_id,
+                        job_id=self.job_id,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                    ),
+                    method_name="execute_complete",
+                    timeout=timedelta(seconds=self.waiter_max_attempts * 
self.waiter_delay),
+                )
+            else:
+                waiter = self.hook.get_waiter("serverless_job_completed")
+                wait(
+                    waiter=waiter,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    args={"applicationId": self.application_id, "jobRunId": 
self.job_id},
+                    failure_message="Serverless Job failed",
+                    status_message="Serverless Job status is",
+                    status_args=["jobRun.state", "jobRun.stateDetails"],
+                )
 
         return self.job_id
 
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index 4804f22869..12c5cc9380 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -836,6 +836,27 @@ class TestEmrServerlessStartJobOperator:
         with pytest.raises(TaskDeferred):
             operator.execute(self.mock_context)
 
+    @mock.patch.object(EmrServerlessHook, "conn")
+    def test_start_job_deferrable_without_wait_for_completion(self, mock_conn):
+        mock_conn.get_application.return_value = {"application": {"state": 
"STARTED"}}
+        mock_conn.start_job_run.return_value = {
+            "jobRunId": job_run_id,
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+            deferrable=True,
+            wait_for_completion=False,
+        )
+
+        result = operator.execute(self.mock_context)
+
+        assert result == job_run_id
+
     @mock.patch.object(EmrServerlessHook, "get_waiter")
     @mock.patch.object(EmrServerlessHook, "conn")
     def test_start_job_deferrable_app_not_started(self, mock_conn, 
mock_get_waiter):

Reply via email to