This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bea1b7f70c Improve `DataprocCreateClusterOperator` Triggers for Better Error Handling and Resource Cleanup (#39130)
bea1b7f70c is described below
commit bea1b7f70cd08b0cdb3cf0515646374d101c8f27
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Fri Apr 26 14:20:46 2024 +0545
Improve `DataprocCreateClusterOperator` Triggers for Better Error Handling and Resource Cleanup (#39130)
---
.../providers/google/cloud/operators/dataproc.py | 1 +
.../providers/google/cloud/triggers/dataproc.py | 92 ++++++++++++---
.../google/cloud/triggers/test_dataproc.py | 125 ++++++++++++++++++---
3 files changed, 189 insertions(+), 29 deletions(-)
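For readers skimming the diff: the user-facing effect is that the existing `delete_on_error` operator argument is now honored in deferrable mode as well. A minimal, hypothetical usage sketch (task id, project, region, and cluster config below are placeholders, not part of this commit):

    from airflow.providers.google.cloud.operators.dataproc import (
        DataprocCreateClusterOperator,
    )

    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id="my-project",      # placeholder
        region="us-central1",         # placeholder
        cluster_name="my-cluster",    # placeholder
        cluster_config={},            # placeholder; supply a real cluster config
        deferrable=True,              # hand polling off to DataprocClusterTrigger
        delete_on_error=True,         # now forwarded to the trigger (see below)
    )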
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index edbfbd3f39..e4fccfedd8 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -816,6 +816,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
+ delete_on_error=self.delete_on_error,
),
method_name="execute_complete",
)
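The one-line operator change above forwards the flag; the trigger change below does the real work. It hinges on an asyncio detail: when the triggerer cancels a trigger task, cleanup performed through the async hook is not awaited, so the trigger falls back to a synchronous delete inside its `except asyncio.CancelledError` handler. A minimal, self-contained sketch of that pattern (illustrative only, not Airflow code):

    import asyncio

    def sync_cleanup() -> None:
        # Stand-in for DataprocHook.delete_cluster: a plain blocking call
        # that, once started, runs to completion.
        print("cleanup done")

    async def poll_forever():
        try:
            while True:
                await asyncio.sleep(5)  # stand-in for polling cluster state
        except asyncio.CancelledError:
            sync_cleanup()  # same shape as get_sync_hook().delete_cluster(...)
            raise

    async def main():
        task = asyncio.create_task(poll_forever())
        await asyncio.sleep(0.1)  # let the poller start
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    asyncio.run(main())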
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index f0aecddb4a..32b536a2ec 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -25,9 +25,10 @@ import time
from typing import Any, AsyncIterator, Sequence
from google.api_core.exceptions import NotFound
-from google.cloud.dataproc_v1 import Batch, ClusterStatus, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
-from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -43,6 +44,7 @@ class DataprocBaseTrigger(BaseTrigger):
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
polling_interval_seconds: int = 30,
+ delete_on_error: bool = True,
):
super().__init__()
self.region = region
@@ -50,6 +52,7 @@ class DataprocBaseTrigger(BaseTrigger):
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_interval_seconds = polling_interval_seconds
+ self.delete_on_error = delete_on_error
def get_async_hook(self):
return DataprocAsyncHook(
@@ -57,6 +60,16 @@ class DataprocBaseTrigger(BaseTrigger):
impersonation_chain=self.impersonation_chain,
)
+ def get_sync_hook(self):
+ # The synchronous hook is used to delete the cluster when a task is cancelled,
+ # because the asynchronous hook's deletion is not awaited when the trigger task
+ # is cancelled. The delete call made through the sync hook is not blocking: it
+ # does not wait until the cluster is deleted.
+ return DataprocHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
class DataprocSubmitTrigger(DataprocBaseTrigger):
"""
@@ -140,24 +153,73 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
+ "delete_on_error": self.delete_on_error,
},
)
async def run(self) -> AsyncIterator[TriggerEvent]:
- while True:
- cluster = await self.get_async_hook().get_cluster(
- project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+ try:
+ while True:
+ cluster = await self.fetch_cluster()
+ state = cluster.status.state
+ if state == ClusterStatus.State.ERROR:
+ await self.delete_when_error_occurred(cluster)
+ yield TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": ClusterStatus.State.DELETING,
+ "cluster": cluster,
+ }
+ )
+ return
+ elif state == ClusterStatus.State.RUNNING:
+ yield TriggerEvent(
+ {
+ "cluster_name": self.cluster_name,
+ "cluster_state": state,
+ "cluster": cluster,
+ }
+ )
+ return
+ self.log.info("Current state is %s", state)
+ self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
+ await asyncio.sleep(self.polling_interval_seconds)
+ except asyncio.CancelledError:
+ try:
+ if self.delete_on_error:
+ self.log.info("Deleting cluster %s.", self.cluster_name)
+ # The synchronous hook is used to delete the cluster when a task is cancelled,
+ # because the asynchronous hook's deletion is not awaited when the trigger task
+ # is cancelled. The delete call made through the sync hook is not blocking: it
+ # does not wait until the cluster is deleted.
+ self.get_sync_hook().delete_cluster(
+ region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
+ )
+ self.log.info("Deleted cluster %s during cancellation.",
self.cluster_name)
+ except Exception as e:
+ self.log.error("Error during cancellation handling: %s", e)
+ raise AirflowException("Error during cancellation handling:
%s", e)
+
+ async def fetch_cluster(self) -> Cluster:
+ """Fetch the cluster status."""
+ return await self.get_async_hook().get_cluster(
+ project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+ )
+
+ async def delete_when_error_occurred(self, cluster: Cluster) -> None:
+ """
+ Delete the cluster on error.
+
+ :param cluster: The cluster to delete.
+ """
+ if self.delete_on_error:
+ self.log.info("Deleting cluster %s.", self.cluster_name)
+ await self.get_async_hook().delete_cluster(
+ region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
- state = cluster.status.state
- self.log.info("Dataproc cluster: %s is in state: %s",
self.cluster_name, state)
- if state in (
- ClusterStatus.State.ERROR,
- ClusterStatus.State.RUNNING,
- ):
- break
- self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
- await asyncio.sleep(self.polling_interval_seconds)
- yield TriggerEvent({"cluster_name": self.cluster_name,
"cluster_state": state, "cluster": cluster})
+ self.log.info("Cluster %s has been deleted.", self.cluster_name)
+ else:
+ self.log.info("Cluster %s is not deleted as delete_on_error is set
to False.", self.cluster_name)
class DataprocBatchTrigger(DataprocBaseTrigger):
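To make the new event contract concrete: on ERROR the trigger now deletes the cluster (when `delete_on_error` is set) and emits `cluster_state: DELETING`, while a healthy cluster still yields RUNNING. A hypothetical sketch of how an operator-side `execute_complete` might consume that payload (the actual handler in `DataprocCreateClusterOperator` is not part of this diff):

    from airflow.exceptions import AirflowException
    from google.cloud.dataproc_v1 import ClusterStatus

    def execute_complete(context, event):
        # Assumes the payload shape yielded by DataprocClusterTrigger.run().
        if event["cluster_state"] == ClusterStatus.State.DELETING:
            raise AirflowException(
                f"Cluster {event['cluster_name']} hit ERROR and is being deleted."
            )
        return event["cluster"]  # RUNNING: creation succeeded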
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index 45607d51b8..e310f2e0df 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -22,7 +22,7 @@ from asyncio import Future
from unittest import mock
import pytest
-from google.cloud.dataproc_v1 import Batch, ClusterStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus
from google.protobuf.any_pb2 import Any
from google.rpc.status_pb2 import Status
@@ -70,6 +70,7 @@ def batch_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
+ delete_on_error=True,
)
return trigger
@@ -96,6 +97,7 @@ def diagnose_operation_trigger():
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
+ delete_on_error=True,
)
@@ -147,6 +149,7 @@ class TestDataprocClusterTrigger:
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
"polling_interval_seconds": TEST_POLL_INTERVAL,
+ "delete_on_error": True,
}
@pytest.mark.asyncio
@@ -175,27 +178,37 @@ class TestDataprocClusterTrigger:
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+ @mock.patch(
+ "airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
+ return_value=asyncio.Future(),
+ )
+ @mock.patch("google.auth.default")
async def test_async_cluster_trigger_run_returns_error_event(
- self, mock_hook, cluster_trigger, async_get_cluster
+ self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog
):
- mock_hook.return_value = async_get_cluster(
+ mock_credentials = mock.MagicMock()
+ mock_credentials.universe_domain = "googleapis.com"
+
+ mock_auth.return_value = (mock_credentials, "project-id")
+
+ mock_delete_cluster.return_value = asyncio.Future()
+ mock_delete_cluster.return_value.set_result(None)
+
+ mock_get_cluster.return_value = async_get_cluster(
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.ERROR),
)
- actual_event = await cluster_trigger.run().asend(None)
- await asyncio.sleep(0.5)
+ caplog.set_level(logging.INFO)
- expected_event = TriggerEvent(
- {
- "cluster_name": TEST_CLUSTER_NAME,
- "cluster_state": ClusterStatus.State.ERROR,
- "cluster": actual_event.payload["cluster"],
- }
- )
- assert expected_event == actual_event
+ trigger_event = None
+ async for event in cluster_trigger.run():
+ trigger_event = event
+
+ assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
+ assert trigger_event.payload["cluster_state"] ==
ClusterStatus.State.DELETING
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
@@ -215,9 +228,93 @@ class TestDataprocClusterTrigger:
await asyncio.sleep(0.5)
assert not task.done()
- assert f"Current state is: {ClusterStatus.State.CREATING}"
+ assert f"Current state is: {ClusterStatus.State.CREATING}."
assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+ @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
+ async def test_cluster_trigger_cancellation_handling(
+ self, mock_get_sync_hook, mock_get_async_hook, caplog
+ ):
+ cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
+ mock_get_async_hook.return_value.get_cluster.return_value = asyncio.Future()
+ mock_get_async_hook.return_value.get_cluster.return_value.set_result(cluster)
+
+ mock_delete_cluster = mock.MagicMock()
+ mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
+
+ cluster_trigger = DataprocClusterTrigger(
+ cluster_name="cluster_name",
+ project_id="project-id",
+ region="region",
+ gcp_conn_id="google_cloud_default",
+ impersonation_chain=None,
+ polling_interval_seconds=5,
+ delete_on_error=True,
+ )
+
+ cluster_trigger_gen = cluster_trigger.run()
+
+ try:
+ await cluster_trigger_gen.__anext__()
+ await cluster_trigger_gen.aclose()
+
+ except asyncio.CancelledError:
+ # Verify that cancellation was handled as expected
+ if cluster_trigger.delete_on_error:
+ mock_get_sync_hook.assert_called_once()
+ mock_delete_cluster.assert_called_once_with(
+ region=cluster_trigger.region,
+ cluster_name=cluster_trigger.cluster_name,
+ project_id=cluster_trigger.project_id,
+ )
+ assert "Deleting cluster" in caplog.text
+ assert "Deleted cluster" in caplog.text
+ else:
+ mock_delete_cluster.assert_not_called()
+ except Exception as e:
+ pytest.fail(f"Unexpected exception raised: {e}")
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+ async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster):
+ mock_get_cluster.return_value = async_get_cluster(
+ status=ClusterStatus(state=ClusterStatus.State.RUNNING)
+ )
+ cluster = await cluster_trigger.fetch_cluster()
+
+ assert cluster.status.state == ClusterStatus.State.RUNNING, "The
cluster state should be RUNNING"
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
+ async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_trigger):
+ mock_cluster = mock.MagicMock(spec=Cluster)
+ type(mock_cluster).status = mock.PropertyMock(
+ return_value=mock.MagicMock(state=ClusterStatus.State.ERROR)
+ )
+
+ mock_delete_future = asyncio.Future()
+ mock_delete_future.set_result(None)
+ mock_delete_cluster.return_value = mock_delete_future
+
+ cluster_trigger.delete_on_error = True
+
+ await cluster_trigger.delete_when_error_occurred(mock_cluster)
+
+ mock_delete_cluster.assert_called_once_with(
+ region=cluster_trigger.region,
+ cluster_name=cluster_trigger.cluster_name,
+ project_id=cluster_trigger.project_id,
+ )
+
+ mock_delete_cluster.reset_mock()
+ cluster_trigger.delete_on_error = False
+
+ await cluster_trigger.delete_when_error_occurred(mock_cluster)
+
+ mock_delete_cluster.assert_not_called()
+
@pytest.mark.db_test
class TestDataprocBatchTrigger: