dstandish commented on a change in pull request #19572:
URL: https://github.com/apache/airflow/pull/19572#discussion_r772562490



##########
File path: airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
##########
@@ -328,101 +339,133 @@ def create_labels_for_pod(context) -> dict:
             labels[label_id] = safe_label
         return labels
 
-    def create_pod_launcher(self) -> Type[pod_launcher.PodLauncher]:
-        return pod_launcher.PodLauncher(kube_client=self.client, 
extract_xcom=self.do_xcom_push)
+    @cached_property
+    def launcher(self) -> pod_launcher.PodLauncher:
+        return pod_launcher.PodLauncher(kube_client=self.client)
 
-    def execute(self, context) -> Optional[str]:
+    @cached_property
+    def client(self) -> CoreV1Api:
+        # todo: use airflow Connection / hook to authenticate to the cluster
+        kwargs: Dict[str, Any] = dict(
+            cluster_context=self.cluster_context,
+            config_file=self.config_file,
+        )
+        if self.in_cluster is not None:
+            kwargs.update(in_cluster=self.in_cluster)
+        return kube_client.get_kube_client(**kwargs)
+
+    def find_pod(self, namespace, context) -> Optional[k8s.V1Pod]:
+        """Returns an already-running pod for this task instance if one 
exists."""
+        labels = self._create_labels_for_pod(context)
+        label_selector = self._get_pod_identifying_label_string(labels)
+        pod_list = self.client.list_namespaced_pod(
+            namespace=namespace,
+            label_selector=label_selector,
+        ).items
+
+        num_pods = len(pod_list)
+        if num_pods > 1:
+            raise AirflowException(f'More than one pod running with labels 
{label_selector}')
+        elif num_pods == 1:
+            pod = pod_list[0]
+            self.log.info("Found matching pod %s", pod.metadata.name)
+            self._compare_try_numbers(context, pod)
+            return pod
+
+    def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context):
+        if self.reattach_on_restart:
+            pod = self.find_pod(self.namespace or 
pod_request_obj.metadata.namespace, context=context)
+            if pod:
+                return pod
+        self.log.debug("Starting pod:\n%s", 
yaml.safe_dump(pod_request_obj.to_dict()))
+        self.launcher.create_pod(pod=pod_request_obj)
+        return pod_request_obj
+
+    def await_pod_start(self, pod):
         try:
-            if self.in_cluster is not None:
-                client = kube_client.get_kube_client(
-                    in_cluster=self.in_cluster,
-                    cluster_context=self.cluster_context,
-                    config_file=self.config_file,
-                )
-            else:
-                client = kube_client.get_kube_client(
-                    cluster_context=self.cluster_context, 
config_file=self.config_file
-                )
-
-            self.client = client
-
-            self.pod = self.create_pod_request_obj()
-            self.namespace = self.pod.metadata.namespace
-
-            # Add combination of labels to uniquely identify a running pod
-            labels = self.create_labels_for_pod(context)
-
-            label_selector = self._get_pod_identifying_label_string(labels)
-
-            pod_list = self.client.list_namespaced_pod(self.namespace, 
label_selector=label_selector)
+            self.launcher.await_pod_start(pod=pod, 
startup_timeout=self.startup_timeout_seconds)
+        except PodLaunchFailedException:
+            if self.log_events_on_failure:
+                for event in self.launcher.read_pod_events(pod).items:
+                    self.log.error("Pod Event: %s - %s", event.reason, 
event.message)
+            raise
 
-            if len(pod_list.items) > 1 and self.reattach_on_restart:
-                raise AirflowException(
-                    f'More than one pod running with labels: {label_selector}'
-                )
+    def extract_xcom(self, pod):
+        """Retrieves xcom value and kills xcom sidecar container"""
+        result = self.launcher.extract_xcom(pod)
+        self.log.info("xcom result: \n%s", result)
+        return json.loads(result)
 
-            launcher = self.create_pod_launcher()
+    def execute(self, context):
+        remote_pod = None
+        try:
+            self.pod_request_obj = self.build_pod_request_obj(context)
+            self.pod = self.get_or_create_pod(  # must set `self.pod` for 
`on_kill`
+                pod_request_obj=self.pod_request_obj,
+                context=context,
+            )
+            self.await_pod_start(pod=self.pod)
 
-            if len(pod_list.items) == 1:
-                try_numbers_match = self._try_numbers_match(context, 
pod_list.items[0])
-                final_state, remote_pod, result = self.handle_pod_overlap(
-                    labels, try_numbers_match, launcher, pod_list.items[0]
+            if self.get_logs:
+                self.launcher.follow_container_logs(
+                    pod=self.pod,
+                    container_name=self.BASE_CONTAINER_NAME,
                 )
             else:
-                self.log.info("creating pod with labels %s and launcher %s", 
labels, launcher)
-                final_state, remote_pod, result = 
self.create_new_pod_for_operator(labels, launcher)
-            if final_state != State.SUCCESS:
-                raise AirflowException(f'Pod {self.pod.metadata.name} returned 
a failure: {remote_pod}')
-            context['task_instance'].xcom_push(key='pod_name', 
value=self.pod.metadata.name)
-            context['task_instance'].xcom_push(key='pod_namespace', 
value=self.namespace)
-            return result
-        except AirflowException as ex:
-            raise AirflowException(f'Pod Launching failed: {ex}')
+                self.launcher.await_container_completion(
+                    pod=self.pod, container_name=self.BASE_CONTAINER_NAME
+                )
 
-    def handle_pod_overlap(
-        self, labels: dict, try_numbers_match: bool, launcher: Any, pod: 
k8s.V1Pod
-    ) -> Tuple[State, k8s.V1Pod, Optional[str]]:
-        """
+            if self.do_xcom_push:
+                result = self.extract_xcom(pod=self.pod)
+            remote_pod = self.launcher.await_pod_completion(self.pod)
+        finally:
+            self.cleanup(
+                pod=self.pod or self.pod_request_obj,
+                remote_pod=remote_pod,
+            )
+        if self.do_xcom_push:
+            ti = context['ti']
+            if remote_pod:
+                ti.xcom_push(key='pod_name', value=remote_pod.metadata.name)
+                ti.xcom_push(key='pod_namespace', 
value=remote_pod.metadata.namespace)
+            return result
 
-        In cases where the Scheduler restarts while a KubernetesPodOperator 
task is running,
-        this function will either continue to monitor the existing pod or 
launch a new pod
-        based on the `reattach_on_restart` parameter.
+    def cleanup(self, pod, remote_pod):
+        with _suppress(Exception):
+            self.process_pod_deletion(pod)
 
-        :param labels: labels used to determine if a pod is repeated
-        :type labels: dict
-        :param try_numbers_match: do the try numbers match? Only needed for 
logging purposes
-        :type try_numbers_match: bool
-        :param launcher: PodLauncher
-        :param pod: Pod found with matching labels
-        """
-        if try_numbers_match:
-            log_line = f"found a running pod with labels {labels} and the same 
try_number."
-        else:
-            log_line = f"found a running pod with labels {labels} but a 
different try_number."
-
-        # In case of failed pods, should reattach the first time, but only once
-        # as the task will have already failed.
-        if self.reattach_on_restart and not 
pod.metadata.labels.get("already_checked"):
-            log_line += " Will attach to this pod and monitor instead of 
starting new one"
-            self.log.info(log_line)
-            self.pod = pod
-            final_state, remote_pod, result = 
self.monitor_launched_pod(launcher, pod)
+        pod_phase = remote_pod.status.phase if hasattr(remote_pod, 'status') 
else None
+        if pod_phase != PodStatus.SUCCEEDED:
+            if self.log_events_on_failure:
+                with _suppress(Exception):
+                    for event in self.launcher.read_pod_events(pod).items:
+                        self.log.error("Pod Event: %s - %s", event.reason, 
event.message)
+            if not self.is_delete_operator_pod:
+                with _suppress(Exception):
+                    self.patch_already_checked(pod)
+            raise AirflowException(f'Pod {pod and pod.metadata.name} returned 
a failure: {remote_pod}')
+
+    def process_pod_deletion(self, pod):
+        if self.is_delete_operator_pod:
+            self.log.info("deleting pod: %s", pod.metadata.name)
+            self.launcher.delete_pod(pod)
         else:
-            log_line += f"creating pod with labels {labels} and launcher 
{launcher}"
-            self.log.info(log_line)
-            final_state, remote_pod, result = 
self.create_new_pod_for_operator(labels, launcher)
-        return final_state, remote_pod, result
+            self.log.info("skipping deleting pod: %s", pod.metadata.name)
 
-    @staticmethod
-    def _get_pod_identifying_label_string(labels) -> str:
+    def _get_pod_identifying_label_string(self, labels) -> str:
         label_strings = [
             f'{label_id}={label}' for label_id, label in 
sorted(labels.items()) if label_id != 'try_number'
         ]
-        return ','.join(label_strings) + ',already_checked!=True'
-
-    @staticmethod
-    def _try_numbers_match(context, pod) -> bool:
-        return pod.metadata.labels['try_number'] == context['ti'].try_number
+        return ','.join(label_strings) + f',{self.POD_CHECKED_KEY}!=True'
+
+    def _compare_try_numbers(self, context, pod):
+        tries_match = pod.metadata.labels['try_number'] == 
context['ti'].try_number
+        self.log.info(
+            "found a running pod with labels %s %s try_number.",
+            pod.metadata.labels,
+            "and the same" if tries_match else "but a different",
+        )

Review comment:
       Update: I inlined the comparison and simplified it a bit, plainly 
stating the facts instead of using the flowery expression.
   
   There was also already a `Found matching pod` log.info at the call site, 
and I combined them.
   
   ```python
               self.log.info("Found matching pod %s with labels %s", 
pod.metadata.name, pod.metadata.labels)
               self.log.info("`try_number` of task_instance: %s", 
context['ti'].try_number)
               self.log.info("`try_number` of pod: %s", 
pod.metadata.labels['try_number'])
   ```
   
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to