This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit bcd02ddb81a07026dcbbc5e5a4dc669a6483b59b Author: Daniel Imberman <[email protected]> AuthorDate: Thu Jul 30 11:40:23 2020 -0700 Fixes PodMutationHook for backwards compatibility (#9903) Co-authored-by: Daniel Imberman <[email protected]> Co-authored-by: Kaxil Naik <[email protected]> --- airflow/kubernetes/k8s_model.py | 16 +++ airflow/kubernetes/pod.py | 33 ++++-- airflow/kubernetes/pod_launcher.py | 26 +++- airflow/kubernetes/pod_launcher_helper.py | 96 +++++++++++++++ airflow/kubernetes/volume_mount.py | 1 + airflow/kubernetes_deprecated/__init__.py | 16 +++ airflow/kubernetes_deprecated/pod.py | 171 +++++++++++++++++++++++++++ docs/conf.py | 1 + tests/kubernetes/models/test_pod.py | 81 +++++++++++++ tests/kubernetes/test_pod_launcher_helper.py | 97 +++++++++++++++ tests/test_local_settings.py | 96 +++++++++++++++ 11 files changed, 619 insertions(+), 15 deletions(-) diff --git a/airflow/kubernetes/k8s_model.py b/airflow/kubernetes/k8s_model.py index 3fd2f9e..e10a946 100644 --- a/airflow/kubernetes/k8s_model.py +++ b/airflow/kubernetes/k8s_model.py @@ -29,6 +29,7 @@ else: class K8SModel(ABC): + """ These Airflow Kubernetes models are here for backwards compatibility reasons only. Ideally clients should use the kubernetes api @@ -39,6 +40,7 @@ class K8SModel(ABC): can be avoided. All of these models implement the `attach_to_pod` method so that they integrate with the kubernetes client.
""" + @abc.abstractmethod def attach_to_pod(self, pod): """ @@ -47,9 +49,23 @@ class K8SModel(ABC): :return: The pod with the object attached """ + def as_dict(self): + res = {} + if hasattr(self, "__slots__"): + for s in self.__slots__: + if hasattr(self, s): + res[s] = getattr(self, s) + if hasattr(self, "__dict__"): + res_dict = self.__dict__.copy() + res_dict.update(res) + return res_dict + return res + def append_to_pod(pod, k8s_objects): """ + Attach Kubernetes objects to the given POD + :param pod: A pod to attach a list of Kubernetes objects to :type pod: kubernetes.client.models.V1Pod :param k8s_objects: a potential None list of K8SModels diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py index 0b332c2..9e455af 100644 --- a/airflow/kubernetes/pod.py +++ b/airflow/kubernetes/pod.py @@ -26,7 +26,13 @@ from airflow.kubernetes.k8s_model import K8SModel class Resources(K8SModel): - __slots__ = ('request_memory', 'request_cpu', 'limit_memory', 'limit_cpu', 'limit_gpu') + __slots__ = ('request_memory', + 'request_cpu', + 'limit_memory', + 'limit_cpu', + 'limit_gpu', + 'request_ephemeral_storage', + 'limit_ephemeral_storage') """ :param request_memory: requested memory @@ -44,15 +50,17 @@ class Resources(K8SModel): :param limit_ephemeral_storage: Limit for ephemeral storage :type limit_ephemeral_storage: float | str """ + def __init__( - self, - request_memory=None, - request_cpu=None, - request_ephemeral_storage=None, - limit_memory=None, - limit_cpu=None, - limit_gpu=None, - limit_ephemeral_storage=None): + self, + request_memory=None, + request_cpu=None, + request_ephemeral_storage=None, + limit_memory=None, + limit_cpu=None, + limit_gpu=None, + limit_ephemeral_storage=None + ): self.request_memory = request_memory self.request_cpu = request_cpu self.request_ephemeral_storage = request_ephemeral_storage @@ -104,9 +112,10 @@ class Port(K8SModel): __slots__ = ('name', 'container_port') def __init__( - self, - name=None, - container_port=None): +
self, + name=None, + container_port=None + ): """Creates port""" self.name = name self.container_port = container_port diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py index d27a647..05df204 100644 --- a/airflow/kubernetes/pod_launcher.py +++ b/airflow/kubernetes/pod_launcher.py @@ -26,10 +26,12 @@ from kubernetes.stream import stream as kubernetes_stream from requests.exceptions import BaseHTTPError from airflow import AirflowException +from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod from airflow.kubernetes.pod_generator import PodDefaults -from airflow.settings import pod_mutation_hook +from airflow import settings from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State +import kubernetes.client.models as k8s # noqa from .kube_client import get_kube_client @@ -62,8 +64,12 @@ class PodLauncher(LoggingMixin): self.extract_xcom = extract_xcom def run_pod_async(self, pod, **kwargs): - """Runs POD asynchronously""" - pod_mutation_hook(pod) + """Runs POD asynchronously + + :param pod: Pod to run + :type pod: k8s.V1Pod + """ + pod = self._mutate_pod_backcompat(pod) sanitized_pod = self._client.api_client.sanitize_for_serialization(pod) json_pod = json.dumps(sanitized_pod, indent=2) @@ -79,6 +85,20 @@ class PodLauncher(LoggingMixin): raise e return resp + @staticmethod + def _mutate_pod_backcompat(pod): + """Backwards compatible Pod Mutation Hook; returns the pod to launch (rebuilt via the deprecated Pod API if the hook raised AttributeError)""" + try: + settings.pod_mutation_hook(pod) + # attempts to run pod_mutation_hook using k8s.V1Pod, if this + # fails we attempt to run by converting pod to Old Pod; NOTE: any AttributeError raised inside the hook lands here + except AttributeError: + dummy_pod = convert_to_airflow_pod(pod) + settings.pod_mutation_hook(dummy_pod) + dummy_pod = dummy_pod.to_v1_kubernetes_pod() + return dummy_pod + return pod + def delete_pod(self, pod): 
"""Deletes POD""" try: diff --git a/airflow/kubernetes/pod_launcher_helper.py b/airflow/kubernetes/pod_launcher_helper.py new file mode 100644 index
0000000..d8b2698 --- /dev/null +++ b/airflow/kubernetes/pod_launcher_helper.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Union + +import kubernetes.client.models as k8s # noqa + +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.volume_mount import VolumeMount +from airflow.kubernetes.pod import Port +from airflow.kubernetes_deprecated.pod import Pod + + +def convert_to_airflow_pod(pod): + base_container = pod.spec.containers[0] # type: k8s.V1Container + + dummy_pod = Pod( + image=base_container.image, + envs=_extract_env_vars(base_container.env), + volumes=_extract_volumes(pod.spec.volumes), + volume_mounts=_extract_volume_mounts(base_container.volume_mounts), + labels=pod.metadata.labels, + name=pod.metadata.name, + namespace=pod.metadata.namespace, + image_pull_policy=base_container.image_pull_policy or 'IfNotPresent', + cmds=[],  # NOTE(review): container command/args are not copied, so the rebuilt pod has an empty command + ports=_extract_ports(base_container.ports) + ) + return dummy_pod + + +def _extract_env_vars(env_vars): + """ + Convert a list of k8s.V1EnvVar objects or plain dicts into a {name: value} dict. 
+ :param env_vars: + :type env_vars: list + :return: result + :rtype: dict + """ + result = {} + env_vars = env_vars or [] # type: List[Union[k8s.V1EnvVar, dict]] + for env_var in env_vars: + if isinstance(env_var,
k8s.V1EnvVar): + env_var = env_var.to_dict() + result[env_var.get("name")] = env_var.get("value")  # value_from sources are not carried over + return result + + +def _extract_volumes(volumes): + result = [] + volumes = volumes or [] # type: List[Union[k8s.V1Volume, dict]] + for volume in volumes: + if isinstance(volume, k8s.V1Volume): + volume = volume.to_dict() + result.append(Volume(name=volume.get("name"), configs=volume)) + return result + + +def _extract_volume_mounts(volume_mounts): + result = [] + volume_mounts = volume_mounts or [] # type: List[Union[k8s.V1VolumeMount, dict]] + for volume_mount in volume_mounts: + if isinstance(volume_mount, k8s.V1VolumeMount): + volume_mount = volume_mount.to_dict() + result.append( + VolumeMount( + name=volume_mount.get("name"), + mount_path=volume_mount.get("mount_path"), + sub_path=volume_mount.get("sub_path"), + read_only=volume_mount.get("read_only")) + ) + + return result + + +def _extract_ports(ports): + result = [] + ports = ports or [] # type: List[Union[k8s.V1ContainerPort, dict]] + for port in ports: + if isinstance(port, k8s.V1ContainerPort): + port = port.to_dict() + result.append(Port(name=port.get("name"), container_port=port.get("container_port"))) + return result diff --git a/airflow/kubernetes/volume_mount.py b/airflow/kubernetes/volume_mount.py index 0dbca5f..ab87ba9 100644 --- a/airflow/kubernetes/volume_mount.py +++ b/airflow/kubernetes/volume_mount.py @@ -24,6 +24,7 @@ from airflow.kubernetes.k8s_model import K8SModel class VolumeMount(K8SModel): + __slots__ = ('name', 'mount_path', 'sub_path', 'read_only') """ Initialize a Kubernetes Volume Mount. Used to mount pod level volumes to running container. diff --git a/airflow/kubernetes_deprecated/__init__.py b/airflow/kubernetes_deprecated/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/airflow/kubernetes_deprecated/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/kubernetes_deprecated/pod.py b/airflow/kubernetes_deprecated/pod.py new file mode 100644 index 0000000..22a8c12 --- /dev/null +++ b/airflow/kubernetes_deprecated/pod.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import kubernetes.client.models as k8s +from airflow.kubernetes.pod import Resources + + +class Pod(object): + """ + Represents a kubernetes pod and manages execution of a single pod.
+ + :param image: The docker image + :type image: str + :param envs: A dict containing the environment variables + :type envs: dict + :param cmds: The command to be run on the pod + :type cmds: list[str] + :param secrets: Secrets to be launched to the pod + :type secrets: list[airflow.contrib.kubernetes.secret.Secret] + :param result: The result that will be returned to the operator after + successful execution of the pod + :type result: any + :param image_pull_policy: Specify a policy to cache or always pull an image + :type image_pull_policy: str + :param image_pull_secrets: Any image pull secrets to be given to the pod. + If more than one secret is required, provide a comma separated list: + secret_a,secret_b + :type image_pull_secrets: str + :param affinity: A dict containing a group of affinity scheduling rules + :type affinity: dict + :param hostnetwork: If True enable host networking on the pod + :type hostnetwork: bool + :param tolerations: A list of kubernetes tolerations + :type tolerations: list + :param security_context: A dict containing the security context for the pod + :type security_context: dict + :param configmaps: A list containing names of configmaps object + mounting env variables to the pod + :type configmaps: list[str] + :param pod_runtime_info_envs: environment variables about + pod runtime information (ip, namespace, nodeName, podName) + :type pod_runtime_info_envs: list[PodRuntimeEnv] + :param dnspolicy: Specify a dnspolicy for the pod + :type dnspolicy: str + """ + def __init__( + self, + image, + envs, + cmds, + args=None, + secrets=None, + labels=None, + node_selectors=None, + name=None, + ports=None, + volumes=None, + volume_mounts=None, + namespace='default', + result=None, + image_pull_policy='IfNotPresent', + image_pull_secrets=None, + init_containers=None, + service_account_name=None, + resources=None, + annotations=None, + affinity=None, + hostnetwork=False, + tolerations=None, + security_context=None, + configmaps=None, +
pod_runtime_info_envs=None, + dnspolicy=None + ): + self.image = image + self.envs = envs or {} + self.cmds = cmds + self.args = args or [] + self.secrets = secrets or [] + self.result = result + self.labels = labels or {} + self.name = name + self.ports = ports or [] + self.volumes = volumes or [] + self.volume_mounts = volume_mounts or [] + self.node_selectors = node_selectors or {} + self.namespace = namespace + self.image_pull_policy = image_pull_policy + self.image_pull_secrets = image_pull_secrets + self.init_containers = init_containers + self.service_account_name = service_account_name + self.resources = resources or Resources() + self.annotations = annotations or {} + self.affinity = affinity or {} + self.hostnetwork = hostnetwork or False + self.tolerations = tolerations or [] + self.security_context = security_context + self.configmaps = configmaps or [] + self.pod_runtime_info_envs = pod_runtime_info_envs or [] + self.dnspolicy = dnspolicy + + def to_v1_kubernetes_pod(self): + """ + Convert to a k8s.V1Pod (annotations, affinity, node_selectors and configmaps are not copied) + + :return: k8s.V1Pod + """ + meta = k8s.V1ObjectMeta( + labels=self.labels, + name=self.name, + namespace=self.namespace, + ) + spec = k8s.V1PodSpec( + init_containers=self.init_containers, + containers=[ + k8s.V1Container( + image=self.image, + command=self.cmds, + name="base", + env=[k8s.V1EnvVar(name=key, value=val) for key, val in self.envs.items()], + args=self.args, + image_pull_policy=self.image_pull_policy, + ) + ], + image_pull_secrets=self.image_pull_secrets, + service_account_name=self.service_account_name, + dns_policy=self.dnspolicy, + host_network=self.hostnetwork, + tolerations=self.tolerations, + security_context=self.security_context, + ) + + pod = k8s.V1Pod( + spec=spec, + metadata=meta, + ) + for port in self.ports: + pod = port.attach_to_pod(pod) + for volume in self.volumes: + pod = volume.attach_to_pod(pod) + for volume_mount in self.volume_mounts: + pod = 
volume_mount.attach_to_pod(pod) + for secret in self.secrets: + pod
= secret.attach_to_pod(pod) + for runtime_info in self.pod_runtime_info_envs: + pod = runtime_info.attach_to_pod(pod) + pod = self.resources.attach_to_pod(pod) + return pod + + def as_dict(self): + res = self.__dict__.copy()  # shallow copy so as_dict() does not overwrite this Pod's own attributes + res['resources'] = res['resources'].as_dict() + res['ports'] = [port.as_dict() for port in res['ports']] + res['volume_mounts'] = [volume_mount.as_dict() for volume_mount in res['volume_mounts']] + res['volumes'] = [volume.as_dict() for volume in res['volumes']] + + return res diff --git a/docs/conf.py b/docs/conf.py index 6df66f8..d18b6ea 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -201,6 +201,7 @@ exclude_patterns = [ '_api/airflow/example_dags', '_api/airflow/index.rst', '_api/airflow/jobs', + '_api/airflow/kubernetes_deprecated', '_api/airflow/lineage', '_api/airflow/logging_config', '_api/airflow/macros', diff --git a/tests/kubernetes/models/test_pod.py b/tests/kubernetes/models/test_pod.py index 45c32aa..096b5f0 100644 --- a/tests/kubernetes/models/test_pod.py +++ b/tests/kubernetes/models/test_pod.py @@ -74,3 +74,84 @@ class TestPod(unittest.TestCase): 'volumes': [] } }, result) + + def test_to_v1_pod(self): + from airflow.kubernetes_deprecated.pod import Pod as DeprecatedPod + from airflow.kubernetes.volume import Volume + from airflow.kubernetes.volume_mount import VolumeMount + from airflow.kubernetes.pod import Resources + + pod = DeprecatedPod( + image="foo", + name="bar", + namespace="baz", + image_pull_policy="Never", + envs={"test_key": "test_value"}, + cmds=["airflow"], + resources=Resources( + request_memory="1G", + request_cpu="100Mi", + limit_gpu="100G" + ), + volumes=[Volume(name="foo", configs={})], + volume_mounts=[VolumeMount(name="foo", mount_path="/mnt", sub_path="/", read_only=True)] + ) + + k8s_client = ApiClient() + + result = pod.to_v1_kubernetes_pod() + result = k8s_client.sanitize_for_serialization(result) + + expected = \ + { + 'metadata': + { 
+ 'labels': {}, + 'name': 'bar', + 'namespace': 'baz' + }, + 'spec': +
{'containers': + [ + { + 'args': [], + 'command': ['airflow'], + 'env': [{'name': 'test_key', 'value': 'test_value'}], + 'image': 'foo', + 'imagePullPolicy': 'Never', + 'name': 'base', + 'volumeMounts': + [ + { + 'mountPath': '/mnt', + 'name': 'foo', + 'readOnly': True, 'subPath': '/' + } + ], # noqa + 'resources': + { + 'limits': + { + 'cpu': None, + 'memory': None, + 'nvidia.com/gpu': '100G', + 'ephemeral-storage': None + }, + 'requests': + { + 'cpu': '100Mi', + 'memory': '1G', + 'ephemeral-storage': None + } + } + } + ], + 'hostNetwork': False, + 'tolerations': [], + 'volumes': [ + {'name': 'foo'} + ] + } + } + self.maxDiff = None + self.assertEqual(expected, result) diff --git a/tests/kubernetes/test_pod_launcher_helper.py b/tests/kubernetes/test_pod_launcher_helper.py new file mode 100644 index 0000000..a308ac3 --- /dev/null +++ b/tests/kubernetes/test_pod_launcher_helper.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+import unittest + +from airflow.kubernetes.pod import Port +from airflow.kubernetes.volume_mount import VolumeMount +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.pod_launcher_helper import convert_to_airflow_pod +from airflow.kubernetes_deprecated.pod import Pod +import kubernetes.client.models as k8s + + +class TestPodLauncherHelper(unittest.TestCase): + def test_convert_to_airflow_pod(self): + input_pod = k8s.V1Pod( + metadata=k8s.V1ObjectMeta( + name="foo", + namespace="bar" + ), + spec=k8s.V1PodSpec( + containers=[ + k8s.V1Container( + name="base", + command="foo", + image="myimage", + ports=[ + k8s.V1ContainerPort( + name="myport", + container_port=8080, + ) + ], + volume_mounts=[k8s.V1VolumeMount( + name="mymount", + mount_path="/tmp/mount", + read_only="True" + )] + ) + ], + volumes=[ + k8s.V1Volume( + name="myvolume" + ) + ] + ) + ) + result_pod = convert_to_airflow_pod(input_pod) + + expected = Pod( + name="foo", + namespace="bar", + envs={}, + cmds=[], + image="myimage", + ports=[ + Port(name="myport", container_port=8080) + ], + volume_mounts=[VolumeMount( + name="mymount", + mount_path="/tmp/mount", + sub_path=None, + read_only="True" + )], + volumes=[Volume(name="myvolume", configs={'name': 'myvolume'})] + ) + expected_dict = expected.as_dict() + result_dict = result_pod.as_dict() + parsed_configs = self.pull_out_volumes(result_dict) + result_dict['volumes'] = parsed_configs + self.maxDiff = None + + self.assertDictEqual(expected_dict, result_dict) + + def pull_out_volumes(self, result_dict):  # keep only truthy, non-underscore config keys so converted volumes compare to the expected configs + parsed_configs = [] + for volume in result_dict['volumes']: + vol = {'name': volume['name']} + confs = {} + for k, v in volume['configs'].items(): + if v and k[0] != '_': + confs[k] = v + vol['configs'] = confs + parsed_configs.append(vol) + return parsed_configs diff --git a/tests/test_local_settings.py b/tests/test_local_settings.py index 3497ee2..0e45ad8 100644 --- 
a/tests/test_local_settings.py +++ b/tests/test_local_settings.py @@
-21,6 +21,7 @@ import os import sys import tempfile import unittest +from airflow.kubernetes import pod_generator from tests.compat import MagicMock, Mock, call, patch @@ -40,8 +41,26 @@ def not_policy(): """ SETTINGS_FILE_POD_MUTATION_HOOK = """ +from airflow.kubernetes.volume import Volume +from airflow.kubernetes.pod import Port, Resources + def pod_mutation_hook(pod): pod.namespace = 'airflow-tests' + pod.image = 'my_image' + pod.volumes.append(Volume(name="bar", configs={})) + pod.ports = [Port(container_port=8080)] + pod.resources = Resources( + request_memory="2G", + request_cpu="200Mi", + limit_gpu="200G" + ) + +""" + +SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD = """ +def pod_mutation_hook(pod): + pod.spec.containers[0].image = "test-image" + """ @@ -148,9 +167,86 @@ class LocalSettingsTest(unittest.TestCase): settings.import_local_settings() # pylint: ignore pod = MagicMock() + pod.volumes = [] settings.pod_mutation_hook(pod) assert pod.namespace == 'airflow-tests' + self.assertEqual(pod.volumes[0].name, "bar") + + def test_pod_mutation_to_k8s_pod(self): + with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + from airflow.kubernetes.pod_launcher import PodLauncher + + self.mock_kube_client = Mock() + self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) + pod = pod_generator.PodGenerator( + image="foo", + name="bar", + namespace="baz", + image_pull_policy="Never", + cmds=["foo"], + volume_mounts=[ + {"name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True"} + ], + volumes=[{"name": "foo"}] + ).gen_pod() + + self.assertEqual(pod.metadata.namespace, "baz") + self.assertEqual(pod.spec.containers[0].image, "foo") + self.assertEqual(pod.spec.volumes, [{'name': 'foo'}]) + self.assertEqual(pod.spec.containers[0].ports, []) + self.assertEqual(pod.spec.containers[0].resources, None) + + pod =
self.pod_launcher._mutate_pod_backcompat(pod) + + self.assertEqual(pod.metadata.namespace, "airflow-tests") + self.assertEqual(pod.spec.containers[0].image, "my_image") + self.assertEqual(pod.spec.volumes, [{'name': 'foo'}, {'name': 'bar'}]) + self.maxDiff = None + self.assertEqual( + pod.spec.containers[0].ports[0].to_dict(), + { + "container_port": 8080, + "host_ip": None, + "host_port": None, + "name": None, + "protocol": None + } + ) + self.assertEqual( + pod.spec.containers[0].resources.to_dict(), + { + 'limits': { + 'cpu': None, + 'memory': None, + 'ephemeral-storage': None, + 'nvidia.com/gpu': '200G'}, + 'requests': {'cpu': '200Mi', 'ephemeral-storage': None, 'memory': '2G'} + } + ) + + def test_pod_mutation_v1_pod(self): + with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK_V1_POD, "airflow_local_settings"): + from airflow import settings + settings.import_local_settings() # pylint: ignore + from airflow.kubernetes.pod_launcher import PodLauncher + + self.mock_kube_client = Mock() + self.pod_launcher = PodLauncher(kube_client=self.mock_kube_client) + pod = pod_generator.PodGenerator( + image="myimage", + cmds=["foo"], + volume_mounts=[{ + "name": "foo", "mount_path": "/mnt", "sub_path": "/", "read_only": "True" + }], + volumes=[{"name": "foo"}] + ).gen_pod() + + self.assertEqual(pod.spec.containers[0].image, "myimage") + pod = self.pod_launcher._mutate_pod_backcompat(pod) + self.assertEqual(pod.spec.containers[0].image, "test-image") class TestStatsWithAllowList(unittest.TestCase):
