This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch fix-more-pod-mutations-tests in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3d7b30d5a4f98e3e60731141e4ce42755948c7ce Author: Kaxil Naik <[email protected]> AuthorDate: Sat Aug 1 03:24:15 2020 +0100 Fix more PodMutationHook issues for backwards compatibility This PR/commit - Adds missing affinity from old POD - Adds comprehensive tests to check pod_mutation_hook works well with both new and old PODs with various configs like volume, volumeMounts, Ports, affinity, tolerations etc - Refactors various parts of k8s code --- airflow/contrib/kubernetes/pod.py | 74 ++++++++- airflow/kubernetes/pod.py | 2 +- airflow/kubernetes/pod_launcher.py | 40 +++-- airflow/kubernetes/pod_launcher_helper.py | 96 ----------- tests/kubernetes/test_pod_launcher.py | 81 +++++++++- tests/kubernetes/test_pod_launcher_helper.py | 98 ----------- tests/test_local_settings.py | 232 ++++++++++++++++++++++----- 7 files changed, 376 insertions(+), 247 deletions(-) diff --git a/airflow/contrib/kubernetes/pod.py b/airflow/contrib/kubernetes/pod.py index 0ab3616..f2ce056 100644 --- a/airflow/contrib/kubernetes/pod.py +++ b/airflow/contrib/kubernetes/pod.py @@ -19,7 +19,13 @@ import warnings # pylint: disable=unused-import -from airflow.kubernetes.pod import Port, Resources # noqa +from typing import List, Union + +from kubernetes.client import models as k8s + +from airflow.kubernetes.pod import Port, Resources # noqa +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.volume_mount import VolumeMount warnings.warn( "This module is deprecated. 
Please use `airflow.kubernetes.pod`.", @@ -154,6 +160,7 @@ class Pod(object): dns_policy=self.dnspolicy, host_network=self.hostnetwork, tolerations=self.tolerations, + affinity=self.affinity, security_context=self.security_context, ) @@ -161,11 +168,11 @@ class Pod(object): spec=spec, metadata=meta, ) - for port in self.ports: + for port in _extract_ports(self.ports): pod = port.attach_to_pod(pod) - for volume in self.volumes: + for volume in _extract_volumes(self.volumes): pod = volume.attach_to_pod(pod) - for volume_mount in self.volume_mounts: + for volume_mount in _extract_volume_mounts(self.volume_mounts): pod = volume_mount.attach_to_pod(pod) for secret in self.secrets: pod = secret.attach_to_pod(pod) @@ -182,3 +189,62 @@ class Pod(object): res['volumes'] = [volume.as_dict() for volume in res['volumes']] return res + + +def _extract_env_vars(env_vars): + result = {} + env_vars = env_vars or [] # type: List[Union[k8s.V1EnvVar, dict]] + for env_var in env_vars: + if isinstance(env_var, k8s.V1EnvVar): + env_var = env_var.to_dict() # V1EnvVar has no .get(); read name/value from its dict form + result[env_var.get("name")] = env_var.get("value") + return result + + +def _extract_ports(ports): + result = [] + ports = ports or [] # type: List[Union[k8s.V1ContainerPort, dict]] + for port in ports: + if isinstance(port, k8s.V1ContainerPort): + port = port.to_dict() + port = Port(name=port.get("name"), container_port=port.get("container_port")) + if not isinstance(port, Port): + port = Port(name=port.get("name"), container_port=port.get("containerPort")) + result.append(port) + return result + + +def _extract_volume_mounts(volume_mounts): + result = [] + volume_mounts = volume_mounts or [] # type: List[Union[k8s.V1VolumeMount, dict]] + for volume_mount in volume_mounts: + if isinstance(volume_mount, k8s.V1VolumeMount): + volume_mount = volume_mount.to_dict() + volume_mount = VolumeMount( + name=volume_mount.get("name"), + mount_path=volume_mount.get("mount_path"), + sub_path=volume_mount.get("sub_path"), +
read_only=volume_mount.get("read_only") + ) + elif not isinstance(volume_mount, VolumeMount): + volume_mount = VolumeMount( + name=volume_mount.get("name"), + mount_path=volume_mount.get("mountPath"), + sub_path=volume_mount.get("subPath"), + read_only=volume_mount.get("readOnly") + ) + + result.append(volume_mount) + return result + + +def _extract_volumes(volumes): + result = [] + volumes = volumes or [] # type: List[Union[k8s.V1Volume, dict]] + for volume in volumes: + if isinstance(volume, k8s.V1Volume): + volume = volume.to_dict() + if not isinstance(volume, Volume): + volume = Volume(name=volume.get("name"), configs=volume) + result.append(volume) + return result diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py index 9e455af..c1854a1 100644 --- a/airflow/kubernetes/pod.py +++ b/airflow/kubernetes/pod.py @@ -20,7 +20,7 @@ Classes for interacting with Kubernetes API import copy -import kubernetes.client.models as k8s +from kubernetes.client import models as k8s from airflow.kubernetes.k8s_model import K8SModel diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index d6507df..dc75f8a 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -27,13 +27,15 @@ from kubernetes.stream import stream as kubernetes_stream from requests.exceptions import BaseHTTPError from airflow import AirflowException -from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod from airflow.kubernetes.pod_generator import PodDefaults from airflow import settings from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State import kubernetes.client.models as k8s # noqa from .kube_client import get_kube_client +from ..contrib.kubernetes.pod import ( + Pod, _extract_env_vars, _extract_volumes, _extract_volume_mounts, _extract_ports +) class PodStatus: @@ -90,19 +92,17 @@ class PodLauncher(LoggingMixin): def _mutate_pod_backcompat(pod): """Backwards 
compatible Pod Mutation Hook""" try: - settings.pod_mutation_hook(pod) - # attempts to run pod_mutation_hook using k8s.V1Pod, if this - # fails we attempt to run by converting pod to Old Pod - except AttributeError: + dummy_pod = _convert_to_airflow_pod(pod) + settings.pod_mutation_hook(dummy_pod) warnings.warn( "Using `airflow.contrib.kubernetes.pod.Pod` is deprecated. " "Please use `k8s.V1Pod` instead.", DeprecationWarning, stacklevel=2 ) - dummy_pod = convert_to_airflow_pod(pod) - settings.pod_mutation_hook(dummy_pod) dummy_pod = dummy_pod.to_v1_kubernetes_pod() - return dummy_pod - return pod + except AttributeError: + settings.pod_mutation_hook(pod) + return pod + return dummy_pod def delete_pod(self, pod): """Deletes POD""" @@ -269,7 +269,7 @@ class PodLauncher(LoggingMixin): return None def process_status(self, job_id, status): - """Process status infomration for the JOB""" + """Process status information for the JOB""" status = status.lower() if status == PodStatus.PENDING: return State.QUEUED @@ -284,3 +284,23 @@ class PodLauncher(LoggingMixin): else: self.log.info('Event: Invalid state %s on job %s', status, job_id) return State.FAILED + + +def _convert_to_airflow_pod(pod): + base_container = pod.spec.containers[0] # type: k8s.V1Container + + dummy_pod = Pod( + image=base_container.image, + envs=_extract_env_vars(base_container.env), + volumes=_extract_volumes(pod.spec.volumes), + volume_mounts=_extract_volume_mounts(base_container.volume_mounts), + labels=pod.metadata.labels, + name=pod.metadata.name, + namespace=pod.metadata.namespace, + image_pull_policy=base_container.image_pull_policy or 'IfNotPresent', + cmds=[], + ports=_extract_ports(base_container.ports), + tolerations=pod.spec.tolerations, + affinity=pod.spec.affinity + ) + return dummy_pod diff --git a/airflow/kubernetes/pod_launcher_helper.py b/airflow/kubernetes/pod_launcher_helper.py deleted file mode 100644 index 8c9fc6e..0000000 --- a/airflow/kubernetes/pod_launcher_helper.py +++ /dev/null 
@@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import List, Union - -import kubernetes.client.models as k8s # noqa - -from airflow.kubernetes.volume import Volume -from airflow.kubernetes.volume_mount import VolumeMount -from airflow.kubernetes.pod import Port -from airflow.contrib.kubernetes.pod import Pod - - -def convert_to_airflow_pod(pod): - base_container = pod.spec.containers[0] # type: k8s.V1Container - - dummy_pod = Pod( - image=base_container.image, - envs=_extract_env_vars(base_container.env), - volumes=_extract_volumes(pod.spec.volumes), - volume_mounts=_extract_volume_mounts(base_container.volume_mounts), - labels=pod.metadata.labels, - name=pod.metadata.name, - namespace=pod.metadata.namespace, - image_pull_policy=base_container.image_pull_policy or 'IfNotPresent', - cmds=[], - ports=_extract_ports(base_container.ports) - ) - return dummy_pod - - -def _extract_env_vars(env_vars): - """ - - :param env_vars: - :type env_vars: list - :return: result - :rtype: dict - """ - result = {} - env_vars = env_vars or [] # type: List[Union[k8s.V1EnvVar, dict]] - for env_var in env_vars: - if isinstance(env_var, k8s.V1EnvVar): - env_var.to_dict() - result[env_var.get("name")] = env_var.get("value") - 
return result - - -def _extract_volumes(volumes): - result = [] - volumes = volumes or [] # type: List[Union[k8s.V1Volume, dict]] - for volume in volumes: - if isinstance(volume, k8s.V1Volume): - volume = volume.to_dict() - result.append(Volume(name=volume.get("name"), configs=volume)) - return result - - -def _extract_volume_mounts(volume_mounts): - result = [] - volume_mounts = volume_mounts or [] # type: List[Union[k8s.V1VolumeMount, dict]] - for volume_mount in volume_mounts: - if isinstance(volume_mount, k8s.V1VolumeMount): - volume_mount = volume_mount.to_dict() - result.append( - VolumeMount( - name=volume_mount.get("name"), - mount_path=volume_mount.get("mount_path"), - sub_path=volume_mount.get("sub_path"), - read_only=volume_mount.get("read_only")) - ) - - return result - - -def _extract_ports(ports): - result = [] - ports = ports or [] # type: List[Union[k8s.V1ContainerPort, dict]] - for port in ports: - if isinstance(port, k8s.V1ContainerPort): - port = port.to_dict() - result.append(Port(name=port.get("name"), container_port=port.get("container_port"))) - return result diff --git a/tests/kubernetes/test_pod_launcher.py b/tests/kubernetes/test_pod_launcher.py index 09ba339..e7e9a44 100644 --- a/tests/kubernetes/test_pod_launcher.py +++ b/tests/kubernetes/test_pod_launcher.py @@ -16,11 +16,16 @@ # under the License. 
import unittest import mock +from kubernetes.client import models as k8s from requests.exceptions import BaseHTTPError from airflow import AirflowException -from airflow.kubernetes.pod_launcher import PodLauncher +from airflow.contrib.kubernetes.pod import Pod +from airflow.kubernetes.pod import Port +from airflow.kubernetes.pod_launcher import PodLauncher, _convert_to_airflow_pod +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.volume_mount import VolumeMount class TestPodLauncher(unittest.TestCase): @@ -162,3 +167,77 @@ class TestPodLauncher(unittest.TestCase): self.pod_launcher.read_pod, mock.sentinel ) + + +class TestPodLauncherHelper(unittest.TestCase): + def test_convert_to_airflow_pod(self): + input_pod = k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + name="foo", + namespace="bar" + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name="base", + command="foo", + image="myimage", + ports=[ + k8s.V1ContainerPort( + name="myport", + container_port=8080, + ) + ], + volume_mounts=[k8s.V1VolumeMount( + name="mymount", + mount_path="/tmp/mount", + read_only="True" + )] + ) + ], + volumes=[ + k8s.V1Volume( + name="myvolume" + ) + ] + ) + ) + result_pod = _convert_to_airflow_pod(input_pod) + + expected = Pod( + name="foo", + namespace="bar", + envs={}, + cmds=[], + image="myimage", + ports=[ + Port(name="myport", container_port=8080) + ], + volume_mounts=[VolumeMount( + name="mymount", + mount_path="/tmp/mount", + sub_path=None, + read_only="True" + )], + volumes=[Volume(name="myvolume", configs={'name': 'myvolume'})] + ) + expected_dict = expected.as_dict() + result_dict = result_pod.as_dict() + parsed_configs = self.pull_out_volumes(result_dict) + result_dict['volumes'] = parsed_configs + self.maxDiff = None + + self.assertDictEqual(expected_dict, result_dict) + + @staticmethod + def pull_out_volumes(result_dict): + parsed_configs = [] + for volume in result_dict['volumes']: + vol = {'name': volume['name']} + confs = {} + for k, v in 
volume['configs'].items(): + if v and k[0] != '_': + confs[k] = v + vol['configs'] = confs + parsed_configs.append(vol) + return parsed_configs diff --git a/tests/kubernetes/test_pod_launcher_helper.py b/tests/kubernetes/test_pod_launcher_helper.py deleted file mode 100644 index 761d138..0000000 --- a/tests/kubernetes/test_pod_launcher_helper.py +++ /dev/null @@ -1,98 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import unittest - -from airflow.kubernetes.pod import Port -from airflow.kubernetes.volume_mount import VolumeMount -from airflow.kubernetes.volume import Volume -from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod -from airflow.contrib.kubernetes.pod import Pod -import kubernetes.client.models as k8s - - -class TestPodLauncherHelper(unittest.TestCase): - def test_convert_to_airflow_pod(self): - input_pod = k8s.V1Pod( - metadata=k8s.V1ObjectMeta( - name="foo", - namespace="bar" - ), - spec=k8s.V1PodSpec( - containers=[ - k8s.V1Container( - name="base", - command="foo", - image="myimage", - ports=[ - k8s.V1ContainerPort( - name="myport", - container_port=8080, - ) - ], - volume_mounts=[k8s.V1VolumeMount( - name="mymount", - mount_path="/tmp/mount", - read_only="True" - )] - ) - ], - volumes=[ - k8s.V1Volume( - name="myvolume" - ) - ] - ) - ) - result_pod = convert_to_airflow_pod(input_pod) - - expected = Pod( - name="foo", - namespace="bar", - envs={}, - cmds=[], - image="myimage", - ports=[ - Port(name="myport", container_port=8080) - ], - volume_mounts=[VolumeMount( - name="mymount", - mount_path="/tmp/mount", - sub_path=None, - read_only="True" - )], - volumes=[Volume(name="myvolume", configs={'name': 'myvolume'})] - ) - expected_dict = expected.as_dict() - result_dict = result_pod.as_dict() - parsed_configs = self.pull_out_volumes(result_dict) - result_dict['volumes'] = parsed_configs - self.maxDiff = None - - self.assertDictEqual(expected_dict, result_dict) - - @staticmethod - def pull_out_volumes(result_dict): - parsed_configs = [] - for volume in result_dict['volumes']: - vol = {'name': volume['name']} - confs = {} - for k, v in volume['configs'].items(): - if v and k[0] != '_': - confs[k] = v - vol['configs'] = confs - parsed_configs.append(vol) - return parsed_configs diff --git a/tests/test_local_settings.py b/tests/test_local_settings.py index 0e45ad8..4f1d0b0 100644 --- a/tests/test_local_settings.py +++ 
b/tests/test_local_settings.py @@ -22,7 +22,9 @@ import sys import tempfile import unittest from airflow.kubernetes import pod_generator -from tests.compat import MagicMock, Mock, call, patch +from tests.compat import MagicMock, Mock, mock, call, patch + +from kubernetes.client.api_client import ApiClient SETTINGS_FILE_POLICY = """ @@ -48,18 +50,71 @@ def pod_mutation_hook(pod): pod.namespace = 'airflow-tests' pod.image = 'my_image' pod.volumes.append(Volume(name="bar", configs={})) - pod.ports = [Port(container_port=8080)] + pod.ports = [Port(container_port=8080), {"containerPort": 8081}] pod.resources = Resources( request_memory="2G", request_cpu="200Mi", limit_gpu="200G" ) + secret_volume = { + "name": "airflow-secrets-mount", + "secret": { + "secretName": "airflow-test-secrets" + } + } + secret_volume_mount = { + "name": "airflow-secrets-mount", + "readOnly": True, + "mountPath": "/opt/airflow/secrets/" + } + + pod.volumes.append(secret_volume) + pod.volume_mounts.append(secret_volume_mount) + + pod.labels.update({"test_label": "test_value"}) + pod.envs.update({"TEST_USER": "ADMIN"}) + + pod.tolerations += [ + {"key": "dynamic-pods", "operator": "Equal", "value": "true", "effect": "NoSchedule"} + ] + pod.affinity.update( + {"nodeAffinity": + {"requiredDuringSchedulingIgnoredDuringExecution": + {"nodeSelectorTerms": + [{ + "matchExpressions": [ + {"key": "test/dynamic-pods", "operator": "In", "values": ["true"]} + ] + }] + } + } + } + ) """ SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD = """ def pod_mutation_hook(pod): - pod.spec.containers[0].image = "test-image" + from kubernetes.client import models as k8s + secret_volume = { + "name": "airflow-secrets-mount", + "secret": { + "secretName": "airflow-test-secrets" + } + } + secret_volume_mount = { + "name": "airflow-secrets-mount", + "readOnly": True, + "mountPath": "/opt/airflow/secrets/" + } + base_container = pod.spec.containers[0] + base_container.image = "test-image" + 
base_container.volume_mounts.append(secret_volume_mount) + base_container.env.extend([{'name': 'TEST_USER', 'value': 'ADMIN'}]) + base_container.ports.extend([{'containerPort': 8080}, k8s.V1ContainerPort(container_port=8081)]) + + pod.spec.volumes.append(secret_volume) + pod.metadata.namespace = 'airflow-tests' """ @@ -85,6 +140,7 @@ class LocalSettingsTest(unittest.TestCase): # Make sure that the configure_logging is not cached def setUp(self): self.old_modules = dict(sys.modules) + self.maxDiff = None def tearDown(self): # Remove any new modules imported during the test run. This lets us @@ -181,50 +237,101 @@ class LocalSettingsTest(unittest.TestCase): self.mock_kube_client = Mock() self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) + self.api_client = ApiClient() pod = pod_generator.PodGenerator( image="foo", name="bar", namespace="baz", image_pull_policy="Never", cmds=["foo"], + tolerations=[ + {'effect': 'NoSchedule', + 'key': 'static-pods', + 'operator': 'Equal', + 'value': 'true'} + ], volume_mounts=[ - {"name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True"} + {"name": "foo", "mountPath": "/mnt", "subPath": "/", "readOnly": True} ], volumes=[{"name": "foo"}] ).gen_pod() - self.assertEqual(pod.metadata.namespace, "baz") - self.assertEqual(pod.spec.containers[0].image, "foo") - self.assertEqual(pod.spec.volumes, [{'name': 'foo'}]) - self.assertEqual(pod.spec.containers[0].ports, []) - self.assertEqual(pod.spec.containers[0].resources, None) + sanitized_pod_pre_mutation = self.api_client.sanitize_for_serialization(pod) + self.assertEqual( + sanitized_pod_pre_mutation, + {'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'name': mock.ANY, + 'namespace': 'baz'}, + 'spec': {'containers': [{'args': [], + 'command': ['foo'], + 'env': [], + 'envFrom': [], + 'image': 'foo', + 'imagePullPolicy': 'Never', + 'name': 'base', + 'ports': [], + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': 
'/'}]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'tolerations': [{'effect': 'NoSchedule', + 'key': 'static-pods', + 'operator': 'Equal', + 'value': 'true'}], + 'volumes': [{'name': 'foo'}]}} + ) + # Apply Pod Mutation Hook pod = self.pod_launcher._mutate_pod_backcompat(pod) - self.assertEqual(pod.metadata.namespace, "airflow-tests") - self.assertEqual(pod.spec.containers[0].image, "my_image") - self.assertEqual(pod.spec.volumes, [{'name': 'foo'}, {'name': 'bar'}]) - self.maxDiff = None + sanitized_pod_post_mutation = self.api_client.sanitize_for_serialization(pod) self.assertEqual( - pod.spec.containers[0].ports[0].to_dict(), - { - "container_port": 8080, - "host_ip": None, - "host_port": None, - "name": None, - "protocol": None - } - ) - self.assertEqual( - pod.spec.containers[0].resources.to_dict(), - { - 'limits': { - 'cpu': None, - 'memory': None, - 'ephemeral-storage': None, - 'nvidia.com/gpu': '200G'}, - 'requests': {'cpu': '200Mi', 'ephemeral-storage': None, 'memory': '2G'} - } + sanitized_pod_post_mutation, + {'metadata': {'labels': {'test_label': 'test_value'}, + 'name': mock.ANY, + 'namespace': 'airflow-tests'}, + 'spec': {'affinity': {'nodeAffinity': {'requiredDuringSchedulingIgnoredDuringExecution': { + 'nodeSelectorTerms': [{'matchExpressions': [{'key': 'test/dynamic-pods', + 'operator': 'In', + 'values': ['true']}]}]}}}, + 'containers': [{'args': [], + 'command': [], + 'env': [{'name': 'TEST_USER', 'value': 'ADMIN'}], + 'image': 'my_image', + 'imagePullPolicy': 'Never', + 'name': 'base', + 'ports': [{'containerPort': 8080}, + {'containerPort': 8081}], + 'resources': {'limits': {'cpu': None, + 'ephemeral-storage': None, + 'memory': None, + 'nvidia.com/gpu': '200G'}, + 'requests': {'cpu': '200Mi', + 'ephemeral-storage': None, + 'memory': '2G'}}, + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'}, + {'mountPath': '/opt/airflow/secrets/', + 'name': 'airflow-secrets-mount', + 'readOnly': True}]}], + 
'hostNetwork': False, + 'tolerations': [{'effect': 'NoSchedule', + 'key': 'static-pods', + 'operator': 'Equal', + 'value': 'true'}, + {'effect': 'NoSchedule', + 'key': 'dynamic-pods', + 'operator': 'Equal', + 'value': 'true'}], + 'volumes': [{'name': 'foo'}, + {'name': 'bar'}, + {'name': 'airflow-secrets-mount', + 'secret': {'secretName': 'airflow-test-secrets'}}]}} ) def test_pod_mutation_v1_pod(self): @@ -234,19 +341,70 @@ class LocalSettingsTest(unittest.TestCase): from airflow.kubernetes.pod_launcher import PodLauncher self.mock_kube_client = Mock() + self.api_client = ApiClient() self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) pod = pod_generator.PodGenerator( image="myimage", cmds=["foo"], - volume_mounts={ - "name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True" - }, + namespace="baz", + volume_mounts=[ + {"name": "foo", "mountPath": "/mnt", "subPath": "/", "readOnly": True} + ], volumes=[{"name": "foo"}] ).gen_pod() - self.assertEqual(pod.spec.containers[0].image, "myimage") + sanitized_pod_pre_mutation = self.api_client.sanitize_for_serialization(pod) + + self.assertEqual( + sanitized_pod_pre_mutation, + {'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'namespace': 'baz'}, + 'spec': {'containers': [{'args': [], + 'command': ['foo'], + 'env': [], + 'envFrom': [], + 'image': 'myimage', + 'name': 'base', + 'ports': [], + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'}]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'volumes': [{'name': 'foo'}]}} + ) + + # Apply Pod Mutation Hook pod = self.pod_launcher._mutate_pod_backcompat(pod) - self.assertEqual(pod.spec.containers[0].image, "test-image") + + sanitized_pod_post_mutation = self.api_client.sanitize_for_serialization(pod) + self.assertEqual( + sanitized_pod_post_mutation, + {'apiVersion': 'v1', + 'kind': 'Pod', + 'metadata': {'namespace': 'airflow-tests'}, + 'spec': {'containers': [{'args': [], + 'command': 
['foo'], + 'env': [{'name': 'TEST_USER', 'value': 'ADMIN'}], + 'envFrom': [], + 'image': 'test-image', + 'name': 'base', + 'ports': [{'containerPort': 8080}, {'containerPort': 8081}], + 'volumeMounts': [{'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, + 'subPath': '/'}, + {'mountPath': '/opt/airflow/secrets/', + 'name': 'airflow-secrets-mount', + 'readOnly': True}]}], + 'hostNetwork': False, + 'imagePullSecrets': [], + 'volumes': [{'name': 'foo'}, + {'name': 'airflow-secrets-mount', + 'secret': {'secretName': 'airflow-test-secrets'}}]}} + ) class TestStatsWithAllowList(unittest.TestCase):
