This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9fd80130e2 Add deferrable mode to DataprocCreateClusterOperator and 
DataprocUpdateClusterOperator (#28529)
9fd80130e2 is described below

commit 9fd80130e2351c7ec31bbeb6c10f6b11708b318b
Author: Beata Kossakowska <[email protected]>
AuthorDate: Wed Jan 25 23:20:22 2023 +0100

    Add deferrable mode to DataprocCreateClusterOperator and 
DataprocUpdateClusterOperator (#28529)
    
    Co-authored-by: Beata Kossakowska <[email protected]>
---
 .../providers/google/cloud/operators/dataproc.py   |  96 +++++++++++++--
 .../providers/google/cloud/triggers/dataproc.py    |  68 ++++++++++-
 .../operators/cloud/dataproc.rst                   |  16 +++
 .../google/cloud/operators/test_dataproc.py        | 100 ++++++++++++++-
 .../google/cloud/triggers/test_dataproc.py         | 136 +++++++++++++++++++++
 .../example_dataproc_cluster_deferrable.py         | 121 ++++++++++++++++++
 6 files changed, 515 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py 
b/airflow/providers/google/cloud/operators/dataproc.py
index d2899037ae..316a15981f 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -26,13 +26,13 @@ import time
 import uuid
 import warnings
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from google.api_core import operation  # type: ignore
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 from google.protobuf.duration_pb2 import Duration
 from google.protobuf.field_mask_pb2 import FieldMask
 
@@ -50,7 +50,7 @@ from airflow.providers.google.cloud.links.dataproc import (
     DataprocLink,
     DataprocListLink,
 )
-from airflow.providers.google.cloud.triggers.dataproc import 
DataprocBaseTrigger
+from airflow.providers.google.cloud.triggers.dataproc import 
DataprocClusterTrigger, DataprocSubmitTrigger
 from airflow.utils import timezone
 
 if TYPE_CHECKING:
@@ -438,6 +438,8 @@ class DataprocCreateClusterOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding 
identity, with first
         account from the list granting this role to the originating account 
(templated).
+    :param deferrable: Run operator in the deferrable mode.
+    :param polling_interval_seconds: Time (seconds) to wait between calls to 
check the run status.
     """
 
     template_fields: Sequence[str] = (
@@ -470,6 +472,8 @@ class DataprocCreateClusterOperator(BaseOperator):
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = False,
+        polling_interval_seconds: int = 10,
         **kwargs,
     ) -> None:
 
@@ -502,7 +506,8 @@ class DataprocCreateClusterOperator(BaseOperator):
                     del kwargs[arg]
 
         super().__init__(**kwargs)
-
+        if deferrable and polling_interval_seconds <= 0:
+            raise ValueError("Invalid value for polling_interval_seconds. 
Expected value greater than 0")
         self.cluster_config = cluster_config
         self.cluster_name = cluster_name
         self.labels = labels
@@ -517,9 +522,11 @@ class DataprocCreateClusterOperator(BaseOperator):
         self.use_if_exists = use_if_exists
         self.impersonation_chain = impersonation_chain
         self.virtual_cluster_config = virtual_cluster_config
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds
 
     def _create_cluster(self, hook: DataprocHook):
-        operation = hook.create_cluster(
+        return hook.create_cluster(
             project_id=self.project_id,
             region=self.region,
             cluster_name=self.cluster_name,
@@ -531,9 +538,6 @@ class DataprocCreateClusterOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        cluster = operation.result()
-        self.log.info("Cluster created.")
-        return cluster
 
     def _delete_cluster(self, hook):
         self.log.info("Deleting the cluster")
@@ -596,7 +600,25 @@ class DataprocCreateClusterOperator(BaseOperator):
         )
         try:
             # First try to create a new cluster
-            cluster = self._create_cluster(hook)
+            operation = self._create_cluster(hook)
+            if not self.deferrable:
+                cluster = hook.wait_for_operation(
+                    timeout=self.timeout, result_retry=self.retry, 
operation=operation
+                )
+                self.log.info("Cluster created.")
+                return Cluster.to_dict(cluster)
+            else:
+                self.defer(
+                    trigger=DataprocClusterTrigger(
+                        cluster_name=self.cluster_name,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
         except AlreadyExists:
             if not self.use_if_exists:
                 raise
@@ -618,6 +640,21 @@ class DataprocCreateClusterOperator(BaseOperator):
 
         return Cluster.to_dict(cluster)
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Raises AirflowException if the event reports ClusterStatus.State.ERROR; otherwise
+        logs success and returns the ``cluster`` entry of the trigger event payload.
+        """
+        cluster_state = event["cluster_state"]
+        cluster_name = event["cluster_name"]
+
+        if cluster_state == ClusterStatus.State.ERROR:
+            raise AirflowException(f"Cluster is in ERROR 
state:\n{cluster_name}")
+
+        self.log.info("%s completed successfully.", self.task_id)
+        return event["cluster"]
+
 
 class DataprocScaleClusterOperator(BaseOperator):
     """
@@ -974,7 +1011,7 @@ class DataprocJobBaseOperator(BaseOperator):
 
             if self.deferrable:
                 self.defer(
-                    trigger=DataprocBaseTrigger(
+                    trigger=DataprocSubmitTrigger(
                         job_id=job_id,
                         project_id=self.project_id,
                         region=self.region,
@@ -1888,7 +1925,7 @@ class DataprocSubmitJobOperator(BaseOperator):
         self.job_id = new_job_id
         if self.deferrable:
             self.defer(
-                trigger=DataprocBaseTrigger(
+                trigger=DataprocSubmitTrigger(
                     job_id=self.job_id,
                     project_id=self.project_id,
                     region=self.region,
@@ -1964,6 +2001,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding 
identity, with first
         account from the list granting this role to the originating account 
(templated).
+    :param deferrable: Run operator in the deferrable mode.
+    :param polling_interval_seconds: Time (seconds) to wait between calls to 
check the run status.
     """
 
     template_fields: Sequence[str] = (
@@ -1991,9 +2030,13 @@ class DataprocUpdateClusterOperator(BaseOperator):
         metadata: Sequence[tuple[str, str]] = (),
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = False,
+        polling_interval_seconds: int = 10,
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if deferrable and polling_interval_seconds <= 0:
+            raise ValueError("Invalid value for polling_interval_seconds. 
Expected value greater than 0")
         self.project_id = project_id
         self.region = region
         self.cluster_name = cluster_name
@@ -2006,6 +2049,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
         self.metadata = metadata
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.polling_interval_seconds = polling_interval_seconds
 
     def execute(self, context: Context):
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain)
@@ -2026,9 +2071,36 @@ class DataprocUpdateClusterOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        operation.result()
+
+        if not self.deferrable:
+            hook.wait_for_operation(timeout=self.timeout, 
result_retry=self.retry, operation=operation)
+        else:
+            self.defer(
+                trigger=DataprocClusterTrigger(
+                    cluster_name=self.cluster_name,
+                    project_id=self.project_id,
+                    region=self.region,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    polling_interval_seconds=self.polling_interval_seconds,
+                ),
+                method_name="execute_complete",
+            )
         self.log.info("Updated %s cluster.", self.cluster_name)
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Raises AirflowException if the event reports ClusterStatus.State.ERROR; otherwise
+        only logs success (nothing is returned).
+        """
+        cluster_state = event["cluster_state"]
+        cluster_name = event["cluster_name"]
+
+        if cluster_state == ClusterStatus.State.ERROR:
+            raise AirflowException(f"Cluster is in ERROR 
state:\n{cluster_name}")
+        self.log.info("%s completed successfully.", self.task_id)
+
 
 class DataprocCreateBatchOperator(BaseOperator):
     """
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py 
b/airflow/providers/google/cloud/triggers/dataproc.py
index baaec28892..48d3666f4a 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -20,16 +20,16 @@ from __future__ import annotations
 
 import asyncio
 import warnings
-from typing import Sequence
+from typing import Any, AsyncIterator, Sequence
 
-from google.cloud.dataproc_v1 import JobStatus
+from google.cloud.dataproc_v1 import ClusterStatus, JobStatus
 
 from airflow import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class DataprocBaseTrigger(BaseTrigger):
+class DataprocSubmitTrigger(BaseTrigger):
     """
     Trigger that periodically polls information from Dataproc API to verify 
job status.
     Implementation leverages asynchronous transport.
@@ -65,7 +65,7 @@ class DataprocBaseTrigger(BaseTrigger):
 
     def serialize(self):
         return (
-            
"airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger",
+            
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
             {
                 "job_id": self.job_id,
                 "project_id": self.project_id,
@@ -89,3 +89,63 @@ class DataprocBaseTrigger(BaseTrigger):
                     raise AirflowException(f"Dataproc job execution failed 
{self.job_id}")
             await asyncio.sleep(self.polling_interval_seconds)
         yield TriggerEvent({"job_id": self.job_id, "job_state": state})
+
+
+class DataprocClusterTrigger(BaseTrigger):
+    """
+    Trigger that polls Dataproc every ``polling_interval_seconds`` until the cluster is RUNNING or ERROR.
+    Uses async transport. NOTE(review): the event carries the raw cluster object - confirm it serializes.
+    """
+
+    def __init__(
+        self,
+        cluster_name: str,
+        region: str,
+        project_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        polling_interval_seconds: int = 10,
+    ):
+        super().__init__()
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.cluster_name = cluster_name
+        self.project_id = project_id
+        self.region = region
+        self.polling_interval_seconds = polling_interval_seconds
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:  # (classpath, kwargs mirroring __init__)
+        return (
+            
"airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger",
+            {
+                "cluster_name": self.cluster_name,
+                "project_id": self.project_id,
+                "region": self.region,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "polling_interval_seconds": self.polling_interval_seconds,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        hook = self._get_hook()
+        while True:  # NOTE(review): no timeout; only RUNNING or ERROR end the loop
+            cluster = await hook.get_cluster(
+                project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
+            )
+            state = cluster.status.state
+            self.log.info("Dataproc cluster: %s is in state: %s", 
self.cluster_name, state)
+            if state in (
+                ClusterStatus.State.ERROR,
+                ClusterStatus.State.RUNNING,
+            ):
+                break  # RUNNING or ERROR: stop polling and emit one event
+            self.log.info("Sleeping for %s seconds.", 
self.polling_interval_seconds)
+            await asyncio.sleep(self.polling_interval_seconds)
+        yield TriggerEvent({"cluster_name": self.cluster_name, 
"cluster_state": state, "cluster": cluster})
+
+    def _get_hook(self) -> DataprocAsyncHook:
+        return DataprocAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst 
b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index 9c2508410f..d613031042 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -75,6 +75,14 @@ With this configuration we can create the cluster:
     :start-after: [START how_to_cloud_dataproc_create_cluster_operator_in_gke]
     :end-before: [END how_to_cloud_dataproc_create_cluster_operator_in_gke]
 
+You can use deferrable mode for this action in order to run the operator 
asynchronously:
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_create_cluster_operator_async]
+    :end-before: [END how_to_cloud_dataproc_create_cluster_operator_async]
+
 Generating Cluster Config
 ^^^^^^^^^^^^^^^^^^^^^^^^^
 You can also generate **CLUSTER_CONFIG** using functional API,
@@ -111,6 +119,14 @@ To update a cluster you can use:
     :start-after: [START how_to_cloud_dataproc_update_cluster_operator]
     :end-before: [END how_to_cloud_dataproc_update_cluster_operator]
 
+You can use deferrable mode for this action in order to run the operator 
asynchronously:
+
+.. exampleinclude:: 
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
+    :end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]
+
 Deleting a cluster
 ------------------
 
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py 
b/tests/providers/google/cloud/operators/test_dataproc.py
index 06e5644c82..70ebd991a8 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -54,7 +54,7 @@ from airflow.providers.google.cloud.operators.dataproc import 
(
     DataprocSubmitSparkSqlJobOperator,
     DataprocUpdateClusterOperator,
 )
-from airflow.providers.google.cloud.triggers.dataproc import 
DataprocBaseTrigger
+from airflow.providers.google.cloud.triggers.dataproc import 
DataprocClusterTrigger, DataprocSubmitTrigger
 from airflow.providers.google.common.consts import 
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils.timezone import datetime
@@ -436,7 +436,6 @@ class 
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook, to_dict_mock):
         self.extra_links_manager_mock.attach_mock(mock_hook, "hook")
-        mock_hook.return_value.create_cluster.result.return_value = None
         create_cluster_args = {
             "region": GCP_REGION,
             "project_id": GCP_PROJECT,
@@ -474,7 +473,7 @@ class 
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
         # Test whether xcom push occurs before create cluster is called
         self.extra_links_manager_mock.assert_has_calls(expected_calls, 
any_order=False)
 
-        
to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
+        to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
         self.mock_ti.xcom_push.assert_called_once_with(
             key="conf",
             value=DATAPROC_CLUSTER_CONF_EXPECTED,
@@ -485,7 +484,7 @@ class 
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_in_gke(self, mock_hook, to_dict_mock):
         self.extra_links_manager_mock.attach_mock(mock_hook, "hook")
-        mock_hook.return_value.create_cluster.result.return_value = None
+        mock_hook.return_value.create_cluster.return_value = None
         create_cluster_args = {
             "region": GCP_REGION,
             "project_id": GCP_PROJECT,
@@ -523,7 +522,7 @@ class 
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
         # Test whether xcom push occurs before create cluster is called
         self.extra_links_manager_mock.assert_has_calls(expected_calls, 
any_order=False)
 
-        
to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
+        to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
         self.mock_ti.xcom_push.assert_called_once_with(
             key="conf",
             value=DATAPROC_CLUSTER_CONF_EXPECTED,
@@ -664,6 +663,51 @@ class 
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
             region=GCP_REGION, project_id=GCP_PROJECT, 
cluster_name=CLUSTER_NAME
         )
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_defer_method(self, mock_trigger_hook, 
mock_hook):
+        mock_hook.return_value.create_cluster.return_value = None  # operation unused when deferring
+        operator = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            delete_on_error=True,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:  # deferrable=True must defer, not wait
+            operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.create_cluster.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            request_id=None,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            virtual_cluster_config=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        mock_hook.return_value.wait_for_operation.assert_not_called()  # no blocking wait
+        assert isinstance(exc.value.trigger, DataprocClusterTrigger)
+        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
 
 @pytest.mark.need_serialized_dag
 def test_create_cluster_operator_extra_links(dag_maker, 
create_task_instance_of_operator):
@@ -961,7 +1005,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
 
         self.mock_ti.xcom_push.assert_not_called()
 
-        assert isinstance(exc.value.trigger, DataprocBaseTrigger)
+        assert isinstance(exc.value.trigger, DataprocSubmitTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1151,6 +1195,50 @@ class 
TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
             )
             op.execute(context=self.mock_context)
 
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_update_cluster_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
+        mock_hook.return_value.update_cluster.return_value = None  # operation unused when deferring
+        operator = DataprocUpdateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            project_id=GCP_PROJECT,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:  # deferrable=True must defer, not wait
+            operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.update_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()  # no blocking wait
+        assert isinstance(exc.value.trigger, DataprocClusterTrigger)
+        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
 
 @pytest.mark.need_serialized_dag
 def test_update_cluster_operator_extra_links(dag_maker, 
create_task_instance_of_operator):
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py 
b/tests/providers/google/cloud/triggers/test_dataproc.py
new file mode 100644
index 0000000000..854c02c0de
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+
+import pytest
+from google.cloud.dataproc_v1 import ClusterStatus
+
+from airflow.providers.google.cloud.triggers.dataproc import 
DataprocClusterTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.google.cloud.utils.compat import async_mock
+
+TEST_PROJECT_ID = "project-id"
+TEST_REGION = "region"
+TEST_CLUSTER_NAME = "cluster_name"
+TEST_POLL_INTERVAL = 5
+TEST_GCP_CONN_ID = "google_cloud_default"
+
+
+@pytest.fixture
+def trigger():  # DataprocClusterTrigger built from the module-level TEST_* constants
+    return DataprocClusterTrigger(
+        cluster_name=TEST_CLUSTER_NAME,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=None,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+    )
+
+
+@pytest.fixture()
+def async_get_cluster():  # factory: already-resolved Future wrapping a configured MagicMock
+    def func(**kwargs):
+        m = async_mock.MagicMock()
+        m.configure_mock(**kwargs)
+        f = asyncio.Future()  # awaiting the patched get_cluster then yields m immediately
+        f.set_result(m)
+        return f
+
+    return func
+
+
+class TestDataprocClusterTrigger:  # unit tests for DataprocClusterTrigger.serialize() and run()
+    def test_async_cluster_trigger_serialization_should_execute_successfully(self, trigger):
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger"
+        assert kwargs == {
+            "cluster_name": TEST_CLUSTER_NAME,
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_REGION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "impersonation_chain": None,
+            "polling_interval_seconds": TEST_POLL_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+    async def test_async_cluster_triggers_on_success_should_execute_successfully(
+        self, mock_hook, trigger, async_get_cluster
+    ):
+        mock_hook.return_value = async_get_cluster(
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            cluster_name=TEST_CLUSTER_NAME,
+            status=ClusterStatus(state=ClusterStatus.State.RUNNING),
+        )
+
+        generator = trigger.run()
+        actual_event = await generator.asend(None)
+
+        expected_event = TriggerEvent(
+            {
+                "cluster_name": TEST_CLUSTER_NAME,
+                "cluster_state": ClusterStatus.State.RUNNING,
+                "cluster": actual_event.payload["cluster"],
+            }
+        )
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+    async def test_async_cluster_trigger_run_returns_error_event(self, mock_hook, trigger, async_get_cluster):
+        mock_hook.return_value = async_get_cluster(
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            cluster_name=TEST_CLUSTER_NAME,
+            status=ClusterStatus(state=ClusterStatus.State.ERROR),
+        )
+
+        actual_event = await (trigger.run()).asend(None)
+        await asyncio.sleep(0.5)
+
+        expected_event = TriggerEvent(
+            {
+                "cluster_name": TEST_CLUSTER_NAME,
+                "cluster_state": ClusterStatus.State.ERROR,
+                "cluster": actual_event.payload["cluster"],
+            }
+        )
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+    async def test_cluster_run_loop_is_still_running(self, mock_hook, trigger, caplog, async_get_cluster):
+        mock_hook.return_value = async_get_cluster(
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            cluster_name=TEST_CLUSTER_NAME,
+            status=ClusterStatus(state=ClusterStatus.State.CREATING),
+        )
+
+        caplog.set_level(logging.INFO)
+
+        task = asyncio.create_task(trigger.run().__anext__())  # CREATING is not terminal: keeps polling
+        await asyncio.sleep(0.5)
+
+        assert not task.done()
+        assert f"is in state: {ClusterStatus.State.CREATING!s}" in caplog.text
+        assert f"Sleeping for {TEST_POLL_INTERVAL} seconds." in caplog.text
diff --git 
a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
 
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
new file mode 100644
index 0000000000..df4e5a1f2a
--- /dev/null
+++ 
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocCreateClusterOperator and DataprocUpdateClusterOperator in deferrable mode.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.operators.dataproc import (
+    DataprocCreateClusterOperator,
+    DataprocDeleteClusterOperator,
+    DataprocUpdateClusterOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "dataproc_cluster_def"  # kept distinct from the other Dataproc example DAG ids
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "")
+
+CLUSTER_NAME = f"cluster-dataproc-def-{ENV_ID}"  # own cluster name, not shared with other examples
+REGION = "europe-west1"
+ZONE = "europe-west1-b"
+
+
+# Cluster definition
+CLUSTER_CONFIG = {
+    "master_config": {
+        "num_instances": 1,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
+    },
+}
+
+# Update options
+# [START how_to_cloud_dataproc_updatemask_cluster_operator]
+CLUSTER_UPDATE = {
+    "config": {"worker_config": {"num_instances": 3}, "secondary_worker_config": {"num_instances": 3}}
+}
+UPDATE_MASK = {
+    "paths": ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"]
+}
+# [END how_to_cloud_dataproc_updatemask_cluster_operator]
+
+TIMEOUT = {"seconds": 1 * 24 * 60 * 60}
+
+
+with models.DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "dataproc"],
+) as dag:
+    # [START how_to_cloud_dataproc_create_cluster_operator_async]
+    create_cluster = DataprocCreateClusterOperator(
+        task_id="create_cluster",
+        project_id=PROJECT_ID,
+        cluster_config=CLUSTER_CONFIG,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+        deferrable=True,
+    )
+    # [END how_to_cloud_dataproc_create_cluster_operator_async]
+
+    # [START how_to_cloud_dataproc_update_cluster_operator_async]
+    update_cluster = DataprocUpdateClusterOperator(
+        task_id="update_cluster",
+        cluster_name=CLUSTER_NAME,
+        cluster=CLUSTER_UPDATE,
+        update_mask=UPDATE_MASK,
+        graceful_decommission_timeout=TIMEOUT,
+        project_id=PROJECT_ID,
+        region=REGION,
+        deferrable=True,
+    )
+    # [END how_to_cloud_dataproc_update_cluster_operator_async]
+
+    delete_cluster = DataprocDeleteClusterOperator(
+        task_id="delete_cluster",
+        project_id=PROJECT_ID,
+        cluster_name=CLUSTER_NAME,
+        region=REGION,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    create_cluster >> update_cluster >> delete_cluster
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "teardown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to