Lee-W commented on code in PR #39130:
URL: https://github.com/apache/airflow/pull/39130#discussion_r1575713841
##########
airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -140,24 +142,97 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
+ "delete_on_error": self.delete_on_error,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
- while True:
- cluster = await self.get_async_hook().get_cluster(
- project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ """Run the trigger."""
+ try:
+ while True:
+ cluster = await self.fetch_cluster_status()
+ if self.check_cluster_state(cluster.status.state):
+ if cluster.status.state == ClusterStatus.State.ERROR:
+ await self.gather_diagnostics_and_maybe_delete(cluster)
+ else:
+ yield TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": cluster.status.state,
+ "cluster": cluster,
+ }
+ )
Review Comment:
```suggestion
)
return
```
##########
airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -140,24 +142,97 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
+ "delete_on_error": self.delete_on_error,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
- while True:
- cluster = await self.get_async_hook().get_cluster(
- project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ """Run the trigger."""
+ try:
+ while True:
+ cluster = await self.fetch_cluster_status()
+ if self.check_cluster_state(cluster.status.state):
+ if cluster.status.state == ClusterStatus.State.ERROR:
+ await self.gather_diagnostics_and_maybe_delete(cluster)
+ else:
+ yield TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": cluster.status.state,
+ "cluster": cluster,
+ }
+ )
+ break
+ self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
+ await asyncio.sleep(self.polling_interval_seconds)
+ except asyncio.CancelledError:
+ await self.handle_cancellation()
+
+ async def fetch_cluster_status(self) -> Cluster:
Review Comment:
Looks like we're fetching the whole cluster rather than just the cluster
status. Or is there anything I've missed?
##########
airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -140,24 +142,97 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
+ "delete_on_error": self.delete_on_error,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
- while True:
- cluster = await self.get_async_hook().get_cluster(
- project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ """Run the trigger."""
+ try:
+ while True:
+ cluster = await self.fetch_cluster_status()
+ if self.check_cluster_state(cluster.status.state):
+ if cluster.status.state == ClusterStatus.State.ERROR:
+ await self.gather_diagnostics_and_maybe_delete(cluster)
+ else:
+ yield TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": cluster.status.state,
+ "cluster": cluster,
+ }
+ )
+ break
+ self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
+ await asyncio.sleep(self.polling_interval_seconds)
+ except asyncio.CancelledError:
+ await self.handle_cancellation()
+
+ async def fetch_cluster_status(self) -> Cluster:
+ """Fetch the cluster status."""
+ return await self.get_async_hook().get_cluster(
+ project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ )
+
+ def check_cluster_state(self, state: ClusterStatus.State) -> bool:
Review Comment:
Looks like we can make it a staticmethod
##########
airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -140,24 +142,97 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
+ "delete_on_error": self.delete_on_error,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
- while True:
- cluster = await self.get_async_hook().get_cluster(
- project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ """Run the trigger."""
+ try:
+ while True:
+ cluster = await self.fetch_cluster_status()
+ if self.check_cluster_state(cluster.status.state):
+ if cluster.status.state == ClusterStatus.State.ERROR:
+ await self.gather_diagnostics_and_maybe_delete(cluster)
+ else:
+ yield TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": cluster.status.state,
+ "cluster": cluster,
+ }
+ )
+ break
+ self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
+ await asyncio.sleep(self.polling_interval_seconds)
+ except asyncio.CancelledError:
+ await self.handle_cancellation()
+
+ async def fetch_cluster_status(self) -> Cluster:
+ """Fetch the cluster status."""
+ return await self.get_async_hook().get_cluster(
+ project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ )
+
+ def check_cluster_state(self, state: ClusterStatus.State) -> bool:
+ """
+ Check if the state is error or running.
+
+ :param state: The state of the cluster.
+ """
+ return state in (ClusterStatus.State.ERROR,
ClusterStatus.State.RUNNING)
+
+ async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster):
Review Comment:
```suggestion
async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster) ->
TriggerEvent:
```
Is there a way to rename the `maybe`? Perhaps something like `delete_if_...`?
##########
airflow/providers/google/cloud/triggers/dataproc.py:
##########
@@ -140,24 +142,97 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
+ "delete_on_error": self.delete_on_error,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
- while True:
- cluster = await self.get_async_hook().get_cluster(
- project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ """Run the trigger."""
+ try:
+ while True:
+ cluster = await self.fetch_cluster_status()
+ if self.check_cluster_state(cluster.status.state):
+ if cluster.status.state == ClusterStatus.State.ERROR:
+ await self.gather_diagnostics_and_maybe_delete(cluster)
+ else:
+ yield TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": cluster.status.state,
+ "cluster": cluster,
+ }
+ )
+ break
+ self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
+ await asyncio.sleep(self.polling_interval_seconds)
+ except asyncio.CancelledError:
+ await self.handle_cancellation()
+
+ async def fetch_cluster_status(self) -> Cluster:
+ """Fetch the cluster status."""
+ return await self.get_async_hook().get_cluster(
+ project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ )
+
+ def check_cluster_state(self, state: ClusterStatus.State) -> bool:
+ """
+ Check if the state is error or running.
+
+ :param state: The state of the cluster.
+ """
+ return state in (ClusterStatus.State.ERROR,
ClusterStatus.State.RUNNING)
+
+ async def gather_diagnostics_and_maybe_delete(self, cluster: Cluster):
+ """
+ Gather diagnostics and maybe delete the cluster.
+
+ :param cluster: The cluster to gather diagnostics for.
+ """
+ self.log.info("Cluster is in ERROR state. Gathering diagnostic
information.")
+ try:
+ operation = await self.get_async_hook().diagnose_cluster(
+ region=self.region, cluster_name=self.cluster_name,
project_id=self.project_id
)
- state = cluster.status.state
- self.log.info("Dataproc cluster: %s is in state: %s",
self.cluster_name, state)
- if state in (
- ClusterStatus.State.ERROR,
- ClusterStatus.State.RUNNING,
- ):
- break
- self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
- await asyncio.sleep(self.polling_interval_seconds)
- yield TriggerEvent({"cluster_name": self.cluster_name,
"cluster_state": state, "cluster": cluster})
+ result = await operation.result()
+ gcs_uri = str(result.response.value)
+ self.log.info(
+ "Diagnostic information for cluster %s available at: %s",
self.cluster_name, gcs_uri
+ )
+ except Exception as e:
+ self.log.error("Failed to diagnose cluster: %s", e)
+
+ if self.delete_on_error:
+ await self.get_async_hook().delete_cluster(
+ region=self.region, cluster_name=self.cluster_name,
project_id=self.project_id
+ )
+ return TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": cluster.status.state,
+ "cluster": None,
+ "action": "deleted",
+ }
+ )
+ else:
+ return TriggerEvent(
+ {"cluster_name": self.cluster_name, "cluster_state":
cluster.status.state, "cluster": cluster}
+ )
+
+ async def handle_cancellation(self) -> None:
Review Comment:
`handle_cancellation` seems to be a broad idea, and it's not easy to
understand what's actually being handled. I'm thinking of something like the
following in the run method. WDYT?
```python
except asyncio.CancelledError:
try:
if self.delete_on_error:
await cleanup_cluster()
except Exception:
......
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]