This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 66294de4e0 Fix failure state in waiter call for EmrServerlessStartJobOperator. (#26853)
66294de4e0 is described below

commit 66294de4e081e1c65731296c66824ae847bdca7d
Author: syedahsn <[email protected]>
AuthorDate: Mon Oct 10 00:09:26 2022 -0700

    Fix failure state in waiter call for EmrServerlessStartJobOperator. (#26853)
    
    Add additional tests to cover previously untested cases
---
 airflow/providers/amazon/aws/operators/emr.py      |   2 +-
 .../amazon/aws/operators/test_emr_serverless.py    | 137 +++++++++++++++++++--
 2 files changed, 128 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 93c8adb840..1413e8f57e 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -643,7 +643,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
                 get_state_args={'applicationId': self.application_id},
                 parse_response=['application', 'state'],
                 desired_state={'STARTED'},
-                failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
+                failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
                 object_type='application',
                 action='started',
             )
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index a175bc8d10..44b98d453d 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -44,14 +44,16 @@ application_id_delete_operator = 'test_emr_serverless_delete_application_operato
 
 
 class TestEmrServerlessCreateApplicationOperator:
-    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
-    def test_execute_successfully_with_wait_for_completion(self, mock_conn, mock_waiter):
-        mock_waiter.return_value = True
+    def test_execute_successfully_with_wait_for_completion(self, mock_conn):
         mock_conn.create_application.return_value = {
             "applicationId": application_id,
             "ResponseMetadata": {"HTTPStatusCode": 200},
         }
+        mock_conn.get_application.side_effect = [
+            {'application': {'state': 'CREATED'}},
+            {'application': {'state': 'STARTED'}},
+        ]
 
         operator = EmrServerlessCreateApplicationOperator(
             task_id=task_id,
@@ -70,9 +72,8 @@ class TestEmrServerlessCreateApplicationOperator:
             **config,
         )
         mock_conn.start_application.assert_called_once_with(applicationId=application_id)
-
-        assert mock_waiter.call_count == 2
         assert id == application_id
+        assert mock_conn.get_application.call_count == 2
 
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
@@ -107,7 +108,7 @@ class TestEmrServerlessCreateApplicationOperator:
 
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
-    def test_failed_create_application(self, mock_conn, mock_waiter):
+    def test_failed_create_application_request(self, mock_conn, mock_waiter):
         mock_waiter.return_value = True
         mock_conn.create_application.return_value = {
             "applicationId": application_id,
@@ -134,6 +135,67 @@ class TestEmrServerlessCreateApplicationOperator:
             **config,
         )
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+    def test_failed_create_application(self, mock_conn):
+        mock_conn.create_application.return_value = {
+            "applicationId": application_id,
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        mock_conn.get_application.return_value = {'application': {'state': 'TERMINATED'}}
+
+        operator = EmrServerlessCreateApplicationOperator(
+            task_id=task_id,
+            release_label=release_label,
+            job_type=job_type,
+            client_request_token=client_request_token,
+            config=config,
+        )
+
+        with pytest.raises(AirflowException) as ex_message:
+            operator.execute(None)
+
+        assert "Application reached failure state" in str(ex_message.value)
+
+        mock_conn.create_application.assert_called_once_with(
+            clientToken=client_request_token,
+            releaseLabel=release_label,
+            type=job_type,
+            **config,
+        )
+        mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+    def test_failed_start_application(self, mock_conn):
+        mock_conn.create_application.return_value = {
+            "applicationId": application_id,
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        mock_conn.get_application.side_effect = [
+            {'application': {'state': 'CREATED'}},
+            {'application': {'state': 'TERMINATED'}},
+        ]
+
+        operator = EmrServerlessCreateApplicationOperator(
+            task_id=task_id,
+            release_label=release_label,
+            job_type=job_type,
+            client_request_token=client_request_token,
+            config=config,
+        )
+
+        with pytest.raises(AirflowException) as ex_message:
+            operator.execute(None)
+
+        assert "Application reached failure state" in str(ex_message.value)
+
+        mock_conn.create_application.assert_called_once_with(
+            clientToken=client_request_token,
+            releaseLabel=release_label,
+            type=job_type,
+            **config,
+        )
+        assert mock_conn.get_application.call_count == 2
+
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
     def test_no_client_request_token(self, mock_conn, mock_waiter):
@@ -187,15 +249,14 @@ class TestEmrServerlessCreateApplicationOperator:
 
 
 class TestEmrServerlessStartJobOperator:
-    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
-    def test_job_run_app_started(self, mock_conn, mock_waiter):
-        mock_waiter.return_value = True
+    def test_job_run_app_started(self, mock_conn):
         mock_conn.get_application.return_value = {"application": {"state": "STARTED"}}
         mock_conn.start_job_run.return_value = {
             'jobRunId': job_run_id,
             'ResponseMetadata': {'HTTPStatusCode': 200},
         }
+        mock_conn.get_job_run.return_value = {'jobRun': {'state': 'SUCCESS'}}
 
         operator = EmrServerlessStartJobOperator(
             task_id=task_id,
@@ -210,7 +271,6 @@ class TestEmrServerlessStartJobOperator:
 
         assert operator.wait_for_completion is True
         mock_conn.get_application.assert_called_once_with(applicationId=application_id)
-        mock_waiter.assert_called_once()
         assert id == job_run_id
         mock_conn.start_job_run.assert_called_once_with(
             clientToken=client_request_token,
@@ -219,6 +279,7 @@ class TestEmrServerlessStartJobOperator:
             jobDriver=job_driver,
             configurationOverrides=configuration_overrides,
         )
+        mock_conn.get_job_run.assert_called_once_with(applicationId=application_id, jobRunId=job_run_id)
 
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
     def test_job_run_job_failed(self, mock_conn):
@@ -285,6 +346,32 @@ class TestEmrServerlessStartJobOperator:
             configurationOverrides=configuration_overrides,
         )
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+    def test_job_run_app_not_started_app_failed(self, mock_conn):
+        mock_conn.get_application.side_effect = [
+            {"application": {"state": "CREATING"}},
+            {"application": {"state": "TERMINATED"}},
+        ]
+        mock_conn.start_job_run.return_value = {
+            'jobRunId': job_run_id,
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            client_request_token=client_request_token,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+        )
+        with pytest.raises(AirflowException) as ex_message:
+            operator.execute(None)
+        assert "Application reached failure state" in str(ex_message.value)
+        assert operator.wait_for_completion is True
+        assert mock_conn.get_application.call_count == 2
+        mock_conn.start_job_run.assert_not_called()
+
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
     def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_waiter):
@@ -381,6 +468,36 @@ class TestEmrServerlessStartJobOperator:
             configurationOverrides=configuration_overrides,
         )
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.conn")
+    def test_start_job_run_fail_on_wait_for_completion(self, mock_conn):
+        mock_conn.get_application.return_value = {"application": {"state": "CREATED"}}
+        mock_conn.start_job_run.return_value = {
+            'jobRunId': job_run_id,
+            'ResponseMetadata': {'HTTPStatusCode': 200},
+        }
+        mock_conn.get_job_run.return_value = {'jobRun': {'state': 'FAILED'}}
+
+        operator = EmrServerlessStartJobOperator(
+            task_id=task_id,
+            client_request_token=client_request_token,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=job_driver,
+            configuration_overrides=configuration_overrides,
+        )
+        with pytest.raises(AirflowException) as ex_message:
+            operator.execute(None)
+
+        assert "Job reached failure state" in str(ex_message.value)
+        mock_conn.get_application.assert_called_once_with(applicationId=application_id)
+        mock_conn.start_job_run.assert_called_once_with(
+            clientToken=client_request_token,
+            applicationId=application_id,
+            executionRoleArn=execution_role_arn,
+            jobDriver=job_driver,
+            configurationOverrides=configuration_overrides,
+        )
+
 
 class TestEmrServerlessDeleteOperator:
     @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrServerlessHook.waiter")

Reply via email to