This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new df132b2dd6 Add GKEStartKueueInsideClusterOperator (#37072)
df132b2dd6 is described below
commit df132b2dd6fcb9022e1ff5f28841bec7a120853b
Author: VladaZakharova <[email protected]>
AuthorDate: Thu Feb 15 14:54:30 2024 +0100
Add GKEStartKueueInsideClusterOperator (#37072)
---
.../providers/cncf/kubernetes/hooks/kubernetes.py | 24 ++-
.../google/cloud/hooks/kubernetes_engine.py | 163 ++++++++++++++++-
.../google/cloud/operators/kubernetes_engine.py | 191 +++++++++++++++++++-
.../operators/cloud/kubernetes_engine.rst | 19 ++
docs/spelling_wordlist.txt | 5 +
.../google/cloud/hooks/test_kubernetes_engine.py | 189 ++++++++++++++++++++
.../cloud/operators/test_kubernetes_engine.py | 193 ++++++++++++++++++++-
.../example_kubernetes_engine_kueue.py | 84 +++++++++
8 files changed, 850 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 18495d0a38..e6e1054f10 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -37,7 +37,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager
import PodOperatorHookP
from airflow.utils import yaml
if TYPE_CHECKING:
- from kubernetes.client.models import V1Pod
+ from kubernetes.client.models import V1Deployment, V1Pod
LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file
kube_config from {}..."
@@ -282,6 +282,10 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
def core_v1_client(self) -> client.CoreV1Api:
return client.CoreV1Api(api_client=self.api_client)
+ @cached_property
+ def apps_v1_client(self) -> client.AppsV1Api:
+ return client.AppsV1Api(api_client=self.api_client)
+
@cached_property
def custom_object_client(self) -> client.CustomObjectsApi:
return client.CustomObjectsApi(api_client=self.api_client)
@@ -450,6 +454,24 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
**kwargs,
)
+ def get_deployment_status(
+ self,
+ name: str,
+ namespace: str = "default",
+ **kwargs,
+ ) -> V1Deployment:
+ """Get status of existing Deployment.
+
+ :param name: Name of Deployment to retrieve
+ :param namespace: Deployment namespace
+ """
+ try:
+ return self.apps_v1_client.read_namespaced_deployment_status(
+ name=name, namespace=namespace, pretty=True, **kwargs
+ )
+ except Exception as exc:
+ raise exc
+
def _get_bool(val) -> bool | None:
"""Convert val to bool if can be done with certainty; if we cannot infer
intention we return None."""
diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index e4ec748116..0878fa6b6f 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -15,14 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-This module contains a Google Kubernetes Engine Hook.
-
-.. spelling:word-list::
-
- gapic
- enums
-"""
+"""This module contains a Google Kubernetes Engine Hook."""
from __future__ import annotations
import contextlib
@@ -41,13 +34,15 @@ from google.auth.transport import requests as
google_requests
from google.cloud import container_v1, exceptions # type: ignore[attr-defined]
from google.cloud.container_v1 import ClusterManagerAsyncClient,
ClusterManagerClient
from google.cloud.container_v1.types import Cluster, Operation
-from kubernetes import client
+from kubernetes import client, utils
+from kubernetes.client.models import V1Deployment
from kubernetes_asyncio import client as async_client
from kubernetes_asyncio.config.kube_config import FileOrData
from urllib3.exceptions import HTTPError
from airflow import version
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kube_client import _enable_tcp_keepalive
from airflow.providers.cncf.kubernetes.utils.pod_manager import
PodOperatorHookProtocol
from airflow.providers.google.common.consts import CLIENT_INFO
@@ -299,6 +294,130 @@ class GKEHook(GoogleBaseHook):
timeout=timeout,
)
+ def check_cluster_autoscaling_ability(self, cluster: Cluster | dict):
+ """
+ Check if the specified Cluster has ability to autoscale.
+
+ Cluster should be Autopilot, with Node Auto-provisioning or regular
auto-scaled node pools.
+ Returns True if the Cluster supports autoscaling, otherwise returns
False.
+
+ :param cluster: The Cluster object.
+ """
+ if isinstance(cluster, Cluster):
+ cluster_dict_representation = Cluster.to_dict(cluster)
+ elif not isinstance(cluster, dict):
+ raise AirflowException("cluster is not instance of Cluster proto
or python dict")
+ else:
+ cluster_dict_representation = cluster
+
+ node_pools_autoscaled = False
+ for node_pool in cluster_dict_representation["node_pools"]:
+ try:
+ if node_pool["autoscaling"]["enabled"] is True:
+ node_pools_autoscaled = True
+ break
+ except KeyError:
+ self.log.info("No autoscaling enabled in Node pools level.")
+ break
+ if (
+ cluster_dict_representation["autopilot"]["enabled"]
+ or
cluster_dict_representation["autoscaling"]["enable_node_autoprovisioning"]
+ or node_pools_autoscaled
+ ):
+ return True
+ else:
+ return False
+
+
+class GKEDeploymentHook(GoogleBaseHook, KubernetesHook):
+ """Google Kubernetes Engine Deployment APIs."""
+
+ def __init__(
+ self,
+ cluster_url: str,
+ ssl_ca_cert: str,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self._cluster_url = cluster_url
+ self._ssl_ca_cert = ssl_ca_cert
+
+ @cached_property
+ def api_client(self) -> client.ApiClient:
+ return self.get_conn()
+
+ @cached_property
+ def core_v1_client(self) -> client.CoreV1Api:
+ return client.CoreV1Api(self.api_client)
+
+ @cached_property
+ def batch_v1_client(self) -> client.BatchV1Api:
+ return client.BatchV1Api(self.api_client)
+
+ @cached_property
+ def apps_v1_client(self) -> client.AppsV1Api:
+ return client.AppsV1Api(api_client=self.api_client)
+
+ def get_conn(self) -> client.ApiClient:
+ configuration = self._get_config()
+ configuration.refresh_api_key_hook = self._refresh_api_key_hook
+ return client.ApiClient(configuration)
+
+ def _refresh_api_key_hook(self, configuration:
client.configuration.Configuration):
+ configuration.api_key = {"authorization":
self._get_token(self.get_credentials())}
+
+ def _get_config(self) -> client.configuration.Configuration:
+ configuration = client.Configuration(
+ host=self._cluster_url,
+ api_key_prefix={"authorization": "Bearer"},
+ api_key={"authorization": self._get_token(self.get_credentials())},
+ )
+ configuration.ssl_ca_cert = FileOrData(
+ {
+ "certificate-authority-data": self._ssl_ca_cert,
+ },
+ file_key_name="certificate-authority",
+ ).as_file()
+ return configuration
+
+ @staticmethod
+ def _get_token(creds: google.auth.credentials.Credentials) -> str:
+ if creds.token is None or creds.expired:
+ auth_req = google_requests.Request()
+ creds.refresh(auth_req)
+ return creds.token
+
+ def check_kueue_deployment_running(self, name, namespace):
+ timeout = 300
+ polling_period_seconds = 2
+
+ while timeout is None or timeout > 0:
+ try:
+ deployment = self.get_deployment_status(name=name,
namespace=namespace)
+ deployment_status = V1Deployment.to_dict(deployment)["status"]
+ replicas = deployment_status["replicas"]
+ ready_replicas = deployment_status["ready_replicas"]
+ unavailable_replicas =
deployment_status["unavailable_replicas"]
+ if (
+ replicas is not None
+ and ready_replicas is not None
+ and unavailable_replicas is None
+ and replicas == ready_replicas
+ ):
+ return
+ else:
+ self.log.info("Waiting until Deployment will be ready...")
+ time.sleep(polling_period_seconds)
+ except Exception as e:
+ self.log.exception("Exception occurred while checking for
Deployment status.")
+ raise e
+
+ if timeout is not None:
+ timeout -= polling_period_seconds
+
+ raise AirflowException("Deployment timed out")
+
class GKEAsyncHook(GoogleBaseAsyncHook):
"""Asynchronous client of GKE."""
@@ -431,6 +550,32 @@ class GKEPodHook(GoogleBaseHook, PodOperatorHookProtocol):
creds.refresh(auth_req)
return creds.token
+ def apply_from_yaml_file(
+ self,
+ yaml_file: str | None = None,
+ yaml_objects: list[dict] | None = None,
+ verbose: bool = False,
+ namespace: str = "default",
+ ):
+ """
+ Perform an action from a yaml file on a Pod.
+
+ :param yaml_file: Contains the path to yaml file.
+ :param yaml_objects: List of YAML objects; used instead of reading the
yaml_file.
+ :param verbose: If True, print confirmation from create action.
Default is False.
+ :param namespace: Contains the namespace to create all resources
inside. The namespace must
+ preexist otherwise the resource creation will fail.
+ """
+ k8s_client = self.get_conn()
+
+ utils.create_from_yaml(
+ k8s_client=k8s_client,
+ yaml_objects=yaml_objects,
+ yaml_file=yaml_file,
+ verbose=verbose,
+ namespace=namespace,
+ )
+
def get_pod(self, name: str, namespace: str) -> V1Pod:
"""Get a pod object.
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py
b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 6538542f21..e063bc7bdc 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -18,19 +18,23 @@
"""This module contains Google Kubernetes Engine operators."""
from __future__ import annotations
+import re
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
+import requests
+import yaml
from deprecated import deprecated
from google.api_core.exceptions import AlreadyExists
from google.cloud.container_v1.types import Cluster
+from kubernetes.utils.create_from_yaml import FailToCreateError
from airflow.configuration import conf
from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.operators.pod import
KubernetesPodOperator
from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
-from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook,
GKEPodHook
+from airflow.providers.google.cloud.hooks.kubernetes_engine import
GKEDeploymentHook, GKEHook, GKEPodHook
from airflow.providers.google.cloud.links.kubernetes_engine import (
KubernetesEngineClusterLink,
KubernetesEnginePodLink,
@@ -47,6 +51,45 @@ if TYPE_CHECKING:
KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
+class GKEClusterAuthDetails:
+ """
+ Helper for fetching information about cluster for connecting.
+
+ :param cluster_name: The name of the Google Kubernetes Engine cluster the
pod should be spawned in.
+ :param project_id: The Google Developers Console project id.
+ :param use_internal_ip: Use the internal IP address as the endpoint.
+ :param cluster_hook: airflow hook for working with kubernetes cluster.
+ """
+
+ def __init__(
+ self,
+ cluster_name,
+ project_id,
+ use_internal_ip,
+ cluster_hook,
+ ):
+ self.cluster_name = cluster_name
+ self.project_id = project_id
+ self.use_internal_ip = use_internal_ip
+ self.cluster_hook = cluster_hook
+ self._cluster_url = None
+ self._ssl_ca_cert = None
+
+ def fetch_cluster_info(self) -> tuple[str, str | None]:
+ """Fetch cluster info for connecting to it."""
+ cluster = self.cluster_hook.get_cluster(
+ name=self.cluster_name,
+ project_id=self.project_id,
+ )
+
+ if not self.use_internal_ip:
+ self._cluster_url = f"https://{cluster.endpoint}"
+ else:
+ self._cluster_url =
f"https://{cluster.private_cluster_config.private_endpoint}"
+ self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate
+ return self._cluster_url, self._ssl_ca_cert
+
+
class GKEDeleteClusterOperator(GoogleCloudBaseOperator):
"""
Deletes the cluster, including the Kubernetes endpoint and all worker
nodes.
@@ -388,6 +431,152 @@ class GKECreateClusterOperator(GoogleCloudBaseOperator):
return self._hook
+class GKEStartKueueInsideClusterOperator(GoogleCloudBaseOperator):
+ """
+ Installs Kueue of specific version inside Cluster.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:GKEStartKueueInsideClusterOperator`
+
+ .. seealso::
+ For more details about Kueue have a look at the reference:
+ https://kueue.sigs.k8s.io/docs/overview/
+
+ :param project_id: The Google Developers Console [project ID or project
number].
+ :param location: The name of the Google Kubernetes Engine zone or region
in which the cluster resides.
+ :param cluster_name: The Cluster name in which to install Kueue.
+ :param kueue_version: Version of Kueue to install.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ """
+
+ template_fields: Sequence[str] = (
+ "project_id",
+ "location",
+ "kueue_version",
+ "cluster_name",
+ "gcp_conn_id",
+ "impersonation_chain",
+ )
+ operator_extra_links = (KubernetesEngineClusterLink(),)
+
+ def __init__(
+ self,
+ *,
+ location: str,
+ cluster_name: str,
+ kueue_version: str,
+ use_internal_ip: bool = False,
+ project_id: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.location = location
+ self.cluster_name = cluster_name
+ self.kueue_version = kueue_version
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.use_internal_ip = use_internal_ip
+ self._kueue_yaml_url = (
+
f"https://github.com/kubernetes-sigs/kueue/releases/download/{self.kueue_version}/manifests.yaml"
+ )
+
+ @cached_property
+ def cluster_hook(self) -> GKEHook:
+ return GKEHook(
+ gcp_conn_id=self.gcp_conn_id,
+ location=self.location,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ @cached_property
+ def deployment_hook(self) -> GKEDeploymentHook:
+ if self._cluster_url is None or self._ssl_ca_cert is None:
+ raise AttributeError(
+ "Cluster url and ssl_ca_cert should be defined before using
self.hook method. "
+ "Try to use self.get_kube_creds method",
+ )
+ return GKEDeploymentHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ cluster_url=self._cluster_url,
+ ssl_ca_cert=self._ssl_ca_cert,
+ )
+
+ @cached_property
+ def pod_hook(self) -> GKEPodHook:
+ if self._cluster_url is None or self._ssl_ca_cert is None:
+ raise AttributeError(
+ "Cluster url and ssl_ca_cert should be defined before using
self.hook method. "
+ "Try to use self.get_kube_creds method",
+ )
+ return GKEPodHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ cluster_url=self._cluster_url,
+ ssl_ca_cert=self._ssl_ca_cert,
+ )
+
+ @staticmethod
+ def _get_yaml_content_from_file(kueue_yaml_url) -> list[dict]:
+ """Download content of YAML file and separate it into several
dictionaries."""
+ response = requests.get(kueue_yaml_url, allow_redirects=True)
+ yaml_dicts = []
+ if response.status_code == 200:
+ yaml_data = response.text
+ documents = re.split(r"---\n", yaml_data)
+
+ for document in documents:
+ document_dict = yaml.safe_load(document)
+ yaml_dicts.append(document_dict)
+ else:
+ raise AirflowException("Was not able to read the yaml file from
given URL")
+ return yaml_dicts
+
+ def execute(self, context: Context):
+ self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
+ cluster_name=self.cluster_name,
+ project_id=self.project_id,
+ use_internal_ip=self.use_internal_ip,
+ cluster_hook=self.cluster_hook,
+ ).fetch_cluster_info()
+
+ cluster = self.cluster_hook.get_cluster(
+ name=self.cluster_name,
+ project_id=self.project_id,
+ )
+ KubernetesEngineClusterLink.persist(context=context,
task_instance=self, cluster=cluster)
+
+ yaml_objects =
self._get_yaml_content_from_file(kueue_yaml_url=self._kueue_yaml_url)
+
+ if
self.cluster_hook.check_cluster_autoscaling_ability(cluster=cluster):
+ try:
+ self.pod_hook.apply_from_yaml_file(yaml_objects=yaml_objects)
+
+ self.deployment_hook.check_kueue_deployment_running(
+ name="kueue-controller-manager", namespace="kueue-system"
+ )
+
+ self.log.info("Kueue installed successfully!")
+ except FailToCreateError:
+ self.log.info("Kueue is already enabled for the cluster")
+ else:
+ self.log.info(
+ "Cluster doesn't have ability to autoscale, will not install
Kueue inside. Aborting"
+ )
+
+
class GKEStartPodOperator(KubernetesPodOperator):
"""
Executes a task in a Kubernetes pod in the specified Google Kubernetes
Engine cluster.
diff --git
a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
index 74c396c985..7663ce4811 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
@@ -71,6 +71,25 @@ lot less resources wasted on idle Operators or Sensors:
:end-before: [END howto_operator_gke_create_cluster_async]
+.. _howto/operator:GKEStartKueueInsideClusterOperator:
+
+Install Kueue of specific version inside Cluster
+""""""""""""""""""""""""""""""""""""""""""""""""
+
+Kueue is a Cloud Native Job scheduler that works with the default Kubernetes
scheduler, the Job controller,
+and the cluster autoscaler to provide an end-to-end batch system. Kueue
implements Job queueing, deciding when
+Jobs should wait and when they should start, based on quotas and a hierarchy
for sharing resources fairly among teams.
+Kueue supports Autopilot clusters, Standard GKE with Node Auto-provisioning
and regular autoscaled node pools.
+To install and use Kueue on your cluster, use
+:class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartKueueInsideClusterOperator`
+as shown in this example:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
+ :language: python
+ :start-after: [START howto_operator_gke_install_kueue]
+ :end-before: [END howto_operator_gke_install_kueue]
+
+
.. _howto/operator:GKEDeleteClusterOperator:
Delete GKE cluster
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 2d168708a8..f04918bf39 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -114,9 +114,11 @@ autoenv
autogenerated
automl
AutoMlClient
+autoprovisioning
autorestart
Autoscale
autoscale
+autoscaled
autoscaler
autoscaling
avp
@@ -553,6 +555,7 @@ entrypoint
entrypoints
Enum
enum
+enums
Env
env
envFrom
@@ -907,6 +910,7 @@ kubeconfig
Kubernetes
kubernetes
KubernetesPodOperator
+Kueue
Kusto
kv
kwarg
@@ -1285,6 +1289,7 @@ QuboleCheckHook
Quboles
queryParameters
querystring
+queueing
quickstart
quotechar
rabbitmq
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index fae3db76e9..06cbea84bc 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -24,10 +24,12 @@ import kubernetes.client
import pytest
from google.cloud.container_v1 import ClusterManagerAsyncClient
from google.cloud.container_v1.types import Cluster
+from kubernetes.client.models import V1Deployment, V1DeploymentStatus
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.kubernetes_engine import (
GKEAsyncHook,
+ GKEDeploymentHook,
GKEHook,
GKEPodAsyncHook,
GKEPodHook,
@@ -37,10 +39,12 @@ from tests.providers.google.cloud.utils.base_gcp_mock
import mock_base_gcp_hook_
TASK_ID = "test-gke-cluster-operator"
CLUSTER_NAME = "test-cluster"
+NAMESPACE = "test-cluster-namespace"
TEST_GCP_PROJECT_ID = "test-project"
GKE_ZONE = "test-zone"
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
GKE_STRING = "airflow.providers.google.cloud.hooks.kubernetes_engine.{}"
+K8S_HOOK = "airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook"
CLUSTER_URL = "https://path.to.cluster"
SSL_CA_CERT = "test-ssl-ca-cert"
POD_NAME = "test-pod-name"
@@ -49,6 +53,113 @@ ASYNC_HOOK_STRING = GKE_STRING.format("GKEAsyncHook")
GCP_CONN_ID = "test-gcp-conn-id"
IMPERSONATE_CHAIN = ["impersonate", "this", "test"]
OPERATION_NAME = "test-operation-name"
+CLUSTER_TEST_AUTOPILOT = {
+ "name": "autopilot-cluster",
+ "initial_node_count": 1,
+ "autopilot": {
+ "enabled": True,
+ },
+ "autoscaling": {
+ "enable_node_autoprovisioning": False,
+ },
+ "node_pools": [
+ {
+ "name": "pool",
+ "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+ "initial_node_count": 2,
+ }
+ ],
+}
+CLUSTER_TEST_AUTOPROVISIONING = {
+ "name": "cluster_autoprovisioning",
+ "initial_node_count": 1,
+ "autopilot": {
+ "enabled": False,
+ },
+ "node_pools": [
+ {
+ "name": "pool",
+ "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+ "initial_node_count": 2,
+ }
+ ],
+ "autoscaling": {
+ "enable_node_autoprovisioning": True,
+ "resource_limits": [
+ {"resource_type": "cpu", "maximum": 1000000000},
+ {"resource_type": "memory", "maximum": 1000000000},
+ ],
+ },
+}
+CLUSTER_TEST_AUTOSCALED = {
+ "name": "autoscaled_cluster",
+ "autopilot": {
+ "enabled": False,
+ },
+ "node_pools": [
+ {
+ "name": "autoscaled-pool",
+ "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+ "initial_node_count": 2,
+ "autoscaling": {
+ "enabled": True,
+ "max_node_count": 10,
+ },
+ }
+ ],
+ "autoscaling": {
+ "enable_node_autoprovisioning": False,
+ },
+}
+
+CLUSTER_TEST_REGULAR = {
+ "name": "regular_cluster",
+ "initial_node_count": 1,
+ "autopilot": {
+ "enabled": False,
+ },
+ "node_pools": [
+ {
+ "name": "autoscaled-pool",
+ "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+ "initial_node_count": 2,
+ "autoscaling": {
+ "enabled": False,
+ },
+ }
+ ],
+ "autoscaling": {
+ "enable_node_autoprovisioning": False,
+ },
+}
+pods = {
+ "succeeded": {
+ "metadata": {"name": "test-pod", "namespace": "default"},
+ "status": {"phase": "Succeeded"},
+ },
+ "pending": {
+ "metadata": {"name": "test-pod", "namespace": "default"},
+ "status": {"phase": "Pending"},
+ },
+ "running": {
+ "metadata": {"name": "test-pod", "namespace": "default"},
+ "status": {"phase": "Running"},
+ },
+}
+NOT_READY_DEPLOYMENT = V1Deployment(
+ status=V1DeploymentStatus(
+ observed_generation=1,
+ ready_replicas=None,
+ replicas=None,
+ unavailable_replicas=1,
+ updated_replicas=None,
+ )
+)
+READY_DEPLOYMENT = V1Deployment(
+ status=V1DeploymentStatus(
+ observed_generation=1, ready_replicas=1, replicas=1,
unavailable_replicas=None, updated_replicas=1
+ )
+)
@pytest.mark.db_test
@@ -298,6 +409,84 @@ class TestGKEHook:
operation_mock.assert_any_call(pending_op.name,
project_id=TEST_GCP_PROJECT_ID)
assert operation_mock.call_count == 2
+ @pytest.mark.parametrize(
+ "cluster_obj, expected_result",
+ [
+ (CLUSTER_TEST_AUTOPROVISIONING, True),
+ (CLUSTER_TEST_AUTOSCALED, True),
+ (CLUSTER_TEST_AUTOPILOT, True),
+ (CLUSTER_TEST_REGULAR, False),
+ ],
+ )
+ def test_check_cluster_autoscaling_ability(self, cluster_obj,
expected_result):
+ result = self.gke_hook.check_cluster_autoscaling_ability(cluster_obj)
+ assert result == expected_result
+
+
+class TestGKEDeploymentHook:
+ def setup_method(self):
+ with mock.patch(
+ BASE_STRING.format("GoogleBaseHook.__init__"),
new=mock_base_gcp_hook_default_project_id
+ ):
+ self.gke_hook = GKEDeploymentHook(gcp_conn_id="test",
ssl_ca_cert=None, cluster_url=None)
+ self.gke_hook._client = mock.Mock()
+
+ def refresh_token(request):
+ self.credentials.token = "New"
+
+ self.credentials = mock.MagicMock()
+ self.credentials.token = "Old"
+ self.credentials.expired = False
+ self.credentials.refresh = refresh_token
+
+ @mock.patch(GKE_STRING.format("google_requests.Request"))
+ def test_get_connection_update_hook_with_invalid_token(self, mock_request):
+ self.gke_hook._get_config = self._get_config
+ self.gke_hook.get_credentials = self._get_credentials
+ self.gke_hook.get_credentials().expired = True
+ the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()
+
+ the_client.configuration.refresh_api_key_hook(the_client.configuration)
+
+ assert self.gke_hook.get_credentials().token == "New"
+
+ @mock.patch(GKE_STRING.format("google_requests.Request"))
+ def test_get_connection_update_hook_with_valid_token(self, mock_request):
+ self.gke_hook._get_config = self._get_config
+ self.gke_hook.get_credentials = self._get_credentials
+ self.gke_hook.get_credentials().expired = False
+ the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()
+
+ the_client.configuration.refresh_api_key_hook(the_client.configuration)
+
+ assert self.gke_hook.get_credentials().token == "Old"
+
+ def _get_config(self):
+ return kubernetes.client.configuration.Configuration()
+
+ def _get_credentials(self):
+ return self.credentials
+
+ @mock.patch("kubernetes.client.AppsV1Api")
+ def test_check_kueue_deployment_running(self, gke_deployment_hook, caplog):
+ self.gke_hook.get_credentials = self._get_credentials
+
gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect
= [
+ NOT_READY_DEPLOYMENT,
+ READY_DEPLOYMENT,
+ ]
+ self.gke_hook.check_kueue_deployment_running(name=CLUSTER_NAME,
namespace=NAMESPACE)
+
+ assert "Waiting until Deployment will be ready..." in caplog.text
+
+ @mock.patch("kubernetes.client.AppsV1Api")
+ def test_check_kueue_deployment_raise_exception(self, gke_deployment_hook,
caplog):
+ self.gke_hook.get_credentials = self._get_credentials
+
gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect
= ValueError()
+ with pytest.raises(ValueError):
+ self.gke_hook.check_kueue_deployment_running(name=CLUSTER_NAME,
namespace=NAMESPACE)
+
+ assert "Exception occurred while checking for Deployment status." in
caplog.text
+
class TestGKEPodAsyncHook:
@staticmethod
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index 7805804485..30a5247a68 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -22,6 +22,9 @@ import os
from unittest import mock
import pytest
+from google.cloud.container_v1.types import Cluster, NodePool
+from kubernetes.client.models import V1Deployment, V1DeploymentStatus
+from kubernetes.utils.create_from_yaml import FailToCreateError
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import Connection
@@ -30,6 +33,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager
import OnFinishAction
from airflow.providers.google.cloud.operators.kubernetes_engine import (
GKECreateClusterOperator,
GKEDeleteClusterOperator,
+ GKEStartKueueInsideClusterOperator,
GKEStartPodOperator,
)
from airflow.providers.google.cloud.triggers.kubernetes_engine import
GKEStartPodTrigger
@@ -46,12 +50,10 @@ PROJECT_BODY_CREATE_DICT_NODE_POOLS = {
"node_pools": [{"name": "a_node_pool", "initial_node_count": 1}],
}
-PROJECT_BODY_CREATE_CLUSTER = type("Cluster", (object,), {"name": "test-name",
"initial_node_count": 1})()
-PROJECT_BODY_CREATE_CLUSTER_NODE_POOLS = type(
- "Cluster",
- (object,),
- {"name": "test-name", "node_pools": [{"name": "a_node_pool",
"initial_node_count": 1}]},
-)()
+PROJECT_BODY_CREATE_CLUSTER = Cluster(name="test-name", initial_node_count=1)
+PROJECT_BODY_CREATE_CLUSTER_NODE_POOLS = Cluster(
+ name="test-name", node_pools=[NodePool(name="a_node_pool",
initial_node_count=1)]
+)
TASK_NAME = "test-task-name"
NAMESPACE = ("default",)
@@ -63,12 +65,28 @@ FILE_NAME = "/tmp/mock_name"
KUB_OP_PATH =
"airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.{}"
GKE_HOOK_MODULE_PATH =
"airflow.providers.google.cloud.operators.kubernetes_engine"
GKE_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEHook"
+GKE_POD_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEPodHook"
+GKE_DEPLOYMENT_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEDeploymentHook"
KUB_OPERATOR_EXEC =
"airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute"
TEMP_FILE = "tempfile.NamedTemporaryFile"
GKE_OP_PATH =
"airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator"
+GKE_CREATE_CLUSTER_PATH = (
+
"airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator"
+)
+GKE_CLUSTER_AUTH_DETAILS_PATH = (
+
"airflow.providers.google.cloud.operators.kubernetes_engine.GKEClusterAuthDetails"
+)
CLUSTER_URL = "https://test-host"
CLUSTER_PRIVATE_URL = "https://test-private-host"
SSL_CA_CERT = "TEST_SSL_CA_CERT_CONTENT"
+KUEUE_VERSION = "v0.5.1"
+IMPERSONATION_CHAIN = "[email protected]"
+USE_INTERNAL_API = False
+READY_DEPLOYMENT = V1Deployment(
+ status=V1DeploymentStatus(
+ observed_generation=1, ready_replicas=1, replicas=1,
unavailable_replicas=None, updated_replicas=1
+ )
+)
class TestGoogleCloudPlatformContainerOperator:
@@ -83,6 +101,7 @@ class TestGoogleCloudPlatformContainerOperator:
)
@mock.patch(GKE_HOOK_PATH)
def test_create_execute(self, mock_hook, body):
+ print("type: ", type(body))
operator = GKECreateClusterOperator(
project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION,
body=body, task_id=PROJECT_TASK_ID
)
@@ -420,6 +439,166 @@ class TestGKEPodOperator:
assert op.__getattribute__(expected_attr) ==
expected_attributes[expected_attr]
class TestGKEStartKueueInsideClusterOperator:
    """Unit tests for GKEStartKueueInsideClusterOperator.

    All GCP/Kubernetes interaction is mocked out; the cluster URL and SSL CA
    cert are injected directly on the operator so no real cluster is contacted.
    Note that ``mock.patch`` decorators are applied bottom-up, so the mock
    arguments of each test method are listed in reverse decorator order.
    """

    @pytest.fixture(autouse=True)
    def setup_test(self):
        self.gke_op = GKEStartKueueInsideClusterOperator(
            project_id=TEST_GCP_PROJECT_ID,
            location=PROJECT_LOCATION,
            cluster_name=CLUSTER_NAME,
            task_id=PROJECT_TASK_ID,
            kueue_version=KUEUE_VERSION,
            impersonation_chain=IMPERSONATION_CHAIN,
            use_internal_ip=USE_INTERNAL_API,
        )
        # Pre-populate the cached cluster connection details so execute()
        # does not need a real fetch to succeed.
        self.gke_op._cluster_url = CLUSTER_URL
        self.gke_op._ssl_ca_cert = SSL_CA_CERT

    @mock.patch.dict(os.environ, {})
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    @mock.patch(f"{GKE_DEPLOYMENT_HOOK_PATH}.check_kueue_deployment_running")
    @mock.patch(GKE_POD_HOOK_PATH)
    def test_execute(self, mock_pod_hook, mock_deployment, mock_hook, fetch_cluster_info_mock, file_mock):
        """Happy path: cluster info is fetched exactly once and Kueue YAML is applied."""
        mock_pod_hook.return_value.apply_from_yaml_file.side_effect = mock.MagicMock()
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        mock_hook.return_value.get_cluster.return_value = PROJECT_BODY_CREATE_CLUSTER
        self.gke_op.execute(context=mock.MagicMock())

        fetch_cluster_info_mock.assert_called_once()

    @mock.patch.dict(os.environ, {})
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_DEPLOYMENT_HOOK_PATH)
    @mock.patch(GKE_HOOK_PATH)
    @mock.patch(GKE_POD_HOOK_PATH)
    def test_execute_autoscaled_cluster(
        self, mock_pod_hook, mock_hook, mock_depl_hook, fetch_cluster_info_mock, file_mock, caplog
    ):
        """An autoscaling-capable cluster with a ready deployment logs a success message."""
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        mock_hook.return_value.get_cluster.return_value = mock.MagicMock()
        mock_pod_hook.return_value.apply_from_yaml_file.side_effect = mock.MagicMock()
        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
        # Deployment reports all replicas ready, so installation is considered done.
        mock_depl_hook.return_value.get_deployment_status.return_value = READY_DEPLOYMENT
        self.gke_op.execute(context=mock.MagicMock())

        assert "Kueue installed successfully!" in caplog.text

    @mock.patch.dict(os.environ, {})
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    @mock.patch(GKE_POD_HOOK_PATH)
    def test_execute_autoscaled_cluster_check_error(
        self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, file_mock, caplog
    ):
        """A FailToCreateError while applying the YAML is treated as 'already installed'."""
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        mock_hook.return_value.get_cluster.return_value = mock.MagicMock()
        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
        mock_pod_hook.return_value.apply_from_yaml_file.side_effect = FailToCreateError("error")
        self.gke_op.execute(context=mock.MagicMock())

        assert "Kueue is already enabled for the cluster" in caplog.text

    @mock.patch.dict(os.environ, {})
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    @mock.patch(GKE_POD_HOOK_PATH)
    def test_execute_non_autoscaled_cluster_check_error(
        self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, file_mock, caplog
    ):
        """Without autoscaling ability the operator aborts and never touches the pod hook."""
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        mock_hook.return_value.get_cluster.return_value = mock.MagicMock()
        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
        self.gke_op.execute(context=mock.MagicMock())

        assert (
            "Cluster doesn't have ability to autoscale, will not install Kueue inside. Aborting"
            in caplog.text
        )
        mock_pod_hook.assert_not_called()

    @mock.patch.dict(os.environ, {})
    # Patch the singular ``get_connection`` (the method hooks actually call) and
    # return a single Connection, consistent with test_gcp_conn_id below; the
    # plural ``get_connections`` returning a list would leave the real lookup
    # unpatched.
    @mock.patch(
        "airflow.hooks.base.BaseHook.get_connection",
        return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
    )
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    def test_execute_with_impersonation_service_account(
        self, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
    ):
        """Impersonation chain given as a single string still fetches cluster info once."""
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        mock_hook.return_value.get_cluster.return_value = PROJECT_BODY_CREATE_CLUSTER
        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
        self.gke_op.impersonation_chain = "[email protected]"
        self.gke_op.execute(context=mock.MagicMock())

        fetch_cluster_info_mock.assert_called_once()

    @mock.patch.dict(os.environ, {})
    # See note above: patch the singular ``get_connection`` with a single Connection.
    @mock.patch(
        "airflow.hooks.base.BaseHook.get_connection",
        return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
    )
    @mock.patch(TEMP_FILE)
    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
    @mock.patch(GKE_HOOK_PATH)
    def test_execute_with_impersonation_service_chain_one_element(
        self, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
    ):
        """Impersonation chain given as a one-element list behaves like the string form."""
        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
        mock_hook.return_value.get_cluster.return_value = PROJECT_BODY_CREATE_CLUSTER
        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
        self.gke_op.impersonation_chain = ["[email protected]"]
        self.gke_op.execute(context=mock.MagicMock())

        fetch_cluster_info_mock.assert_called_once()

    @pytest.mark.db_test
    def test_default_gcp_conn_id(self):
        """When no gcp_conn_id is passed, cluster_hook uses 'google_cloud_default'."""
        gke_op = GKEStartKueueInsideClusterOperator(
            project_id=TEST_GCP_PROJECT_ID,
            location=PROJECT_LOCATION,
            cluster_name=CLUSTER_NAME,
            task_id=PROJECT_TASK_ID,
            kueue_version=KUEUE_VERSION,
            impersonation_chain=IMPERSONATION_CHAIN,
            use_internal_ip=USE_INTERNAL_API,
        )
        gke_op._cluster_url = CLUSTER_URL
        gke_op._ssl_ca_cert = SSL_CA_CERT
        hook = gke_op.cluster_hook

        assert hook.gcp_conn_id == "google_cloud_default"

    @mock.patch.dict(os.environ, {})
    @mock.patch(
        "airflow.hooks.base.BaseHook.get_connection",
        return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})),
    )
    def test_gcp_conn_id(self, mock_get_credentials):
        """An explicit gcp_conn_id propagates to the cluster hook."""
        gke_op = GKEStartKueueInsideClusterOperator(
            project_id=TEST_GCP_PROJECT_ID,
            location=PROJECT_LOCATION,
            cluster_name=CLUSTER_NAME,
            task_id=PROJECT_TASK_ID,
            kueue_version=KUEUE_VERSION,
            impersonation_chain=IMPERSONATION_CHAIN,
            use_internal_ip=USE_INTERNAL_API,
            gcp_conn_id="test_conn",
        )
        gke_op._cluster_url = CLUSTER_URL
        gke_op._ssl_ca_cert = SSL_CA_CERT
        hook = gke_op.cluster_hook

        assert hook.gcp_conn_id == "test_conn"
+
class TestGKEPodOperatorAsync:
def setup_method(self):
self.gke_op = GKEStartPodOperator(
@@ -443,7 +622,7 @@ class TestGKEPodOperatorAsync:
@mock.patch(KUB_OP_PATH.format("build_pod_request_obj"))
@mock.patch(KUB_OP_PATH.format("get_or_create_pod"))
@mock.patch(
- "airflow.hooks.base.BaseHook.get_connections",
+ "airflow.hooks.base.BaseHook.get_connection",
return_value=[Connection(extra=json.dumps({"keyfile_dict":
'{"private_key": "r4nd0m_k3y"}'}))],
)
@mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info")
diff --git
a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
new file mode 100644
index 0000000000..d87d0bd7d7
--- /dev/null
+++
b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
"""
Example Airflow DAG for Google Kubernetes Engine.

Creates an Autopilot GKE cluster, installs Kueue inside it with
GKEStartKueueInsideClusterOperator, and deletes the cluster again.
"""
from __future__ import annotations

import os
from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.kubernetes_engine import (
    GKECreateClusterOperator,
    GKEDeleteClusterOperator,
    GKEStartKueueInsideClusterOperator,
)

# Environment/project identifiers supplied by the system-test harness.
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
DAG_ID = "example_kubernetes_engine_kueue"
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")

GCP_LOCATION = "europe-west3"
# GKE cluster names must not contain underscores, hence the replacement.
CLUSTER_NAME = f"cluster-name-test-kueue-{ENV_ID}".replace("_", "-")
# Autopilot is enabled here; presumably because Kueue installation requires an
# autoscaling-capable cluster — TODO confirm against the operator's docs.
CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1, "autopilot": {"enabled": True}}

with DAG(
    DAG_ID,
    schedule="@once",  # Override to match your needs
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "kubernetes-engine", "kueue"],
) as dag:
    create_cluster = GKECreateClusterOperator(
        task_id="create_cluster",
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
        body=CLUSTER,
    )

    # The START/END markers below are consumed by the docs build to embed this
    # snippet; keep them exactly as written.
    # [START howto_operator_gke_install_kueue]
    add_kueue_cluster = GKEStartKueueInsideClusterOperator(
        task_id="add_kueue_cluster",
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
        cluster_name=CLUSTER_NAME,
        kueue_version="v0.5.1",
    )
    # [END howto_operator_gke_install_kueue]

    # NOTE(review): the watcher comment below mentions a teardown task with a
    # trigger rule, but delete_cluster does not set trigger_rule=ALL_DONE, so it
    # will be skipped if an upstream task fails — confirm whether that is intended.
    delete_cluster = GKEDeleteClusterOperator(
        task_id="delete_cluster",
        name=CLUSTER_NAME,
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
    )

    create_cluster >> add_kueue_cluster >> delete_cluster

    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "teardown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)