sunank200 commented on code in PR #39130:
URL: https://github.com/apache/airflow/pull/39130#discussion_r1575675452


##########
airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -140,24 +142,97 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "polling_interval_seconds": self.polling_interval_seconds,
+                "delete_on_error": self.delete_on_error,
             },
         )
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
-        while True:
-            cluster = await self.get_async_hook().get_cluster(
-                project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
+        """Run the trigger."""
+        try:
+            while True:
+                cluster = await self.fetch_cluster_status()
+                if self.check_cluster_state(cluster.status.state):
+                    if cluster.status.state == ClusterStatus.State.ERROR:
+                        await self.gather_diagnostics_and_maybe_delete(cluster)
+                    else:
+                        yield TriggerEvent(
+                            {
+                                "cluster_name": self.cluster_name,
+                                "cluster_state": cluster.status.state,
+                                "cluster": cluster,
+                            }
+                        )
+                    break
+                self.log.info("Sleeping for %s seconds.", 
self.polling_interval_seconds)
+                await asyncio.sleep(self.polling_interval_seconds)
+        except asyncio.CancelledError:
+            await self.handle_cancellation()
+
+    async def fetch_cluster_status(self) -> Cluster:
+        """Fetch the cluster status."""
+        return await self.get_async_hook().get_cluster(
+            project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
+        )
+
+    def check_cluster_state(self, state: ClusterStatus.State) -> bool:
+        """
+        Check if the state is error or running.
+
+        :param state: The state of the cluster.
+        """
+        return state in (ClusterStatus.State.ERROR, 
ClusterStatus.State.RUNNING)
+
+    async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster):
+        """
+        Gather diagnostics and maybe delete the cluster.
+
+        :param cluster: The cluster to gather diagnostics for.
+        """
+        self.log.info("Cluster is in ERROR state. Gathering diagnostic 
information.")
+        try:
+            operation = await self.get_async_hook().diagnose_cluster(
+                region=self.region, cluster_name=self.cluster_name, 
project_id=self.project_id
             )
-            state = cluster.status.state
-            self.log.info("Dataproc cluster: %s is in state: %s", 
self.cluster_name, state)
-            if state in (
-                ClusterStatus.State.ERROR,
-                ClusterStatus.State.RUNNING,
-            ):
-                break
-            self.log.info("Sleeping for %s seconds.", 
self.polling_interval_seconds)
-            await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"cluster_name": self.cluster_name, 
"cluster_state": state, "cluster": cluster})
+            result = await operation.result()
+            gcs_uri = str(result.response.value)
+            self.log.info(
+                "Diagnostic information for cluster %s available at: %s", 
self.cluster_name, gcs_uri
+            )
+        except Exception as e:
+            self.log.error("Failed to diagnose cluster: %s", e)
+
+        if self.delete_on_error:
+            await self.get_async_hook().delete_cluster(
+                region=self.region, cluster_name=self.cluster_name, 
project_id=self.project_id
+            )
+            return TriggerEvent(
+                {
+                    "cluster_name": self.cluster_name,
+                    "cluster_state": cluster.status.state,
+                    "cluster": None,
+                    "action": "deleted",
+                }
+            )
+        else:
+            return TriggerEvent(
+                {"cluster_name": self.cluster_name, "cluster_state": 
cluster.status.state, "cluster": cluster}
+            )
+
+    async def handle_cancellation(self) -> None:
+        """Handle the cancellation of the trigger, cleaning up resources if 
necessary."""
+        self.log.info("Cancellation requested. Deleting the cluster if 
created.")
+        try:
+            if self.delete_on_error:
+                cluster = await self.fetch_cluster_status()
+                if cluster.status.state == ClusterStatus.State.ERROR:
+                    await self.get_async_hook().async_delete_cluster(
+                        region=self.region, cluster_name=self.cluster_name, 
project_id=self.project_id
+                    )
+                    self.log.info("Deleted cluster due to ERROR state during 
cancellation.")
+                else:
+                    self.log.info("Cancellation did not require cluster 
deletion.")
+        except Exception as e:
+            self.log.error("Error during cancellation handling: %s", e)

Review Comment:
   I have added exception handling here: the `except Exception` block now logs any error raised during cancellation handling.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to