This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 60a1d9d191 [FEATURE] google provider - split GkeStartPodOperator execute (#23518)
60a1d9d191 is described below

commit 60a1d9d191fb8fc01893024c897df9632ad5fbf4
Author: raphaelauv <[email protected]>
AuthorDate: Tue May 10 17:51:37 2022 +0200

    [FEATURE] google provider - split GkeStartPodOperator execute (#23518)
---
 .../google/cloud/operators/kubernetes_engine.py    | 57 ++++++++++++++++------
 1 file changed, 41 insertions(+), 16 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 9ee718165a..83c013ba44 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -21,7 +21,8 @@
 import os
 import tempfile
 import warnings
-from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Generator, Optional, Sequence, Union
 
 from google.cloud.container_v1.types import Cluster
 
@@ -336,11 +337,22 @@ class GKEStartPodOperator(KubernetesPodOperator):
         if self.config_file:
             raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.")
 
-    def execute(self, context: 'Context') -> Optional[str]:
-        hook = GoogleBaseHook(gcp_conn_id=self.gcp_conn_id)
-        self.project_id = self.project_id or hook.project_id
+    @staticmethod
+    @contextmanager
+    def get_gke_config_file(
+        gcp_conn_id,
+        project_id: Optional[str],
+        cluster_name: str,
+        impersonation_chain: Optional[Union[str, Sequence[str]]],
+        regional: bool,
+        location: str,
+        use_internal_ip: bool,
+    ) -> Generator[str, None, None]:
 
-        if not self.project_id:
+        hook = GoogleBaseHook(gcp_conn_id=gcp_conn_id)
+        project_id = project_id or hook.project_id
+
+        if not project_id:
             raise AirflowException(
                 "The project id must be passed either as "
                 "keyword project_id parameter or as project_id extra "
@@ -363,15 +375,15 @@ class GKEStartPodOperator(KubernetesPodOperator):
                 "container",
                 "clusters",
                 "get-credentials",
-                self.cluster_name,
+                cluster_name,
                 "--project",
-                self.project_id,
+                project_id,
             ]
-            if self.impersonation_chain:
-                if isinstance(self.impersonation_chain, str):
-                    impersonation_account = self.impersonation_chain
-                elif len(self.impersonation_chain) == 1:
-                    impersonation_account = self.impersonation_chain[0]
+            if impersonation_chain:
+                if isinstance(impersonation_chain, str):
+                    impersonation_account = impersonation_chain
+                elif len(impersonation_chain) == 1:
+                    impersonation_account = impersonation_chain[0]
                 else:
                     raise AirflowException(
                         "Chained list of accounts is not supported, please 
specify only one service account"
@@ -383,15 +395,28 @@ class GKEStartPodOperator(KubernetesPodOperator):
                         impersonation_account,
                     ]
                 )
-            if self.regional:
+            if regional:
                 cmd.append('--region')
             else:
                 cmd.append('--zone')
-            cmd.append(self.location)
-            if self.use_internal_ip:
+            cmd.append(location)
+            if use_internal_ip:
                 cmd.append('--internal-ip')
             execute_in_subprocess(cmd)
 
             # Tell `KubernetesPodOperator` where the config file is located
-            self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
+            yield os.environ[KUBE_CONFIG_ENV_VAR]
+
+    def execute(self, context: 'Context') -> Optional[str]:
+
+        with GKEStartPodOperator.get_gke_config_file(
+            gcp_conn_id=self.gcp_conn_id,
+            project_id=self.project_id,
+            cluster_name=self.cluster_name,
+            impersonation_chain=self.impersonation_chain,
+            regional=self.regional,
+            location=self.location,
+            use_internal_ip=self.use_internal_ip,
+        ) as config_file:
+            self.config_file = config_file
             return super().execute(context)
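
A minimal usage sketch of the new get_gke_config_file context manager, assuming the
signature added in this commit; the connection id, cluster name, and location below
are placeholder values chosen for illustration, not taken from the commit.

    from airflow.providers.google.cloud.operators.kubernetes_engine import (
        GKEStartPodOperator,
    )

    # Placeholder arguments for illustration only.
    with GKEStartPodOperator.get_gke_config_file(
        gcp_conn_id="google_cloud_default",
        project_id=None,                # falls back to the connection's project id
        cluster_name="example-cluster",
        impersonation_chain=None,
        regional=False,                 # zonal cluster, so --zone is passed
        location="europe-west1-b",
        use_internal_ip=False,
    ) as config_file:
        # config_file is the temporary kubeconfig written by
        # `gcloud container clusters get-credentials`; it is only
        # valid inside this block.
        print(config_file)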
