This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new c4a106e69b  Create GKESuspendJobOperator and GKEResumeJobOperator operators (#38677)
c4a106e69b is described below

commit c4a106e69bbc396d2527a3b8c94e2d95fced4284
Author: Maksim <maks...@google.com>
AuthorDate: Fri Apr 12 11:51:03 2024 +0200

    Create GKESuspendJobOperator and GKEResumeJobOperator operators (#38677)
---
 .../google/cloud/operators/kubernetes_engine.py    | 211 +++++++++++++++++++-
 .../operators/cloud/kubernetes_engine.rst          |  30 +++
 .../cloud/operators/test_kubernetes_engine.py      | 219 +++++++++++++++++++++
 .../example_kubernetes_engine_job.py               |  26 +++
 4 files changed, 485 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 8729335c5a..4b4261ff8d 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -29,7 +29,7 @@ import yaml
 from deprecated import deprecated
 from google.api_core.exceptions import AlreadyExists
 from google.cloud.container_v1.types import Cluster
-from kubernetes.client import V1JobList
+from kubernetes.client import V1JobList, models as k8s
 from kubernetes.utils.create_from_yaml import FailToCreateError
 from packaging.version import parse as parse_version
 
@@ -47,6 +47,7 @@ from airflow.providers.google.cloud.hooks.kubernetes_engine import (
     GKEDeploymentHook,
     GKEHook,
     GKEJobHook,
+    GKEKubernetesHook,
     GKEPodHook,
 )
 from airflow.providers.google.cloud.links.kubernetes_engine import (
@@ -1494,3 +1495,211 @@ class GKEDeleteJobOperator(KubernetesDeleteJobOperator):
         ).fetch_cluster_info()
 
         return super().execute(context)
+
+
+class GKESuspendJobOperator(GoogleCloudBaseOperator):
+    """
+    Suspend a Job by the given name.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:GKESuspendJobOperator`
+
+    :param name: The name of the Job to suspend.
+    :param project_id: The Google Developers Console project id.
+    :param location: The name of the Google Kubernetes Engine zone or region in which the cluster
+        resides.
+    :param cluster_name: The name of the Google Kubernetes Engine cluster.
+    :param namespace: The name of the Google Kubernetes Engine namespace.
+    :param use_internal_ip: Use the internal IP address as the endpoint.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    template_fields: Sequence[str] = (
+        "project_id",
+        "gcp_conn_id",
+        "name",
+        "namespace",
+        "cluster_name",
+        "location",
+        "impersonation_chain",
+    )
+    operator_extra_links = (KubernetesEngineJobLink(),)
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        location: str,
+        namespace: str,
+        cluster_name: str,
+        project_id: str | None = None,
+        use_internal_ip: bool = False,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
+        self.location = location
+        self.name = name
+        self.namespace = namespace
+        self.cluster_name = cluster_name
+        self.use_internal_ip = use_internal_ip
+        self.impersonation_chain = impersonation_chain
+
+        self.job: V1Job | None = None
+        self._ssl_ca_cert: str
+        self._cluster_url: str
+
+    @cached_property
+    def cluster_hook(self) -> GKEHook:
+        return GKEHook(
+            gcp_conn_id=self.gcp_conn_id,
+            location=self.location,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    @cached_property
+    def hook(self) -> GKEKubernetesHook:
+        self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails(
+            cluster_name=self.cluster_name,
+            project_id=self.project_id,
+            use_internal_ip=self.use_internal_ip,
+            cluster_hook=self.cluster_hook,
+        ).fetch_cluster_info()
+
+        return GKEKubernetesHook(
+            gcp_conn_id=self.gcp_conn_id,
+            cluster_url=self._cluster_url,
+            ssl_ca_cert=self._ssl_ca_cert,
+        )
+
+    def execute(self, context: Context) -> dict:
+        self.job = self.hook.patch_namespaced_job(
+            job_name=self.name,
+            namespace=self.namespace,
+            body={"spec": {"suspend": True}},
+        )
+        self.log.info(
+            "Job %s from cluster %s was suspended.",
+            self.name,
+            self.cluster_name,
+        )
+        KubernetesEngineJobLink.persist(context=context, task_instance=self)
+
+        return k8s.V1Job.to_dict(self.job)
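The operator above is a thin wrapper over the Kubernetes batch API: setting spec.suspend to true on an existing Job tells the Job controller to delete the Job's active pods and create no new ones until the flag is cleared. A minimal sketch of the same patch using the plain kubernetes Python client rather than the provider's hook (the kubeconfig setup, job name, and namespace below are illustrative assumptions, not part of this commit):

    from kubernetes import client, config

    # Assumes cluster credentials are already available in ~/.kube/config;
    # "demo-job" and "default" are placeholder values.
    config.load_kube_config()
    batch = client.BatchV1Api()

    # The same patch body that GKESuspendJobOperator sends through
    # GKEKubernetesHook.patch_namespaced_job.
    job = batch.patch_namespaced_job(
        name="demo-job",
        namespace="default",
        body={"spec": {"suspend": True}},
    )
    print(job.spec.suspend)  # True once the API server applies the patch

Resuming is the same call with suspend set back to False, which is exactly what GKEResumeJobOperator below does.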
+ """ + + template_fields: Sequence[str] = ( + "project_id", + "gcp_conn_id", + "name", + "namespace", + "cluster_name", + "location", + "impersonation_chain", + ) + operator_extra_links = (KubernetesEngineJobLink(),) + + def __init__( + self, + *, + name: str, + location: str, + namespace: str, + cluster_name: str, + project_id: str | None = None, + use_internal_ip: bool = False, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.location = location + self.name = name + self.namespace = namespace + self.cluster_name = cluster_name + self.use_internal_ip = use_internal_ip + self.impersonation_chain = impersonation_chain + + self.job: V1Job | None = None + self._ssl_ca_cert: str + self._cluster_url: str + + @cached_property + def cluster_hook(self) -> GKEHook: + return GKEHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + @cached_property + def hook(self) -> GKEKubernetesHook: + self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( + cluster_name=self.cluster_name, + project_id=self.project_id, + use_internal_ip=self.use_internal_ip, + cluster_hook=self.cluster_hook, + ).fetch_cluster_info() + + return GKEKubernetesHook( + gcp_conn_id=self.gcp_conn_id, + cluster_url=self._cluster_url, + ssl_ca_cert=self._ssl_ca_cert, + ) + + def execute(self, context: Context) -> None: + self.job = self.hook.patch_namespaced_job( + job_name=self.name, + namespace=self.namespace, + body={"spec": {"suspend": True}}, + ) + self.log.info( + "Job %s from cluster %s was suspended.", + self.name, + self.cluster_name, + ) + KubernetesEngineJobLink.persist(context=context, task_instance=self) + + return k8s.V1Job.to_dict(self.job) + + +class GKEResumeJobOperator(GoogleCloudBaseOperator): + """ + Resume Job by given name. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GKEResumeJobOperator` + + :param name: The name of the Job to resume + :param project_id: The Google Developers Console project id. + :param location: The name of the Google Kubernetes Engine zone or region in which the cluster + resides. + :param cluster_name: The name of the Google Kubernetes Engine cluster. + :param namespace: The name of the Google Kubernetes Engine namespace. + :param use_internal_ip: Use the internal IP address as the endpoint. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "project_id", + "gcp_conn_id", + "name", + "namespace", + "cluster_name", + "location", + "impersonation_chain", + ) + operator_extra_links = (KubernetesEngineJobLink(),) + + def __init__( + self, + *, + name: str, + location: str, + namespace: str, + cluster_name: str, + project_id: str | None = None, + use_internal_ip: bool = False, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.location = location + self.name = name + self.namespace = namespace + self.cluster_name = cluster_name + self.use_internal_ip = use_internal_ip + self.impersonation_chain = impersonation_chain + + self.job: V1Job | None = None + self._ssl_ca_cert: str + self._cluster_url: str + + @cached_property + def cluster_hook(self) -> GKEHook: + return GKEHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + @cached_property + def hook(self) -> GKEKubernetesHook: + self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( + cluster_name=self.cluster_name, + project_id=self.project_id, + use_internal_ip=self.use_internal_ip, + cluster_hook=self.cluster_hook, + ).fetch_cluster_info() + + return GKEKubernetesHook( + gcp_conn_id=self.gcp_conn_id, + cluster_url=self._cluster_url, + ssl_ca_cert=self._ssl_ca_cert, + ) + + def execute(self, context: Context) -> None: + self.job = self.hook.patch_namespaced_job( + job_name=self.name, + namespace=self.namespace, + body={"spec": {"suspend": False}}, + ) + self.log.info( + "Job %s from cluster %s was resumed.", + self.name, + self.cluster_name, + ) + KubernetesEngineJobLink.persist(context=context, task_instance=self) + + return k8s.V1Job.to_dict(self.job) diff --git a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst index 7d5cc44826..d71ebf87e9 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst @@ -312,6 +312,36 @@ delete resource in the specified Google Kubernetes Engine cluster. :start-after: [START howto_operator_gke_delete_resource] :end-before: [END howto_operator_gke_delete_resource] + +.. _howto/operator:GKESuspendJobOperator: + +Suspend a Job on a GKE cluster +"""""""""""""""""""""""""""""" + +You can use :class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKESuspendJobOperator` to +suspend Job in the specified Google Kubernetes Engine cluster. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_gke_suspend_job] + :end-before: [END howto_operator_gke_suspend_job] + + +.. _howto/operator:GKEResumeJobOperator: + +Resume a Job on a GKE cluster +""""""""""""""""""""""""""""" + +You can use :class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEResumeJobOperator` to +resume Job in the specified Google Kubernetes Engine cluster. + +.. 
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index 442cdde68e..a6d0de56c4 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -42,10 +42,12 @@ from airflow.providers.google.cloud.operators.kubernetes_engine import (
     GKEDeleteCustomResourceOperator,
     GKEDeleteJobOperator,
     GKEDescribeJobOperator,
+    GKEResumeJobOperator,
     GKEStartJobOperator,
     GKEStartKueueInsideClusterOperator,
     GKEStartKueueJobOperator,
     GKEStartPodOperator,
+    GKESuspendJobOperator,
 )
 from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger
 
@@ -82,6 +84,7 @@ GKE_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEHook"
 GKE_POD_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEPodHook"
 GKE_DEPLOYMENT_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEDeploymentHook"
 GKE_JOB_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEJobHook"
+GKE_K8S_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEKubernetesHook"
 KUB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute"
 KUB_JOB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator.execute"
 KUB_CREATE_RES_OPERATOR_EXEC = (
@@ -1319,3 +1322,219 @@ class TestGKEDeleteJobOperator:
         hook = gke_op.hook
 
         assert hook.gcp_conn_id == "test_conn"
+
+
+class TestGKESuspendJobOperator:
+    def setup_method(self):
+        self.gke_op = GKESuspendJobOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            name=TASK_NAME,
+            namespace=NAMESPACE,
+        )
+
+    def test_config_file_throws_error(self):
+        with pytest.raises(AirflowException):
+            GKESuspendJobOperator(
+                project_id=TEST_GCP_PROJECT_ID,
+                location=PROJECT_LOCATION,
+                cluster_name=CLUSTER_NAME,
+                task_id=PROJECT_TASK_ID,
+                name=TASK_NAME,
+                namespace=NAMESPACE,
+                config_file="/path/to/alternative/kubeconfig",
+            )
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch(TEMP_FILE)
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_K8S_HOOK_PATH)
+    def test_execute(self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock):
+        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.execute(context=mock.MagicMock())
+        fetch_cluster_info_mock.assert_called_once()
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch(
+        "airflow.hooks.base.BaseHook.get_connections",
+        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
+    )
+    @mock.patch(TEMP_FILE)
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_K8S_HOOK_PATH)
+    def test_execute_with_impersonation_service_account(
+        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
+    ):
+        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.impersonation_chain = "test_acco...@example.com"
+        self.gke_op.execute(context=mock.MagicMock())
+        fetch_cluster_info_mock.assert_called_once()
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch(
+        "airflow.hooks.base.BaseHook.get_connections",
+        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
+    )
+    @mock.patch(TEMP_FILE)
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_K8S_HOOK_PATH)
+    def test_execute_with_impersonation_service_chain_one_element(
+        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
+    ):
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.impersonation_chain = ["test_acco...@example.com"]
+        self.gke_op.execute(context=mock.MagicMock())
+
+        fetch_cluster_info_mock.assert_called_once()
+
+    @pytest.mark.db_test
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    def test_default_gcp_conn_id(self, fetch_cluster_info_mock):
+        gke_op = GKESuspendJobOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            name=TASK_NAME,
+            namespace=NAMESPACE,
+        )
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        hook = gke_op.hook
+
+        assert hook.gcp_conn_id == "google_cloud_default"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+        return_value=Connection(conn_id="test_conn"),
+    )
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    def test_gcp_conn_id(self, mock_hook, fetch_cluster_info_mock, mock_gke_conn):
+        gke_op = GKESuspendJobOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            name=TASK_NAME,
+            namespace=NAMESPACE,
+            gcp_conn_id="test_conn",
+        )
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        hook = gke_op.hook
+
+        assert hook.gcp_conn_id == "test_conn"
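The tests above exercise credential plumbing (fetch_cluster_info, connection ids) but never assert the request itself. One might additionally pin down the patch body on the mocked hook; a sketch of such a test, written to sit inside TestGKESuspendJobOperator and reusing the module's mock paths (it is not part of this commit, and the mirror-image check with suspend: False would apply to TestGKEResumeJobOperator below):

    def test_execute_sends_suspend_patch(self):
        fetch_path = f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info"
        with mock.patch(GKE_K8S_HOOK_PATH) as k8s_hook_mock, mock.patch(GKE_HOOK_PATH), mock.patch(
            fetch_path, return_value=(CLUSTER_URL, SSL_CA_CERT)
        ):
            self.gke_op.execute(context=mock.MagicMock())
        # The operator should send exactly one suspend patch for the configured Job.
        k8s_hook_mock.return_value.patch_namespaced_job.assert_called_once_with(
            job_name=TASK_NAME,
            namespace=NAMESPACE,
            body={"spec": {"suspend": True}},
        )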
+
+
+class TestGKEResumeJobOperator:
+    def setup_method(self):
+        self.gke_op = GKEResumeJobOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            name=TASK_NAME,
+            namespace=NAMESPACE,
+        )
+
+    def test_config_file_throws_error(self):
+        with pytest.raises(AirflowException):
+            GKEResumeJobOperator(
+                project_id=TEST_GCP_PROJECT_ID,
+                location=PROJECT_LOCATION,
+                cluster_name=CLUSTER_NAME,
+                task_id=PROJECT_TASK_ID,
+                name=TASK_NAME,
+                namespace=NAMESPACE,
+                config_file="/path/to/alternative/kubeconfig",
+            )
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch(TEMP_FILE)
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_K8S_HOOK_PATH)
+    def test_execute(self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock):
+        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.execute(context=mock.MagicMock())
+        fetch_cluster_info_mock.assert_called_once()
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch(
+        "airflow.hooks.base.BaseHook.get_connections",
+        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
+    )
+    @mock.patch(TEMP_FILE)
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_K8S_HOOK_PATH)
+    def test_execute_with_impersonation_service_account(
+        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
+    ):
+        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.impersonation_chain = "test_acco...@example.com"
+        self.gke_op.execute(context=mock.MagicMock())
+        fetch_cluster_info_mock.assert_called_once()
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch(
+        "airflow.hooks.base.BaseHook.get_connections",
+        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
+    )
+    @mock.patch(TEMP_FILE)
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_K8S_HOOK_PATH)
+    def test_execute_with_impersonation_service_chain_one_element(
+        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
+    ):
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.impersonation_chain = ["test_acco...@example.com"]
+        self.gke_op.execute(context=mock.MagicMock())
+
+        fetch_cluster_info_mock.assert_called_once()
+
+    @pytest.mark.db_test
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    def test_default_gcp_conn_id(self, fetch_cluster_info_mock):
+        gke_op = GKEResumeJobOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            name=TASK_NAME,
+            namespace=NAMESPACE,
+        )
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        hook = gke_op.hook
+
+        assert hook.gcp_conn_id == "google_cloud_default"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+        return_value=Connection(conn_id="test_conn"),
+    )
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    def test_gcp_conn_id(self, mock_hook, fetch_cluster_info_mock, mock_gke_conn):
+        gke_op = GKEResumeJobOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            name=TASK_NAME,
+            namespace=NAMESPACE,
+            gcp_conn_id="test_conn",
+        )
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        hook = gke_op.hook
+
+        assert hook.gcp_conn_id == "test_conn"
diff --git a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
index 97b5cd406b..2c7790c102 100644
--- a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
+++ b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
@@ -32,7 +32,9 @@ from airflow.providers.google.cloud.operators.kubernetes_engine import (
     GKEDeleteJobOperator,
     GKEDescribeJobOperator,
     GKEListJobsOperator,
+    GKEResumeJobOperator,
     GKEStartJobOperator,
+    GKESuspendJobOperator,
 )
 from airflow.utils.trigger_rule import TriggerRule
 
@@ -116,6 +118,28 @@ with DAG(
         cluster_name=CLUSTER_NAME,
     )
 
+    # [START howto_operator_gke_suspend_job]
+    suspend_job = GKESuspendJobOperator(
+        task_id="suspend_job",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        name=job_task.output["job_name"],
+        namespace="default",
+    )
+    # [END howto_operator_gke_suspend_job]
+
+    # [START howto_operator_gke_resume_job]
+    resume_job = GKEResumeJobOperator(
+        task_id="resume_job",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        name=job_task.output["job_name"],
+        namespace="default",
+    )
+    # [END howto_operator_gke_resume_job]
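Note the name=job_task.output["job_name"] arguments above: .output["job_name"] is an XComArg, so the Job name produced by GKEStartJobOperator is resolved at run time, and the reference also makes job_task an upstream dependency of both tasks. A hypothetical equivalent written as a Jinja template (which, unlike the XComArg form, would not create the dependency by itself):

    suspend_job_templated = GKESuspendJobOperator(
        task_id="suspend_job_templated",  # hypothetical task id, not in this commit
        project_id=GCP_PROJECT_ID,
        location=GCP_LOCATION,
        cluster_name=CLUSTER_NAME,
        name="{{ ti.xcom_pull(task_ids='job_task', key='job_name') }}",
        namespace="default",
    )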
+
     # [START howto_operator_gke_delete_job]
     delete_job = GKEDeleteJobOperator(
         task_id="delete_job",
@@ -149,6 +173,8 @@ with DAG(
         [job_task, job_task_def],
         list_job_task,
         [describe_job_task, describe_job_task_def],
+        suspend_job,
+        resume_job,
         [delete_job, delete_job_def],
         delete_cluster,
     )