This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d307102b4 More typing and minor refactor for kubernetes (#24719)
9d307102b4 is described below

commit 9d307102b4a604034d9b1d7f293884821263575f
Author: Jed Cunningham <[email protected]>
AuthorDate: Wed Jun 29 16:02:47 2022 -0600

    More typing and minor refactor for kubernetes (#24719)
---
 airflow/kubernetes/pod_generator.py                |  2 +-
 .../providers/cncf/kubernetes/hooks/kubernetes.py  | 17 ++++++++-------
 .../cncf/kubernetes/operators/kubernetes_pod.py    | 24 ++++++++++++++--------
 .../cncf/kubernetes/hooks/test_kubernetes.py       |  6 ++----
 4 files changed, 26 insertions(+), 23 deletions(-)

diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 8a86919a65..705c108b79 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -42,7 +42,7 @@ from airflow.version import version as airflow_version
 MAX_LABEL_LEN = 63
 
 
-def make_safe_label_value(string):
+def make_safe_label_value(string: str) -> str:
     """
     Valid label values must be 63 characters or less and must be empty or 
begin and
     end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), 
underscores (_),
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 5752de5041..725343211b 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -169,7 +169,7 @@ class KubernetesHook(BaseHook):
             DeprecationWarning,
         )
 
-    def get_conn(self) -> Any:
+    def get_conn(self) -> client.ApiClient:
         """Returns kubernetes api session for use with requests"""
         in_cluster = self._coalesce_param(
             self.in_cluster, 
self.conn_extras.get("extra__kubernetes__in_cluster") or None
@@ -258,7 +258,7 @@ class KubernetesHook(BaseHook):
 
         return self._get_default_client(cluster_context=cluster_context)
 
-    def _get_default_client(self, *, cluster_context=None):
+    def _get_default_client(self, *, cluster_context: Optional[str] = None) -> 
client.ApiClient:
         # if we get here, then no configuration has been supplied
         # we should try in_cluster since that's most likely
         # but failing that just load assuming a kubeconfig file
@@ -276,20 +276,21 @@ class KubernetesHook(BaseHook):
         return client.ApiClient()
 
     @property
-    def is_in_cluster(self):
+    def is_in_cluster(self) -> bool:
         """Expose whether the hook is configured with 
``load_incluster_config`` or not"""
         if self._is_in_cluster is not None:
             return self._is_in_cluster
         self.api_client  # so we can determine if we are in_cluster or not
+        assert self._is_in_cluster is not None
         return self._is_in_cluster
 
     @cached_property
-    def api_client(self) -> Any:
+    def api_client(self) -> client.ApiClient:
         """Cached Kubernetes API client"""
         return self.get_conn()
 
     @cached_property
-    def core_v1_client(self):
+    def core_v1_client(self) -> client.CoreV1Api:
         return client.CoreV1Api(api_client=self.api_client)
 
     def create_custom_object(
@@ -377,12 +378,11 @@ class KubernetesHook(BaseHook):
         :param container: container name
         :param namespace: kubernetes namespace
         """
-        api = client.CoreV1Api(self.api_client)
         watcher = watch.Watch()
         return (
             watcher,
             watcher.stream(
-                api.read_namespaced_pod_log,
+                self.core_v1_client.read_namespaced_pod_log,
                 name=pod_name,
                 container=container,
                 namespace=namespace if namespace else self.get_namespace(),
@@ -402,8 +402,7 @@ class KubernetesHook(BaseHook):
         :param container: container name
         :param namespace: kubernetes namespace
         """
-        api = client.CoreV1Api(self.api_client)
-        return api.read_namespaced_pod_log(
+        return self.core_v1_client.read_namespaced_pod_log(
             name=pod_name,
             container=container,
             _preload_content=False,
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 09cad504fe..9d95ffbd59 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -300,7 +300,9 @@ class KubernetesPodOperator(BaseOperator):
         super()._render_nested_template_fields(content, context, jinja_env, 
seen_oids)
 
     @staticmethod
-    def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: 
bool = True) -> dict:
+    def _get_ti_pod_labels(
+        context: Optional['Context'] = None, include_try_number: bool = True
+    ) -> Dict[str, str]:
         """
         Generate labels for the pod to track the pod in case of Operator crash
 
@@ -360,7 +362,9 @@ class KubernetesPodOperator(BaseOperator):
     def client(self) -> CoreV1Api:
         return self.hook.core_v1_client
 
-    def find_pod(self, namespace, context, *, exclude_checked=True) -> 
Optional[k8s.V1Pod]:
+    def find_pod(
+        self, namespace: str, context: 'Context', *, exclude_checked: bool = 
True
+    ) -> Optional[k8s.V1Pod]:
         """Returns an already-running pod for this task instance if one 
exists."""
         label_selector = self._build_find_pod_label_selector(context, 
exclude_checked=exclude_checked)
         pod_list = self.client.list_namespaced_pod(
@@ -379,7 +383,7 @@ class KubernetesPodOperator(BaseOperator):
             self.log.info("`try_number` of pod: %s", 
pod.metadata.labels['try_number'])
         return pod
 
-    def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context):
+    def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: 
'Context') -> k8s.V1Pod:
         if self.reattach_on_restart:
             pod = self.find_pod(self.namespace or 
pod_request_obj.metadata.namespace, context=context)
             if pod:
@@ -388,7 +392,7 @@ class KubernetesPodOperator(BaseOperator):
         self.pod_manager.create_pod(pod=pod_request_obj)
         return pod_request_obj
 
-    def await_pod_start(self, pod):
+    def await_pod_start(self, pod: k8s.V1Pod):
         try:
             self.pod_manager.await_pod_start(pod=pod, 
startup_timeout=self.startup_timeout_seconds)
         except PodLaunchFailedException:
@@ -397,7 +401,7 @@ class KubernetesPodOperator(BaseOperator):
                     self.log.error("Pod Event: %s - %s", event.reason, 
event.message)
             raise
 
-    def extract_xcom(self, pod):
+    def extract_xcom(self, pod: k8s.V1Pod):
         """Retrieves xcom value and kills xcom sidecar container"""
         result = self.pod_manager.extract_xcom(pod)
         self.log.info("xcom result: \n%s", result)
@@ -461,7 +465,7 @@ class KubernetesPodOperator(BaseOperator):
             with _suppress(Exception):
                 self.process_pod_deletion(remote_pod)
 
-    def process_pod_deletion(self, pod):
+    def process_pod_deletion(self, pod: k8s.V1Pod):
         if pod is not None:
             if self.is_delete_operator_pod:
                 self.log.info("Deleting pod: %s", pod.metadata.name)
@@ -469,7 +473,9 @@ class KubernetesPodOperator(BaseOperator):
             else:
                 self.log.info("skipping deleting pod: %s", pod.metadata.name)
 
-    def _build_find_pod_label_selector(self, context: Optional[dict] = None, 
*, exclude_checked=True) -> str:
+    def _build_find_pod_label_selector(
+        self, context: Optional['Context'] = None, *, exclude_checked=True
+    ) -> str:
         labels = self._get_ti_pod_labels(context, include_try_number=False)
         label_strings = [f'{label_id}={label}' for label_id, label in 
sorted(labels.items())]
         labels_value = ','.join(label_strings)
@@ -478,7 +484,7 @@ class KubernetesPodOperator(BaseOperator):
         labels_value += ',!airflow-worker'
         return labels_value
 
-    def _set_name(self, name):
+    def _set_name(self, name: Optional[str]) -> Optional[str]:
         if name is None:
             if self.pod_template_file or self.full_pod_spec:
                 return None
@@ -504,7 +510,7 @@ class KubernetesPodOperator(BaseOperator):
                 
kwargs.update(grace_period_seconds=self.termination_grace_period)
             self.client.delete_namespaced_pod(**kwargs)
 
-    def build_pod_request_obj(self, context=None):
+    def build_pod_request_obj(self, context: Optional['Context'] = None) -> 
k8s.V1Pod:
         """
         Returns V1Pod object based on pod template file, full pod spec, and 
other operator parameters.
 
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index 6bbe5926e9..44f815379c 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -106,10 +106,8 @@ class TestKubernetesHook:
         else:
             mock_get_default_client.assert_called()
         assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
-        if mock_get_default_client.called:
-            # get_default_client sets it, but it's mocked
-            assert kubernetes_hook.is_in_cluster is None
-        else:
+        if not mock_get_default_client.called:
+            # get_default_client is mocked, so only check is_in_cluster if it 
isn't called
             assert kubernetes_hook.is_in_cluster is in_cluster_called
 
     @pytest.mark.parametrize('in_cluster_fails', [True, False])

Reply via email to