This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9fd80130e2 Add deferrable mode to DataprocCreateClusterOperator and DataprocUpdateClusterOperator (#28529)
9fd80130e2 is described below
commit 9fd80130e2351c7ec31bbeb6c10f6b11708b318b
Author: Beata Kossakowska <[email protected]>
AuthorDate: Wed Jan 25 23:20:22 2023 +0100
Add deferrable mode to DataprocCreateClusterOperator and DataprocUpdateClusterOperator (#28529)
Co-authored-by: Beata Kossakowska <[email protected]>
---
.../providers/google/cloud/operators/dataproc.py | 96 +++++++++++++--
.../providers/google/cloud/triggers/dataproc.py | 68 ++++++++++-
.../operators/cloud/dataproc.rst | 16 +++
.../google/cloud/operators/test_dataproc.py | 100 ++++++++++++++-
.../google/cloud/triggers/test_dataproc.py | 136 +++++++++++++++++++++
.../example_dataproc_cluster_deferrable.py | 121 ++++++++++++++++++
6 files changed, 515 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py
b/airflow/providers/google/cloud/operators/dataproc.py
index d2899037ae..316a15981f 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -26,13 +26,13 @@ import time
import uuid
import warnings
from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from google.api_core import operation # type: ignore
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
from google.protobuf.duration_pb2 import Duration
from google.protobuf.field_mask_pb2 import FieldMask
@@ -50,7 +50,7 @@ from airflow.providers.google.cloud.links.dataproc import (
DataprocLink,
DataprocListLink,
)
-from airflow.providers.google.cloud.triggers.dataproc import
DataprocBaseTrigger
+from airflow.providers.google.cloud.triggers.dataproc import
DataprocClusterTrigger, DataprocSubmitTrigger
from airflow.utils import timezone
if TYPE_CHECKING:
@@ -438,6 +438,8 @@ class DataprocCreateClusterOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
+ :param deferrable: Run operator in the deferrable mode.
+ :param polling_interval_seconds: Time (seconds) to wait between calls to
check the run status.
"""
template_fields: Sequence[str] = (
@@ -470,6 +472,8 @@ class DataprocCreateClusterOperator(BaseOperator):
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
+ deferrable: bool = False,
+ polling_interval_seconds: int = 10,
**kwargs,
) -> None:
@@ -502,7 +506,8 @@ class DataprocCreateClusterOperator(BaseOperator):
del kwargs[arg]
super().__init__(**kwargs)
-
+ if deferrable and polling_interval_seconds <= 0:
+ raise ValueError("Invalid value for polling_interval_seconds.
Expected value greater than 0")
self.cluster_config = cluster_config
self.cluster_name = cluster_name
self.labels = labels
@@ -517,9 +522,11 @@ class DataprocCreateClusterOperator(BaseOperator):
self.use_if_exists = use_if_exists
self.impersonation_chain = impersonation_chain
self.virtual_cluster_config = virtual_cluster_config
+ self.deferrable = deferrable
+ self.polling_interval_seconds = polling_interval_seconds
def _create_cluster(self, hook: DataprocHook):
- operation = hook.create_cluster(
+ return hook.create_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
@@ -531,9 +538,6 @@ class DataprocCreateClusterOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- cluster = operation.result()
- self.log.info("Cluster created.")
- return cluster
def _delete_cluster(self, hook):
self.log.info("Deleting the cluster")
@@ -596,7 +600,25 @@ class DataprocCreateClusterOperator(BaseOperator):
)
try:
# First try to create a new cluster
- cluster = self._create_cluster(hook)
+ operation = self._create_cluster(hook)
+ if not self.deferrable:
+ cluster = hook.wait_for_operation(
+ timeout=self.timeout, result_retry=self.retry,
operation=operation
+ )
+ self.log.info("Cluster created.")
+ return Cluster.to_dict(cluster)
+ else:
+ self.defer(
+ trigger=DataprocClusterTrigger(
+ cluster_name=self.cluster_name,
+ project_id=self.project_id,
+ region=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ polling_interval_seconds=self.polling_interval_seconds,
+ ),
+ method_name="execute_complete",
+ )
except AlreadyExists:
if not self.use_if_exists:
raise
@@ -618,6 +640,21 @@ class DataprocCreateClusterOperator(BaseOperator):
return Cluster.to_dict(cluster)
+ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ cluster_state = event["cluster_state"]
+ cluster_name = event["cluster_name"]
+
+ if cluster_state == ClusterStatus.State.ERROR:
+ raise AirflowException(f"Cluster is in ERROR
state:\n{cluster_name}")
+
+ self.log.info("%s completed successfully.", self.task_id)
+ return event["cluster"]
+
class DataprocScaleClusterOperator(BaseOperator):
"""
@@ -974,7 +1011,7 @@ class DataprocJobBaseOperator(BaseOperator):
if self.deferrable:
self.defer(
- trigger=DataprocBaseTrigger(
+ trigger=DataprocSubmitTrigger(
job_id=job_id,
project_id=self.project_id,
region=self.region,
@@ -1888,7 +1925,7 @@ class DataprocSubmitJobOperator(BaseOperator):
self.job_id = new_job_id
if self.deferrable:
self.defer(
- trigger=DataprocBaseTrigger(
+ trigger=DataprocSubmitTrigger(
job_id=self.job_id,
project_id=self.project_id,
region=self.region,
@@ -1964,6 +2001,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
+ :param deferrable: Run operator in the deferrable mode.
+ :param polling_interval_seconds: Time (seconds) to wait between calls to
check the run status.
"""
template_fields: Sequence[str] = (
@@ -1991,9 +2030,13 @@ class DataprocUpdateClusterOperator(BaseOperator):
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
+ deferrable: bool = False,
+ polling_interval_seconds: int = 10,
**kwargs,
):
super().__init__(**kwargs)
+ if deferrable and polling_interval_seconds <= 0:
+ raise ValueError("Invalid value for polling_interval_seconds.
Expected value greater than 0")
self.project_id = project_id
self.region = region
self.cluster_name = cluster_name
@@ -2006,6 +2049,8 @@ class DataprocUpdateClusterOperator(BaseOperator):
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
+ self.deferrable = deferrable
+ self.polling_interval_seconds = polling_interval_seconds
def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
@@ -2026,9 +2071,36 @@ class DataprocUpdateClusterOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
- operation.result()
+
+ if not self.deferrable:
+ hook.wait_for_operation(timeout=self.timeout,
result_retry=self.retry, operation=operation)
+ else:
+ self.defer(
+ trigger=DataprocClusterTrigger(
+ cluster_name=self.cluster_name,
+ project_id=self.project_id,
+ region=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ polling_interval_seconds=self.polling_interval_seconds,
+ ),
+ method_name="execute_complete",
+ )
self.log.info("Updated %s cluster.", self.cluster_name)
+ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ cluster_state = event["cluster_state"]
+ cluster_name = event["cluster_name"]
+
+ if cluster_state == ClusterStatus.State.ERROR:
+ raise AirflowException(f"Cluster is in ERROR
state:\n{cluster_name}")
+ self.log.info("%s completed successfully.", self.task_id)
+
class DataprocCreateBatchOperator(BaseOperator):
"""
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py
b/airflow/providers/google/cloud/triggers/dataproc.py
index baaec28892..48d3666f4a 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -20,16 +20,16 @@ from __future__ import annotations
import asyncio
import warnings
-from typing import Sequence
+from typing import Any, AsyncIterator, Sequence
-from google.cloud.dataproc_v1 import JobStatus
+from google.cloud.dataproc_v1 import ClusterStatus, JobStatus
from airflow import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
-class DataprocBaseTrigger(BaseTrigger):
+class DataprocSubmitTrigger(BaseTrigger):
"""
Trigger that periodically polls information from Dataproc API to verify
job status.
Implementation leverages asynchronous transport.
@@ -65,7 +65,7 @@ class DataprocBaseTrigger(BaseTrigger):
def serialize(self):
return (
-
"airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger",
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
{
"job_id": self.job_id,
"project_id": self.project_id,
@@ -89,3 +89,63 @@ class DataprocBaseTrigger(BaseTrigger):
raise AirflowException(f"Dataproc job execution failed
{self.job_id}")
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"job_id": self.job_id, "job_state": state})
+
+
+class DataprocClusterTrigger(BaseTrigger):
+ """
+ Trigger that periodically polls information from Dataproc API to verify
status.
+ Implementation leverages asynchronous transport.
+ """
+
+ def __init__(
+ self,
+ cluster_name: str,
+ region: str,
+ project_id: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ polling_interval_seconds: int = 10,
+ ):
+ super().__init__()
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.cluster_name = cluster_name
+ self.project_id = project_id
+ self.region = region
+ self.polling_interval_seconds = polling_interval_seconds
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger",
+ {
+ "cluster_name": self.cluster_name,
+ "project_id": self.project_id,
+ "region": self.region,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "polling_interval_seconds": self.polling_interval_seconds,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]:
+ hook = self._get_hook()
+ while True:
+ cluster = await hook.get_cluster(
+ project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
+ )
+ state = cluster.status.state
+ self.log.info("Dataproc cluster: %s is in state: %s",
self.cluster_name, state)
+ if state in (
+ ClusterStatus.State.ERROR,
+ ClusterStatus.State.RUNNING,
+ ):
+ break
+ self.log.info("Sleeping for %s seconds.",
self.polling_interval_seconds)
+ await asyncio.sleep(self.polling_interval_seconds)
+ yield TriggerEvent({"cluster_name": self.cluster_name,
"cluster_state": state, "cluster": cluster})
+
+ def _get_hook(self) -> DataprocAsyncHook:
+ return DataprocAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index 9c2508410f..d613031042 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -75,6 +75,14 @@ With this configuration we can create the cluster:
:start-after: [START how_to_cloud_dataproc_create_cluster_operator_in_gke]
:end-before: [END how_to_cloud_dataproc_create_cluster_operator_in_gke]
+You can use deferrable mode for this action in order to run the operator
asynchronously:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_dataproc_create_cluster_operator_async]
+ :end-before: [END how_to_cloud_dataproc_create_cluster_operator_async]
+
Generating Cluster Config
^^^^^^^^^^^^^^^^^^^^^^^^^
You can also generate **CLUSTER_CONFIG** using functional API,
@@ -111,6 +119,14 @@ To update a cluster you can use:
:start-after: [START how_to_cloud_dataproc_update_cluster_operator]
:end-before: [END how_to_cloud_dataproc_update_cluster_operator]
+You can use deferrable mode for this action in order to run the operator
asynchronously:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
+ :end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]
+
Deleting a cluster
------------------
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py
b/tests/providers/google/cloud/operators/test_dataproc.py
index 06e5644c82..70ebd991a8 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -54,7 +54,7 @@ from airflow.providers.google.cloud.operators.dataproc import
(
DataprocSubmitSparkSqlJobOperator,
DataprocUpdateClusterOperator,
)
-from airflow.providers.google.cloud.triggers.dataproc import
DataprocBaseTrigger
+from airflow.providers.google.cloud.triggers.dataproc import
DataprocClusterTrigger, DataprocSubmitTrigger
from airflow.providers.google.common.consts import
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.timezone import datetime
@@ -436,7 +436,6 @@ class
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook, to_dict_mock):
self.extra_links_manager_mock.attach_mock(mock_hook, "hook")
- mock_hook.return_value.create_cluster.result.return_value = None
create_cluster_args = {
"region": GCP_REGION,
"project_id": GCP_PROJECT,
@@ -474,7 +473,7 @@ class
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
# Test whether xcom push occurs before create cluster is called
self.extra_links_manager_mock.assert_has_calls(expected_calls,
any_order=False)
-
to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
+ to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
self.mock_ti.xcom_push.assert_called_once_with(
key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
@@ -485,7 +484,7 @@ class
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_in_gke(self, mock_hook, to_dict_mock):
self.extra_links_manager_mock.attach_mock(mock_hook, "hook")
- mock_hook.return_value.create_cluster.result.return_value = None
+ mock_hook.return_value.create_cluster.return_value = None
create_cluster_args = {
"region": GCP_REGION,
"project_id": GCP_PROJECT,
@@ -523,7 +522,7 @@ class
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
# Test whether xcom push occurs before create cluster is called
self.extra_links_manager_mock.assert_has_calls(expected_calls,
any_order=False)
-
to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
+ to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
self.mock_ti.xcom_push.assert_called_once_with(
key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
@@ -664,6 +663,51 @@ class
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
region=GCP_REGION, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME
)
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+ def test_create_execute_call_defer_method(self, mock_trigger_hook,
mock_hook):
+ mock_hook.return_value.create_cluster.return_value = None
+ operator = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_config=CONFIG,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
+ delete_on_error=True,
+ metadata=METADATA,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(mock.MagicMock())
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ mock_hook.return_value.create_cluster.assert_called_once_with(
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_config=CONFIG,
+ request_id=None,
+ labels=LABELS,
+ cluster_name=CLUSTER_NAME,
+ virtual_cluster_config=None,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+ mock_hook.return_value.wait_for_operation.assert_not_called()
+ assert isinstance(exc.value.trigger, DataprocClusterTrigger)
+ assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
@pytest.mark.need_serialized_dag
def test_create_cluster_operator_extra_links(dag_maker,
create_task_instance_of_operator):
@@ -961,7 +1005,7 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
self.mock_ti.xcom_push.assert_not_called()
- assert isinstance(exc.value.trigger, DataprocBaseTrigger)
+ assert isinstance(exc.value.trigger, DataprocSubmitTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1151,6 +1195,50 @@ class
TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
)
op.execute(context=self.mock_context)
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+ def test_create_execute_call_defer_method(self, mock_trigger_hook,
mock_hook):
+ mock_hook.return_value.update_cluster.return_value = None
+ operator = DataprocUpdateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ cluster_name=CLUSTER_NAME,
+ cluster=CLUSTER,
+ update_mask=UPDATE_MASK,
+ request_id=REQUEST_ID,
+ graceful_decommission_timeout={"graceful_decommission_timeout":
"600s"},
+ project_id=GCP_PROJECT,
+ gcp_conn_id=GCP_CONN_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(mock.MagicMock())
+
+ mock_hook.assert_called_once_with(
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ mock_hook.return_value.update_cluster.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ region=GCP_REGION,
+ cluster_name=CLUSTER_NAME,
+ cluster=CLUSTER,
+ update_mask=UPDATE_MASK,
+ request_id=REQUEST_ID,
+ graceful_decommission_timeout={"graceful_decommission_timeout":
"600s"},
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.wait_for_operation.assert_not_called()
+ assert isinstance(exc.value.trigger, DataprocClusterTrigger)
+ assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
@pytest.mark.need_serialized_dag
def test_update_cluster_operator_extra_links(dag_maker,
create_task_instance_of_operator):
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py
b/tests/providers/google/cloud/triggers/test_dataproc.py
new file mode 100644
index 0000000000..854c02c0de
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+
+import pytest
+from google.cloud.dataproc_v1 import ClusterStatus
+
+from airflow.providers.google.cloud.triggers.dataproc import
DataprocClusterTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.google.cloud.utils.compat import async_mock
+
+TEST_PROJECT_ID = "project-id"
+TEST_REGION = "region"
+TEST_CLUSTER_NAME = "cluster_name"
+TEST_POLL_INTERVAL = 5
+TEST_GCP_CONN_ID = "google_cloud_default"
+
+
[email protected]
+def trigger():
+ return DataprocClusterTrigger(
+ cluster_name=TEST_CLUSTER_NAME,
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ impersonation_chain=None,
+ polling_interval_seconds=TEST_POLL_INTERVAL,
+ )
+
+
[email protected]()
+def async_get_cluster():
+ def func(**kwargs):
+ m = async_mock.MagicMock()
+ m.configure_mock(**kwargs)
+ f = asyncio.Future()
+ f.set_result(m)
+ return f
+
+ return func
+
+
+class TestDataprocClusterTrigger:
+ def
test_async_cluster_trigger_serialization_should_execute_successfully(self,
trigger):
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger"
+ assert kwargs == {
+ "cluster_name": TEST_CLUSTER_NAME,
+ "project_id": TEST_PROJECT_ID,
+ "region": TEST_REGION,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "impersonation_chain": None,
+ "polling_interval_seconds": TEST_POLL_INTERVAL,
+ }
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+ async def
test_async_cluster_triggers_on_success_should_execute_successfully(
+ self, mock_hook, trigger, async_get_cluster
+ ):
+ mock_hook.return_value = async_get_cluster(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ cluster_name=TEST_CLUSTER_NAME,
+ status=ClusterStatus(state=ClusterStatus.State.RUNNING),
+ )
+
+ generator = trigger.run()
+ actual_event = await generator.asend(None)
+
+ expected_event = TriggerEvent(
+ {
+ "cluster_name": TEST_CLUSTER_NAME,
+ "cluster_state": ClusterStatus.State.RUNNING,
+ "cluster": actual_event.payload["cluster"],
+ }
+ )
+ assert expected_event == actual_event
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+ async def test_async_cluster_trigger_run_returns_error_event(self,
mock_hook, trigger, async_get_cluster):
+ mock_hook.return_value = async_get_cluster(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ cluster_name=TEST_CLUSTER_NAME,
+ status=ClusterStatus(state=ClusterStatus.State.ERROR),
+ )
+
+ actual_event = await (trigger.run()).asend(None)
+ await asyncio.sleep(0.5)
+
+ expected_event = TriggerEvent(
+ {
+ "cluster_name": TEST_CLUSTER_NAME,
+ "cluster_state": ClusterStatus.State.ERROR,
+ "cluster": actual_event.payload["cluster"],
+ }
+ )
+ assert expected_event == actual_event
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
+ async def test_cluster_run_loop_is_still_running(self, mock_hook, trigger,
caplog, async_get_cluster):
+ mock_hook.return_value = async_get_cluster(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ cluster_name=TEST_CLUSTER_NAME,
+ status=ClusterStatus(state=ClusterStatus.State.CREATING),
+ )
+
+ caplog.set_level(logging.INFO)
+
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ assert not task.done()
+ assert f"Current state is: {ClusterStatus.State.CREATING}"
+ assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
diff --git
a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
new file mode 100644
index 0000000000..df4e5a1f2a
--- /dev/null
+++
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_deferrable.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocUpdateClusterOperator.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import models
+from airflow.providers.google.cloud.operators.dataproc import (
+ DataprocCreateClusterOperator,
+ DataprocDeleteClusterOperator,
+ DataprocUpdateClusterOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "dataproc_update"
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "")
+
+CLUSTER_NAME = f"cluster-dataproc-update-{ENV_ID}"
+REGION = "europe-west1"
+ZONE = "europe-west1-b"
+
+
+# Cluster definition
+CLUSTER_CONFIG = {
+ "master_config": {
+ "num_instances": 1,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
1024},
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
1024},
+ },
+}
+
+# Update options
+# [START how_to_cloud_dataproc_updatemask_cluster_operator]
+CLUSTER_UPDATE = {
+ "config": {"worker_config": {"num_instances": 3},
"secondary_worker_config": {"num_instances": 3}}
+}
+UPDATE_MASK = {
+ "paths": ["config.worker_config.num_instances",
"config.secondary_worker_config.num_instances"]
+}
+# [END how_to_cloud_dataproc_updatemask_cluster_operator]
+
+TIMEOUT = {"seconds": 1 * 24 * 60 * 60}
+
+
+with models.DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "dataproc"],
+) as dag:
+ # [START how_to_cloud_dataproc_create_cluster_operator_async]
+ create_cluster = DataprocCreateClusterOperator(
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ cluster_config=CLUSTER_CONFIG,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ deferrable=True,
+ )
+ # [END how_to_cloud_dataproc_create_cluster_operator_async]
+
+ # [START how_to_cloud_dataproc_update_cluster_operator_async]
+ update_cluster = DataprocUpdateClusterOperator(
+ task_id="update_cluster",
+ cluster_name=CLUSTER_NAME,
+ cluster=CLUSTER_UPDATE,
+ update_mask=UPDATE_MASK,
+ graceful_decommission_timeout=TIMEOUT,
+ project_id=PROJECT_ID,
+ region=REGION,
+ deferrable=True,
+ )
+ # [END how_to_cloud_dataproc_update_cluster_operator_async]
+
+ delete_cluster = DataprocDeleteClusterOperator(
+ task_id="delete_cluster",
+ project_id=PROJECT_ID,
+ cluster_name=CLUSTER_NAME,
+ region=REGION,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ create_cluster >> update_cluster >> delete_cluster
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)