This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 08d15d06ba Add support for driver pool, instance flexibility policy, and min_num_instances for Dataproc (#34172)
08d15d06ba is described below

commit 08d15d06ba8675d70fcbd19f0500d67fc5f310cd
Author: Ahzaz Hingora <[email protected]>
AuthorDate: Thu Nov 16 16:51:52 2023 +0530

    Add support for driver pool, instance flexibility policy, and min_num_instances for Dataproc (#34172)
---
 .../providers/google/cloud/operators/dataproc.py   |  85 +++++++++++++
 airflow/providers/google/provider.yaml             |   2 +-
 docs/spelling_wordlist.txt                         |   2 +
 .../google/cloud/operators/test_dataproc.py        | 140 +++++++++++++++++++++
 4 files changed, 228 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 8d3387a700..b489a79dc8 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -25,6 +25,7 @@ import re
 import time
 import uuid
 import warnings
+from dataclasses import asdict, dataclass
 from datetime import datetime, timedelta
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Sequence
@@ -77,6 +78,38 @@ class PreemptibilityType(Enum):
     NON_PREEMPTIBLE = "NON_PREEMPTIBLE"
 
 
+@dataclass
+class InstanceSelection:
+    """Defines machine types and a rank to which the machine types belong.
+
+    Representation for
+    google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.InstanceSelection.
+
+    :param machine_types: Full machine-type names, e.g. "n1-standard-16".
+    :param rank: Preference of this instance selection. Lower number means higher preference.
+        Dataproc will first try to create a VM based on the machine-type with priority rank and fall back
+        to the next rank based on availability. Machine types and instance selections with the same priority have
+        the same preference.
+    """
+
+    machine_types: list[str]
+    rank: int = 0
+
+
+@dataclass
+class InstanceFlexibilityPolicy:
+    """
+    Instance flexibility policy allowing a mixture of VM shapes and provisioning models.
+
+    Representation for google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.
+
+    :param instance_selection_list: List of instance selection options that the group will use when
+        creating new VMs.
+    """
+
+    instance_selection_list: list[InstanceSelection]
+
+
 class ClusterGenerator:
     """Create a new Dataproc Cluster.
 
@@ -85,6 +118,11 @@ class ClusterGenerator:
         to create the cluster. (templated)
     :param num_workers: The # of workers to spin up. If set to zero will
         spin up cluster in a single node mode
+    :param min_num_workers: The minimum number of primary worker instances to create.
+        If at least ``min_num_workers`` VMs are created out of ``num_workers``, the failed VMs will be
+        deleted, the cluster is resized to the available VMs and set to RUNNING.
+        If fewer than ``min_num_workers`` VMs are created, the cluster is placed in ERROR state. The failed
+        VMs are not deleted.
     :param storage_bucket: The storage bucket to use, setting to None lets dataproc
         generate a custom one for you
     :param init_actions_uris: List of GCS uri's containing
@@ -153,12 +191,18 @@ class ClusterGenerator:
         ``projects/[PROJECT_STORING_KEYS]/locations/[LOCATION]/keyRings/[KEY_RING_NAME]/cryptoKeys/[KEY_NAME]``  # noqa
     :param enable_component_gateway: Provides access to the web interfaces of default and selected optional
         components on the cluster.
+    :param driver_pool_size: The number of driver nodes in the node group.
+    :param driver_pool_id: The ID for the driver pool. Must be unique within the cluster. Use this ID to
+        identify the driver group in future operations, such as resizing the node group.
+    :param secondary_worker_instance_flexibility_policy: Instance flexibility policy allowing a mixture of VM
+        shapes and provisioning models. Only applied when ``num_preemptible_workers`` > 0.
     """
 
     def __init__(
         self,
         project_id: str,
         num_workers: int | None = None,
+        min_num_workers: int | None = None,
         zone: str | None = None,
         network_uri: str | None = None,
         subnetwork_uri: str | None = None,
@@ -191,11 +235,15 @@ class ClusterGenerator:
         auto_delete_ttl: int | None = None,
         customer_managed_key: str | None = None,
         enable_component_gateway: bool | None = False,
+        driver_pool_size: int = 0,
+        driver_pool_id: str | None = None,
+        secondary_worker_instance_flexibility_policy: InstanceFlexibilityPolicy | None = None,
         **kwargs,
     ) -> None:
         self.project_id = project_id
         self.num_masters = num_masters
         self.num_workers = num_workers
+        self.min_num_workers = min_num_workers
         self.num_preemptible_workers = num_preemptible_workers
         self.preemptibility = self._set_preemptibility_type(preemptibility)
         self.storage_bucket = storage_bucket
@@ -228,6 +276,9 @@ class ClusterGenerator:
         self.customer_managed_key = customer_managed_key
         self.enable_component_gateway = enable_component_gateway
         self.single_node = num_workers == 0
+        self.driver_pool_size = driver_pool_size
+        self.driver_pool_id = driver_pool_id
+        self.secondary_worker_instance_flexibility_policy = secondary_worker_instance_flexibility_policy
 
         if self.custom_image and self.image_version:
             raise ValueError("The custom_image and image_version can't be both set")
@@ -241,6 +292,15 @@ class ClusterGenerator:
         if self.single_node and self.num_preemptible_workers > 0:
             raise ValueError("Single node cannot have preemptible workers.")
 
+        if self.min_num_workers:
+            if not self.num_workers:
+                raise ValueError("Must specify num_workers when min_num_workers are provided.")
+            if self.min_num_workers > self.num_workers:
+                raise ValueError(
+                    "The value of min_num_workers must be less than or equal to num_workers. "
+                    f"Provided {self.min_num_workers}(min_num_workers) and {self.num_workers}(num_workers)."
+                )
+
     def _set_preemptibility_type(self, preemptibility: str):
         return PreemptibilityType(preemptibility.upper())
 
@@ -307,6 +367,17 @@ class ClusterGenerator:
 
         return cluster_data
 
+    def _build_driver_pool(self):
+        driver_pool = {
+            "node_group": {
+                "roles": ["DRIVER"],
+                "node_group_config": {"num_instances": self.driver_pool_size},
+            },
+        }
+        if self.driver_pool_id:
+            driver_pool["node_group_id"] = self.driver_pool_id
+        return driver_pool
+
     def _build_cluster_data(self):
         if self.zone:
             master_type_uri = (
@@ -344,6 +415,10 @@ class ClusterGenerator:
             "autoscaling_config": {},
             "endpoint_config": {},
         }
+
+        if self.min_num_workers:
+            cluster_data["worker_config"]["min_num_instances"] = self.min_num_workers
+
         if self.num_preemptible_workers > 0:
             cluster_data["secondary_worker_config"] = {
                 "num_instances": self.num_preemptible_workers,
@@ -355,6 +430,13 @@ class ClusterGenerator:
                 "is_preemptible": True,
                 "preemptibility": self.preemptibility.value,
             }
+            if self.secondary_worker_instance_flexibility_policy:
+                cluster_data["secondary_worker_config"]["instance_flexibility_policy"] = {
+                    "instance_selection_list": [
+                        asdict(s)
+                        for s in self.secondary_worker_instance_flexibility_policy.instance_selection_list
+                    ]
+                }
 
         if self.storage_bucket:
             cluster_data["config_bucket"] = self.storage_bucket
@@ -382,6 +464,9 @@ class ClusterGenerator:
             if not self.single_node:
                 cluster_data["worker_config"]["image_uri"] = custom_image_url
 
+        if self.driver_pool_size > 0:
+            cluster_data["auxiliary_node_groups"] = [self._build_driver_pool()]
+
         cluster_data = self._build_gce_cluster_config(cluster_data)
 
         if self.single_node:
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 0500304a0d..4287801d2b 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -102,7 +102,7 @@ dependencies:
   - google-cloud-dataflow-client>=0.8.2
   - google-cloud-dataform>=0.5.0
   - google-cloud-dataplex>=1.4.2
-  - google-cloud-dataproc>=5.4.0
+  - google-cloud-dataproc>=5.5.0
   - google-cloud-dataproc-metastore>=1.12.0
   - google-cloud-dlp>=3.12.0
   - google-cloud-kms>=2.15.0
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index b787191fd0..c56f81ebaf 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -792,7 +792,9 @@ InspectContentResponse
 InspectTemplate
 instafail
 installable
+InstanceFlexibilityPolicy
 InstanceGroupConfig
+InstanceSelection
 instanceTemplates
 instantiation
 integrations
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 39361c5a98..59a9c1008c 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -60,6 +60,8 @@ from airflow.providers.google.cloud.operators.dataproc import (
     DataprocSubmitSparkJobOperator,
     DataprocSubmitSparkSqlJobOperator,
     DataprocUpdateClusterOperator,
+    InstanceFlexibilityPolicy,
+    InstanceSelection,
 )
 from airflow.providers.google.cloud.triggers.dataproc import (
     DataprocBatchTrigger,
@@ -112,6 +114,7 @@ CONFIG = {
         "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
         "image_uri": "https://www.googleapis.com/compute/beta/projects/"
         "custom_image_project_id/global/images/custom_image",
+        "min_num_instances": 1,
     },
     "secondary_worker_config": {
         "num_instances": 4,
@@ -132,6 +135,17 @@ CONFIG = {
         {"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}}
     ],
     "endpoint_config": {},
+    "auxiliary_node_groups": [
+        {
+            "node_group": {
+                "roles": ["DRIVER"],
+                "node_group_config": {
+                    "num_instances": 2,
+                },
+            },
+            "node_group_id": "cluster_driver_pool",
+        }
+    ],
 }
 VIRTUAL_CLUSTER_CONFIG = {
     "kubernetes_cluster_config": {
@@ -197,6 +211,64 @@ CONFIG_WITH_CUSTOM_IMAGE_FAMILY = {
     },
 }
 
+CONFIG_WITH_FLEX_MIG = {
+    "gce_cluster_config": {
+        "zone_uri": "https://www.googleapis.com/compute/v1/projects/project_id/zones/zone",
+        "metadata": {"metadata": "data"},
+        "network_uri": "network_uri",
+        "subnetwork_uri": "subnetwork_uri",
+        "internal_ip_only": True,
+        "tags": ["tags"],
+        "service_account": "service_account",
+        "service_account_scopes": ["service_account_scopes"],
+    },
+    "master_config": {
+        "num_instances": 2,
+        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/master_machine_type",
+        "disk_config": {"boot_disk_type": "master_disk_type", "boot_disk_size_gb": 128},
+        "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+        "custom_image_project_id/global/images/custom_image",
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
+        "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
+        "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+        "custom_image_project_id/global/images/custom_image",
+    },
+    "secondary_worker_config": {
+        "num_instances": 4,
+        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
+        "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
+        "is_preemptible": True,
+        "preemptibility": "SPOT",
+        "instance_flexibility_policy": {
+            "instance_selection_list": [
+                {
+                    "machine_types": [
+                        "projects/project_id/zones/zone/machineTypes/machine1",
+                        "projects/project_id/zones/zone/machineTypes/machine2",
+                    ],
+                    "rank": 0,
+                },
+                {"machine_types": ["projects/project_id/zones/zone/machineTypes/machine3"], "rank": 1},
+            ],
+        },
+    },
+    "software_config": {"properties": {"properties": "data"}, "optional_components": ["optional_components"]},
+    "lifecycle_config": {
+        "idle_delete_ttl": {"seconds": 60},
+        "auto_delete_time": "2019-09-12T00:00:00.000000Z",
+    },
+    "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
+    "autoscaling_config": {"policy_uri": "autoscaling_policy"},
+    "config_bucket": "storage_bucket",
+    "initialization_actions": [
+        {"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}}
+    ],
+    "endpoint_config": {},
+}
+
 LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION}
 
 LABELS.update({"airflow-version": "v" + airflow_version.replace(".", "-").replace("+", "-")})
@@ -361,10 +433,26 @@ class TestsClusterGenerator:
             )
             assert "num_workers == 0 means single" in str(ctx.value)
 
+    def test_min_num_workers_less_than_num_workers(self):
+        with pytest.raises(ValueError) as ctx:
+            ClusterGenerator(
+                num_workers=3, min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
+            )
+        assert (
+            "The value of min_num_workers must be less than or equal to num_workers. "
+            "Provided 4(min_num_workers) and 3(num_workers)." in str(ctx.value)
+        )
+
+    def test_min_num_workers_without_num_workers(self):
+        with pytest.raises(ValueError) as ctx:
+            ClusterGenerator(min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME)
+        assert "Must specify num_workers when min_num_workers are provided." in str(ctx.value)
+
     def test_build(self):
         generator = ClusterGenerator(
             project_id="project_id",
             num_workers=2,
+            min_num_workers=1,
             zone="zone",
             network_uri="network_uri",
             subnetwork_uri="subnetwork_uri",
@@ -395,6 +483,8 @@ class TestsClusterGenerator:
             auto_delete_time=datetime(2019, 9, 12),
             auto_delete_ttl=250,
             customer_managed_key="customer_managed_key",
+            driver_pool_id="cluster_driver_pool",
+            driver_pool_size=2,
         )
         cluster = generator.make()
         assert CONFIG == cluster
@@ -438,6 +528,56 @@ class TestsClusterGenerator:
         cluster = generator.make()
         assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster
 
+    def test_build_with_flex_migs(self):
+        generator = ClusterGenerator(
+            project_id="project_id",
+            num_workers=2,
+            zone="zone",
+            network_uri="network_uri",
+            subnetwork_uri="subnetwork_uri",
+            internal_ip_only=True,
+            tags=["tags"],
+            storage_bucket="storage_bucket",
+            init_actions_uris=["init_actions_uris"],
+            init_action_timeout="10m",
+            metadata={"metadata": "data"},
+            custom_image="custom_image",
+            custom_image_project_id="custom_image_project_id",
+            autoscaling_policy="autoscaling_policy",
+            properties={"properties": "data"},
+            optional_components=["optional_components"],
+            num_masters=2,
+            master_machine_type="master_machine_type",
+            master_disk_type="master_disk_type",
+            master_disk_size=128,
+            worker_machine_type="worker_machine_type",
+            worker_disk_type="worker_disk_type",
+            worker_disk_size=256,
+            num_preemptible_workers=4,
+            preemptibility="Spot",
+            region="region",
+            service_account="service_account",
+            service_account_scopes=["service_account_scopes"],
+            idle_delete_ttl=60,
+            auto_delete_time=datetime(2019, 9, 12),
+            auto_delete_ttl=250,
+            customer_managed_key="customer_managed_key",
+            secondary_worker_instance_flexibility_policy=InstanceFlexibilityPolicy(
+                [
+                    InstanceSelection(
+                        [
+                            "projects/project_id/zones/zone/machineTypes/machine1",
+                            "projects/project_id/zones/zone/machineTypes/machine2",
+                        ],
+                        0,
+                    ),
+                    InstanceSelection(["projects/project_id/zones/zone/machineTypes/machine3"], 1),
+                ]
+            ),
+        )
+        cluster = generator.make()
+        assert CONFIG_WITH_FLEX_MIG == cluster
+
 
 class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):

Reply via email to