This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 8985df0 Monitor pods by labels instead of names (#6377)
8985df0 is described below
commit 8985df0bfcb5f2b2cd69a21b9814021f9f8ce953
Author: Daniel Imberman <[email protected]>
AuthorDate: Sat May 16 14:13:58 2020 -0700
Monitor pods by labels instead of names (#6377)
* Monitor k8sPodOperator pods by labels
To prevent situations where the scheduler starts a
second k8sPodOperator pod after a restart, we now check
for existing pods using kubernetes labels
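In effect, the operator now looks its pod up with a label selector rather
than a pod name. A minimal sketch of that lookup (illustrative only, assuming
the official kubernetes Python client and a reachable cluster; names,
namespace, and label values are made up):

    from kubernetes import client, config

    config.load_kube_config()  # use load_incluster_config() inside a cluster
    v1 = client.CoreV1Api()
    # try_number is deliberately left out of the selector so a pod started
    # by a previous scheduler can still be found after a restart
    selector = "dag_id=example_dag,task_id=example_task,execution_date=2020-05-16T000000"
    pods = v1.list_namespaced_pod("default", label_selector=selector)
    print([p.metadata.name for p in pods.items])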
* Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Co-authored-by: Kaxil Naik <[email protected]>
* Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Co-authored-by: Kaxil Naik <[email protected]>
* add docs
* Update airflow/kubernetes/pod_launcher.py
Co-authored-by: Kaxil Naik <[email protected]>
Co-authored-by: Daniel Imberman <[email protected]>
Co-authored-by: Kaxil Naik <[email protected]>
---
airflow/executors/kubernetes_executor.py | 34 +---
airflow/kubernetes/pod_generator.py | 23 +++
airflow/kubernetes/pod_launcher.py | 16 +-
.../cncf/kubernetes/operators/kubernetes_pod.py | 226 ++++++++++++++-------
tests/executors/test_kubernetes_executor.py | 9 +-
.../kubernetes/test_kubernetes_pod_operator.py | 195 ++++++++++++------
6 files changed, 337 insertions(+), 166 deletions(-)
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 74367c7..841a8e2 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -23,10 +23,8 @@ KubernetesExecutor
"""
import base64
import datetime
-import hashlib
import json
import multiprocessing
-import re
import time
from queue import Empty, Queue # pylint: disable=unused-import
from typing import Any, Dict, Optional, Tuple, Union
@@ -42,6 +40,7 @@ from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType
+from airflow.kubernetes import pod_generator
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.pod_generator import MAX_POD_ID_LEN, PodGenerator
from airflow.kubernetes.pod_launcher import PodLauncher
@@ -462,8 +461,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
namespace=self.namespace,
worker_uuid=self.worker_uuid,
pod_id=self._create_pod_id(dag_id, task_id),
- dag_id=self._make_safe_label_value(dag_id),
- task_id=self._make_safe_label_value(task_id),
+ dag_id=pod_generator.make_safe_label_value(dag_id),
+ task_id=pod_generator.make_safe_label_value(task_id),
try_number=try_number,
date=self._datetime_to_label_safe_datestring(execution_date),
command=command,
@@ -556,25 +555,6 @@ class AirflowKubernetesScheduler(LoggingMixin):
return safe_pod_id
@staticmethod
- def _make_safe_label_value(string: str) -> str:
- """
- Valid label values must be 63 characters or less and must be empty or begin and
- end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
- dots (.), and alphanumerics between.
-
- If the label value is then greater than 63 chars once made safe, or differs in any
- way from the original value sent to this function, then we need to truncate to
- 53chars, and append it with a unique hash.
- """
- safe_label = re.sub(r'^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$', '', string)
-
- if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
- safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
- safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
-
- return safe_label
-
- @staticmethod
def _create_pod_id(dag_id: str, task_id: str) -> str:
safe_dag_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars(dag_id)
@@ -657,8 +637,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
)
for task in tasks:
if (
- self._make_safe_label_value(task.dag_id) == dag_id and
- self._make_safe_label_value(task.task_id) == task_id and
+ pod_generator.make_safe_label_value(task.dag_id) == dag_id and
+ pod_generator.make_safe_label_value(task.task_id) == task_id and
task.execution_date == ex_time
):
self.log.info(
@@ -744,8 +724,8 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
# pylint: disable=protected-access
dict_string = (
"dag_id={},task_id={},execution_date={},airflow-worker={}".format(
- AirflowKubernetesScheduler._make_safe_label_value(task.dag_id),
- AirflowKubernetesScheduler._make_safe_label_value(task.task_id),
+ pod_generator.make_safe_label_value(task.dag_id),
+ pod_generator.make_safe_label_value(task.task_id),
AirflowKubernetesScheduler._datetime_to_label_safe_datestring(
task.execution_date
),
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 294565d..8c3a091 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -22,8 +22,10 @@ is supported and no serialization need be written.
"""
import copy
+import hashlib
import inspect
import os
+import re
import uuid
from functools import reduce
from typing import Dict, List, Optional, Union
@@ -37,6 +39,8 @@ from airflow.version import version as airflow_version
MAX_POD_ID_LEN = 253
+MAX_LABEL_LEN = 63
+
class PodDefaults:
"""
@@ -66,6 +70,25 @@ class PodDefaults:
)
+def make_safe_label_value(string):
+ """
+ Valid label values must be 63 characters or less and must be empty or begin and
+ end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
+ dots (.), and alphanumerics between.
+
+ If the label value is greater than 63 chars once made safe, or differs in any
+ way from the original value sent to this function, then we need to truncate to
+ 53 chars, and append it with a unique hash.
+ """
+ safe_label = re.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string)
+
+ if len(safe_label) > MAX_LABEL_LEN or string != safe_label:
+ safe_hash = hashlib.md5(string.encode()).hexdigest()[:9]
+ safe_label = safe_label[:MAX_LABEL_LEN - len(safe_hash) - 1] + "-" + safe_hash
+
+ return safe_label
+
+
class PodGenerator:
"""
Contains Kubernetes Airflow Worker configuration logic
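For reference, make_safe_label_value behaves as exercised by the updated
executor test further below: already-valid values pass through unchanged,
while over-long or invalid values are truncated and suffixed with a short
md5 hash (a usage sketch, not part of the commit):

    from airflow.kubernetes.pod_generator import make_safe_label_value

    make_safe_label_value("my_dag_id")
    # -> "my_dag_id" (already a valid label value, returned unchanged)
    make_safe_label_value("my_dag_id_" + "a" * 64)
    # -> "my_dag_id_" + "a" * 43 + "-0ce114c45": truncated to 53 chars,
    #    then "-" plus the first 9 chars of the md5 of the original value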
diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py
index 32efb1d..45cc085 100644
--- a/airflow/kubernetes/pod_launcher.py
+++ b/airflow/kubernetes/pod_launcher.py
@@ -93,17 +93,15 @@ class PodLauncher(LoggingMixin):
if e.status != 404:
raise
- def run_pod(
+ def start_pod(
self,
pod: V1Pod,
- startup_timeout: int = 120,
- get_logs: bool = True) -> Tuple[State, Optional[str]]:
+ startup_timeout: int = 120):
"""
Launches the pod synchronously and waits for completion.
:param pod:
:param startup_timeout: Timeout for startup of the pod (if pod is pending for too long, fails task)
- :param get_logs: whether to query k8s for logs
:return:
"""
resp = self.run_pod_async(pod)
@@ -116,9 +114,15 @@ class PodLauncher(LoggingMixin):
time.sleep(1)
self.log.debug('Pod not yet started')
- return self._monitor_pod(pod, get_logs)
+ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]:
+ """
+ Monitors a pod and returns the final state
- def _monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]:
+ :param pod: pod spec that will be monitored
+ :type pod: V1Pod
+ :param get_logs: whether to read the logs locally
+ :return: Tuple[State, Optional[str]]
+ """
if get_logs:
logs = self.read_pod_logs(pod)
for line in logs:
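The old run_pod call is thus split in two; callers now do roughly the
following (a sketch mirroring how the operator uses the launcher, error
handling omitted; `client` and `pod` are assumed to be built elsewhere):

    from airflow.kubernetes.pod_launcher import PodLauncher

    launcher = PodLauncher(kube_client=client, extract_xcom=False)
    # start_pod returns once the pod has left the Pending phase
    # (or raises after startup_timeout seconds)
    launcher.start_pod(pod, startup_timeout=120)
    # monitor_pod follows the pod to completion and returns the final state
    final_state, result = launcher.monitor_pod(pod=pod, get_logs=True)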
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index c2ccd74..bc5679d 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -16,7 +16,7 @@
# under the License.
"""Executes task in a Kubernetes POD"""
import re
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
import kubernetes.client.models as k8s
@@ -74,6 +74,8 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
:param cluster_context: context that points to kubernetes cluster.
Ignored when in_cluster is True. If None, current-context is used.
:type cluster_context: str
+ :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor
+ :type reattach_on_restart: bool
:param labels: labels to apply to the Pod.
:type labels: dict
:param startup_timeout_seconds: timeout in seconds to startup the pod.
@@ -156,6 +158,7 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
in_cluster: Optional[bool] = None,
cluster_context: Optional[str] = None,
labels: Optional[Dict] = None,
+ reattach_on_restart: bool = True,
startup_timeout_seconds: int = 120,
get_logs: bool = True,
image_pull_policy: str = 'IfNotPresent',
@@ -201,6 +204,7 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
self.secrets = secrets or []
self.in_cluster = in_cluster
self.cluster_context = cluster_context
+ self.reattach_on_restart = reattach_on_restart
self.get_logs = get_logs
self.image_pull_policy = image_pull_policy
self.node_selectors = node_selectors or {}
@@ -225,7 +229,31 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
self.pod_template_file = pod_template_file
self.name = self._set_name(name)
- def execute(self, context):
+ @staticmethod
+ def create_labels_for_pod(context) -> dict:
+ """
+ Generate labels for the pod to track the pod in case of Operator crash
+
+ :param context: task context provided by airflow DAG
+ :return: dict
+ """
+ labels = {
+ 'dag_id': context['dag'].dag_id,
+ 'task_id': context['task'].task_id,
+ 'execution_date': context['ts'],
+ 'try_number': context['ti'].try_number,
+ }
+ # In the case of sub dags, it is also useful to track the parent dag
+ if context['dag'].is_subdag:
+ labels['parent_dag_id'] = context['dag'].parent_dag.dag_id
+ # Ensure that label is valid for Kube,
+ # and if not truncate/remove invalid chars and replace with short hash.
+ for label_id, label in labels.items():
+ safe_label = pod_generator.make_safe_label_value(str(label))
+ labels[label_id] = safe_label
+ return labels
+
+ def execute(self, context) -> Optional[str]:
try:
if self.in_cluster is not None:
client = kube_client.get_kube_client(in_cluster=self.in_cluster,
@@ -235,85 +263,49 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
client = kube_client.get_kube_client(cluster_context=self.cluster_context,
config_file=self.config_file)
- if not (self.full_pod_spec or self.pod_template_file):
- # Add Airflow Version to the label
- # And a label to identify that pod is launched by KubernetesPodOperator
- self.labels.update(
- {
- 'airflow_version': airflow_version.replace('+', '-'),
- 'kubernetes_pod_operator': 'True',
- }
- )
- pod = pod_generator.PodGenerator(
- image=self.image,
- namespace=self.namespace,
- cmds=self.cmds,
- args=self.arguments,
- labels=self.labels,
- name=self.name,
- envs=self.env_vars,
- extract_xcom=self.do_xcom_push,
- image_pull_policy=self.image_pull_policy,
- node_selectors=self.node_selectors,
- annotations=self.annotations,
- affinity=self.affinity,
- image_pull_secrets=self.image_pull_secrets,
- service_account_name=self.service_account_name,
- hostnetwork=self.hostnetwork,
- tolerations=self.tolerations,
- configmaps=self.configmaps,
- security_context=self.security_context,
- dnspolicy=self.dnspolicy,
- schedulername=self.schedulername,
- init_containers=self.init_containers,
- restart_policy='Never',
- priority_class_name=self.priority_class_name,
- pod_template_file=self.pod_template_file,
- pod=self.full_pod_spec,
- ).gen_pod()
+ # Add combination of labels to uniquely identify a running pod
+ labels = self.create_labels_for_pod(context)
- pod = append_to_pod(
- pod,
- self.pod_runtime_info_envs +
- self.ports +
- self.resources +
- self.secrets +
- self.volumes +
- self.volume_mounts
- )
+ label_selector = self._get_pod_identifying_label_string(labels)
- self.pod = pod
+ pod_list = client.list_namespaced_pod(self.namespace, label_selector=label_selector)
- launcher = pod_launcher.PodLauncher(kube_client=client,
- extract_xcom=self.do_xcom_push)
+ if len(pod_list.items) > 1:
+ raise AirflowException(
+ 'More than one pod running with labels: '
+ '{label_selector}'.format(label_selector=label_selector))
- try:
- (final_state, result) = launcher.run_pod(
- pod,
- startup_timeout=self.startup_timeout_seconds,
- get_logs=self.get_logs)
- except AirflowException:
- if self.log_events_on_failure:
- for event in launcher.read_pod_events(pod).items:
- self.log.error("Pod Event: %s - %s", event.reason, event.message)
- raise
- finally:
- if self.is_delete_operator_pod:
- launcher.delete_pod(pod)
+ launcher = pod_launcher.PodLauncher(kube_client=client, extract_xcom=self.do_xcom_push)
+ if len(pod_list.items) == 1 and \
+ self._try_numbers_do_not_match(context, pod_list.items[0]) and \
+ self.reattach_on_restart:
+ self.log.info("found a running pod with labels %s but a different try_number. "
+ "Will create a new pod and monitor it instead of reattaching", labels)
+ final_state, _, result = self.create_new_pod_for_operator(labels, launcher)
+ elif len(pod_list.items) == 1:
+ self.log.info("found a running pod with labels %s. "
+ "Will monitor this pod instead of starting new one", labels)
+ final_state, result = self.monitor_launched_pod(launcher, pod_list.items[0])
+ else:
+ final_state, _, result = self.create_new_pod_for_operator(labels, launcher)
if final_state != State.SUCCESS:
- if self.log_events_on_failure:
- for event in launcher.read_pod_events(pod).items:
- self.log.error("Pod Event: %s - %s", event.reason,
event.message)
raise AirflowException(
- 'Pod returned a failure: {state}'.format(state=final_state)
- )
-
+ 'Pod returned a failure: {state}'.format(state=final_state))
return result
except AirflowException as ex:
raise AirflowException('Pod Launching failed: {error}'.format(error=ex))
@staticmethod
+ def _get_pod_identifying_label_string(labels):
+ filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != 'try_number'}
+ return ','.join([label_id + '=' + label for label_id, label in sorted(filtered_labels.items())])
+
+ @staticmethod
+ def _try_numbers_do_not_match(context, pod):
+ return pod.metadata.labels['try_number'] != str(context['ti'].try_number)
+
+ @staticmethod
def _set_resources(resources):
if not resources:
return []
@@ -324,3 +316,99 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
return None
validate_key(name, max_length=220)
return re.sub(r'[^a-z0-9.-]+', '-', name.lower())
+
+ def create_new_pod_for_operator(self, labels, launcher) -> Tuple[State, k8s.V1Pod, Optional[str]]:
+ """
+ Creates a new pod and monitors it for the duration of the task
+
+ :param labels: labels used to track pod
+ :param launcher: pod launcher that will manage launching and monitoring pods
+ :return:
+ """
+ if not (self.full_pod_spec or self.pod_template_file):
+ # Add Airflow Version to the label
+ # And a label to identify that pod is launched by KubernetesPodOperator
+ self.labels.update(
+ {
+ 'airflow_version': airflow_version.replace('+', '-'),
+ 'kubernetes_pod_operator': 'True',
+ }
+ )
+ self.labels.update(labels)
+ pod = pod_generator.PodGenerator(
+ image=self.image,
+ namespace=self.namespace,
+ cmds=self.cmds,
+ args=self.arguments,
+ labels=self.labels,
+ name=self.name,
+ envs=self.env_vars,
+ extract_xcom=self.do_xcom_push,
+ image_pull_policy=self.image_pull_policy,
+ node_selectors=self.node_selectors,
+ annotations=self.annotations,
+ affinity=self.affinity,
+ image_pull_secrets=self.image_pull_secrets,
+ service_account_name=self.service_account_name,
+ hostnetwork=self.hostnetwork,
+ tolerations=self.tolerations,
+ configmaps=self.configmaps,
+ security_context=self.security_context,
+ dnspolicy=self.dnspolicy,
+ schedulername=self.schedulername,
+ init_containers=self.init_containers,
+ restart_policy='Never',
+ priority_class_name=self.priority_class_name,
+ pod_template_file=self.pod_template_file,
+ pod=self.full_pod_spec,
+ ).gen_pod()
+
+ # noinspection PyTypeChecker
+ pod = append_to_pod(
+ pod,
+ self.pod_runtime_info_envs + # type: ignore
+ self.ports + # type: ignore
+ self.resources + # type: ignore
+ self.secrets + # type: ignore
+ self.volumes + # type: ignore
+ self.volume_mounts # type: ignore
+ )
+
+ self.pod = pod
+
+ try:
+ launcher.start_pod(
+ pod,
+ startup_timeout=self.startup_timeout_seconds)
+ final_state, result = launcher.monitor_pod(pod=pod, get_logs=self.get_logs)
+ except AirflowException:
+ if self.log_events_on_failure:
+ for event in launcher.read_pod_events(pod).items:
+ self.log.error("Pod Event: %s - %s", event.reason, event.message)
+ raise
+ finally:
+ if self.is_delete_operator_pod:
+ launcher.delete_pod(pod)
+ return final_state, pod, result
+
+ def monitor_launched_pod(self, launcher, pod) -> Tuple[State, Optional[str]]:
+ """
+ Monitors a pod to completion that was created by a previous KubernetesPodOperator
+
+ :param launcher: pod launcher that will manage launching and monitoring pods
+ :param pod: podspec used to find pod using k8s API
+ :return:
+ """
+ try:
+ (final_state, result) = launcher.monitor_pod(pod, get_logs=self.get_logs)
+ finally:
+ if self.is_delete_operator_pod:
+ launcher.delete_pod(pod)
+ if final_state != State.SUCCESS:
+ if self.log_events_on_failure:
+ for event in launcher.read_pod_events(pod).items:
+ self.log.error("Pod Event: %s - %s", event.reason, event.message)
+ raise AirflowException(
+ 'Pod returned a failure: {state}'.format(state=final_state)
+ )
+ return final_state, result
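From the task author's side the change is mostly transparent; the new flag
can be set explicitly when the old start-a-fresh-pod behaviour is wanted (a
hedged usage sketch; argument values are illustrative, taken from the tests
below):

    from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator

    k = KubernetesPodOperator(
        namespace='default',
        image='ubuntu:16.04',
        cmds=['bash', '-cx'],
        arguments=['echo 10'],
        labels={'foo': 'bar'},
        name='test',
        task_id='task',
        reattach_on_restart=False,  # always start a new pod; defaults to True
        in_cluster=False,
        do_xcom_push=False,
    )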
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index df2c4fc..f9002a1 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -33,6 +33,7 @@ try:
from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler
from airflow.executors.kubernetes_executor import KubernetesExecutor
from airflow.executors.kubernetes_executor import KubeConfig
+ from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.utils.state import State
except ImportError:
@@ -91,19 +92,19 @@ class TestAirflowKubernetesScheduler(unittest.TestCase):
def test_make_safe_label_value(self):
for dag_id, task_id in self._cases():
- safe_dag_id = AirflowKubernetesScheduler._make_safe_label_value(dag_id)
+ safe_dag_id = pod_generator.make_safe_label_value(dag_id)
self.assertTrue(self._is_safe_label_value(safe_dag_id))
- safe_task_id = AirflowKubernetesScheduler._make_safe_label_value(task_id)
+ safe_task_id = pod_generator.make_safe_label_value(task_id)
self.assertTrue(self._is_safe_label_value(safe_task_id))
dag_id = "my_dag_id"
self.assertEqual(
dag_id,
- AirflowKubernetesScheduler._make_safe_label_value(dag_id)
+ pod_generator.make_safe_label_value(dag_id)
)
dag_id = "my_dag_id_" + "a" * 64
self.assertEqual(
"my_dag_id_" + "a" * 43 + "-0ce114c45",
- AirflowKubernetesScheduler._make_safe_label_value(dag_id)
+ pod_generator.make_safe_label_value(dag_id)
)
@unittest.skipIf(AirflowKubernetesScheduler is None,
diff --git a/tests/runtime/kubernetes/test_kubernetes_pod_operator.py b/tests/runtime/kubernetes/test_kubernetes_pod_operator.py
index 2d6f0ac..98d47ac 100644
--- a/tests/runtime/kubernetes/test_kubernetes_pod_operator.py
+++ b/tests/runtime/kubernetes/test_kubernetes_pod_operator.py
@@ -23,18 +23,22 @@ from unittest import mock
from unittest.mock import ANY
import kubernetes.client.models as k8s
+import pendulum
import pytest
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
from airflow.exceptions import AirflowException
+from airflow.kubernetes import kube_client
from airflow.kubernetes.pod import Port
from airflow.kubernetes.pod_generator import PodDefaults
from airflow.kubernetes.pod_launcher import PodLauncher
from airflow.kubernetes.secret import Secret
from airflow.kubernetes.volume import Volume
from airflow.kubernetes.volume_mount import VolumeMount
+from airflow.models import DAG, TaskInstance
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
+from airflow.utils import timezone
from airflow.version import version as airflow_version
@@ -53,8 +57,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
'annotations': {},
'labels': {
'foo': 'bar', 'kubernetes_pod_operator': 'True',
- 'airflow_version': airflow_version.replace('+', '-')
- }
+ 'airflow_version': airflow_version.replace('+', '-'),
+ 'execution_date': '2016-01-01T0100000100-a2f50a31f',
+ 'dag_id': 'dag',
+ 'task_id': 'task',
+ 'try_number': '1'},
},
'spec': {
'affinity': {},
@@ -81,6 +88,23 @@ class TestKubernetesPodOperator(unittest.TestCase):
}
}
+ def tearDown(self) -> None:
+ client = kube_client.get_kube_client(in_cluster=False)
+ client.delete_collection_namespaced_pod(namespace="default")
+
+ def create_context(self, task):
+ dag = DAG(dag_id="dag")
+ tzinfo = pendulum.timezone("Europe/Amsterdam")
+ execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+ task_instance = TaskInstance(task=task,
+ execution_date=execution_date)
+ return {
+ "dag": dag,
+ "ts": execution_date.isoformat(),
+ "task": task,
+ "ti": task_instance,
+ }
+
def test_do_xcom_push_defaults_false(self):
new_config_path = '/tmp/kube_config'
old_config_path = os.path.expanduser('~/.kube/config')
@@ -111,19 +135,21 @@ class TestKubernetesPodOperator(unittest.TestCase):
cmds=["bash", "-cx"],
arguments=["echo 10"],
labels={"foo": "bar"},
- name="test",
+ name="test1",
task_id="task",
in_cluster=False,
do_xcom_push=False,
config_file=new_config_path,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.assertEqual(self.expected_pod, actual_pod)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_config_path(self, client_mock, launcher_mock):
+ def test_config_path(self, client_mock, monitor_mock, start_mock):  # pylint: disable=unused-argument
from airflow.utils.state import State
file_path = "/tmp/fake_file"
@@ -140,17 +166,20 @@ class TestKubernetesPodOperator(unittest.TestCase):
config_file=file_path,
cluster_context='default',
)
- launcher_mock.return_value = (State.SUCCESS, None)
- k.execute(None)
+ monitor_mock.return_value = (State.SUCCESS, None)
+ client_mock.list_namespaced_pod.return_value = []
+ context = self.create_context(k)
+ k.execute(context=context)
client_mock.assert_called_once_with(
in_cluster=False,
cluster_context='default',
config_file=file_path,
)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_image_pull_secrets_correctly_set(self, mock_client, launcher_mock):
+ def test_image_pull_secrets_correctly_set(self, mock_client, monitor_mock, start_mock):
from airflow.utils.state import State
fake_pull_secrets = "fakeSecret"
@@ -167,17 +196,24 @@ class TestKubernetesPodOperator(unittest.TestCase):
image_pull_secrets=fake_pull_secrets,
cluster_context='default',
)
- launcher_mock.return_value = (State.SUCCESS, None)
- k.execute(None)
+ monitor_mock.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
+ k.execute(context=context)
self.assertEqual(
- launcher_mock.call_args[0][0].spec.image_pull_secrets,
+ start_mock.call_args[0][0].spec.image_pull_secrets,
[k8s.V1LocalObjectReference(name=fake_pull_secrets)]
)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.delete_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_pod_delete_even_on_launcher_error(self, mock_client, delete_pod_mock, run_pod_mock):
+ def test_pod_delete_even_on_launcher_error(
+ self,
+ mock_client,
+ delete_pod_mock,
+ monitor_pod_mock,
+ start_pod_mock): # pylint: disable=unused-argument
k = KubernetesPodOperator(
namespace='default',
image="ubuntu:16.04",
@@ -191,9 +227,10 @@ class TestKubernetesPodOperator(unittest.TestCase):
cluster_context='default',
is_delete_operator_pod=True,
)
- run_pod_mock.side_effect = AirflowException('fake failure')
+ monitor_pod_mock.side_effect = AirflowException('fake failure')
with self.assertRaises(AirflowException):
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context=context)
assert delete_pod_mock.called
def test_working_pod(self):
@@ -208,9 +245,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
in_cluster=False,
do_xcom_push=False,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.assertEqual(self.expected_pod, actual_pod)
+ self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+ self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
def test_delete_operator_pod(self):
k = KubernetesPodOperator(
@@ -225,9 +264,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
is_delete_operator_pod=True,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
- self.assertEqual(self.expected_pod, actual_pod)
+ self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+ self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
def test_pod_hostnetwork(self):
k = KubernetesPodOperator(
@@ -242,10 +283,12 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
hostnetwork=True,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['hostNetwork'] = True
- self.assertEqual(self.expected_pod, actual_pod)
+ self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+ self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
def test_pod_dnspolicy(self):
dns_policy = "ClusterFirstWithHostNet"
@@ -262,11 +305,13 @@ class TestKubernetesPodOperator(unittest.TestCase):
hostnetwork=True,
dnspolicy=dns_policy
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['hostNetwork'] = True
self.expected_pod['spec']['dnsPolicy'] = dns_policy
- self.assertEqual(self.expected_pod, actual_pod)
+ self.assertEqual(self.expected_pod['spec'], actual_pod['spec'])
+ self.assertEqual(self.expected_pod['metadata']['labels'], actual_pod['metadata']['labels'])
def test_pod_schedulername(self):
scheduler_name = "default-scheduler"
@@ -282,7 +327,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
schedulername=scheduler_name
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['schedulerName'] = scheduler_name
self.assertEqual(self.expected_pod, actual_pod)
@@ -303,7 +349,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
node_selectors=node_selectors,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['nodeSelector'] = node_selectors
self.assertEqual(self.expected_pod, actual_pod)
@@ -329,7 +376,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
resources=resources,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['resources'] = {
'requests': {
@@ -376,7 +424,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
affinity=affinity,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['affinity'] = affinity
self.assertEqual(self.expected_pod, actual_pod)
@@ -396,7 +445,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
ports=[port],
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context=context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['ports'] = [{
'name': 'http',
@@ -432,7 +482,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
in_cluster=False,
do_xcom_push=False,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context=context)
mock_logger.info.assert_any_call(b"retrieved from mount\n")
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['args'] = args
@@ -467,7 +518,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
security_context=security_context,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
self.assertEqual(self.expected_pod, actual_pod)
@@ -491,7 +543,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
security_context=security_context,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
self.assertEqual(self.expected_pod, actual_pod)
@@ -515,7 +568,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
security_context=security_context,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['securityContext'] = security_context
self.assertEqual(self.expected_pod, actual_pod)
@@ -535,7 +589,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
startup_timeout_seconds=5,
)
with self.assertRaises(AirflowException):
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['image'] = bad_image_name
self.assertEqual(self.expected_pod, actual_pod)
@@ -556,7 +611,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
service_account_name=bad_service_account_name,
)
with self.assertRaises(ApiException):
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['serviceAccountName'] = bad_service_account_name
self.assertEqual(self.expected_pod, actual_pod)
@@ -578,7 +634,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
)
with self.assertRaises(AirflowException):
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['containers'][0]['args'] = bad_internal_command
self.assertEqual(self.expected_pod, actual_pod)
@@ -597,7 +654,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
in_cluster=False,
do_xcom_push=True,
)
- self.assertEqual(k.execute(None), json.loads(return_value))
+ context = self.create_context(k)
+ self.assertEqual(k.execute(context), json.loads(return_value))
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
volume = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME)
volume_mount = self.api_client.sanitize_for_serialization(PodDefaults.VOLUME_MOUNT)
@@ -608,9 +666,10 @@ class TestKubernetesPodOperator(unittest.TestCase):
self.expected_pod['spec']['containers'].append(container)
self.assertEqual(self.expected_pod, actual_pod)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_envs_from_configmaps(self, mock_client, mock_launcher):
+ def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
# GIVEN
from airflow.utils.state import State
@@ -629,18 +688,20 @@ class TestKubernetesPodOperator(unittest.TestCase):
configmaps=[configmap],
)
# THEN
- mock_launcher.return_value = (State.SUCCESS, None)
- k.execute(None)
+ mock_monitor.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
+ k.execute(context)
self.assertEqual(
- mock_launcher.call_args[0][0].spec.containers[0].env_from,
+ mock_start.call_args[0][0].spec.containers[0].env_from,
[k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(
name=configmap
))]
)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_envs_from_secrets(self, mock_client, launcher_mock):
+ def test_envs_from_secrets(self, mock_client, monitor_mock, start_mock):
# GIVEN
from airflow.utils.state import State
secret_ref = 'secret_name'
@@ -659,10 +720,11 @@ class TestKubernetesPodOperator(unittest.TestCase):
do_xcom_push=False,
)
# THEN
- launcher_mock.return_value = (State.SUCCESS, None)
- k.execute(None)
+ monitor_mock.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
+ k.execute(context)
self.assertEqual(
- launcher_mock.call_args[0][0].spec.containers[0].env_from,
+ start_mock.call_args[0][0].spec.containers[0].env_from,
[k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(
name=secret_ref
))]
@@ -696,9 +758,9 @@ class TestKubernetesPodOperator(unittest.TestCase):
volume_config = {
'persistentVolumeClaim':
- {
- 'claimName': 'test-volume'
- }
+ {
+ 'claimName': 'test-volume'
+ }
}
volume = Volume(name='test-volume', configs=volume_config)
@@ -734,7 +796,8 @@ class TestKubernetesPodOperator(unittest.TestCase):
in_cluster=False,
do_xcom_push=False,
)
- k.execute(None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['initContainers'] = [expected_init_container]
self.expected_pod['spec']['volumes'] = [{
@@ -745,17 +808,23 @@ class TestKubernetesPodOperator(unittest.TestCase):
}]
self.assertEqual(self.expected_pod, actual_pod)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_pod_template_file(self, mock_client, launcher_mock):
+ def test_pod_template_file(
+ self,
+ mock_client,
+ monitor_mock,
+ start_mock): # pylint: disable=unused-argument
from airflow.utils.state import State
k = KubernetesPodOperator(
task_id='task',
pod_template_file='tests/kubernetes/pod.yaml',
do_xcom_push=True
)
- launcher_mock.return_value = (State.SUCCESS, None)
- k.execute(None)
+ monitor_mock.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.assertEqual({
'apiVersion': 'v1',
@@ -791,9 +860,14 @@ class TestKubernetesPodOperator(unittest.TestCase):
}
}, actual_pod)
- @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
+ @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
@mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_pod_priority_class_name(self, mock_client, launcher_mock):
+ def test_pod_priority_class_name(
+ self,
+ mock_client,
+ monitor_mock,
+ start_mock): # pylint: disable=unused-argument
"""Test ability to assign priorityClassName to pod
"""
@@ -813,8 +887,9 @@ class TestKubernetesPodOperator(unittest.TestCase):
priority_class_name=priority_class_name,
)
- launcher_mock.return_value = (State.SUCCESS, None)
- k.execute(None)
+ monitor_mock.return_value = (State.SUCCESS, None)
+ context = self.create_context(k)
+ k.execute(context)
actual_pod = self.api_client.sanitize_for_serialization(k.pod)
self.expected_pod['spec']['priorityClassName'] = priority_class_name
self.assertEqual(self.expected_pod, actual_pod)
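The recurring pattern in these test changes: the single run_pod mock becomes
separate start_pod and monitor_pod mocks, and execute() now needs a real
context built by create_context(). A condensed sketch of that pattern
(hypothetical test name; body mirrors the tests above):

    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod")
    @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod")
    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
    def test_mock_pattern(self, client_mock, monitor_mock, start_mock):  # pylint: disable=unused-argument
        from airflow.utils.state import State
        k = KubernetesPodOperator(
            namespace='default', image='ubuntu:16.04', cmds=['bash', '-cx'],
            arguments=['echo 10'], labels={'foo': 'bar'}, name='test',
            task_id='task', in_cluster=False, do_xcom_push=False,
        )
        # monitor_pod returns the (final state, xcom result) tuple
        monitor_mock.return_value = (State.SUCCESS, None)
        context = self.create_context(k)
        k.execute(context)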