This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9d307102b4 More typing and minor refactor for kubernetes (#24719)
9d307102b4 is described below
commit 9d307102b4a604034d9b1d7f293884821263575f
Author: Jed Cunningham <[email protected]>
AuthorDate: Wed Jun 29 16:02:47 2022 -0600
More typing and minor refactor for kubernetes (#24719)
---
airflow/kubernetes/pod_generator.py | 2 +-
.../providers/cncf/kubernetes/hooks/kubernetes.py | 17 ++++++++-------
.../cncf/kubernetes/operators/kubernetes_pod.py | 24 ++++++++++++++--------
.../cncf/kubernetes/hooks/test_kubernetes.py | 6 ++----
4 files changed, 26 insertions(+), 23 deletions(-)
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 8a86919a65..705c108b79 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -42,7 +42,7 @@ from airflow.version import version as airflow_version
MAX_LABEL_LEN = 63
-def make_safe_label_value(string):
+def make_safe_label_value(string: str) -> str:
"""
Valid label values must be 63 characters or less and must be empty or begin and
end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 5752de5041..725343211b 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -169,7 +169,7 @@ class KubernetesHook(BaseHook):
DeprecationWarning,
)
- def get_conn(self) -> Any:
+ def get_conn(self) -> client.ApiClient:
"""Returns kubernetes api session for use with requests"""
in_cluster = self._coalesce_param(
self.in_cluster,
self.conn_extras.get("extra__kubernetes__in_cluster") or None
@@ -258,7 +258,7 @@ class KubernetesHook(BaseHook):
return self._get_default_client(cluster_context=cluster_context)
- def _get_default_client(self, *, cluster_context=None):
+ def _get_default_client(self, *, cluster_context: Optional[str] = None) -> client.ApiClient:
# if we get here, then no configuration has been supplied
# we should try in_cluster since that's most likely
# but failing that just load assuming a kubeconfig file
@@ -276,20 +276,21 @@ class KubernetesHook(BaseHook):
return client.ApiClient()
@property
- def is_in_cluster(self):
+ def is_in_cluster(self) -> bool:
"""Expose whether the hook is configured with ``load_incluster_config`` or not"""
if self._is_in_cluster is not None:
return self._is_in_cluster
self.api_client # so we can determine if we are in_cluster or not
+ assert self._is_in_cluster is not None
return self._is_in_cluster
@cached_property
- def api_client(self) -> Any:
+ def api_client(self) -> client.ApiClient:
"""Cached Kubernetes API client"""
return self.get_conn()
@cached_property
- def core_v1_client(self):
+ def core_v1_client(self) -> client.CoreV1Api:
return client.CoreV1Api(api_client=self.api_client)
def create_custom_object(
@@ -377,12 +378,11 @@ class KubernetesHook(BaseHook):
:param container: container name
:param namespace: kubernetes namespace
"""
- api = client.CoreV1Api(self.api_client)
watcher = watch.Watch()
return (
watcher,
watcher.stream(
- api.read_namespaced_pod_log,
+ self.core_v1_client.read_namespaced_pod_log,
name=pod_name,
container=container,
namespace=namespace if namespace else self.get_namespace(),
@@ -402,8 +402,7 @@ class KubernetesHook(BaseHook):
:param container: container name
:param namespace: kubernetes namespace
"""
- api = client.CoreV1Api(self.api_client)
- return api.read_namespaced_pod_log(
+ return self.core_v1_client.read_namespaced_pod_log(
name=pod_name,
container=container,
_preload_content=False,
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 09cad504fe..9d95ffbd59 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -300,7 +300,9 @@ class KubernetesPodOperator(BaseOperator):
super()._render_nested_template_fields(content, context, jinja_env, seen_oids)
@staticmethod
- def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool = True) -> dict:
+ def _get_ti_pod_labels(
+ context: Optional['Context'] = None, include_try_number: bool = True
+ ) -> Dict[str, str]:
"""
Generate labels for the pod to track the pod in case of Operator crash
@@ -360,7 +362,9 @@ class KubernetesPodOperator(BaseOperator):
def client(self) -> CoreV1Api:
return self.hook.core_v1_client
- def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]:
+ def find_pod(
+ self, namespace: str, context: 'Context', *, exclude_checked: bool = True
+ ) -> Optional[k8s.V1Pod]:
"""Returns an already-running pod for this task instance if one exists."""
label_selector = self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
pod_list = self.client.list_namespaced_pod(
@@ -379,7 +383,7 @@ class KubernetesPodOperator(BaseOperator):
self.log.info("`try_number` of pod: %s", pod.metadata.labels['try_number'])
return pod
- def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context):
+ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: 'Context') -> k8s.V1Pod:
if self.reattach_on_restart:
pod = self.find_pod(self.namespace or pod_request_obj.metadata.namespace, context=context)
if pod:
@@ -388,7 +392,7 @@ class KubernetesPodOperator(BaseOperator):
self.pod_manager.create_pod(pod=pod_request_obj)
return pod_request_obj
- def await_pod_start(self, pod):
+ def await_pod_start(self, pod: k8s.V1Pod):
try:
self.pod_manager.await_pod_start(pod=pod, startup_timeout=self.startup_timeout_seconds)
except PodLaunchFailedException:
@@ -397,7 +401,7 @@ class KubernetesPodOperator(BaseOperator):
self.log.error("Pod Event: %s - %s", event.reason, event.message)
raise
- def extract_xcom(self, pod):
+ def extract_xcom(self, pod: k8s.V1Pod):
"""Retrieves xcom value and kills xcom sidecar container"""
result = self.pod_manager.extract_xcom(pod)
self.log.info("xcom result: \n%s", result)
@@ -461,7 +465,7 @@ class KubernetesPodOperator(BaseOperator):
with _suppress(Exception):
self.process_pod_deletion(remote_pod)
- def process_pod_deletion(self, pod):
+ def process_pod_deletion(self, pod: k8s.V1Pod):
if pod is not None:
if self.is_delete_operator_pod:
self.log.info("Deleting pod: %s", pod.metadata.name)
@@ -469,7 +473,9 @@ class KubernetesPodOperator(BaseOperator):
else:
self.log.info("skipping deleting pod: %s", pod.metadata.name)
- def _build_find_pod_label_selector(self, context: Optional[dict] = None, *, exclude_checked=True) -> str:
+ def _build_find_pod_label_selector(
+ self, context: Optional['Context'] = None, *, exclude_checked=True
+ ) -> str:
labels = self._get_ti_pod_labels(context, include_try_number=False)
label_strings = [f'{label_id}={label}' for label_id, label in sorted(labels.items())]
labels_value = ','.join(label_strings)
@@ -478,7 +484,7 @@ class KubernetesPodOperator(BaseOperator):
labels_value += ',!airflow-worker'
return labels_value
- def _set_name(self, name):
+ def _set_name(self, name: Optional[str]) -> Optional[str]:
if name is None:
if self.pod_template_file or self.full_pod_spec:
return None
@@ -504,7 +510,7 @@ class KubernetesPodOperator(BaseOperator):
kwargs.update(grace_period_seconds=self.termination_grace_period)
self.client.delete_namespaced_pod(**kwargs)
- def build_pod_request_obj(self, context=None):
+ def build_pod_request_obj(self, context: Optional['Context'] = None) -> k8s.V1Pod:
"""
Returns V1Pod object based on pod template file, full pod spec, and other operator parameters.
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index 6bbe5926e9..44f815379c 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -106,10 +106,8 @@ class TestKubernetesHook:
else:
mock_get_default_client.assert_called()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
- if mock_get_default_client.called:
- # get_default_client sets it, but it's mocked
- assert kubernetes_hook.is_in_cluster is None
- else:
+ if not mock_get_default_client.called:
+ # get_default_client is mocked, so only check is_in_cluster if it isn't called
assert kubernetes_hook.is_in_cluster is in_cluster_called
@pytest.mark.parametrize('in_cluster_fails', [True, False])