This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b222e1717b8dd466bfb6880e0e079bc19f60e383 Author: Daniel Imberman <[email protected]> AuthorDate: Fri Jun 26 17:55:00 2020 -0700 [AIRFLOW-5413] Allow K8S worker pod to be configured from JSON/YAML file (#6230) * [AIRFLOW-5413] enable pod config from file * Update airflow/kubernetes/pod_generator.py Co-Authored-By: Ash Berlin-Taylor <[email protected]> * Update airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py Co-Authored-By: Ash Berlin-Taylor <[email protected]> Co-authored-by: Ash Berlin-Taylor <[email protected]> (cherry picked from commit 967930c0cb6e2293f2a49e5c9add5aa1917f3527) --- airflow/config_templates/config.yml | 7 + airflow/config_templates/default_airflow.cfg | 3 + .../contrib/operators/kubernetes_pod_operator.py | 15 ++- .../example_kubernetes_executor_config.py | 6 +- airflow/executors/kubernetes_executor.py | 11 +- airflow/kubernetes/pod_generator.py | 144 ++++++++++++++------- airflow/kubernetes/worker_configuration.py | 20 ++- kubernetes_tests/test_kubernetes_pod_operator.py | 50 ++++++- tests/executors/test_kubernetes_executor.py | 2 +- tests/kubernetes/models/test_pod.py | 2 - tests/kubernetes/models/test_secret.py | 2 - tests/kubernetes/pod.yaml | 33 +++++ tests/kubernetes/test_pod_generator.py | 80 +++++++++--- tests/kubernetes/test_worker_configuration.py | 14 +- 14 files changed, 304 insertions(+), 85 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 9b63200..61491d8 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1736,6 +1736,13 @@ type: string example: ~ default: "" + - name: pod_template_file + description: | + Path to the YAML pod file. If set, all other kubernetes-related fields are ignored. + version_added: ~ + type: string + example: ~ + default: "" - name: worker_container_tag description: ~ version_added: ~ diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index ca9de12..2cc97e2 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -801,6 +801,9 @@ verify_certs = True [kubernetes] # The repository, tag and imagePullPolicy of the Kubernetes Image for the Worker to Run worker_container_repository = + +# Path to the YAML pod file. If set, all other kubernetes-related fields are ignored. +pod_template_file = worker_container_tag = worker_container_image_pull_policy = IfNotPresent diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py index ce8f19c..41f0df3 100644 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ b/airflow/contrib/operators/kubernetes_pod_operator.py @@ -130,8 +130,18 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- :type schedulername: str :param full_pod_spec: The complete podSpec :type full_pod_spec: kubernetes.client.models.V1Pod + :param init_containers: init container for the launched Pod + :type init_containers: list[kubernetes.client.models.V1Container] + :param log_events_on_failure: Log the pod's events if a failure occurs + :type log_events_on_failure: bool + :param do_xcom_push: If True, the content of the file + /airflow/xcom/return.json in the container will also be pushed to an + XCom when the container completes. + :type do_xcom_push: bool + :param pod_template_file: path to pod template file + :type pod_template_file: str """ - template_fields = ('cmds', 'arguments', 'env_vars', 'config_file') + template_fields = ('cmds', 'arguments', 'env_vars', 'config_file', 'pod_template_file') @apply_defaults def __init__(self, # pylint: disable=too-many-arguments,too-many-locals @@ -215,8 +225,8 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- self.full_pod_spec = full_pod_spec self.init_containers = init_containers or [] self.log_events_on_failure = log_events_on_failure - self.priority_class_name = priority_class_name self.pod_template_file = pod_template_file + self.priority_class_name = priority_class_name self.name = self._set_name(name) @staticmethod @@ -348,6 +358,7 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- init_containers=self.init_containers, restart_policy='Never', schedulername=self.schedulername, + pod_template_file=self.pod_template_file, priority_class_name=self.priority_class_name, pod=self.full_pod_spec, ).gen_pod() diff --git a/airflow/example_dags/example_kubernetes_executor_config.py b/airflow/example_dags/example_kubernetes_executor_config.py index d740956..2e4ba00 100644 --- a/airflow/example_dags/example_kubernetes_executor_config.py +++ b/airflow/example_dags/example_kubernetes_executor_config.py @@ -83,14 +83,14 @@ with DAG( } ) - # Test that we can run tasks as a normal user + # Test that we can add labels to pods third_task = PythonOperator( task_id="non_root_task", python_callable=print_stuff, executor_config={ "KubernetesExecutor": { - "securityContext": { - "runAsUser": 1000 + "labels": { + "release": "stable" } } } diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 98e3154..e014aa3 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -20,7 +20,6 @@ import json import multiprocessing import time from queue import Empty -from uuid import uuid4 import kubernetes from dateutil import parser @@ -72,6 +71,9 @@ class KubeConfig: ) self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {}) self.kube_annotations = configuration_dict.get('kubernetes_annotations', {}) or None + self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file', + fallback=None) + self.kube_labels = configuration_dict.get('kubernetes_labels', {}) self.delete_worker_pods = conf.getboolean( self.kubernetes_section, 'delete_worker_pods') @@ -220,6 +222,8 @@ class KubeConfig: return int(val) def _validate(self): + if self.pod_template_file: + return # TODO: use XOR for dags_volume_claim and git_dags_folder_mount_point if not self.dags_volume_claim \ and not self.dags_volume_host \ @@ -498,10 +502,7 @@ class AirflowKubernetesScheduler(LoggingMixin): dag_id) safe_task_id = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars( task_id) - safe_uuid = AirflowKubernetesScheduler._strip_unsafe_kubernetes_special_chars( - uuid4().hex) - return AirflowKubernetesScheduler._make_safe_pod_id(safe_dag_id, safe_task_id, - safe_uuid) + return safe_dag_id + safe_task_id @staticmethod def _label_safe_datestring_to_datetime(string): diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index 711b1a9..e46407b 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -24,10 +24,20 @@ is supported and no serialization need be written. import copy import hashlib import re +try: + from inspect import signature +except ImportError: + # Python 2.7 + from funcsigs import signature # type: ignore +import os import uuid +from functools import reduce import kubernetes.client.models as k8s +import yaml +from kubernetes.client.api_client import ApiClient +from airflow.exceptions import AirflowConfigException from airflow.version import version as airflow_version MAX_LABEL_LEN = 63 @@ -35,10 +45,14 @@ MAX_LABEL_LEN = 63 MAX_POD_ID_LEN = 253 -class PodDefaults: +class PodDefaults(object): """ Static defaults for the PodGenerator """ + + def __init__(self): + pass + XCOM_MOUNT_PATH = '/airflow/xcom' SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar' XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;' @@ -82,7 +96,7 @@ def make_safe_label_value(string): return safe_label -class PodGenerator: +class PodGenerator(object): """ Contains Kubernetes Airflow Worker configuration logic @@ -147,9 +161,11 @@ class PodGenerator: :param dnspolicy: Specify a dnspolicy for the pod :type dnspolicy: str :param schedulername: Specify a schedulername for the pod - :type schedulername: str - :param pod: The fully specified pod. - :type pod: kubernetes.client.models.V1Pod + :type schedulername: Optional[str] + :param pod: The fully specified pod. Mutually exclusive with `path_or_string` + :type pod: Optional[kubernetes.client.models.V1Pod] + :param pod_template_file: Path to YAML file. Mutually exclusive with `pod` + :type pod_template_file: Optional[str] :param extract_xcom: Whether to bring up a container for xcom :type extract_xcom: bool """ @@ -167,8 +183,8 @@ class PodGenerator: node_selectors=None, ports=None, volumes=None, - image_pull_policy='IfNotPresent', - restart_policy='Never', + image_pull_policy=None, + restart_policy=None, image_pull_secrets=None, init_containers=None, service_account_name=None, @@ -183,9 +199,16 @@ class PodGenerator: schedulername=None, priority_class_name=None, pod=None, + pod_template_file=None, extract_xcom=False, ): - self.ud_pod = pod + self.validate_pod_generator_args(locals()) + + if pod_template_file: + self.ud_pod = self.deserialize_model_file(pod_template_file) + else: + self.ud_pod = pod + self.pod = k8s.V1Pod() self.pod.api_version = 'v1' self.pod.kind = 'Pod' @@ -348,37 +371,7 @@ class PodGenerator: 'iam.cloud.google.com/service-account': gcp_service_account_key }) - pod_spec_generator = PodGenerator( - image=namespaced.get('image'), - envs=namespaced.get('env'), - cmds=namespaced.get('cmds'), - args=namespaced.get('args'), - labels=namespaced.get('labels'), - node_selectors=namespaced.get('node_selectors'), - name=namespaced.get('name'), - ports=namespaced.get('ports'), - volumes=namespaced.get('volumes'), - volume_mounts=namespaced.get('volume_mounts'), - namespace=namespaced.get('namespace'), - image_pull_policy=namespaced.get('image_pull_policy'), - restart_policy=namespaced.get('restart_policy'), - image_pull_secrets=namespaced.get('image_pull_secrets'), - init_containers=namespaced.get('init_containers'), - service_account_name=namespaced.get('service_account_name'), - resources=resources, - annotations=namespaced.get('annotations'), - affinity=namespaced.get('affinity'), - hostnetwork=namespaced.get('hostnetwork'), - tolerations=namespaced.get('tolerations'), - security_context=namespaced.get('security_context'), - configmaps=namespaced.get('configmaps'), - dnspolicy=namespaced.get('dnspolicy'), - schedulername=namespaced.get('schedulername'), - pod=namespaced.get('pod'), - extract_xcom=namespaced.get('extract_xcom'), - ) - - return pod_spec_generator.gen_pod() + return PodGenerator(**namespaced).gen_pod() @staticmethod def reconcile_pods(base_pod, client_pod): @@ -495,12 +488,73 @@ class PodGenerator: name=pod_id ).gen_pod() - # Reconcile the pod generated by the Operator and the Pod - # generated by the .cfg file - pod_with_executor_config = PodGenerator.reconcile_pods(worker_config, - kube_executor_config) - # Reconcile that pod with the dynamic fields. - return PodGenerator.reconcile_pods(pod_with_executor_config, dynamic_pod) + # Reconcile the pods starting with the first chronologically, + # Pod from the airflow.cfg -> Pod from executor_config arg -> Pod from the K8s executor + pod_list = [worker_config, kube_executor_config, dynamic_pod] + + return reduce(PodGenerator.reconcile_pods, pod_list) + + @staticmethod + def deserialize_model_file(path): + """ + :param path: Path to the file + :return: a kubernetes.client.models.V1Pod + + Unfortunately we need access to the private method + ``_ApiClient__deserialize_model`` from the kubernetes client. + This issue is tracked here; https://github.com/kubernetes-client/python/issues/977. + """ + api_client = ApiClient() + if os.path.exists(path): + with open(path) as stream: + pod = yaml.safe_load(stream) + else: + pod = yaml.safe_load(path) + + # pylint: disable=protected-access + return api_client._ApiClient__deserialize_model(pod, k8s.V1Pod) + + @staticmethod + def validate_pod_generator_args(given_args): + """ + :param given_args: The arguments passed to the PodGenerator constructor. + :type given_args: dict + :return: None + + Validate that if `pod` or `pod_template_file` are set that the user is not attempting + to configure the pod with the other arguments. + """ + pod_args = list(signature(PodGenerator).parameters.items()) + + def predicate(k, v): + """ + :param k: an arg to PodGenerator + :type k: string + :param v: the parameter of the given arg + :type v: inspect.Parameter + :return: bool + + returns True if the PodGenerator argument has no default arguments + or the default argument is None, and it is not one of the listed field + in `non_empty_fields`. + """ + non_empty_fields = { + 'pod', 'pod_template_file', 'extract_xcom', 'service_account_name', 'image_pull_policy', + 'restart_policy' + } + + return (v.default is None or v.default is v.empty) and k not in non_empty_fields + + args_without_defaults = {k: given_args[k] for k, v in pod_args if predicate(k, v) and given_args[k]} + + if given_args['pod'] and given_args['pod_template_file']: + raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments") + if args_without_defaults and (given_args['pod'] or given_args['pod_template_file']): + raise AirflowConfigException( + "Cannot configure pod and pass either `pod` or `pod_template_file`. Fields {} passed.".format( + list(args_without_defaults.keys()) + ) + ) def merge_objects(base_obj, client_obj): diff --git a/airflow/kubernetes/worker_configuration.py b/airflow/kubernetes/worker_configuration.py index 9c35910..0357f9f 100644 --- a/airflow/kubernetes/worker_configuration.py +++ b/airflow/kubernetes/worker_configuration.py @@ -28,7 +28,12 @@ from airflow.utils.log.logging_mixin import LoggingMixin class WorkerConfiguration(LoggingMixin): - """Contains Kubernetes Airflow Worker configuration logic""" + """ + Contains Kubernetes Airflow Worker configuration logic + + :param kube_config: the kubernetes configuration from airflow.cfg + :type kube_config: airflow.executors.kubernetes_executor.KubeConfig + """ dags_volume_name = 'airflow-dags' logs_volume_name = 'airflow-logs' @@ -424,9 +429,12 @@ class WorkerConfiguration(LoggingMixin): def as_pod(self): """Creates POD.""" - pod_generator = PodGenerator( + if self.kube_config.pod_template_file: + return PodGenerator(pod_template_file=self.kube_config.pod_template_file).gen_pod() + + pod = PodGenerator( image=self.kube_config.kube_image, - image_pull_policy=self.kube_config.kube_image_pull_policy, + image_pull_policy=self.kube_config.kube_image_pull_policy or 'IfNotPresent', image_pull_secrets=self.kube_config.image_pull_secrets, volumes=self._get_volumes(), volume_mounts=self._get_volume_mounts(), @@ -436,10 +444,10 @@ class WorkerConfiguration(LoggingMixin): tolerations=self.kube_config.kube_tolerations, envs=self._get_environment(), node_selectors=self.kube_config.kube_node_selectors, - service_account_name=self.kube_config.worker_service_account_name, - ) + service_account_name=self.kube_config.worker_service_account_name or 'default', + restart_policy='Never' + ).gen_pod() - pod = pod_generator.gen_pod() pod.spec.containers[0].env_from = pod.spec.containers[0].env_from or [] pod.spec.containers[0].env_from.extend(self._get_env_from()) pod.spec.security_context = self._get_security_context() diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index e20324b..b6cecda 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -827,7 +827,55 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase): @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod") - @patch("airflow.kubernetes.kube_client.get_kube_client") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + def test_pod_template_file(self, mock_client, monitor_mock, start_mock): + from airflow.utils.state import State + k = KubernetesPodOperator( + task_id='task', + pod_template_file='tests/kubernetes/pod.yaml', + do_xcom_push=True + ) + monitor_mock.return_value = (State.SUCCESS, None) + context = self.create_context(k) + k.execute(context) + actual_pod = self.api_client.sanitize_for_serialization(k.pod) + self.assertEqual({ + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'name': mock.ANY, 'namespace': 'mem-example'}, + 'spec': { + 'volumes': [{'name': 'xcom', 'emptyDir': {}}], + 'containers': [{ + 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'], + 'command': ['stress'], + 'image': 'polinux/stress', + 'name': 'memory-demo-ctr', + 'resources': { + 'limits': {'memory': '200Mi'}, + 'requests': {'memory': '100Mi'} + }, + 'volumeMounts': [{ + 'name': 'xcom', + 'mountPath': '/airflow/xcom' + }] + }, { + 'name': 'airflow-xcom-sidecar', + 'image': "alpine", + 'command': ['sh', '-c', PodDefaults.XCOM_CMD], + 'volumeMounts': [ + { + 'name': 'xcom', + 'mountPath': '/airflow/xcom' + } + ], + 'resources': {'requests': {'cpu': '1m'}}, + }], + } + }, actual_pod) + + @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.start_pod") + @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.monitor_pod") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") def test_pod_priority_class_name( self, mock_client, diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index bf7bc56..3dabb78 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -33,8 +33,8 @@ try: from airflow.configuration import conf # noqa: F401 from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig from airflow.executors.kubernetes_executor import KubernetesExecutor - from airflow.kubernetes import pod_generator from airflow.kubernetes.pod_generator import PodGenerator + from airflow.kubernetes import pod_generator from airflow.utils.state import State except ImportError: AirflowKubernetesScheduler = None # type: ignore diff --git a/tests/kubernetes/models/test_pod.py b/tests/kubernetes/models/test_pod.py index b63af6d..45c32aa 100644 --- a/tests/kubernetes/models/test_pod.py +++ b/tests/kubernetes/models/test_pod.py @@ -59,7 +59,6 @@ class TestPod(unittest.TestCase): 'env': [], 'envFrom': [], 'image': 'airflow-worker:latest', - 'imagePullPolicy': 'IfNotPresent', 'name': 'base', 'ports': [{ 'name': 'https', @@ -72,7 +71,6 @@ class TestPod(unittest.TestCase): }], 'hostNetwork': False, 'imagePullSecrets': [], - 'restartPolicy': 'Never', 'volumes': [] } }, result) diff --git a/tests/kubernetes/models/test_secret.py b/tests/kubernetes/models/test_secret.py index 44ab8b3..e91ff68 100644 --- a/tests/kubernetes/models/test_secret.py +++ b/tests/kubernetes/models/test_secret.py @@ -100,7 +100,6 @@ class TestSecret(unittest.TestCase): }], 'envFrom': [{'secretRef': {'name': 'secret_a'}}], 'image': 'airflow-worker:latest', - 'imagePullPolicy': 'IfNotPresent', 'name': 'base', 'ports': [], 'volumeMounts': [{ @@ -110,7 +109,6 @@ class TestSecret(unittest.TestCase): }], 'hostNetwork': False, 'imagePullSecrets': [], - 'restartPolicy': 'Never', 'volumes': [{ 'name': 'secretvol' + str(static_uuid), 'secret': {'secretName': 'secret_b'} diff --git a/tests/kubernetes/pod.yaml b/tests/kubernetes/pod.yaml new file mode 100644 index 0000000..cd419ed --- /dev/null +++ b/tests/kubernetes/pod.yaml @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +--- +apiVersion: v1 +kind: Pod +metadata: + name: memory-demo + namespace: mem-example +spec: + containers: + - name: memory-demo-ctr + image: polinux/stress + resources: + limits: + memory: "200Mi" + requests: + memory: "100Mi" + command: ["stress"] + args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"] diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py index ce15a8b..7d39cdc 100644 --- a/tests/kubernetes/test_pod_generator.py +++ b/tests/kubernetes/test_pod_generator.py @@ -21,6 +21,7 @@ import uuid import kubernetes.client.models as k8s from kubernetes.client import ApiClient +from airflow.exceptions import AirflowConfigException from airflow.kubernetes.k8s_model import append_to_pod from airflow.kubernetes.pod import Resources from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects @@ -31,6 +32,24 @@ class TestPodGenerator(unittest.TestCase): def setUp(self): self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48') + self.deserialize_result = { + 'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'name': 'memory-demo', 'namespace': 'mem-example'}, + 'spec': { + 'containers': [{ + 'args': ['--vm', '1', '--vm-bytes', '150M', '--vm-hang', '1'], + 'command': ['stress'], + 'image': 'polinux/stress', + 'name': 'memory-demo-ctr', + 'resources': { + 'limits': {'memory': '200Mi'}, + 'requests': {'memory': '100Mi'} + } + }] + } + } + self.envs = { 'ENVIRONMENT': 'prod', 'LOG_LEVEL': 'warning' @@ -77,7 +96,6 @@ class TestPodGenerator(unittest.TestCase): 'command': [ 'sh', '-c', 'echo Hello Kubernetes!' ], - 'imagePullPolicy': 'IfNotPresent', 'env': [{ 'name': 'ENVIRONMENT', 'value': 'prod' @@ -126,7 +144,6 @@ class TestPodGenerator(unittest.TestCase): 'readOnly': True }] }], - 'restartPolicy': 'Never', 'volumes': [{ 'name': 'secretvol' + str(self.static_uuid), 'secret': { @@ -172,7 +189,7 @@ class TestPodGenerator(unittest.TestCase): result_dict['spec']['containers'][0]['envFrom'].sort( key=lambda x: list(x.values())[0]['name'] ) - self.assertDictEqual(result_dict, self.expected) + self.assertDictEqual(self.expected, result_dict) @mock.patch('uuid.uuid4') def test_gen_pod_extract_xcom(self, mock_uuid): @@ -238,9 +255,6 @@ class TestPodGenerator(unittest.TestCase): "name": "example-kubernetes-test-volume", }, ], - "securityContext": { - "runAsUser": 1000 - } } }) result = self.k8s_client.sanitize_for_serialization(result) @@ -264,6 +278,7 @@ class TestPodGenerator(unittest.TestCase): 'name': 'example-kubernetes-test-volume' }], }], + 'hostNetwork': False, 'imagePullSecrets': [], 'volumes': [{ 'hostPath': {'path': '/tmp/'}, @@ -339,7 +354,7 @@ class TestPodGenerator(unittest.TestCase): result = PodGenerator.reconcile_pods(base_pod, mutator_pod) result = self.k8s_client.sanitize_for_serialization(result) - self.assertEqual(result, { + self.assertEqual({ 'apiVersion': 'v1', 'kind': 'Pod', 'metadata': {'name': 'name2-' + self.static_uuid.hex}, @@ -353,7 +368,6 @@ class TestPodGenerator(unittest.TestCase): ], 'envFrom': [], 'image': 'image1', - 'imagePullPolicy': 'IfNotPresent', 'name': 'base', 'ports': [{ 'containerPort': 2118, @@ -369,7 +383,6 @@ class TestPodGenerator(unittest.TestCase): }], 'hostNetwork': False, 'imagePullSecrets': [], - 'restartPolicy': 'Never', 'volumes': [{ 'hostPath': {'path': '/tmp/'}, 'name': 'example-kubernetes-test-volume1' @@ -378,7 +391,7 @@ class TestPodGenerator(unittest.TestCase): 'name': 'example-kubernetes-test-volume2' }] } - }) + }, result) @mock.patch('uuid.uuid4') def test_construct_pod_empty_worker_config(self, mock_uuid): @@ -424,7 +437,6 @@ class TestPodGenerator(unittest.TestCase): 'command': ['command'], 'env': [], 'envFrom': [], - 'imagePullPolicy': 'IfNotPresent', 'name': 'base', 'ports': [], 'resources': { @@ -437,7 +449,6 @@ class TestPodGenerator(unittest.TestCase): }], 'hostNetwork': False, 'imagePullSecrets': [], - 'restartPolicy': 'Never', 'volumes': [] } }, sanitized_result) @@ -486,7 +497,6 @@ class TestPodGenerator(unittest.TestCase): 'command': ['command'], 'env': [], 'envFrom': [], - 'imagePullPolicy': 'IfNotPresent', 'name': 'base', 'ports': [], 'resources': { @@ -499,7 +509,6 @@ class TestPodGenerator(unittest.TestCase): }], 'hostNetwork': False, 'imagePullSecrets': [], - 'restartPolicy': 'Never', 'volumes': [] } }, sanitized_result) @@ -573,7 +582,6 @@ class TestPodGenerator(unittest.TestCase): 'command': ['command'], 'env': [], 'envFrom': [], - 'imagePullPolicy': 'IfNotPresent', 'name': 'base', 'ports': [], 'resources': { @@ -587,7 +595,6 @@ class TestPodGenerator(unittest.TestCase): }], 'hostNetwork': False, 'imagePullSecrets': [], - 'restartPolicy': 'Never', 'volumes': [] } }, sanitized_result) @@ -721,3 +728,44 @@ class TestPodGenerator(unittest.TestCase): client_spec.containers = [k8s.V1Container(name='client_container1', image='base_image')] client_spec.active_deadline_seconds = 100 self.assertEqual(client_spec, res) + + def test_deserialize_model_file(self): + fixture = 'tests/kubernetes/pod.yaml' + result = PodGenerator.deserialize_model_file(fixture) + sanitized_res = self.k8s_client.sanitize_for_serialization(result) + self.assertEqual(sanitized_res, self.deserialize_result) + + def test_deserialize_model_string(self): + fixture = """ +apiVersion: v1 +kind: Pod +metadata: + name: memory-demo + namespace: mem-example +spec: + containers: + - name: memory-demo-ctr + image: polinux/stress + resources: + limits: + memory: "200Mi" + requests: + memory: "100Mi" + command: ["stress"] + args: ["--vm", "1", "--vm-bytes", "150M", "--vm-hang", "1"] + """ + result = PodGenerator.deserialize_model_file(fixture) + sanitized_res = self.k8s_client.sanitize_for_serialization(result) + self.assertEqual(sanitized_res, self.deserialize_result) + + def test_validate_pod_generator(self): + with self.assertRaises(AirflowConfigException): + PodGenerator(image='k', pod=k8s.V1Pod()) + with self.assertRaises(AirflowConfigException): + PodGenerator(pod=k8s.V1Pod(), pod_template_file='k') + with self.assertRaises(AirflowConfigException): + PodGenerator(image='k', pod_template_file='k') + + PodGenerator(image='k') + PodGenerator(pod_template_file='tests/kubernetes/pod.yaml') + PodGenerator(pod=k8s.V1Pod()) diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py index 0730595..b6db2b5 100644 --- a/tests/kubernetes/test_worker_configuration.py +++ b/tests/kubernetes/test_worker_configuration.py @@ -17,7 +17,6 @@ # import unittest - import six from tests.compat import mock @@ -94,6 +93,9 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase): self.kube_config.dags_folder = None self.kube_config.git_dags_folder_mount_point = None self.kube_config.kube_labels = {'dag_id': 'original_dag_id', 'my_label': 'label_id'} + self.kube_config.pod_template_file = '' + self.kube_config.restart_policy = '' + self.kube_config.image_pull_policy = '' self.api_client = ApiClient() @conf_vars({ @@ -358,7 +360,6 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase): worker_config.as_pod(), "default", "sample-uuid", - ) expected_labels = { 'airflow-worker': 'sample-uuid', @@ -672,6 +673,15 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase): k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name='secretref_b')) ], configmaps) + def test_pod_template_file(self): + fixture = 'tests/kubernetes/pod.yaml' + self.kube_config.pod_template_file = fixture + worker_config = WorkerConfiguration(self.kube_config) + result = worker_config.as_pod() + expected = PodGenerator.deserialize_model_file(fixture) + expected.metadata.name = mock.ANY + self.assertEqual(expected, result) + def test_get_labels(self): worker_config = WorkerConfiguration(self.kube_config) labels = worker_config._get_labels({'my_kube_executor_label': 'kubernetes'}, {
