This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 60a1d9d191 [FEATURE] google provider - split GkeStartPodOperator execute (#23518)
60a1d9d191 is described below
commit 60a1d9d191fb8fc01893024c897df9632ad5fbf4
Author: raphaelauv <[email protected]>
AuthorDate: Tue May 10 17:51:37 2022 +0200
[FEATURE] google provider - split GkeStartPodOperator execute (#23518)
---
.../google/cloud/operators/kubernetes_engine.py | 57 ++++++++++++++++------
1 file changed, 41 insertions(+), 16 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 9ee718165a..83c013ba44 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -21,7 +21,8 @@
import os
import tempfile
import warnings
-from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Generator, Optional, Sequence, Union
from google.cloud.container_v1.types import Cluster
@@ -336,11 +337,22 @@ class GKEStartPodOperator(KubernetesPodOperator):
if self.config_file:
raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.")
- def execute(self, context: 'Context') -> Optional[str]:
- hook = GoogleBaseHook(gcp_conn_id=self.gcp_conn_id)
- self.project_id = self.project_id or hook.project_id
+ @staticmethod
+ @contextmanager
+ def get_gke_config_file(
+ gcp_conn_id,
+ project_id: Optional[str],
+ cluster_name: str,
+ impersonation_chain: Optional[Union[str, Sequence[str]]],
+ regional: bool,
+ location: str,
+ use_internal_ip: bool,
+ ) -> Generator[str, None, None]:
- if not self.project_id:
+ hook = GoogleBaseHook(gcp_conn_id=gcp_conn_id)
+ project_id = project_id or hook.project_id
+
+ if not project_id:
raise AirflowException(
"The project id must be passed either as "
"keyword project_id parameter or as project_id extra "
@@ -363,15 +375,15 @@ class GKEStartPodOperator(KubernetesPodOperator):
"container",
"clusters",
"get-credentials",
- self.cluster_name,
+ cluster_name,
"--project",
- self.project_id,
+ project_id,
]
- if self.impersonation_chain:
- if isinstance(self.impersonation_chain, str):
- impersonation_account = self.impersonation_chain
- elif len(self.impersonation_chain) == 1:
- impersonation_account = self.impersonation_chain[0]
+ if impersonation_chain:
+ if isinstance(impersonation_chain, str):
+ impersonation_account = impersonation_chain
+ elif len(impersonation_chain) == 1:
+ impersonation_account = impersonation_chain[0]
else:
raise AirflowException(
"Chained list of accounts is not supported, please specify only one service account"
@@ -383,15 +395,28 @@ class GKEStartPodOperator(KubernetesPodOperator):
impersonation_account,
]
)
- if self.regional:
+ if regional:
cmd.append('--region')
else:
cmd.append('--zone')
- cmd.append(self.location)
- if self.use_internal_ip:
+ cmd.append(location)
+ if use_internal_ip:
cmd.append('--internal-ip')
execute_in_subprocess(cmd)
# Tell `KubernetesPodOperator` where the config file is located
- self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
+ yield os.environ[KUBE_CONFIG_ENV_VAR]
+
+ def execute(self, context: 'Context') -> Optional[str]:
+
+ with GKEStartPodOperator.get_gke_config_file(
+ gcp_conn_id=self.gcp_conn_id,
+ project_id=self.project_id,
+ cluster_name=self.cluster_name,
+ impersonation_chain=self.impersonation_chain,
+ regional=self.regional,
+ location=self.location,
+ use_internal_ip=self.use_internal_ip,
+ ) as config_file:
+ self.config_file = config_file
return super().execute(context)