This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0f2670e7ac Create DataprocStartClusterOperator and
DataprocStopClusterOperator (#36996)
0f2670e7ac is described below
commit 0f2670e7acaabb7110dd800b42b491aac9a8a511
Author: M. Olcay Tercanlı <[email protected]>
AuthorDate: Fri Jan 26 18:12:13 2024 +0000
Create DataprocStartClusterOperator and DataprocStopClusterOperator (#36996)
---
airflow/providers/google/cloud/hooks/dataproc.py | 88 +++++++++
.../providers/google/cloud/operators/dataproc.py | 197 +++++++++++++++++++++
.../operators/cloud/dataproc.rst | 24 +++
tests/always/test_project_structure.py | 1 +
.../providers/google/cloud/hooks/test_dataproc.py | 42 +++++
.../google/cloud/operators/test_dataproc.py | 86 +++++++++
...proc_cluster_create_existing_stopped_cluster.py | 120 +++++++++++++
.../example_dataproc_cluster_start_stop.py | 114 ++++++++++++
8 files changed, 672 insertions(+)
diff --git a/airflow/providers/google/cloud/hooks/dataproc.py
b/airflow/providers/google/cloud/hooks/dataproc.py
index dae5535e40..4551b24384 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -583,6 +583,94 @@ class DataprocHook(GoogleBaseHook):
)
return operation
+ @GoogleBaseHook.fallback_to_default_project_id
+ def start_cluster(
+ self,
+ region: str,
+ project_id: str,
+ cluster_name: str,
+ cluster_uuid: str | None = None,
+ request_id: str | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """Start a cluster in a project.
+
+ :param region: Cloud Dataproc region to handle the request.
+ :param project_id: Google Cloud project ID that the cluster belongs to.
+ :param cluster_name: The cluster name.
+ :param cluster_uuid: The cluster UUID
+ :param request_id: A unique id used to identify the request. If the
+ server receives two *StartClusterRequest* requests with the same
+ ID, the second request will be ignored, and an operation created
+ for the first one and stored in the backend is returned.
+ :param retry: A retry object used to retry requests. If *None*,
requests
+ will not be retried.
+ :param timeout: The amount of time, in seconds, to wait for the request
+ to complete. If *retry* is specified, the timeout applies to each
+ individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :return: An instance of ``google.api_core.operation.Operation``
+ """
+ client = self.get_cluster_client(region=region)
+ return client.start_cluster(
+ request={
+ "project_id": project_id,
+ "region": region,
+ "cluster_name": cluster_name,
+ "cluster_uuid": cluster_uuid,
+ "request_id": request_id,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def stop_cluster(
+ self,
+ region: str,
+ project_id: str,
+ cluster_name: str,
+ cluster_uuid: str | None = None,
+ request_id: str | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Operation:
+ """Stop a cluster in a project.
+
+ :param region: Cloud Dataproc region to handle the request.
+ :param project_id: Google Cloud project ID that the cluster belongs to.
+ :param cluster_name: The cluster name.
+ :param cluster_uuid: The cluster UUID
+ :param request_id: A unique id used to identify the request. If the
+ server receives two *StopClusterRequest* requests with the same
+ ID, the second request will be ignored, and an operation created
+ for the first one and stored in the backend is returned.
+ :param retry: A retry object used to retry requests. If *None*,
requests
+ will not be retried.
+ :param timeout: The amount of time, in seconds, to wait for the request
+ to complete. If *retry* is specified, the timeout applies to each
+ individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :return: An instance of ``google.api_core.operation.Operation``
+ """
+ client = self.get_cluster_client(region=region)
+ return client.stop_cluster(
+ request={
+ "project_id": project_id,
+ "region": region,
+ "cluster_name": cluster_name,
+ "cluster_uuid": cluster_uuid,
+ "request_id": request_id,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
@GoogleBaseHook.fallback_to_default_project_id
def create_workflow_template(
self,
diff --git a/airflow/providers/google/cloud/operators/dataproc.py
b/airflow/providers/google/cloud/operators/dataproc.py
index 7f3fcd5d01..aacc1adb24 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -724,6 +724,17 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
cluster = self._get_cluster(hook)
return cluster
+ def _start_cluster(self, hook: DataprocHook):
+ op: operation.Operation = hook.start_cluster(
+ region=self.region,
+ project_id=self.project_id,
+ cluster_name=self.cluster_name,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ return hook.wait_for_operation(timeout=self.timeout,
result_retry=self.retry, operation=op)
+
def execute(self, context: Context) -> dict:
self.log.info("Creating cluster: %s", self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
@@ -801,6 +812,9 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
# Create new cluster
cluster = self._create_cluster(hook)
self._handle_error_state(hook, cluster)
+ elif cluster.status.state == cluster.status.State.STOPPED:
+ # if the cluster exists and already stopped, then start the cluster
+ self._start_cluster(hook)
return Cluster.to_dict(cluster)
@@ -1082,6 +1096,189 @@ class
DataprocDeleteClusterOperator(GoogleCloudBaseOperator):
)
+class _DataprocStartStopClusterBaseOperator(GoogleCloudBaseOperator):
+ """Base class to start or stop a cluster in a project.
+
+ :param cluster_name: Required. Name of the cluster to start or stop
+ :param region: Required. The specified region where the dataproc cluster
is created.
+ :param project_id: Optional. The ID of the Google Cloud project the
cluster belongs to.
+ :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the
RPC should fail
+ if cluster with specified UUID does not exist.
+ :param request_id: Optional. A unique id used to identify the request. If
the server receives two
+ ``StartClusterRequest``/``StopClusterRequest`` requests with the same id, then the second
request will be ignored and the
+ first ``google.longrunning.Operation`` created and stored in the
backend is returned.
+ :param retry: A retry object used to retry requests. If ``None`` is
specified, requests will not be
+ retried.
+ :param timeout: The amount of time, in seconds, to wait for the request to
complete. Note that if
+ ``retry`` is specified, the timeout applies to each individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields = (
+ "cluster_name",
+ "region",
+ "project_id",
+ "request_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ cluster_name: str,
+ region: str,
+ project_id: str | None = None,
+ cluster_uuid: str | None = None,
+ request_id: str | None = None,
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
+ timeout: float = 1 * 60 * 60,
+ metadata: Sequence[tuple[str, str]] = (),
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.region = region
+ self.cluster_name = cluster_name
+ self.cluster_uuid = cluster_uuid
+ self.request_id = request_id
+ self.retry = retry
+ self.timeout = timeout
+ self.metadata = metadata
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self._hook: DataprocHook | None = None
+
+ @property
+ def hook(self):
+ if self._hook is None:
+ self._hook = DataprocHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+ return self._hook
+
+ def _get_project_id(self) -> str:
+ return self.project_id or self.hook.project_id
+
+ def _get_cluster(self) -> Cluster:
+ """Retrieve the cluster information.
+
+ :return: Instance of the ``google.cloud.dataproc_v1.Cluster`` class
+ """
+ return self.hook.get_cluster(
+ project_id=self._get_project_id(),
+ region=self.region,
+ cluster_name=self.cluster_name,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+
+ def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool,
str | None]:
+ """Implement this method in child class to return whether the cluster
is in desired state or not.
+
+ If the cluster is in the desired state you can return a log message
content as a second value
+ for the return tuple.
+
+ :param cluster: Required. Instance of
``google.cloud.dataproc_v1.Cluster``
+ class to interact with Dataproc API
+ :return: Tuple of (Boolean, Optional[str]) The first value of the
tuple is whether the cluster is
+ in desired state or not. The second value of the tuple will be used if
you want to log something when
+ the cluster is in desired state already.
+ """
+ raise NotImplementedError
+
+ def _get_operation(self) -> operation.Operation:
+ """Implement this method in child class to call the related hook
method and return its result.
+
+ :return: ``google.api_core.operation.Operation`` instance returned by the
related hook method (start or stop cluster)
+ """
+ raise NotImplementedError
+
+ def execute(self, context: Context) -> dict | None:
+ cluster: Cluster = self._get_cluster()
+ is_already_desired_state, log_str =
self._check_desired_cluster_state(cluster)
+ if is_already_desired_state:
+ self.log.info(log_str)
+ return None
+
+ op: operation.Operation = self._get_operation()
+ result = self.hook.wait_for_operation(timeout=self.timeout,
result_retry=self.retry, operation=op)
+ return Cluster.to_dict(result)
+
+
+class DataprocStartClusterOperator(_DataprocStartStopClusterBaseOperator):
+ """Start a cluster in a project."""
+
+ operator_extra_links = (DataprocClusterLink(),)
+
+ def execute(self, context: Context) -> dict | None:
+ self.log.info("Starting the cluster: %s", self.cluster_name)
+ cluster = super().execute(context)
+ DataprocClusterLink.persist(
+ context=context,
+ operator=self,
+ cluster_id=self.cluster_name,
+ project_id=self._get_project_id(),
+ region=self.region,
+ )
+ self.log.info("Cluster started")
+ return cluster
+
+ def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool,
str | None]:
+ if cluster.status.state == cluster.status.State.RUNNING:
+ return True, f'The cluster "{self.cluster_name}" is already running!'
+ return False, None
+
+ def _get_operation(self) -> operation.Operation:
+ return self.hook.start_cluster(
+ region=self.region,
+ project_id=self._get_project_id(),
+ cluster_name=self.cluster_name,
+ cluster_uuid=self.cluster_uuid,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+
+
+class DataprocStopClusterOperator(_DataprocStartStopClusterBaseOperator):
+ """Stop a cluster in a project."""
+
+ def execute(self, context: Context) -> dict | None:
+ self.log.info("Stopping the cluster: %s", self.cluster_name)
+ cluster = super().execute(context)
+ self.log.info("Cluster stopped")
+ return cluster
+
+ def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool,
str | None]:
+ if cluster.status.state in [cluster.status.State.STOPPED,
cluster.status.State.STOPPING]:
+ return True, f'The cluster "{self.cluster_name}" is already stopped!'
+ return False, None
+
+ def _get_operation(self) -> operation.Operation:
+ return self.hook.stop_cluster(
+ region=self.region,
+ project_id=self._get_project_id(),
+ cluster_name=self.cluster_name,
+ cluster_uuid=self.cluster_uuid,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+
+
class DataprocJobBaseOperator(GoogleCloudBaseOperator):
"""Base class for operators that launch job on DataProc.
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index 67c2831a1a..6277f94e05 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -201,6 +201,30 @@ You can use deferrable mode for this action in order to
run the operator asynchr
:start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
:end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]
+Starting a cluster
+---------------------------
+
+To start a cluster you can use the
+:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStartClusterOperator`:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_dataproc_start_cluster_operator]
+ :end-before: [END how_to_cloud_dataproc_start_cluster_operator]
+
+Stopping a cluster
+---------------------------
+
+To stop a cluster you can use the
+:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStopClusterOperator`:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_dataproc_stop_cluster_operator]
+ :end-before: [END how_to_cloud_dataproc_stop_cluster_operator]
+
Deleting a cluster
------------------
diff --git a/tests/always/test_project_structure.py
b/tests/always/test_project_structure.py
index db026aa6bf..bab56abead 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -403,6 +403,7 @@ class
TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryToSqlBaseOperator",
"airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator",
"airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator",
+
"airflow.providers.google.cloud.operators.dataproc._DataprocStartStopClusterBaseOperator",
"airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator",
"airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator",
}
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py
b/tests/providers/google/cloud/hooks/test_dataproc.py
index 1a82fc8a1c..131f5a342b 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -287,6 +287,48 @@ class TestDataprocHook:
update_mask="update-mask",
)
+ @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
+ def test_start_cluster(self, mock_client):
+ self.hook.start_cluster(
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ )
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.start_cluster.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ cluster_uuid=None,
+ request_id=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
+ @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
+ def test_stop_cluster(self, mock_client):
+ self.hook.stop_cluster(
+ region=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ )
+ mock_client.assert_called_once_with(region=GCP_LOCATION)
+ mock_client.return_value.stop_cluster.assert_called_once_with(
+ request=dict(
+ project_id=GCP_PROJECT,
+ region=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ cluster_uuid=None,
+ request_id=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+
@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_create_workflow_template(self, mock_client):
template = {"test": "test"}
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py
b/tests/providers/google/cloud/operators/test_dataproc.py
index d0b04a6fa9..44e20489a2 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -54,6 +54,8 @@ from airflow.providers.google.cloud.operators.dataproc import
(
DataprocLink,
DataprocListBatchesOperator,
DataprocScaleClusterOperator,
+ DataprocStartClusterOperator,
+ DataprocStopClusterOperator,
DataprocSubmitHadoopJobOperator,
DataprocSubmitHiveJobOperator,
DataprocSubmitJobOperator,
@@ -1683,6 +1685,90 @@ def test_update_cluster_operator_extra_links(dag_maker,
create_task_instance_of_
assert ti.task.get_extra_links(ti, DataprocClusterLink.name) ==
DATAPROC_CLUSTER_LINK_EXPECTED
+class TestDataprocStartClusterOperator(DataprocClusterTestBase):
+ @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute(self, mock_hook, mock_to_dict):
+ cluster = MagicMock()
+ cluster.status.State.RUNNING = 3
+ cluster.status.state = 0
+ mock_hook.return_value.get_cluster.return_value = cluster
+
+ op = DataprocStartClusterOperator(
+ task_id=TASK_ID,
+ cluster_name=CLUSTER_NAME,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context=self.mock_context)
+
+ mock_hook.return_value.get_cluster.assert_called_with(
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.start_cluster.assert_called_once_with(
+ cluster_name=CLUSTER_NAME,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_uuid=None,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestDataprocStopClusterOperator(DataprocClusterTestBase):
+ @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute(self, mock_hook, mock_to_dict):
+ cluster = MagicMock()
+ cluster.status.State.STOPPED = 4
+ cluster.status.state = 0
+ mock_hook.return_value.get_cluster.return_value = cluster
+
+ op = DataprocStopClusterOperator(
+ task_id=TASK_ID,
+ cluster_name=CLUSTER_NAME,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ request_id=REQUEST_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ op.execute(context=self.mock_context)
+
+ mock_hook.return_value.get_cluster.assert_called_with(
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ mock_hook.return_value.stop_cluster.assert_called_once_with(
+ cluster_name=CLUSTER_NAME,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_uuid=None,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
class TestDataprocInstantiateWorkflowTemplateOperator:
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
diff --git
a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py
new file mode 100644
index 0000000000..6a77a14684
--- /dev/null
+++
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocCreateClusterOperator in case of the cluster
is already existing and stopped.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.dataproc import (
+ DataprocCreateClusterOperator,
+ DataprocDeleteClusterOperator,
+ DataprocStartClusterOperator,
+ DataprocStopClusterOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "example_dataproc_cluster_create_existing_stopped_cluster"
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+
+CLUSTER_NAME = f"cluster-{ENV_ID}-{DAG_ID}".replace("_", "-")
+REGION = "europe-west1"
+
+# Cluster definition
+CLUSTER_CONFIG = {
+ "master_config": {
+ "num_instances": 1,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
32},
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
32},
+ },
+}
+
+with DAG(
+ DAG_ID, schedule="@once", start_date=datetime(2024, 1, 1), catchup=False,
tags=["dataproc", "example"]
+) as dag:
+ create_cluster = DataprocCreateClusterOperator(
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ cluster_config=CLUSTER_CONFIG,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ use_if_exists=True,
+ )
+
+ start_cluster = DataprocStartClusterOperator(
+ task_id="start_cluster",
+ project_id=PROJECT_ID,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ )
+
+ stop_cluster = DataprocStopClusterOperator(
+ task_id="stop_cluster",
+ project_id=PROJECT_ID,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ )
+
+ create_cluster_for_stopped_cluster = DataprocCreateClusterOperator(
+ task_id="create_cluster_for_stopped_cluster",
+ project_id=PROJECT_ID,
+ cluster_config=CLUSTER_CONFIG,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ use_if_exists=True,
+ )
+
+ delete_cluster = DataprocDeleteClusterOperator(
+ task_id="delete_cluster",
+ project_id=PROJECT_ID,
+ cluster_name=CLUSTER_NAME,
+ region=REGION,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ create_cluster
+ >> stop_cluster
+ >> start_cluster
+ # TEST BODY
+ >> create_cluster_for_stopped_cluster
+ # TEST TEARDOWN
+ >> delete_cluster
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git
a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
new file mode 100644
index 0000000000..7dcb127cd6
--- /dev/null
+++
b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocStartClusterOperator and
DataprocStopClusterOperator.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.dataproc import (
+ DataprocCreateClusterOperator,
+ DataprocDeleteClusterOperator,
+ DataprocStartClusterOperator,
+ DataprocStopClusterOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "dataproc_cluster_start_stop"
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+
+CLUSTER_NAME = f"cluster-{ENV_ID}-{DAG_ID}".replace("_", "-")
+REGION = "europe-west1"
+
+# Cluster definition
+CLUSTER_CONFIG = {
+ "master_config": {
+ "num_instances": 1,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
32},
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri": "n1-standard-4",
+ "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb":
32},
+ },
+}
+
+with DAG(
+ DAG_ID, schedule="@once", start_date=datetime(2024, 1, 1), catchup=False,
tags=["dataproc", "example"]
+) as dag:
+ create_cluster = DataprocCreateClusterOperator(
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ cluster_config=CLUSTER_CONFIG,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ use_if_exists=True,
+ )
+
+ # [START how_to_cloud_dataproc_start_cluster_operator]
+ start_cluster = DataprocStartClusterOperator(
+ task_id="start_cluster",
+ project_id=PROJECT_ID,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ )
+ # [END how_to_cloud_dataproc_start_cluster_operator]
+
+ # [START how_to_cloud_dataproc_stop_cluster_operator]
+ stop_cluster = DataprocStopClusterOperator(
+ task_id="stop_cluster",
+ project_id=PROJECT_ID,
+ region=REGION,
+ cluster_name=CLUSTER_NAME,
+ )
+ # [END how_to_cloud_dataproc_stop_cluster_operator]
+
+ delete_cluster = DataprocDeleteClusterOperator(
+ task_id="delete_cluster",
+ project_id=PROJECT_ID,
+ cluster_name=CLUSTER_NAME,
+ region=REGION,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ (
+ # TEST SETUP
+ create_cluster
+ # TEST BODY
+ >> stop_cluster
+ >> start_cluster
+ # TEST TEARDOWN
+ >> delete_cluster
+ )
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)