This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8985df0  Monitor pods by labels instead of names (#6377)
8985df0 is described below

commit 8985df0bfcb5f2b2cd69a21b9814021f9f8ce953
Author: Daniel Imberman <[email protected]>
AuthorDate: Sat May 16 14:13:58 2020 -0700

    Monitor pods by labels instead of names (#6377)
    
    * Monitor k8sPodOperator pods by labels
    
    To prevent situations where the scheduler starts a
    second k8sPodOperator pod after a restart, we now check
    for existing pods using kubernetes labels
    
    * Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
    
    Co-authored-by: Kaxil Naik <[email protected]>
    
    * Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
    
    Co-authored-by: Kaxil Naik <[email protected]>
    
    * add docs
    
    * Update airflow/kubernetes/pod_launcher.py
    
    Co-authored-by: Kaxil Naik <[email protected]>
    
    Co-authored-by: Daniel Imberman <[email protected]>
    Co-authored-by: Kaxil Naik <[email protected]>
---
 airflow/executors/kubernetes_executor.py           |  34 +---
 airflow/kubernetes/pod_generator.py                |  23 +++
 airflow/kubernetes/pod_launcher.py                 |  16 +-
 .../cncf/kubernetes/operators/kubernetes_pod.py    | 226 ++++++++++++++-------
 tests/executors/test_kubernetes_executor.py        |   9 +-
 .../kubernetes/test_kubernetes_pod_operator.py     | 195 ++++++++++++------
 6 files changed, 337 insertions(+), 166 deletions(-)

diff --git a/airflow/executors/kubernetes_executor.py 
b/airflow/executors/kubernetes_executor.py
index 74367c7..841a8e2 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -23,10 +23,8 @@ KubernetesExecutor
 """
 import base64
 import datetime
-import hashlib
 import json
 import multiprocessing
-import re
 import time
 from queue import Empty, Queue  # pylint: disable=unused-import
 from typing import Any, Dict, Optional, Tuple, Union
@@ -42,6 +40,7 @@ from airflow import settings
 from airflow.configuration import conf
 from airflow.exceptions import AirflowConfigException, AirflowException
 from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, 
CommandType
+from airflow.kubernetes import pod_generator
 from airflow.kubernetes.kube_client import get_kube_client
 from airflow.kubernetes.pod_generator import MAX_POD_ID_LEN, PodGenerator
 from airflow.kubernetes.pod_launcher import PodLauncher
@@ -462,8 +461,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
             namespace=self.namespace,
             worker_uuid=self.worker_uuid,
             pod_id=self._create_pod_id(dag_id, task_id),
-            dag_id=self._make_safe_label_value(dag_id),
-            task_id=self._make_safe_label_value(task_id),
+            dag_id=pod_generator.make_safe_label_value(dag_id),
+            task_id=pod_generator.make_safe_label_value(task_id),
             try_number=try_number,
             date=self._datetime_to_label_safe_datestring(execution_date),
             command=command,
@@ -556,25 +555,6 @@ class AirflowKubernetesScheduler(LoggingMixin):
         return safe_pod_id
 
     @staticmethod
-    def _make_safe_label_value(string: str) -> str:
-        """
-        Valid label values must be 63 characters or less and must be empty or 
begin and
-        end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), 
underscores (_),
-        dots (.), and alphanumerics between.
-
-        If the label value is then greater than 63 chars once made safe, or 
differs in any
-        way from the original value sent to this function, then we need to 
truncate to
-        53chars, and append it with a unique hash.
-        """
-        safe_label = 
re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$', '', string)
-
-        if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
-            safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
-            safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" 
+ safe_hash
-
-        return safe_label
-
-    @staticmethod
     def _create_pod_id(dag_id: str, task_id: str) -> str:
         safe_dag_id = 
AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(
             dag_id)
@@ -657,8 +637,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
             )
             for task in tasks:
                 if (
-                    self._make_safe_label_value(task.dag_id) == dag_id and
-                    self._make_safe_label_value(task.task_id) == task_id and
+                    pod_generator.make_safe_label_value(task.dag_id) == dag_id 
and
+                    pod_generator.make_safe_label_value(task.task_id) == 
task_id and
                     task.execution_date == ex_time
                 ):
                     self.log.info(
@@ -744,8 +724,8 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
             # pylint: disable=protected-access
             dict_string = (
                 
"dag_id={},task_id={},execution_date={},airflow-worker={}".format(
-                    
AirflowKubernetesScheduler._make_safe_label_value(task.dag_id),
-                    
AirflowKubernetesScheduler._make_safe_label_value(task.task_id),
+                    pod_generator.make_safe_label_value(task.dag_id),
+                    pod_generator.make_safe_label_value(task.task_id),
                     
AirflowKubernetesScheduler._datetime_to_label_safe_datestring(
                         task.execution_date
                     ),
diff --git a/airflow/kubernetes/pod_generator.py 
b/airflow/kubernetes/pod_generator.py
index 294565d..8c3a091 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -22,8 +22,10 @@ is supported and no serialization need be written.
 """
 
 import copy
+import hashlib
 import inspect
 import os
+import re
 import uuid
 from functools import reduce
 from typing import Dict, List, Optional, Union
@@ -37,6 +39,8 @@ from airflow.version import version as airflow_version
 
 MAX_POD_ID_LEN = 253
 
+MAX_LABEL_LEN = 63
+
 
 class PodDefaults:
     """
@@ -66,6 +70,25 @@ class PodDefaults:
     )
 
 
+def make_safe_label_value(string):
+    """
+    Valid label values must be 63 characters or less and must be empty or 
begin and
+    end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), 
underscores (_),
+    dots (.), and alphanumerics between.
+
+    If the label value is greater than 63 chars once made safe, or differs in 
any
+    way from the original value sent to this function, then we need to 
truncate to
+    53 chars, and append it with a unique hash.
+    """
+    safe_label = re.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", 
"", string)
+
+    if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
+        safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
+        safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + 
safe_hash
+
+    return safe_label
+
+
 class PodGenerator:
     """
     Contains Kubernetes Airflow Worker configuration logic
diff --git a/airflow/kubernetes/pod_launcher.py 
b/airflow/kubernetes/pod_launcher.py
index 32efb1d..45cc085 100644
--- a/airflow/kubernetes/pod_launcher.py
+++ b/airflow/kubernetes/pod_launcher.py
@@ -93,17 +93,15 @@ class PodLauncher(LoggingMixin):
             if e.status != 404:
                 raise
 
-    def run_pod(
+    def start_pod(
             self,
             pod: V1Pod,
-            startup_timeout: int = 120,
-            get_logs: bool = True) -> Tuple[State, Optional[str]]:
+            startup_timeout: int = 120):
         """
         Launches the pod synchronously and waits for completion.
 
         :param pod:
         :param startup_timeout: Timeout for startup of the pod (if pod is 
pending for too long, fails task)
-        :param get_logs:  whether to query k8s for logs
         :return:
         """
         resp = self.run_pod_async(pod)
@@ -116,9 +114,15 @@ class PodLauncher(LoggingMixin):
                 time.sleep(1)
             self.log.debug('Pod not yet started')
 
-        return self._monitor_pod(pod, get_logs)
+    def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, 
Optional[str]]:
+        """
+        Monitors a pod and returns the final state
 
-    def _monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, 
Optional[str]]:
+        :param pod: pod spec that will be monitored
+        :type pod: V1Pod
+        :param get_logs: whether to read the logs locally
+        :return:  Tuple[State, Optional[str]]
+        """
         if get_logs:
             logs = self.read_pod_logs(pod)
             for line in logs:
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py 
b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index c2ccd74..bc5679d 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -16,7 +16,7 @@
 # under the License.
 """Executes task in a Kubernetes POD"""
 import re
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import kubernetes.client.models as k8s
 
@@ -74,6 +74,8 @@ class KubernetesPodOperator(BaseOperator):  # pylint: 
disable=too-many-instance-
     :param cluster_context: context that points to kubernetes cluster.
         Ignored when in_cluster is True. If None, current-context is used.
     :type cluster_context: str
+    :param reattach_on_restart: if the scheduler dies while the pod is 
running, reattach and monitor
+    :type reattach_on_restart: bool
     :param labels: labels to apply to the Pod.
     :type labels: dict
     :param startup_timeout_seconds: timeout in seconds to startup the pod.
@@ -156,6 +158,7 @@ class KubernetesPodOperator(BaseOperator):  # pylint: 
disable=too-many-instance-
                  in_cluster: Optional[bool] = None,
                  cluster_context: Optional[str] = None,
                  labels: Optional[Dict] = None,
+                 reattach_on_restart: bool = True,
                  startup_timeout_seconds: int = 120,
                  get_logs: bool = True,
                  image_pull_policy: str = 'IfNotPresent',
@@ -201,6 +204,7 @@ class KubernetesPodOperator(BaseOperator):  # pylint: 
disable=too-many-instance-
         self.secrets = secrets or []
         self.in_cluster = in_cluster
         self.cluster_context = cluster_context
+        self.reattach_on_restart = reattach_on_restart
         self.get_logs = get_logs
         self.image_pull_policy = image_pull_policy
         self.node_selectors = node_selectors or {}
@@ -225,7 +229,31 @@ class KubernetesPodOperator(BaseOperator):  # pylint: 
disable=too-many-instance-
         self.pod_template_file = pod_template_file
         self.name = self._set_name(name)
 
-    def execute(self, context):
+    @staticmethod
+    def create_labels_for_pod(context) -> dict:
+        """
+        Generate labels for the pod to track the pod in case of Operator crash
+
+        :param context: task context provided by airflow DAG
+        :return: dict
+        """
+        labels = {
+            'dag_id': context['dag'].dag_id,
+            'task_id': context['task'].task_id,
+            'execution_date': context['ts'],
+            'try_number': context['ti'].try_number,
+        }
+        # In the case of sub dags this is just useful
+        if context['dag'].is_subdag:
+            labels['parent_dag_id'] = context['dag'].parent_dag.dag_id
+        # Ensure that label is valid for Kube,
+        # and if not truncate/remove invalid chars and replace with short hash.
+        for label_id, label in labels.items():
+            safe_label = pod_generator.make_safe_label_value(str(label))
+            labels[label_id] = safe_label
+        return labels
+
+    def execute(self, context) -> Optional[str]:
         try:
             if self.in_cluster is not None:
                 client = 
kube_client.get_kube_client(in_cluster=self.in_cluster,
@@ -235,85 +263,49 @@ class KubernetesPodOperator(BaseOperator):  # pylint: 
disable=too-many-instance-
                 client = 
kube_client.get_kube_client(cluster_context=self.cluster_context,
                                                      
config_file=self.config_file)
 
-            if not (self.full_pod_spec or self.pod_template_file):
-                # Add Airflow Version to the label
-                # And a label to identify that pod is launched by 
KubernetesPodOperator
-                self.labels.update(
-                    {
-                        'airflow_version': airflow_version.replace('+', '-'),
-                        'kubernetes_pod_operator': 'True',
-                    }
-                )
-            pod = pod_generator.PodGenerator(
-                image=self.image,
-                namespace=self.namespace,
-                cmds=self.cmds,
-                args=self.arguments,
-                labels=self.labels,
-                name=self.name,
-                envs=self.env_vars,
-                extract_xcom=self.do_xcom_push,
-                image_pull_policy=self.image_pull_policy,
-                node_selectors=self.node_selectors,
-                annotations=self.annotations,
-                affinity=self.affinity,
-                image_pull_secrets=self.image_pull_secrets,
-                service_account_name=self.service_account_name,
-                hostnetwork=self.hostnetwork,
-                tolerations=self.tolerations,
-                configmaps=self.configmaps,
-                security_context=self.security_context,
-                dnspolicy=self.dnspolicy,
-                schedulername=self.schedulername,
-                init_containers=self.init_containers,
-                restart_policy='Never',
-                priority_class_name=self.priority_class_name,
-                pod_template_file=self.pod_template_file,
-                pod=self.full_pod_spec,
-            ).gen_pod()
+            # Add combination of labels to uniquely identify a running pod
+            labels = self.create_labels_for_pod(context)
 
-            pod = append_to_pod(
-                pod,
-                self.pod_runtime_info_envs +
-                self.ports +
-                self.resources +
-                self.secrets +
-                self.volumes +
-                self.volume_mounts
-            )
+            label_selector = self._get_pod_identifying_label_string(labels)
 
-            self.pod = pod
+            pod_list = client.list_namespaced_pod(self.namespace, 
label_selector=label_selector)
 
-            launcher = pod_launcher.PodLauncher(kube_client=client,
-                                                extract_xcom=self.do_xcom_push)
+            if len(pod_list.items) > 1:
+                raise AirflowException(
+                    'More than one pod running with labels: '
+                    '{label_selector}'.format(label_selector=label_selector))
 
-            try:
-                (final_state, result) = launcher.run_pod(
-                    pod,
-                    startup_timeout=self.startup_timeout_seconds,
-                    get_logs=self.get_logs)
-            except AirflowException:
-                if self.log_events_on_failure:
-                    for event in launcher.read_pod_events(pod).items:
-                        self.log.error("Pod Event: %s - %s", event.reason, 
event.message)
-                raise
-            finally:
-                if self.is_delete_operator_pod:
-                    launcher.delete_pod(pod)
+            launcher = pod_launcher.PodLauncher(kube_client=client, 
extract_xcom=self.do_xcom_push)
 
+            if len(pod_list.items) == 1 and \
+                    self._try_numbers_do_not_match(context, pod_list.items[0]) 
and \
+                    self.reattach_on_restart:
+                self.log.info("found a running pod with labels %s but a 
different try_number"
+                              "Will attach to this pod and monitor instead of 
starting new one", labels)
+                final_state, _, result = 
self.create_new_pod_for_operator(labels, launcher)
+            elif len(pod_list.items) == 1:
+                self.log.info("found a running pod with labels %s."
+                              "Will monitor this pod instead of starting new 
one", labels)
+                final_state, result = self.monitor_launched_pod(launcher, 
pod_list[0])
+            else:
+                final_state, _, result = 
self.create_new_pod_for_operator(labels, launcher)
             if final_state != State.SUCCESS:
-                if self.log_events_on_failure:
-                    for event in launcher.read_pod_events(pod).items:
-                        self.log.error("Pod Event: %s - %s", event.reason, 
event.message)
                 raise AirflowException(
-                    'Pod returned a failure: {state}'.format(state=final_state)
-                )
-
+                    'Pod returned a failure: 
{state}'.format(state=final_state))
             return result
         except AirflowException as ex:
             raise AirflowException('Pod Launching failed: 
{error}'.format(error=ex))
 
     @staticmethod
+    def _get_pod_identifying_label_string(labels):
+        filtered_labels = {label_id: label for label_id, label in 
labels.items() if label_id != 'try_number'}
+        return ','.join([label_id + '=' + label for label_id, label in 
sorted(filtered_labels.items())])
+
+    @staticmethod
+    def _try_numbers_do_not_match(context, pod):
+        return pod.metadata.labels['try_number'] != context['ti'].try_number
+
+    @staticmethod
     def _set_resources(resources):
         if not resources:
             return []
@@ -324,3 +316,99 @@ class KubernetesPodOperator(BaseOperator):  # pylint: 
disable=too-many-instance-
             return None
         validate_key(name, max_length=220)
         return re.sub(r'[^a-z0-9.-]+', '-', name.lower())
+
+    def create_new_pod_for_operator(self, labels, launcher) -> Tuple[State, 
k8s.V1Pod, Optional[str]]:
+        """
+        Creates a new pod and monitors for duration of task
+
+        @param labels: labels used to track pod
+        @param launcher: pod launcher that will manage launching and 
monitoring pods
+        @return:
+        """
+        if not (self.full_pod_spec or self.pod_template_file):
+            # Add Airflow Version to the label
+            # And a label to identify that pod is launched by 
KubernetesPodOperator
+            self.labels.update(
+                {
+                    'airflow_version': airflow_version.replace('+', '-'),
+                    'kubernetes_pod_operator': 'True',
+                }
+            )
+            self.labels.update(labels)
+        pod = pod_generator.PodGenerator(
+            image=self.image,
+            namespace=self.namespace,
+            cmds=self.cmds,
+            args=self.arguments,
+            labels=self.labels,
+            name=self.name,
+            envs=self.env_vars,
+            extract_xcom=self.do_xcom_push,
+            image_pull_policy=self.image_pull_policy,
+            node_selectors=self.node_selectors,
+            annotations=self.annotations,
+            affinity=self.affinity,
+            image_pull_secrets=self.image_pull_secrets,
+            service_account_name=self.service_account_name,
+            hostnetwork=self.hostnetwork,
+            tolerations=self.tolerations,
+            configmaps=self.configmaps,
+            security_context=self.security_context,
+            dnspolicy=self.dnspolicy,
+            schedulername=self.schedulername,
+            init_containers=self.init_containers,
+            restart_policy='Never',
+            priority_class_name=self.priority_class_name,
+            pod_template_file=self.pod_template_file,
+            pod=self.full_pod_spec,
+        ).gen_pod()
+
+        # noinspection PyTypeChecker
+        pod = append_to_pod(
+            pod,
+            self.pod_runtime_info_envs +  # type: ignore
+            self.ports +  # type: ignore
+            self.resources +  # type: ignore
+            self.secrets +  # type: ignore
+            self.volumes +  # type: ignore
+            self.volume_mounts  # type: ignore
+        )
+
+        self.pod = pod
+
+        try:
+            launcher.start_pod(
+                pod,
+                startup_timeout=self.startup_timeout_seconds)
+            final_state, result = launcher.monitor_pod(pod=pod, 
get_logs=self.get_logs)
+        except AirflowException:
+            if self.log_events_on_failure:
+                for event in launcher.read_pod_events(pod).items:
+                    self.log.error("Pod Event: %s - %s", event.reason, 
event.message)
+            raise
+        finally:
+            if self.is_delete_operator_pod:
+                launcher.delete_pod(pod)
+        return final_state, pod, result
+
+    def monitor_launched_pod(self, launcher, pod) -> Tuple[State, 
Optional[str]]:
+        """
+        Monitors a pod to completion that was created by a previous 
KubernetesPodOperator
+
+        @param launcher: pod launcher that will manage launching and 
monitoring pods
+        :param pod: podspec used to find pod using k8s API
+        :return:
+        """
+        try:
+            (final_state, result) = launcher.monitor_pod(pod, 
get_logs=self.get_logs)
+        finally:
+            if self.is_delete_operator_pod:
+                launcher.delete_pod(pod)
+        if final_state != State.SUCCESS:
+            if self.log_events_on_failure:
+                for event in launcher.read_pod_events(pod).items:
+                    self.log.error("Pod Event: %s - %s", event.reason, 
event.message)
+            raise AirflowException(
+                'Pod returned a failure: {state}'.format(state=final_state)
+            )
+        return final_state, result
diff --git a/tests/executors/test_kubernetes_executor.py 
b/tests/executors/test_kubernetes_executor.py
index df2c4fc..f9002a1 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -33,6 +33,7 @@ try:
     from airflow.executors.kubernetes_executor import 
AirflowKubernetesScheduler
     from airflow.executors.kubernetes_executor import KubernetesExecutor
     from airflow.executors.kubernetes_executor import KubeConfig
+    from airflow.kubernetes import pod_generator
     from airflow.kubernetes.pod_generator import PodGenerator
     from airflow.utils.state import State
 except ImportError:
@@ -91,19 +92,19 @@ class TestAirflowKubernetesScheduler(unittest.TestCase):
 
     def test_make_safe_label_value(self):
         for dag_id, task_id in self._cases():
-            safe_dag_id = 
AirflowKubernetesScheduler._make_safe_label_value(dag_id)
+            safe_dag_id = pod_generator.make_safe_label_value(dag_id)
             self.assertTrue(self._is_safe_label_value(safe_dag_id))
-            safe_task_id = 
AirflowKubernetesScheduler._make_safe_label_value(task_id)
+            safe_task_id = pod_generator.make_safe_label_value(task_id)
             self.assertTrue(self._is_safe_label_value(safe_task_id))
             dag_id = "my_dag_id"
             self.assertEqual(
                 dag_id,
-                AirflowKubernetesScheduler._make_safe_label_value(dag_id)
+                pod_generator.make_safe_label_value(dag_id)
             )
             dag_id = "my_dag_id_" + "a" * 64
             self.assertEqual(
                 "my_dag_id_" + "a" * 43 + "-0ce114c45",
-                AirflowKubernetesScheduler._make_safe_label_value(dag_id)
+                pod_generator.make_safe_label_value(dag_id)
             )
 
     @unittest.skipIf(AirflowKubernetesScheduler is None,
diff --git a/tests/runtime/kubernetes/test_kubernetes_pod_operator.py 
b/tests/runtime/kubernetes/test_kubernetes_pod_operator.py
index 2d6f0ac..98d47ac 100644
--- a/tests/runtime/kubernetes/test_kubernetes_pod_operator.py
+++ b/tests/runtime/kubernetes/test_kubernetes_pod_operator.py
@@ -23,18 +23,22 @@ from unittest import mock
 from unittest.mock import ANY
 
 import kubernetes.client.models as k8s
+import pendulum
 import pytest
 from kubernetes.client.api_client import ApiClient
 from kubernetes.client.rest import ApiException
 
 from airflow.exceptions import AirflowException
+from airflow.kubernetes import kube_client
 from airflow.kubernetes.pod import Port
 from airflow.kubernetes.pod_generator import PodDefaults
 from airflow.kubernetes.pod_launcher import PodLauncher
 from airflow.kubernetes.secret import Secret
 from airflow.kubernetes.volume import Volume
 from airflow.kubernetes.volume_mount import VolumeMount
+from airflow.models import DAG, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import 
KubernetesPodOperator
+from airflow.utils import timezone
 from airflow.version import version as airflow_version
 
 
@@ -53,8 +57,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
                 'annotations': {},
                 'labels': {
                     'foo': 'bar', 'kubernetes_pod_operator': 'True',
-                    'airflow_version': airflow_version.replace('+', '-')
-                }
+                    'airflow_version': airflow_version.replace('+', '-'),
+                    'execution_date': '2016-01-01T0100000100-a2f50a31f',
+                    'dag_id': 'dag',
+                    'task_id': 'task',
+                    'try_number': '1'},
             },
             'spec': {
                 'affinity': {},
@@ -81,6 +88,23 @@ class TestKubernetesPodOperator(unittest.TestCase):
             }
         }
 
+    def tearDown(self) -> None:
+        client = kube_client.get_kube_client(in_cluster=False)
+        client.delete_collection_namespaced_pod(namespace="default")
+
+    def create_context(self, task):
+        dag = DAG(dag_id="dag")
+        tzinfo = pendulum.timezone("Europe/Amsterdam")
+        execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+        task_instance = TaskInstance(task=task,
+                                     execution_date=execution_date)
+        return {
+            "dag": dag,
+            "ts": execution_date.isoformat(),
+            "task": task,
+            "ti": task_instance,
+        }
+
     def test_do_xcom_push_defaults_false(self):
         new_config_path = '/tmp/kube_config'
         old_config_path = os.path.expanduser('~/.kube/config')
@@ -111,19 +135,21 @@ class TestKubernetesPodOperator(unittest.TestCase):
             cmds=["bash", "-cx"],
             arguments=["echo 10"],
             labels={"foo": "bar"},
-            name="test",
+            name="test1",
             task_id="task",
             in_cluster=False,
             do_xcom_push=False,
             config_file=new_config_path,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.assertEqual(self.expected_pod, actual_pod)
 
-    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_config_path(self, client_mock, launcher_mock):
+    def test_config_path(self, client_mock, monitor_mock, start_mock):  # 
pylint: disable=unused-argument
         from airflow.utils.state import State
 
         file_path = "/tmp/fake_file"
@@ -140,17 +166,20 @@ class TestKubernetesPodOperator(unittest.TestCase):
             config_file=file_path,
             cluster_context='default',
         )
-        launcher_mock.return_value = (State.SUCCESS, None)
-        k.execute(None)
+        monitor_mock.return_value = (State.SUCCESS, None)
+        client_mock.list_namespaced_pod.return_value = []
+        context = self.create_context(k)
+        k.execute(context=context)
         client_mock.assert_called_once_with(
             in_cluster=False,
             cluster_context='default',
             config_file=file_path,
         )
 
-    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_image_pull_secrets_correctly_set(self, mock_client, 
launcher_mock):
+    def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, 
start_mock):
         from airflow.utils.state import State
 
         fake_pull_secrets = "fakeSecret"
@@ -167,17 +196,24 @@ class TestKubernetesPodOperator(unittest.TestCase):
             image_pull_secrets=fake_pull_secrets,
             cluster_context='default',
         )
-        launcher_mock.return_value = (State.SUCCESS, None)
-        k.execute(None)
+        monitor_mock.return_value = (State.SUCCESS, None)
+        context = self.create_context(k)
+        k.execute(context=context)
         self.assertEqual(
-            launcher_mock.call_args[0][0].spec.image_pull_secrets,
+            start_mock.call_args[0][0].spec.image_pull_secrets,
             [k8s.V1LocalObjectReference(name=fake_pull_secrets)]
         )
 
-    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
     @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.delete_pod")
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_pod_delete_even_on_launcher_error(self, mock_client, 
delete_pod_mock, run_pod_mock):
+    def test_pod_delete_even_on_launcher_error(
+            self,
+            mock_client,
+            delete_pod_mock,
+            monitor_pod_mock,
+            start_pod_mock):  # pylint: disable=unused-argument
         k = KubernetesPodOperator(
             namespace='default',
             image="ubuntu:16.04",
@@ -191,9 +227,10 @@ class TestKubernetesPodOperator(unittest.TestCase):
             cluster_context='default',
             is_delete_operator_pod=True,
         )
-        run_pod_mock.side_effect = AirflowException('fake failure')
+        monitor_pod_mock.side_effect = AirflowException('fake failure')
         with self.assertRaises(AirflowException):
-            k.execute(None)
+            context = self.create_context(k)
+            k.execute(context=context)
         assert delete_pod_mock.called
 
     def test_working_pod(self):
@@ -208,9 +245,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
             in_cluster=False,
             do_xcom_push=False,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
-        self.assertEqual(self.expected_pod, actual_pod)
+        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+        self.assertEqual(self.expected_pod['metadata']['labels'], 
actual_pod['metadata']['labels'])
 
     def test_delete_operator_pod(self):
         k = KubernetesPodOperator(
@@ -225,9 +264,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             is_delete_operator_pod=True,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
-        self.assertEqual(self.expected_pod, actual_pod)
+        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+        self.assertEqual(self.expected_pod['metadata']['labels'], 
actual_pod['metadata']['labels'])
 
     def test_pod_hostnetwork(self):
         k = KubernetesPodOperator(
@@ -242,10 +283,12 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             hostnetwork=True,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['hostNetwork'] = True
-        self.assertEqual(self.expected_pod, actual_pod)
+        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+        self.assertEqual(self.expected_pod['metadata']['labels'], 
actual_pod['metadata']['labels'])
 
     def test_pod_dnspolicy(self):
         dns_policy = "ClusterFirstWithHostNet"
@@ -262,11 +305,13 @@ class TestKubernetesPodOperator(unittest.TestCase):
             hostnetwork=True,
             dnspolicy=dns_policy
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['hostNetwork'] = True
         self.expected_pod['spec']['dnsPolicy'] = dns_policy
-        self.assertEqual(self.expected_pod, actual_pod)
+        self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+        self.assertEqual(self.expected_pod['metadata']['labels'], 
actual_pod['metadata']['labels'])
 
     def test_pod_schedulername(self):
         scheduler_name = "default-scheduler"
@@ -282,7 +327,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             schedulername=scheduler_name
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['schedulerName'] = scheduler_name
         self.assertEqual(self.expected_pod, actual_pod)
@@ -303,7 +349,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             node_selectors=node_selectors,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['nodeSelector'] = node_selectors
         self.assertEqual(self.expected_pod, actual_pod)
@@ -329,7 +376,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             resources=resources,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['containers'][0]['resources'] = {
             'requests': {
@@ -376,7 +424,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             affinity=affinity,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context=context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['affinity'] = affinity
         self.assertEqual(self.expected_pod, actual_pod)
@@ -396,7 +445,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             ports=[port],
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context=context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['containers'][0]['ports'] = [{
             'name': 'http',
@@ -432,7 +482,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
                 in_cluster=False,
                 do_xcom_push=False,
             )
-            k.execute(None)
+            context = self.create_context(k)
+            k.execute(context=context)
             mock_logger.info.assert_any_call(b"retrieved from mount\n")
             actual_pod = self.api_client.sanitize_for_serialization(k.pod)
             self.expected_pod['spec']['containers'][0]['args'] = args
@@ -467,7 +518,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             security_context=security_context,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['securityContext'] = security_context
         self.assertEqual(self.expected_pod, actual_pod)
@@ -491,7 +543,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             security_context=security_context,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['securityContext'] = security_context
         self.assertEqual(self.expected_pod, actual_pod)
@@ -515,7 +568,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
             security_context=security_context,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['securityContext'] = security_context
         self.assertEqual(self.expected_pod, actual_pod)
@@ -535,7 +589,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             startup_timeout_seconds=5,
         )
         with self.assertRaises(AirflowException):
-            k.execute(None)
+            context = self.create_context(k)
+            k.execute(context)
             actual_pod = self.api_client.sanitize_for_serialization(k.pod)
             self.expected_pod['spec']['containers'][0]['image'] = 
bad_image_name
             self.assertEqual(self.expected_pod, actual_pod)
@@ -556,7 +611,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             service_account_name=bad_service_account_name,
         )
         with self.assertRaises(ApiException):
-            k.execute(None)
+            context = self.create_context(k)
+            k.execute(context)
             actual_pod = self.api_client.sanitize_for_serialization(k.pod)
             self.expected_pod['spec']['serviceAccountName'] = 
bad_service_account_name
             self.assertEqual(self.expected_pod, actual_pod)
@@ -578,7 +634,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
         )
         with self.assertRaises(AirflowException):
-            k.execute(None)
+            context = self.create_context(k)
+            k.execute(context)
             actual_pod = self.api_client.sanitize_for_serialization(k.pod)
             self.expected_pod['spec']['containers'][0]['args'] = 
bad_internal_command
             self.assertEqual(self.expected_pod, actual_pod)
@@ -597,7 +654,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             in_cluster=False,
             do_xcom_push=True,
         )
-        self.assertEqual(k.execute(None), json.loads(return_value))
+        context = self.create_context(k)
+        self.assertEqual(k.execute(context), json.loads(return_value))
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
         volume_mount = 
self.api_client.sanitize_for_serialization(PodDefaults.VOLUME_MOUNT)
@@ -608,9 +666,10 @@ class TestKubernetesPodOperator(unittest.TestCase):
         self.expected_pod['spec']['containers'].append(container)
         self.assertEqual(self.expected_pod, actual_pod)
 
-    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_envs_from_configmaps(self, mock_client, mock_launcher):
+    def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
         # GIVEN
         from airflow.utils.state import State
 
@@ -629,18 +688,20 @@ class TestKubernetesPodOperator(unittest.TestCase):
             configmaps=[configmap],
         )
         # THEN
-        mock_launcher.return_value = (State.SUCCESS, None)
-        k.execute(None)
+        mock_monitor.return_value = (State.SUCCESS, None)
+        context = self.create_context(k)
+        k.execute(context)
         self.assertEqual(
-            mock_launcher.call_args[0][0].spec.containers[0].env_from,
+            mock_start.call_args[0][0].spec.containers[0].env_from,
             [k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
                 name=configmap
             ))]
         )
 
-    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_envs_from_secrets(self, mock_client, launcher_mock):
+    def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
         # GIVEN
         from airflow.utils.state import State
         secret_ref = 'secret_name'
@@ -659,10 +720,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
             do_xcom_push=False,
         )
         # THEN
-        launcher_mock.return_value = (State.SUCCESS, None)
-        k.execute(None)
+        monitor_mock.return_value = (State.SUCCESS, None)
+        context = self.create_context(k)
+        k.execute(context)
         self.assertEqual(
-            launcher_mock.call_args[0][0].spec.containers[0].env_from,
+            start_mock.call_args[0][0].spec.containers[0].env_from,
             [k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
                 name=secret_ref
             ))]
@@ -696,9 +758,9 @@ class TestKubernetesPodOperator(unittest.TestCase):
 
         volume_config = {
             'persistentVolumeClaim':
-            {
-                'claimName': 'test-volume'
-            }
+                {
+                    'claimName': 'test-volume'
+                }
         }
         volume = Volume(name='test-volume', configs=volume_config)
 
@@ -734,7 +796,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
             in_cluster=False,
             do_xcom_push=False,
         )
-        k.execute(None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['initContainers'] = [expected_init_container]
         self.expected_pod['spec']['volumes'] = [{
@@ -745,17 +808,23 @@ class TestKubernetesPodOperator(unittest.TestCase):
         }]
         self.assertEqual(self.expected_pod, actual_pod)
 
-    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_pod_template_file(self, mock_client, launcher_mock):
+    def test_pod_template_file(
+            self,
+            mock_client,
+            monitor_mock,
+            start_mock):  # pylint: disable=unused-argument
         from airflow.utils.state import State
         k = KubernetesPodOperator(
             task_id='task',
             pod_template_file='tests/kubernetes/pod.yaml',
             do_xcom_push=True
         )
-        launcher_mock.return_value = (State.SUCCESS, None)
-        k.execute(None)
+        monitor_mock.return_value = (State.SUCCESS, None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.assertEqual({
             'apiVersion': 'v1',
@@ -791,9 +860,14 @@ class TestKubernetesPodOperator(unittest.TestCase):
             }
         }, actual_pod)
 
-    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
     @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_pod_priority_class_name(self, mock_client, launcher_mock):
+    def test_pod_priority_class_name(
+            self,
+            mock_client,
+            monitor_mock,
+            start_mock):  # pylint: disable=unused-argument
         """Test ability to assign priorityClassName to pod
 
         """
@@ -813,8 +887,9 @@ class TestKubernetesPodOperator(unittest.TestCase):
             priority_class_name=priority_class_name,
         )
 
-        launcher_mock.return_value = (State.SUCCESS, None)
-        k.execute(None)
+        monitor_mock.return_value = (State.SUCCESS, None)
+        context = self.create_context(k)
+        k.execute(context)
         actual_pod = self.api_client.sanitize_for_serialization(k.pod)
         self.expected_pod['spec']['priorityClassName'] = priority_class_name
         self.assertEqual(self.expected_pod, actual_pod)

Reply via email to