This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 08d15d06ba Add support for driver pool, instance flexibility policy,
and min_num_instances for Dataproc (#34172)
08d15d06ba is described below
commit 08d15d06ba8675d70fcbd19f0500d67fc5f310cd
Author: Ahzaz Hingora <[email protected]>
AuthorDate: Thu Nov 16 16:51:52 2023 +0530
Add support for driver pool, instance flexibility policy, and
min_num_instances for Dataproc (#34172)
---
.../providers/google/cloud/operators/dataproc.py | 85 +++++++++++++
airflow/providers/google/provider.yaml | 2 +-
docs/spelling_wordlist.txt | 2 +
.../google/cloud/operators/test_dataproc.py | 140 +++++++++++++++++++++
4 files changed, 228 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/google/cloud/operators/dataproc.py
b/airflow/providers/google/cloud/operators/dataproc.py
index 8d3387a700..b489a79dc8 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -25,6 +25,7 @@ import re
import time
import uuid
import warnings
+from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Sequence
@@ -77,6 +78,38 @@ class PreemptibilityType(Enum):
NON_PREEMPTIBLE = "NON_PREEMPTIBLE"
+@dataclass
+class InstanceSelection:
+ """Defines machine types and a rank to which the machine types belong.
+
+ Representation for
+
google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.InstanceSelection.
+
+ :param machine_types: Full machine-type names, e.g. "n1-standard-16".
+ :param rank: Preference of this instance selection. Lower number means
higher preference.
+ Dataproc will first try to create a VM based on the machine-type with
priority rank and fallback
+ to next rank based on availability. Machine types and instance
selections with the same priority have
+ the same preference.
+ """
+
+ machine_types: list[str]
+ rank: int = 0
+
+
+@dataclass
+class InstanceFlexibilityPolicy:
+ """
+ Instance flexibility Policy allowing a mixture of VM shapes and
provisioning models.
+
+ Representation for
google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.
+
+ :param instance_selection_list: List of instance selection options that
the group will use when
+ creating new VMs.
+ """
+
+ instance_selection_list: list[InstanceSelection]
+
+
class ClusterGenerator:
"""Create a new Dataproc Cluster.
@@ -85,6 +118,11 @@ class ClusterGenerator:
to create the cluster. (templated)
:param num_workers: The # of workers to spin up. If set to zero will
spin up cluster in a single node mode
+ :param min_num_workers: The minimum number of primary worker instances to
create.
+ If more than ``min_num_workers`` VMs are created out of
``num_workers``, the failed VMs will be
+ deleted, cluster is resized to available VMs and set to RUNNING.
+ If created VMs are less than ``min_num_workers``, the cluster is
placed in ERROR state. The failed
+ VMs are not deleted.
:param storage_bucket: The storage bucket to use, setting to None lets
dataproc
generate a custom one for you
:param init_actions_uris: List of GCS uri's containing
@@ -153,12 +191,18 @@ class ClusterGenerator:
``projects/[PROJECT_STORING_KEYS]/locations/[LOCATION]/keyRings/[KEY_RING_NAME]/cryptoKeys/[KEY_NAME]``
# noqa
:param enable_component_gateway: Provides access to the web interfaces of
default and selected optional
components on the cluster.
+ :param driver_pool_size: The number of driver nodes in the node group.
+ :param driver_pool_id: The ID for the driver pool. Must be unique within
the cluster. Use this ID to
+ identify the driver group in future operations, such as resizing the
node group.
+ :param secondary_worker_instance_flexibility_policy: Instance flexibility
Policy allowing a mixture of VM
+ shapes and provisioning models.
"""
def __init__(
self,
project_id: str,
num_workers: int | None = None,
+ min_num_workers: int | None = None,
zone: str | None = None,
network_uri: str | None = None,
subnetwork_uri: str | None = None,
@@ -191,11 +235,15 @@ class ClusterGenerator:
auto_delete_ttl: int | None = None,
customer_managed_key: str | None = None,
enable_component_gateway: bool | None = False,
+ driver_pool_size: int = 0,
+ driver_pool_id: str | None = None,
+ secondary_worker_instance_flexibility_policy:
InstanceFlexibilityPolicy | None = None,
**kwargs,
) -> None:
self.project_id = project_id
self.num_masters = num_masters
self.num_workers = num_workers
+ self.min_num_workers = min_num_workers
self.num_preemptible_workers = num_preemptible_workers
self.preemptibility = self._set_preemptibility_type(preemptibility)
self.storage_bucket = storage_bucket
@@ -228,6 +276,9 @@ class ClusterGenerator:
self.customer_managed_key = customer_managed_key
self.enable_component_gateway = enable_component_gateway
self.single_node = num_workers == 0
+ self.driver_pool_size = driver_pool_size
+ self.driver_pool_id = driver_pool_id
+ self.secondary_worker_instance_flexibility_policy =
secondary_worker_instance_flexibility_policy
if self.custom_image and self.image_version:
raise ValueError("The custom_image and image_version can't be both
set")
@@ -241,6 +292,15 @@ class ClusterGenerator:
if self.single_node and self.num_preemptible_workers > 0:
raise ValueError("Single node cannot have preemptible workers.")
+ if self.min_num_workers:
+ if not self.num_workers:
+ raise ValueError("Must specify num_workers when
min_num_workers are provided.")
+ if self.min_num_workers > self.num_workers:
+ raise ValueError(
+ "The value of min_num_workers must be less than or equal
to num_workers. "
+ f"Provided {self.min_num_workers}(min_num_workers) and
{self.num_workers}(num_workers)."
+ )
+
def _set_preemptibility_type(self, preemptibility: str):
return PreemptibilityType(preemptibility.upper())
@@ -307,6 +367,17 @@ class ClusterGenerator:
return cluster_data
+ def _build_driver_pool(self):
+ driver_pool = {
+ "node_group": {
+ "roles": ["DRIVER"],
+ "node_group_config": {"num_instances": self.driver_pool_size},
+ },
+ }
+ if self.driver_pool_id:
+ driver_pool["node_group_id"] = self.driver_pool_id
+ return driver_pool
+
def _build_cluster_data(self):
if self.zone:
master_type_uri = (
@@ -344,6 +415,10 @@ class ClusterGenerator:
"autoscaling_config": {},
"endpoint_config": {},
}
+
+ if self.min_num_workers:
+ cluster_data["worker_config"]["min_num_instances"] =
self.min_num_workers
+
if self.num_preemptible_workers > 0:
cluster_data["secondary_worker_config"] = {
"num_instances": self.num_preemptible_workers,
@@ -355,6 +430,13 @@ class ClusterGenerator:
"is_preemptible": True,
"preemptibility": self.preemptibility.value,
}
+ if self.secondary_worker_instance_flexibility_policy:
+
cluster_data["secondary_worker_config"]["instance_flexibility_policy"] = {
+ "instance_selection_list": [
+ vars(s)
+ for s in
self.secondary_worker_instance_flexibility_policy.instance_selection_list
+ ]
+ }
if self.storage_bucket:
cluster_data["config_bucket"] = self.storage_bucket
@@ -382,6 +464,9 @@ class ClusterGenerator:
if not self.single_node:
cluster_data["worker_config"]["image_uri"] = custom_image_url
+ if self.driver_pool_size > 0:
+ cluster_data["auxiliary_node_groups"] = [self._build_driver_pool()]
+
cluster_data = self._build_gce_cluster_config(cluster_data)
if self.single_node:
diff --git a/airflow/providers/google/provider.yaml
b/airflow/providers/google/provider.yaml
index 0500304a0d..4287801d2b 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -102,7 +102,7 @@ dependencies:
- google-cloud-dataflow-client>=0.8.2
- google-cloud-dataform>=0.5.0
- google-cloud-dataplex>=1.4.2
- - google-cloud-dataproc>=5.4.0
+ - google-cloud-dataproc>=5.5.0
- google-cloud-dataproc-metastore>=1.12.0
- google-cloud-dlp>=3.12.0
- google-cloud-kms>=2.15.0
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index b787191fd0..c56f81ebaf 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -792,7 +792,9 @@ InspectContentResponse
InspectTemplate
instafail
installable
+InstanceFlexibilityPolicy
InstanceGroupConfig
+InstanceSelection
instanceTemplates
instantiation
integrations
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py
b/tests/providers/google/cloud/operators/test_dataproc.py
index 39361c5a98..59a9c1008c 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -60,6 +60,8 @@ from airflow.providers.google.cloud.operators.dataproc import
(
DataprocSubmitSparkJobOperator,
DataprocSubmitSparkSqlJobOperator,
DataprocUpdateClusterOperator,
+ InstanceFlexibilityPolicy,
+ InstanceSelection,
)
from airflow.providers.google.cloud.triggers.dataproc import (
DataprocBatchTrigger,
@@ -112,6 +114,7 @@ CONFIG = {
"disk_config": {"boot_disk_type": "worker_disk_type",
"boot_disk_size_gb": 256},
"image_uri": "https://www.googleapis.com/compute/beta/projects/"
"custom_image_project_id/global/images/custom_image",
+ "min_num_instances": 1,
},
"secondary_worker_config": {
"num_instances": 4,
@@ -132,6 +135,17 @@ CONFIG = {
{"executable_file": "init_actions_uris", "execution_timeout":
{"seconds": 600}}
],
"endpoint_config": {},
+ "auxiliary_node_groups": [
+ {
+ "node_group": {
+ "roles": ["DRIVER"],
+ "node_group_config": {
+ "num_instances": 2,
+ },
+ },
+ "node_group_id": "cluster_driver_pool",
+ }
+ ],
}
VIRTUAL_CLUSTER_CONFIG = {
"kubernetes_cluster_config": {
@@ -197,6 +211,64 @@ CONFIG_WITH_CUSTOM_IMAGE_FAMILY = {
},
}
+CONFIG_WITH_FLEX_MIG = {
+ "gce_cluster_config": {
+ "zone_uri":
"https://www.googleapis.com/compute/v1/projects/project_id/zones/zone",
+ "metadata": {"metadata": "data"},
+ "network_uri": "network_uri",
+ "subnetwork_uri": "subnetwork_uri",
+ "internal_ip_only": True,
+ "tags": ["tags"],
+ "service_account": "service_account",
+ "service_account_scopes": ["service_account_scopes"],
+ },
+ "master_config": {
+ "num_instances": 2,
+ "machine_type_uri":
"projects/project_id/zones/zone/machineTypes/master_machine_type",
+ "disk_config": {"boot_disk_type": "master_disk_type",
"boot_disk_size_gb": 128},
+ "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+ "custom_image_project_id/global/images/custom_image",
+ },
+ "worker_config": {
+ "num_instances": 2,
+ "machine_type_uri":
"projects/project_id/zones/zone/machineTypes/worker_machine_type",
+ "disk_config": {"boot_disk_type": "worker_disk_type",
"boot_disk_size_gb": 256},
+ "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+ "custom_image_project_id/global/images/custom_image",
+ },
+ "secondary_worker_config": {
+ "num_instances": 4,
+ "machine_type_uri":
"projects/project_id/zones/zone/machineTypes/worker_machine_type",
+ "disk_config": {"boot_disk_type": "worker_disk_type",
"boot_disk_size_gb": 256},
+ "is_preemptible": True,
+ "preemptibility": "SPOT",
+ "instance_flexibility_policy": {
+ "instance_selection_list": [
+ {
+ "machine_types": [
+ "projects/project_id/zones/zone/machineTypes/machine1",
+ "projects/project_id/zones/zone/machineTypes/machine2",
+ ],
+ "rank": 0,
+ },
+ {"machine_types":
["projects/project_id/zones/zone/machineTypes/machine3"], "rank": 1},
+ ],
+ },
+ },
+ "software_config": {"properties": {"properties": "data"},
"optional_components": ["optional_components"]},
+ "lifecycle_config": {
+ "idle_delete_ttl": {"seconds": 60},
+ "auto_delete_time": "2019-09-12T00:00:00.000000Z",
+ },
+ "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
+ "autoscaling_config": {"policy_uri": "autoscaling_policy"},
+ "config_bucket": "storage_bucket",
+ "initialization_actions": [
+ {"executable_file": "init_actions_uris", "execution_timeout":
{"seconds": 600}}
+ ],
+ "endpoint_config": {},
+}
+
LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION}
LABELS.update({"airflow-version": "v" + airflow_version.replace(".",
"-").replace("+", "-")})
@@ -361,10 +433,26 @@ class TestsClusterGenerator:
)
assert "num_workers == 0 means single" in str(ctx.value)
+ def test_min_num_workers_less_than_num_workers(self):
+ with pytest.raises(ValueError) as ctx:
+ ClusterGenerator(
+ num_workers=3, min_num_workers=4, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME
+ )
+ assert (
+ "The value of min_num_workers must be less than or equal to
num_workers. "
+ "Provided 4(min_num_workers) and 3(num_workers)." in
str(ctx.value)
+ )
+
+ def test_min_num_workers_without_num_workers(self):
+ with pytest.raises(ValueError) as ctx:
+ ClusterGenerator(min_num_workers=4, project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME)
+ assert "Must specify num_workers when min_num_workers are
provided." in str(ctx.value)
+
def test_build(self):
generator = ClusterGenerator(
project_id="project_id",
num_workers=2,
+ min_num_workers=1,
zone="zone",
network_uri="network_uri",
subnetwork_uri="subnetwork_uri",
@@ -395,6 +483,8 @@ class TestsClusterGenerator:
auto_delete_time=datetime(2019, 9, 12),
auto_delete_ttl=250,
customer_managed_key="customer_managed_key",
+ driver_pool_id="cluster_driver_pool",
+ driver_pool_size=2,
)
cluster = generator.make()
assert CONFIG == cluster
@@ -438,6 +528,56 @@ class TestsClusterGenerator:
cluster = generator.make()
assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster
+ def test_build_with_flex_migs(self):
+ generator = ClusterGenerator(
+ project_id="project_id",
+ num_workers=2,
+ zone="zone",
+ network_uri="network_uri",
+ subnetwork_uri="subnetwork_uri",
+ internal_ip_only=True,
+ tags=["tags"],
+ storage_bucket="storage_bucket",
+ init_actions_uris=["init_actions_uris"],
+ init_action_timeout="10m",
+ metadata={"metadata": "data"},
+ custom_image="custom_image",
+ custom_image_project_id="custom_image_project_id",
+ autoscaling_policy="autoscaling_policy",
+ properties={"properties": "data"},
+ optional_components=["optional_components"],
+ num_masters=2,
+ master_machine_type="master_machine_type",
+ master_disk_type="master_disk_type",
+ master_disk_size=128,
+ worker_machine_type="worker_machine_type",
+ worker_disk_type="worker_disk_type",
+ worker_disk_size=256,
+ num_preemptible_workers=4,
+ preemptibility="Spot",
+ region="region",
+ service_account="service_account",
+ service_account_scopes=["service_account_scopes"],
+ idle_delete_ttl=60,
+ auto_delete_time=datetime(2019, 9, 12),
+ auto_delete_ttl=250,
+ customer_managed_key="customer_managed_key",
+
secondary_worker_instance_flexibility_policy=InstanceFlexibilityPolicy(
+ [
+ InstanceSelection(
+ [
+
"projects/project_id/zones/zone/machineTypes/machine1",
+
"projects/project_id/zones/zone/machineTypes/machine2",
+ ],
+ 0,
+ ),
+
InstanceSelection(["projects/project_id/zones/zone/machineTypes/machine3"], 1),
+ ]
+ ),
+ )
+ cluster = generator.make()
+ assert CONFIG_WITH_FLEX_MIG == cluster
+
class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
def test_deprecation_warning(self):