This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bea1b7f70c Improve `DataprocCreateClusterOperator` Triggers for Better Error Handling and Resource Cleanup (#39130)
bea1b7f70c is described below

commit bea1b7f70cd08b0cdb3cf0515646374d101c8f27
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Fri Apr 26 14:20:46 2024 +0545

    Improve `DataprocCreateClusterOperator` Triggers for Better Error Handling and Resource Cleanup (#39130)
---
 .../providers/google/cloud/operators/dataproc.py   |   1 +
 .../providers/google/cloud/triggers/dataproc.py    |  92 ++++++++++++---
 .../google/cloud/triggers/test_dataproc.py         | 125 ++++++++++++++++++---
 3 files changed, 189 insertions(+), 29 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index edbfbd3f39..e4fccfedd8 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -816,6 +816,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
                             gcp_conn_id=self.gcp_conn_id,
                             impersonation_chain=self.impersonation_chain,
                             
polling_interval_seconds=self.polling_interval_seconds,
+                            delete_on_error=self.delete_on_error,
                         ),
                         method_name="execute_complete",
                     )
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index f0aecddb4a..32b536a2ec 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -25,9 +25,10 @@ import time
 from typing import Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
-from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
-from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, 
DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import 
PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -43,6 +44,7 @@ class DataprocBaseTrigger(BaseTrigger):
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         polling_interval_seconds: int = 30,
+        delete_on_error: bool = True,
     ):
         super().__init__()
         self.region = region
@@ -50,6 +52,7 @@ class DataprocBaseTrigger(BaseTrigger):
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.polling_interval_seconds = polling_interval_seconds
+        self.delete_on_error = delete_on_error
 
     def get_async_hook(self):
         return DataprocAsyncHook(
@@ -57,6 +60,16 @@ class DataprocBaseTrigger(BaseTrigger):
             impersonation_chain=self.impersonation_chain,
         )
 
+    def get_sync_hook(self):
+        # The synchronous hook is utilized to delete the cluster when a task 
is cancelled.
+        # This is because the asynchronous hook deletion is not awaited when 
the trigger task
+        # is cancelled. The call for deleting the cluster through the sync 
hook is not a blocking
+        # call, which means it does not wait until the cluster is deleted.
+        return DataprocHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
 
 class DataprocSubmitTrigger(DataprocBaseTrigger):
     """
@@ -140,24 +153,73 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "polling_interval_seconds": self.polling_interval_seconds,
+                "delete_on_error": self.delete_on_error,
             },
         )
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
-        while True:
-            cluster = await self.get_async_hook().get_cluster(
-                project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
+        try:
+            while True:
+                cluster = await self.fetch_cluster()
+                state = cluster.status.state
+                if state == ClusterStatus.State.ERROR:
+                    await self.delete_when_error_occurred(cluster)
+                    yield TriggerEvent(
+                        {
+                            "cluster_name": self.cluster_name,
+                            "cluster_state": ClusterStatus.State.DELETING,
+                            "cluster": cluster,
+                        }
+                    )
+                    return
+                elif state == ClusterStatus.State.RUNNING:
+                    yield TriggerEvent(
+                        {
+                            "cluster_name": self.cluster_name,
+                            "cluster_state": state,
+                            "cluster": cluster,
+                        }
+                    )
+                    return
+                self.log.info("Current state is %s", state)
+                self.log.info("Sleeping for %s seconds.", 
self.polling_interval_seconds)
+                await asyncio.sleep(self.polling_interval_seconds)
+        except asyncio.CancelledError:
+            try:
+                if self.delete_on_error:
+                    self.log.info("Deleting cluster %s.", self.cluster_name)
+                    # The synchronous hook is utilized to delete the cluster 
when a task is cancelled.
+                    # This is because the asynchronous hook deletion is not 
awaited when the trigger task
+                    # is cancelled. The call for deleting the cluster through 
the sync hook is not a blocking
+                    # call, which means it does not wait until the cluster is 
deleted.
+                    self.get_sync_hook().delete_cluster(
+                        region=self.region, cluster_name=self.cluster_name, 
project_id=self.project_id
+                    )
+                    self.log.info("Deleted cluster %s during cancellation.", 
self.cluster_name)
+            except Exception as e:
+                self.log.error("Error during cancellation handling: %s", e)
+                raise AirflowException("Error during cancellation handling: 
%s", e)
+
+    async def fetch_cluster(self) -> Cluster:
+        """Fetch the cluster status."""
+        return await self.get_async_hook().get_cluster(
+            project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
+        )
+
+    async def delete_when_error_occurred(self, cluster: Cluster) -> None:
+        """
+        Delete the cluster on error.
+
+        :param cluster: The cluster to delete.
+        """
+        if self.delete_on_error:
+            self.log.info("Deleting cluster %s.", self.cluster_name)
+            await self.get_async_hook().delete_cluster(
+                region=self.region, cluster_name=self.cluster_name, 
project_id=self.project_id
             )
-            state = cluster.status.state
-            self.log.info("Dataproc cluster: %s is in state: %s", 
self.cluster_name, state)
-            if state in (
-                ClusterStatus.State.ERROR,
-                ClusterStatus.State.RUNNING,
-            ):
-                break
-            self.log.info("Sleeping for %s seconds.", 
self.polling_interval_seconds)
-            await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"cluster_name": self.cluster_name, 
"cluster_state": state, "cluster": cluster})
+            self.log.info("Cluster %s has been deleted.", self.cluster_name)
+        else:
+            self.log.info("Cluster %s is not deleted as delete_on_error is set 
to False.", self.cluster_name)
 
 
 class DataprocBatchTrigger(DataprocBaseTrigger):
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index 45607d51b8..e310f2e0df 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -22,7 +22,7 @@ from asyncio import Future
 from unittest import mock
 
 import pytest
-from google.cloud.dataproc_v1 import Batch, ClusterStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus
 from google.protobuf.any_pb2 import Any
 from google.rpc.status_pb2 import Status
 
@@ -70,6 +70,7 @@ def batch_trigger():
         gcp_conn_id=TEST_GCP_CONN_ID,
         impersonation_chain=None,
         polling_interval_seconds=TEST_POLL_INTERVAL,
+        delete_on_error=True,
     )
     return trigger
 
@@ -96,6 +97,7 @@ def diagnose_operation_trigger():
         gcp_conn_id=TEST_GCP_CONN_ID,
         impersonation_chain=None,
         polling_interval_seconds=TEST_POLL_INTERVAL,
+        delete_on_error=True,
     )
 
 
@@ -147,6 +149,7 @@ class TestDataprocClusterTrigger:
             "gcp_conn_id": TEST_GCP_CONN_ID,
             "impersonation_chain": None,
             "polling_interval_seconds": TEST_POLL_INTERVAL,
+            "delete_on_error": True,
         }
 
     @pytest.mark.asyncio
@@ -175,27 +178,37 @@ class TestDataprocClusterTrigger:
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+    @mock.patch(
+        
"airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
+        return_value=asyncio.Future(),
+    )
+    @mock.patch("google.auth.default")
     async def test_async_cluster_trigger_run_returns_error_event(
-        self, mock_hook, cluster_trigger, async_get_cluster
+        self, mock_auth, mock_delete_cluster, mock_get_cluster, 
cluster_trigger, async_get_cluster, caplog
     ):
-        mock_hook.return_value = async_get_cluster(
+        mock_credentials = mock.MagicMock()
+        mock_credentials.universe_domain = "googleapis.com"
+
+        mock_auth.return_value = (mock_credentials, "project-id")
+
+        mock_delete_cluster.return_value = asyncio.Future()
+        mock_delete_cluster.return_value.set_result(None)
+
+        mock_get_cluster.return_value = async_get_cluster(
             project_id=TEST_PROJECT_ID,
             region=TEST_REGION,
             cluster_name=TEST_CLUSTER_NAME,
             status=ClusterStatus(state=ClusterStatus.State.ERROR),
         )
 
-        actual_event = await cluster_trigger.run().asend(None)
-        await asyncio.sleep(0.5)
+        caplog.set_level(logging.INFO)
 
-        expected_event = TriggerEvent(
-            {
-                "cluster_name": TEST_CLUSTER_NAME,
-                "cluster_state": ClusterStatus.State.ERROR,
-                "cluster": actual_event.payload["cluster"],
-            }
-        )
-        assert expected_event == actual_event
+        trigger_event = None
+        async for event in cluster_trigger.run():
+            trigger_event = event
+
+        assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
+        assert trigger_event.payload["cluster_state"] == 
ClusterStatus.State.DELETING
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
@@ -215,9 +228,93 @@ class TestDataprocClusterTrigger:
         await asyncio.sleep(0.5)
 
         assert not task.done()
-        assert f"Current state is: {ClusterStatus.State.CREATING}"
+        assert f"Current state is: {ClusterStatus.State.CREATING}."
         assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
 
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
+    async def test_cluster_trigger_cancellation_handling(
+        self, mock_get_sync_hook, mock_get_async_hook, caplog
+    ):
+        cluster = 
Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
+        mock_get_async_hook.return_value.get_cluster.return_value = 
asyncio.Future()
+        
mock_get_async_hook.return_value.get_cluster.return_value.set_result(cluster)
+
+        mock_delete_cluster = mock.MagicMock()
+        mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
+
+        cluster_trigger = DataprocClusterTrigger(
+            cluster_name="cluster_name",
+            project_id="project-id",
+            region="region",
+            gcp_conn_id="google_cloud_default",
+            impersonation_chain=None,
+            polling_interval_seconds=5,
+            delete_on_error=True,
+        )
+
+        cluster_trigger_gen = cluster_trigger.run()
+
+        try:
+            await cluster_trigger_gen.__anext__()
+            await cluster_trigger_gen.aclose()
+
+        except asyncio.CancelledError:
+            # Verify that cancellation was handled as expected
+            if cluster_trigger.delete_on_error:
+                mock_get_sync_hook.assert_called_once()
+                mock_delete_cluster.assert_called_once_with(
+                    region=cluster_trigger.region,
+                    cluster_name=cluster_trigger.cluster_name,
+                    project_id=cluster_trigger.project_id,
+                )
+                assert "Deleting cluster" in caplog.text
+                assert "Deleted cluster" in caplog.text
+            else:
+                mock_delete_cluster.assert_not_called()
+        except Exception as e:
+            pytest.fail(f"Unexpected exception raised: {e}")
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+    async def test_fetch_cluster_status(self, mock_get_cluster, 
cluster_trigger, async_get_cluster):
+        mock_get_cluster.return_value = async_get_cluster(
+            status=ClusterStatus(state=ClusterStatus.State.RUNNING)
+        )
+        cluster = await cluster_trigger.fetch_cluster()
+
+        assert cluster.status.state == ClusterStatus.State.RUNNING, "The 
cluster state should be RUNNING"
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
+    async def test_delete_when_error_occurred(self, mock_delete_cluster, 
cluster_trigger):
+        mock_cluster = mock.MagicMock(spec=Cluster)
+        type(mock_cluster).status = mock.PropertyMock(
+            return_value=mock.MagicMock(state=ClusterStatus.State.ERROR)
+        )
+
+        mock_delete_future = asyncio.Future()
+        mock_delete_future.set_result(None)
+        mock_delete_cluster.return_value = mock_delete_future
+
+        cluster_trigger.delete_on_error = True
+
+        await cluster_trigger.delete_when_error_occurred(mock_cluster)
+
+        mock_delete_cluster.assert_called_once_with(
+            region=cluster_trigger.region,
+            cluster_name=cluster_trigger.cluster_name,
+            project_id=cluster_trigger.project_id,
+        )
+
+        mock_delete_cluster.reset_mock()
+        cluster_trigger.delete_on_error = False
+
+        await cluster_trigger.delete_when_error_occurred(mock_cluster)
+
+        mock_delete_cluster.assert_not_called()
+
 
 @pytest.mark.db_test
 class TestDataprocBatchTrigger:

Reply via email to