This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 00203dbbd1fc372d3770f0ec858d95b4330a0cfa Author: Daniel Imberman <[email protected]> AuthorDate: Sun Sep 27 14:39:35 2020 -0700 Allow overrides for pod_template_file (#11162) * Allow overrides for pod_template_file A pod_template_file should be treated as a *template* not a steadfast rule. This PR ensures that users can override individual values set by the pod_template_file s.t. the same file can be used for multiple tasks. * fix podtemplatetest * fix name (cherry picked from commit a888198c27bcdbc4538c02360c308ffcaca182fa) --- .../contrib/operators/kubernetes_pod_operator.py | 33 ++++-- airflow/kubernetes/pod_generator.py | 48 --------- kubernetes_tests/test_kubernetes_pod_operator.py | 116 ++++++++++++++------- tests/kubernetes/test_pod_generator.py | 16 +-- 4 files changed, 108 insertions(+), 105 deletions(-) diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py index cdf5076..7754fd7 100644 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ b/airflow/contrib/operators/kubernetes_pod_operator.py @@ -20,14 +20,16 @@ import re import yaml from airflow.exceptions import AirflowException -from airflow.kubernetes import kube_client, pod_generator, pod_launcher from airflow.kubernetes.k8s_model import append_to_pod +from airflow.kubernetes import kube_client, pod_generator, pod_launcher from airflow.kubernetes.pod import Resources from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.utils.helpers import validate_key from airflow.utils.state import State from airflow.version import version as airflow_version +from airflow.kubernetes.pod_generator import PodGenerator +from kubernetes.client import models as k8s class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance-attributes @@ -218,8 +220,9 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- self.annotations = annotations or {} self.affinity = affinity or {} self.resources = self._set_resources(resources) # noqa + self.k8s_resources = self.resources self.config_file = config_file - self.image_pull_secrets = image_pull_secrets + self.image_pull_secrets = image_pull_secrets or [] self.service_account_name = service_account_name self.is_delete_operator_pod = is_delete_operator_pod self.hostnetwork = hostnetwork @@ -272,6 +275,9 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- client = kube_client.get_kube_client(cluster_context=self.cluster_context, config_file=self.config_file) + self.pod = self.create_pod_request_obj() + self.namespace = self.pod.metadata.namespace + self.client = client # Add combination of labels to uniquely identify a running pod @@ -356,6 +362,11 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- Creates a V1Pod based on user parameters. Note that a `pod` or `pod_template_file` will supersede all other values. """ + if self.pod_template_file: + pod_template = pod_generator.PodGenerator.deserialize_model_file(self.pod_template_file) + else: + pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="name")) + pod = pod_generator.PodGenerator( image=self.image, namespace=self.namespace, @@ -373,15 +384,12 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- service_account_name=self.service_account_name, hostnetwork=self.hostnetwork, tolerations=self.tolerations, - configmaps=self.configmaps, security_context=self.security_context, dnspolicy=self.dnspolicy, init_containers=self.init_containers, restart_policy='Never', schedulername=self.schedulername, - pod_template_file=self.pod_template_file, priority_class_name=self.priority_class_name, - pod=self.full_pod_spec, ).gen_pod() # noinspection PyTypeChecker @@ -395,6 +403,17 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- self.volume_mounts # type: ignore ) + env_from = pod.spec.containers[0].env_from or [] + for configmap in self.configmaps: + env_from.append(k8s.V1EnvFromSource(config_map_ref=k8s.V1ConfigMapEnvSource(name=configmap))) + pod.spec.containers[0].env_from = env_from + + if self.full_pod_spec: + pod_template = PodGenerator.reconcile_pods(pod_template, self.full_pod_spec) + pod = PodGenerator.reconcile_pods(pod_template, pod) + + # if self.do_xcom_push: + # pod = PodGenerator.add_sidecar(pod) return pod def create_new_pod_for_operator(self, labels, launcher): @@ -435,9 +454,9 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- def monitor_launched_pod(self, launcher, pod): """ - Montitors a pod to completion that was created by a previous KubernetesPodOperator + Monitors a pod to completion that was created by a previous KubernetesPodOperator - @param launcher: pod launcher that will manage launching and monitoring pods + :param launcher: pod launcher that will manage launching and monitoring pods :param pod: podspec used to find pod using k8s API :return: """ diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index ed518d1..4fbfec1 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -24,11 +24,6 @@ is supported and no serialization need be written. import copy import hashlib import re -try: - from inspect import signature -except ImportError: - # Python 2.7 - from funcsigs import signature # type: ignore import os import uuid from functools import reduce @@ -203,7 +198,6 @@ class PodGenerator(object): pod_template_file=None, extract_xcom=False, ): - self.validate_pod_generator_args(locals()) if pod_template_file: self.ud_pod = self.deserialize_model_file(pod_template_file) @@ -556,48 +550,6 @@ class PodGenerator(object): # pylint: disable=protected-access return api_client._ApiClient__deserialize_model(pod, k8s.V1Pod) - @staticmethod - def validate_pod_generator_args(given_args): - """ - :param given_args: The arguments passed to the PodGenerator constructor. - :type given_args: dict - :return: None - - Validate that if `pod` or `pod_template_file` are set that the user is not attempting - to configure the pod with the other arguments. - """ - pod_args = list(signature(PodGenerator).parameters.items()) - - def predicate(k, v): - """ - :param k: an arg to PodGenerator - :type k: string - :param v: the parameter of the given arg - :type v: inspect.Parameter - :return: bool - - returns True if the PodGenerator argument has no default arguments - or the default argument is None, and it is not one of the listed field - in `non_empty_fields`. - """ - non_empty_fields = { - 'pod', 'pod_template_file', 'extract_xcom', 'service_account_name', 'image_pull_policy', - 'restart_policy' - } - - return (v.default is None or v.default is v.empty) and k not in non_empty_fields - - args_without_defaults = {k: given_args[k] for k, v in pod_args if predicate(k, v) and given_args[k]} - - if given_args['pod'] and given_args['pod_template_file']: - raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments") - if args_without_defaults and (given_args['pod'] or given_args['pod_template_file']): - raise AirflowConfigException( - "Cannot configure pod and pass either `pod` or `pod_template_file`. Fields {} passed.".format( - list(args_without_defaults.keys()) - ) - ) - def merge_objects(base_obj, client_obj): """ diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index 0335b58..7a8674a 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -17,10 +17,12 @@ # under the License. import json +import logging import os import shutil import sys import unittest +import textwrap import kubernetes.client.models as k8s import pendulum @@ -834,6 +836,24 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase): self.assertIsNotNone(result) self.assertDictEqual(result, {"hello": "world"}) + def test_pod_template_file_with_overrides_system(self): + fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml' + k = KubernetesPodOperator( + task_id="task" + self.get_current_task_name(), + labels={"foo": "bar", "fizz": "buzz"}, + env_vars={"env_name": "value"}, + in_cluster=False, + pod_template_file=fixture, + do_xcom_push=True + ) + + context = create_context(k) + result = k.execute(context) + self.assertIsNotNone(result) + self.assertEqual(k.pod.metadata.labels, {'fizz': 'buzz', 'foo': 'bar'}) + self.assertEqual(k.pod.spec.containers[0].env, [k8s.V1EnvVar(name="env_name", value="value")]) + self.assertDictEqual(result, {"hello": "world"}) + def test_init_container(self): # GIVEN volume_mounts = [k8s.V1VolumeMount( @@ -917,48 +937,72 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase): @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_pod_template_file(self, mock_client, monitor_mock, start_mock): from airflow.utils.state import State + fixture = sys.path[0] + '/tests/kubernetes/pod.yaml' k = KubernetesPodOperator( task_id='task', - pod_template_file='tests/kubernetes/pod.yaml', + pod_template_file=fixture, do_xcom_push=True ) monitor_mock.return_value = (State.SUCCESS, None) - context = self.create_context(k) - k.execute(context) + context = create_context(k) + with self.assertLogs(k.log, level=logging.DEBUG) as cm: + k.execute(context) + expected_line = textwrap.dedent("""\ + DEBUG:airflow.task.operators:Starting pod: + api_version: v1 + kind: Pod + metadata: + annotations: {} + cluster_name: null + creation_timestamp: null + deletion_grace_period_seconds: null\ + """).strip() + self.assertTrue(any(line.startswith(expected_line) for line in cm.output)) + actual_pod = self.api_client.sanitize_for_serialization(k.pod) - self.assertEqual({ - 'apiVersion': 'v1', - 'kind': 'Pod', - 'metadata': {'name': mock.ANY, 'namespace': 'mem-example'}, - 'spec': { - 'volumes': [{'name': 'xcom', 'emptyDir': {}}], - 'containers': [{ - 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'], - 'command': ['stress'], - 'image': 'apache/airflow:stress-2020.07.10-1.0.4', - 'name': 'memory-demo-ctr', - 'resources': { - 'limits': {'memory': '200Mi'}, - 'requests': {'memory': '100Mi'} - }, - 'volumeMounts': [{ - 'name': 'xcom', - 'mountPath': '/airflow/xcom' - }] - }, { - 'name': 'airflow-xcom-sidecar', - 'image': "alpine", - 'command': ['sh', '-c', PodDefaults.XCOM_CMD], - 'volumeMounts': [ - { - 'name': 'xcom', - 'mountPath': '/airflow/xcom' - } - ], - 'resources': {'requests': {'cpu': '1m'}}, - }], - } - }, actual_pod) + expected_dict = {'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'annotations': {}, + 'labels': {}, + 'name': 'memory-demo', + 'namespace': 'mem-example'}, + 'spec': {'affinity': {}, + 'containers': [{'args': ['--vm', + '1', + '--vm-bytes', + '150M', + '--vm-hang', + '1'], + 'command': ['stress'], + 'env': [], + 'envFrom': [], + 'image': 'apache/airflow:stress-2020.07.10-1.0.4', + 'imagePullPolicy': 'IfNotPresent', + 'name': 'base', + 'ports': [], + 'resources': {'limits': {'memory': '200Mi'}, + 'requests': {'memory': '100Mi'}}, + 'volumeMounts': [{'mountPath': '/airflow/xcom', + 'name': 'xcom'}]}, + {'command': ['sh', + '-c', + 'trap "exit 0" INT; while true; do sleep ' + '30; done;'], + 'image': 'alpine', + 'name': 'airflow-xcom-sidecar', + 'resources': {'requests': {'cpu': '1m'}}, + 'volumeMounts': [{'mountPath': '/airflow/xcom', + 'name': 'xcom'}]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'initContainers': [], + 'nodeSelector': {}, + 'restartPolicy': 'Never', + 'securityContext': {}, + 'serviceAccountName': 'default', + 'tolerations': [], + 'volumes': [{'emptyDir': {}, 'name': 'xcom'}]}} + self.assertEqual(expected_dict, actual_pod) @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod") diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py index 5243673..0c9d722 100644 --- a/tests/kubernetes/test_pod_generator.py +++ b/tests/kubernetes/test_pod_generator.py @@ -16,12 +16,12 @@ # under the License. import unittest +import sys from tests.compat import mock import uuid import kubernetes.client.models as k8s from kubernetes.client import ApiClient -from airflow.exceptions import AirflowConfigException from airflow.kubernetes.k8s_model import append_to_pod from airflow.kubernetes.pod import Resources from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects @@ -1045,7 +1045,7 @@ class TestPodGenerator(unittest.TestCase): self.assertEqual(client_spec, res) def test_deserialize_model_file(self): - fixture = 'tests/kubernetes/pod.yaml' + fixture = sys.path[0] + '/tests/kubernetes/pod.yaml' result = PodGenerator.deserialize_model_file(fixture) sanitized_res = self.k8s_client.sanitize_for_serialization(result) self.assertEqual(sanitized_res, self.deserialize_result) @@ -1073,18 +1073,6 @@ spec: sanitized_res = self.k8s_client.sanitize_for_serialization(result) self.assertEqual(sanitized_res, self.deserialize_result) - def test_validate_pod_generator(self): - with self.assertRaises(AirflowConfigException): - PodGenerator(image='k', pod=k8s.V1Pod()) - with self.assertRaises(AirflowConfigException): - PodGenerator(pod=k8s.V1Pod(), pod_template_file='k') - with self.assertRaises(AirflowConfigException): - PodGenerator(image='k', pod_template_file='k') - - PodGenerator(image='k') - PodGenerator(pod_template_file='tests/kubernetes/pod.yaml') - PodGenerator(pod=k8s.V1Pod()) - def test_add_custom_label(self): from kubernetes.client import models as k8s
