This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b777514253 Add handling state of existing Dataproc batch (#24924)
b777514253 is described below

commit b7775142530d053527b0f21f48e04b95ca8861ab
Author: Daniel van der Ende <[email protected]>
AuthorDate: Tue Jul 12 09:54:55 2022 +0200

    Add handling state of existing Dataproc batch (#24924)
    
    This change avoids Airflow marking tasks as 'Success' even if the
    existing Batch is in a 'Failed' state. We check the various states,
    and ensure that the Airflow task state reflects the actual state of
    the Dataproc Batch.
    
    Co-authored-by: Daniel van der Ende <[email protected]>
---
 .../providers/google/cloud/operators/dataproc.py   | 16 ++++++++++++
 .../google/cloud/operators/test_dataproc.py        | 30 ++++++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 7a160d1ba9..69ba3943e8 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -2049,6 +2049,22 @@ class DataprocCreateBatchOperator(BaseOperator):
                 timeout=self.timeout,
                 metadata=self.metadata,
             )
+
+            # The existing batch may be a in a number of states other than 'SUCCEEDED'
+            if result.state != Batch.State.SUCCEEDED:
+                if result.state == Batch.State.FAILED or result.state == Batch.State.CANCELLED:
+                    raise AirflowException(
+                        f"Existing Batch {self.batch_id} failed or cancelled. "
+                        f"Error: {result.state_message}"
+                    )
+                else:
+                    # Batch state is either: RUNNING, PENDING, CANCELLING, or UNSPECIFIED
+                    self.log.info(
+                        f"Batch {self.batch_id} is in state {result.state.name}."
+                        "Waiting for state change..."
+                    )
+                    result = hook.wait_for_operation(timeout=self.timeout, operation=result)
+
         batch_id = self.batch_id or result.name.split('/')[-1]
         DataprocLink.persist(context=context, task_instance=self, url=DATAPROC_BATCH_LINK, resource=batch_id)
         return Batch.to_dict(result)
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 7992f5423e..a7101ca558 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -23,6 +23,7 @@ from unittest.mock import MagicMock, Mock, call
 import pytest
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
+from google.cloud.dataproc_v1 import Batch
 
 from airflow import AirflowException
 from airflow.exceptions import AirflowTaskTimeout
@@ -1658,6 +1659,35 @@ class TestDataprocCreateBatchOperator:
             metadata=METADATA,
         )
 
+    @mock.patch(DATAPROC_PATH.format("Batch.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_batch_failed(self, mock_hook, to_dict_mock):
+        op = DataprocCreateBatchOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            batch=BATCH,
+            batch_id=BATCH_ID,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
+        mock_hook.return_value.get_batch.return_value.state = Batch.State.FAILED
+        with pytest.raises(AirflowException):
+            op.execute(context=MagicMock())
+            mock_hook.return_value.get_batch.assert_called_once_with(
+                batch_id=BATCH_ID,
+                region=GCP_LOCATION,
+                project_id=GCP_PROJECT,
+                retry=RETRY,
+                timeout=TIMEOUT,
+                metadata=METADATA,
+            )
+
 
 class TestDataprocDeleteBatchOperator:
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))

Reply via email to