This is an automated email from the ASF dual-hosted git repository.
pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e07a42e69d Check cluster state before defer Dataproc operators to
trigger (#36892)
e07a42e69d is described below
commit e07a42e69d1ab472c4da991fca5782990607ebe0
Author: Wei Lee <[email protected]>
AuthorDate: Mon Jan 22 14:32:00 2024 +0800
Check cluster state before defer Dataproc operators to trigger (#36892)
While operating a Dataproc cluster in deferrable mode, the condition might
already be met (created, deleted, updated) before we defer the task into the
trigger. This PR intends to check the cluster status before deferring the task
to the trigger.
---------
Co-authored-by: Pankaj Koti <[email protected]>
---
.../providers/google/cloud/operators/dataproc.py | 63 ++++++---
.../google/cloud/operators/test_dataproc.py | 146 ++++++++++++++++++++-
2 files changed, 185 insertions(+), 24 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py
b/airflow/providers/google/cloud/operators/dataproc.py
index 306e0dc03d..b14121139d 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -721,6 +721,7 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
def execute(self, context: Context) -> dict:
self.log.info("Creating cluster: %s", self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
+
# Save data required to display extra link no matter what the cluster
status will be
project_id = self.project_id or hook.project_id
if project_id:
@@ -731,6 +732,7 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
project_id=project_id,
region=self.region,
)
+
try:
# First try to create a new cluster
operation = self._create_cluster(hook)
@@ -741,17 +743,24 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
self.log.info("Cluster created.")
return Cluster.to_dict(cluster)
else:
- self.defer(
- trigger=DataprocClusterTrigger(
- cluster_name=self.cluster_name,
- project_id=self.project_id,
- region=self.region,
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- polling_interval_seconds=self.polling_interval_seconds,
- ),
- method_name="execute_complete",
+ cluster = hook.get_cluster(
+ project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
)
+ if cluster.status.state == cluster.status.State.RUNNING:
+ self.log.info("Cluster created.")
+ return Cluster.to_dict(cluster)
+ else:
+ self.defer(
+ trigger=DataprocClusterTrigger(
+ cluster_name=self.cluster_name,
+ project_id=self.project_id,
+ region=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+
polling_interval_seconds=self.polling_interval_seconds,
+ ),
+ method_name="execute_complete",
+ )
except AlreadyExists:
if not self.use_if_exists:
raise
@@ -1016,6 +1025,16 @@ class
DataprocDeleteClusterOperator(GoogleCloudBaseOperator):
hook.wait_for_operation(timeout=self.timeout,
result_retry=self.retry, operation=operation)
self.log.info("Cluster deleted.")
else:
+ try:
+ hook.get_cluster(
+ project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ )
+ except NotFound:
+ self.log.info("Cluster deleted.")
+ return
+ except Exception as e:
+ raise AirflowException(str(e))
+
end_time: float = time.time() + self.timeout
self.defer(
trigger=DataprocDeleteClusterTrigger(
@@ -2480,17 +2499,21 @@ class
DataprocUpdateClusterOperator(GoogleCloudBaseOperator):
if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout,
result_retry=self.retry, operation=operation)
else:
- self.defer(
- trigger=DataprocClusterTrigger(
- cluster_name=self.cluster_name,
- project_id=self.project_id,
- region=self.region,
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- polling_interval_seconds=self.polling_interval_seconds,
- ),
- method_name="execute_complete",
+ cluster = hook.get_cluster(
+ project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
)
+ if cluster.status.state != cluster.status.State.RUNNING:
+ self.defer(
+ trigger=DataprocClusterTrigger(
+ cluster_name=self.cluster_name,
+ project_id=self.project_id,
+ region=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ polling_interval_seconds=self.polling_interval_seconds,
+ ),
+ method_name="execute_complete",
+ )
self.log.info("Updated %s cluster.", self.cluster_name)
def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py
b/tests/providers/google/cloud/operators/test_dataproc.py
index 59a9c1008c..00f45ca8b3 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -23,7 +23,8 @@ from unittest.mock import MagicMock, Mock, call
import pytest
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry
-from google.cloud.dataproc_v1 import Batch, JobStatus
+from google.cloud import dataproc
+from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
from airflow.exceptions import (
AirflowException,
@@ -579,7 +580,7 @@ class TestsClusterGenerator:
assert CONFIG_WITH_FLEX_MIG == cluster
-class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
+class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
def test_deprecation_warning(self):
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
op = DataprocCreateClusterOperator(
@@ -883,6 +884,54 @@ class
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
assert isinstance(exc.value.trigger, DataprocClusterTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+ @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+ def test_create_execute_call_finished_before_defer(self,
mock_trigger_hook, mock_hook, mock_defer):
+ cluster = Cluster(
+ cluster_name="test_cluster",
+
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+ )
+ mock_hook.return_value.create_cluster.return_value = cluster
+ mock_hook.return_value.get_cluster.return_value = cluster
+ operator = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
+ delete_on_error=True,
+ metadata=METADATA,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ deferrable=True,
+ )
+
+ operator.execute(mock.MagicMock())
+ assert not mock_defer.called
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ mock_hook.return_value.create_cluster.assert_called_once_with(
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_config=CONFIG,
+ request_id=None,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
+ virtual_cluster_config=None,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.wait_for_operation.assert_not_called()
+
@pytest.mark.db_test
@pytest.mark.need_serialized_dag
@@ -1100,6 +1149,47 @@ class TestDataprocClusterDeleteOperator:
assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+ @mock.patch(DATAPROC_PATH.format("DataprocDeleteClusterOperator.defer"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+ def test_create_execute_call_finished_before_defer(self,
mock_trigger_hook, mock_hook, mock_defer):
+ mock_hook.return_value.create_cluster.return_value = None
+ mock_hook.return_value.get_cluster.side_effect = NotFound("test")
+ operator = DataprocDeleteClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ request_id=REQUEST_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ deferrable=True,
+ )
+
+ operator.execute(mock.MagicMock())
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ mock_hook.return_value.delete_cluster.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_REGION,
+ cluster_name=CLUSTER_NAME,
+ cluster_uuid=None,
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+ mock_hook.return_value.wait_for_operation.assert_not_called()
+ assert not mock_defer.called
+
class TestDataprocSubmitJobOperator(DataprocJobTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1240,8 +1330,8 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
-
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
-
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job")
+ @mock.patch(DATAPROC_PATH.format("DataprocSubmitJobOperator.defer"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook.submit_job"))
def test_dataproc_operator_execute_async_done_before_defer(self,
mock_submit_job, mock_defer, mock_hook):
mock_submit_job.return_value.reference.job_id = TEST_JOB_ID
job_status = mock_hook.return_value.get_job.return_value.status
@@ -1498,6 +1588,54 @@ class
TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
assert isinstance(exc.value.trigger, DataprocClusterTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+ @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+ def test_create_execute_call_finished_before_defer(self,
mock_trigger_hook, mock_hook, mock_defer):
+ cluster = Cluster(
+ cluster_name="test_cluster",
+
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+ )
+ mock_hook.return_value.update_cluster.return_value = cluster
+ mock_hook.return_value.get_cluster.return_value = cluster
+ operator = DataprocUpdateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ cluster_name=CLUSTER_NAME,
+ cluster=CLUSTER,
+ update_mask=UPDATE_MASK,
+ request_id=REQUEST_ID,
+ graceful_decommission_timeout={"graceful_decommission_timeout":
"600s"},
+ project_id=GCP_PROJECT,
+ gcp_conn_id=GCP_CONN_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ deferrable=True,
+ )
+
+ operator.execute(mock.MagicMock())
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.update_cluster.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_REGION,
+ cluster_name=CLUSTER_NAME,
+ cluster=CLUSTER,
+ update_mask=UPDATE_MASK,
+ request_id=REQUEST_ID,
+ graceful_decommission_timeout={"graceful_decommission_timeout":
"600s"},
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.wait_for_operation.assert_not_called()
+ assert not mock_defer.called
+
@pytest.mark.db_test
@pytest.mark.need_serialized_dag