This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 0a853380fcc5dee7533c70bb0adfde0ebca7b420
Author: davlum <[email protected]>
AuthorDate: Thu Jan 9 15:39:05 2020 -0500

    [AIRFLOW-5413] Refactor worker config (#7114)

    (cherry picked from commit 51f262c65afd7eaecc54661a3b5c4e533feecff8)
---
 .github/workflows/ci.yml                           |   2 +-
 .../contrib/operators/kubernetes_pod_operator.py   |   4 +-
 airflow/executors/kubernetes_executor.py           |  13 +-
 airflow/kubernetes/pod_generator.py                | 259 ++++++++--
 airflow/kubernetes/worker_configuration.py         |  18 +-
 tests/executors/test_kubernetes_executor.py        | 100 ++--
 tests/kubernetes/test_pod_generator.py             | 541 ++++++++++++++++++---
 tests/kubernetes/test_worker_configuration.py      |  95 +++-
 8 files changed, 838 insertions(+), 194 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index fb16aaf..67b9e50 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -144,7 +144,7 @@ jobs:
     - name: Cache virtualenv for kubernetes testing
       uses: actions/cache@v2
       env:
-        cache-name: cache-kubernetes-tests-virtualenv-v2
+        cache-name: cache-kubernetes-tests-virtualenv-v3
       with:
         path: .build/.kubernetes_venv
         key: "${{ env.cache-name }}-${{ github.job }}-\
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index 8adb131..d439eda 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -367,11 +367,11 @@ class KubernetesPodOperator(BaseOperator):  # pylint: disable=too-many-instance-
                 pod,
                 startup_timeout=self.startup_timeout_seconds)
             final_state, result = launcher.monitor_pod(pod=pod, get_logs=self.get_logs)
-        except AirflowException:
+        except AirflowException as ex:
             if self.log_events_on_failure:
                 for event in launcher.read_pod_events(pod).items:
                     self.log.error("Pod Event: %s - %s", event.reason, event.message)
-            raise
+            raise AirflowException('Pod Launching failed: {error}'.format(error=ex))
         finally:
             if self.is_delete_operator_pod:
                 launcher.delete_pod(pod)
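The operator change above is worth a note: instead of re-raising the bare exception, the operator now logs each pod event and then raises a new AirflowException carrying the original error message. A minimal sketch of that log-then-wrap pattern, with a stub launcher standing in for Airflow's PodLauncher (the helper name and stub are illustrative, not Airflow's API):

    class AirflowException(Exception):
        """Stand-in for airflow.exceptions.AirflowException."""


    def launch_and_monitor(launcher, pod, log, log_events_on_failure=True):
        # Log the pod events first, then surface the failure with the
        # original error message attached for easier triage.
        try:
            return launcher.monitor_pod(pod=pod, get_logs=True)
        except AirflowException as ex:
            if log_events_on_failure:
                for event in launcher.read_pod_events(pod).items:
                    log.error("Pod Event: %s - %s", event.reason, event.message)
            raise AirflowException('Pod Launching failed: {error}'.format(error=ex))

One trade-off of wrapping: callers now see a new exception whose message embeds the original one, rather than the original exception object itself.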
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index d458d7a..74e504e 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -71,7 +71,7 @@ class KubeConfig:
             self.kubernetes_section, "worker_container_image_pull_policy"
         )
         self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
-        self.kube_annotations = configuration_dict.get('kubernetes_annotations', {})
+        self.kube_annotations = configuration_dict.get('kubernetes_annotations', {}) or None
         self.kube_labels = configuration_dict.get('kubernetes_labels', {})
         self.delete_worker_pods = conf.getboolean(
             self.kubernetes_section, 'delete_worker_pods')
@@ -357,7 +357,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
         self.log.debug("Kubernetes using namespace %s", self.namespace)
         self.kube_client = kube_client
         self.launcher = PodLauncher(kube_client=self.kube_client)
-        self.worker_configuration = WorkerConfiguration(kube_config=self.kube_config)
+        self.worker_configuration_pod = WorkerConfiguration(kube_config=self.kube_config).as_pod()
         self._manager = multiprocessing.Manager()
         self.watcher_queue = self._manager.Queue()
         self.worker_uuid = worker_uuid
@@ -393,19 +393,20 @@ class AirflowKubernetesScheduler(LoggingMixin):
         if command[0:2] != ["airflow", "run"]:
             raise ValueError('The command must start with ["airflow", "run"].')
 
-        config_pod = self.worker_configuration.make_pod(
+        pod = PodGenerator.construct_pod(
             namespace=self.namespace,
             worker_uuid=self.worker_uuid,
             pod_id=self._create_pod_id(dag_id, task_id),
             dag_id=pod_generator.make_safe_label_value(dag_id),
             task_id=pod_generator.make_safe_label_value(task_id),
             try_number=try_number,
-            execution_date=self._datetime_to_label_safe_datestring(execution_date),
-            airflow_command=command
+            date=self._datetime_to_label_safe_datestring(execution_date),
+            command=command,
+            kube_executor_config=kube_executor_config,
+            worker_config=self.worker_configuration_pod
         )
         # Reconcile the pod generated by the Operator and the Pod
         # generated by the .cfg file
-        pod = PodGenerator.reconcile_pods(config_pod, kube_executor_config)
         self.log.debug("Kubernetes running for command %s", command)
         self.log.debug("Kubernetes launching image %s", pod.spec.containers[0].image)
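Worth spelling out what changed in the executor: the static worker pod is now rendered once per scheduler via WorkerConfiguration(...).as_pod() and handed to PodGenerator.construct_pod for every task, instead of make_pod re-rendering labels and commands each time. A rough sketch of the new wiring, assuming a configured Airflow 1.10 environment (all argument values are illustrative):

    from airflow.executors.kubernetes_executor import KubeConfig
    from airflow.kubernetes.pod_generator import PodGenerator
    from airflow.kubernetes.worker_configuration import WorkerConfiguration

    # Rendered once, at scheduler start - not once per task.
    worker_pod_template = WorkerConfiguration(kube_config=KubeConfig()).as_pod()

    pod = PodGenerator.construct_pod(
        dag_id='example_dag',
        task_id='example_task',
        pod_id='example-dag-example-task-1',
        try_number=1,
        date='2020-01-09T15_39_05',          # label-safe execution date string
        command=['airflow', 'run', 'example_dag', 'example_task'],
        kube_executor_config=None,           # no per-task executor_config override
        worker_config=worker_pod_template,
        namespace='default',
        worker_uuid='worker-uuid',
    )

With kube_executor_config=None, reconcile_pods (below) short-circuits and the worker template is merged only with the dynamic, per-task fields.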
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index a614f41..bf0cedf 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -28,7 +28,7 @@ import uuid
 
 import kubernetes.client.models as k8s
 
-from airflow.executors import Executors
+from airflow.version import version as airflow_version
 
 MAX_LABEL_LEN = 63
@@ -87,28 +87,59 @@ class PodGenerator:
     Contains Kubernetes Airflow Worker configuration logic
 
     Represents a kubernetes pod and manages execution of a single pod.
+    Any configuration that is container specific gets applied to
+    the first container in the list of containers.
+
+    Parameters with a type of `kubernetes.client.models.*`/`k8s.*` can
+    often be replaced with their dictionary equivalent, for example the output of
+    `sanitize_for_serialization`.
+
     :param image: The docker image
-    :type image: str
+    :type image: Optional[str]
+    :param name: name in the metadata section (not the container name)
+    :type name: Optional[str]
+    :param namespace: pod namespace
+    :type namespace: Optional[str]
+    :param volume_mounts: list of kubernetes volume mounts
+    :type volume_mounts: Optional[List[Union[k8s.V1VolumeMount, dict]]]
     :param envs: A dict containing the environment variables
-    :type envs: Dict[str, str]
-    :param cmds: The command to be run on the pod
-    :type cmds: List[str]
-    :param secrets: Secrets to be launched to the pod
-    :type secrets: List[airflow.kubernetes.models.secret.Secret]
+    :type envs: Optional[Dict[str, str]]
+    :param cmds: The command to be run on the first container
+    :type cmds: Optional[List[str]]
+    :param args: The arguments to be run on the first container
+    :type args: Optional[List[str]]
+    :param labels: labels for the pod metadata
+    :type labels: Optional[Dict[str, str]]
+    :param node_selectors: node selectors for the pod
+    :type node_selectors: Optional[Dict[str, str]]
+    :param ports: list of ports. Applies to the first container.
+    :type ports: Optional[List[Union[k8s.V1ContainerPort, dict]]]
+    :param volumes: Volumes to be attached to the first container
+    :type volumes: Optional[List[Union[k8s.V1Volume, dict]]]
     :param image_pull_policy: Specify a policy to cache or always pull an image
     :type image_pull_policy: str
+    :param restart_policy: The restart policy of the pod
+    :type restart_policy: str
     :param image_pull_secrets: Any image pull secrets to be given to the pod.
         If more than one secret is required, provide a comma separated list:
         secret_a,secret_b
     :type image_pull_secrets: str
+    :param init_containers: A list of init containers
+    :type init_containers: Optional[List[k8s.V1Container]]
+    :param service_account_name: Identity for processes that run in a Pod
+    :type service_account_name: Optional[str]
+    :param resources: Resource requirements for the first container
+    :type resources: Optional[Union[k8s.V1ResourceRequirements, dict]]
+    :param annotations: annotations for the pod
+    :type annotations: Optional[Dict[str, str]]
     :param affinity: A dict containing a group of affinity scheduling rules
-    :type affinity: dict
+    :type affinity: Optional[dict]
     :param hostnetwork: If True enable host networking on the pod
     :type hostnetwork: bool
     :param tolerations: A list of kubernetes tolerations
-    :type tolerations: list
+    :type tolerations: Optional[list]
     :param security_context: A dict containing the security context for the pod
-    :type security_context: dict
+    :type security_context: Optional[Union[k8s.V1PodSecurityContext, dict]]
     :param configmaps: Any configmap refs to envfrom.
         If more than one configmap is required, provide a comma separated list
         configmap_a,configmap_b
@@ -117,11 +148,13 @@ class PodGenerator:
     :type dnspolicy: str
     :param pod: The fully specified pod.
     :type pod: kubernetes.client.models.V1Pod
+    :param extract_xcom: Whether to bring up a container for xcom
+    :type extract_xcom: bool
     """
 
     def __init__(
         self,
-        image,
+        image=None,
         name=None,
         namespace=None,
         volume_mounts=None,
@@ -225,10 +258,11 @@ class PodGenerator:
         result.metadata = self.metadata
         result.spec.containers = [self.container]
 
+        result.metadata.name = self.make_unique_pod_id(result.metadata.name)
+
         if self.extract_xcom:
             result = self.add_sidecar(result)
 
-        result.metadata.name = self.make_unique_pod_id(result.metadata.name)
         return result
 
     @staticmethod
@@ -252,8 +286,9 @@ class PodGenerator:
     @staticmethod
     def add_sidecar(pod):
         pod_cp = copy.deepcopy(pod)
-
+        pod_cp.spec.volumes = pod.spec.volumes or []
         pod_cp.spec.volumes.insert(0, PodDefaults.VOLUME)
+        pod_cp.spec.containers[0].volume_mounts = pod_cp.spec.containers[0].volume_mounts or []
         pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaults.VOLUME_MOUNT)
         pod_cp.spec.containers.append(PodDefaults.SIDECAR_CONTAINER)
 
@@ -262,7 +297,7 @@ class PodGenerator:
     @staticmethod
     def from_obj(obj):
         if obj is None:
-            return k8s.V1Pod()
+            return None
 
         if isinstance(obj, PodGenerator):
             return obj.gen_pod()
@@ -272,7 +307,12 @@ class PodGenerator:
             'Cannot convert a non-dictionary or non-PodGenerator '
             'object into a KubernetesExecutorConfig')
 
-        namespaced = obj.get(Executors.KubernetesExecutor, {})
+        # We do not want to import the constant from ExecutorLoader here: it is just a
+        # name in a dictionary, not an executor selection mechanism, and importing it
+        # would cause a cyclic import.
+        namespaced = obj.get("KubernetesExecutor", {})
+
+        if not namespaced:
+            return None
 
         resources = namespaced.get('resources')
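Since from_obj now returns None for a missing or empty "KubernetesExecutor" key, an unset executor_config no longer produces an empty V1Pod that has to be merged. A sketch of how a task-level executor_config dict flows through it (the 'image' key is one the generator's constructor understands; treat the exact set of supported keys as illustrative):

    from airflow.kubernetes.pod_generator import PodGenerator

    # Only entries under the "KubernetesExecutor" key are considered.
    executor_config = {'KubernetesExecutor': {'image': 'custom-image:latest'}}

    pod_override = PodGenerator.from_obj(executor_config)  # a k8s.V1Pod
    assert PodGenerator.from_obj({}) is None               # nothing to reconcile
    assert PodGenerator.from_obj(None) is None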
""" + if client_pod is None: + return base_pod client_pod_cp = copy.deepcopy(client_pod) + client_pod_cp.spec = PodGenerator.reconcile_specs(base_pod.spec, client_pod_cp.spec) - def merge_objects(base_obj, client_obj): - for base_key in base_obj.to_dict().keys(): - base_val = getattr(base_obj, base_key, None) - if not getattr(client_obj, base_key, None) and base_val: - setattr(client_obj, base_key, base_val) - - def extend_object_field(base_obj, client_obj, field_name): - base_obj_field = getattr(base_obj, field_name, None) - client_obj_field = getattr(client_obj, field_name, None) - if not base_obj_field: - return - if not client_obj_field: - setattr(client_obj, field_name, base_obj_field) - return - appended_fields = base_obj_field + client_obj_field - setattr(client_obj, field_name, appended_fields) - - # Values at the pod and metadata should be overwritten where they exist, - # but certain values at the spec and container level must be conserved. - base_container = base_pod.spec.containers[0] - client_container = client_pod_cp.spec.containers[0] - - extend_object_field(base_container, client_container, 'volume_mounts') - extend_object_field(base_container, client_container, 'env') - extend_object_field(base_container, client_container, 'env_from') - extend_object_field(base_container, client_container, 'ports') - extend_object_field(base_container, client_container, 'volume_devices') - client_container.command = base_container.command - client_container.args = base_container.args - merge_objects(base_pod.spec.containers[0], client_pod_cp.spec.containers[0]) - # Just append any additional containers from the base pod - client_pod_cp.spec.containers.extend(base_pod.spec.containers[1:]) - - merge_objects(base_pod.metadata, client_pod_cp.metadata) - - extend_object_field(base_pod.spec, client_pod_cp.spec, 'volumes') - merge_objects(base_pod.spec, client_pod_cp.spec) - merge_objects(base_pod, client_pod_cp) + client_pod_cp.metadata = merge_objects(base_pod.metadata, client_pod_cp.metadata) + client_pod_cp = merge_objects(base_pod, client_pod_cp) return client_pod_cp + + @staticmethod + def reconcile_specs(base_spec, + client_spec): + """ + :param base_spec: has the base attributes which are overwritten if they exist + in the client_spec and remain if they do not exist in the client_spec + :type base_spec: k8s.V1PodSpec + :param client_spec: the spec that the client wants to create. + :type client_spec: k8s.V1PodSpec + :return: the merged specs + """ + if base_spec and not client_spec: + return base_spec + if not base_spec and client_spec: + return client_spec + elif client_spec and base_spec: + client_spec.containers = PodGenerator.reconcile_containers( + base_spec.containers, client_spec.containers + ) + merged_spec = extend_object_field(base_spec, client_spec, 'volumes') + return merge_objects(base_spec, merged_spec) + + return None + + @staticmethod + def reconcile_containers(base_containers, + client_containers): + """ + :param base_containers: has the base attributes which are overwritten if they exist + in the client_containers and remain if they do not exist in the client_containers + :type base_containers: List[k8s.V1Container] + :param client_containers: the containers that the client wants to create. + :type client_containers: List[k8s.V1Container] + :return: the merged containers + + The runs recursively over the list of containers. 
+ """ + if not base_containers: + return client_containers + if not client_containers: + return base_containers + + client_container = client_containers[0] + base_container = base_containers[0] + client_container = extend_object_field(base_container, client_container, 'volume_mounts') + client_container = extend_object_field(base_container, client_container, 'env') + client_container = extend_object_field(base_container, client_container, 'env_from') + client_container = extend_object_field(base_container, client_container, 'ports') + client_container = extend_object_field(base_container, client_container, 'volume_devices') + client_container = merge_objects(base_container, client_container) + + return [client_container] + PodGenerator.reconcile_containers( + base_containers[1:], client_containers[1:] + ) + + @staticmethod + def construct_pod( + dag_id, + task_id, + pod_id, + try_number, + date, + command, + kube_executor_config, + worker_config, + namespace, + worker_uuid + ): + """ + Construct a pod by gathering and consolidating the configuration from 3 places: + - airflow.cfg + - executor_config + - dynamic arguments + """ + dynamic_pod = PodGenerator( + namespace=namespace, + image='', + labels={ + 'airflow-worker': worker_uuid, + 'dag_id': dag_id, + 'task_id': task_id, + 'execution_date': date, + 'try_number': str(try_number), + 'airflow_version': airflow_version.replace('+', '-'), + 'kubernetes_executor': 'True', + }, + cmds=command, + name=pod_id + ).gen_pod() + + # Reconcile the pod generated by the Operator and the Pod + # generated by the .cfg file + pod_with_executor_config = PodGenerator.reconcile_pods(worker_config, + kube_executor_config) + # Reconcile that pod with the dynamic fields. + return PodGenerator.reconcile_pods(pod_with_executor_config, dynamic_pod) + + +def merge_objects(base_obj, client_obj): + """ + :param base_obj: has the base attributes which are overwritten if they exist + in the client_obj and remain if they do not exist in the client_obj + :param client_obj: the object that the client wants to create. + :return: the merged objects + """ + if not base_obj: + return client_obj + if not client_obj: + return base_obj + + client_obj_cp = copy.deepcopy(client_obj) + + for base_key in base_obj.to_dict().keys(): + base_val = getattr(base_obj, base_key, None) + if not getattr(client_obj, base_key, None) and base_val: + setattr(client_obj_cp, base_key, base_val) + return client_obj_cp + + +def extend_object_field(base_obj, client_obj, field_name): + """ + :param base_obj: an object which has a property `field_name` that is a list + :param client_obj: an object which has a property `field_name` that is a list. 
+
+
+def extend_object_field(base_obj, client_obj, field_name):
+    """
+    :param base_obj: an object which has a property `field_name` that is a list
+    :param client_obj: an object which has a property `field_name` that is a list.
+        A copy of this object is returned with `field_name` modified
+    :param field_name: the name of the list field
+    :type field_name: str
+    :return: the client_obj with the property `field_name` being the two properties appended
+    """
+    client_obj_cp = copy.deepcopy(client_obj)
+    base_obj_field = getattr(base_obj, field_name, None)
+    client_obj_field = getattr(client_obj, field_name, None)
+
+    if (not isinstance(base_obj_field, list) and base_obj_field is not None) or \
+            (not isinstance(client_obj_field, list) and client_obj_field is not None):
+        raise ValueError("The chosen field must be a list.")
+
+    if not base_obj_field:
+        return client_obj_cp
+    if not client_obj_field:
+        setattr(client_obj_cp, field_name, base_obj_field)
+        return client_obj_cp
+
+    appended_fields = base_obj_field + client_obj_field
+    setattr(client_obj_cp, field_name, appended_fields)
+    return client_obj_cp
diff --git a/airflow/kubernetes/worker_configuration.py b/airflow/kubernetes/worker_configuration.py
index bed1ac2..3464e81 100644
--- a/airflow/kubernetes/worker_configuration.py
+++ b/airflow/kubernetes/worker_configuration.py
@@ -25,7 +25,6 @@ from airflow.kubernetes.k8s_model import append_to_pod
 from airflow.kubernetes.pod_generator import PodGenerator
 from airflow.kubernetes.secret import Secret
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.version import version as airflow_version
 
 
 class WorkerConfiguration(LoggingMixin):
@@ -418,23 +417,12 @@ class WorkerConfiguration(LoggingMixin):
 
         return self.kube_config.git_dags_folder_mount_point
 
-    def make_pod(self, namespace, worker_uuid, pod_id, dag_id, task_id, execution_date,
-                 try_number, airflow_command):
+    def as_pod(self):
+        """Creates the base worker pod from the airflow.cfg Kubernetes settings."""
         pod_generator = PodGenerator(
-            namespace=namespace,
-            name=pod_id,
             image=self.kube_config.kube_image,
             image_pull_policy=self.kube_config.kube_image_pull_policy,
-            labels={
-                'airflow-worker': worker_uuid,
-                'dag_id': dag_id,
-                'task_id': task_id,
-                'execution_date': execution_date,
-                'try_number': str(try_number),
-                'airflow_version': airflow_version.replace('+', '-'),
-                'kubernetes_executor': 'True',
-            },
-            cmds=airflow_command,
+            image_pull_secrets=self.kube_config.image_pull_secrets,
             volumes=self._get_volumes(),
            volume_mounts=self._get_volume_mounts(),
             init_containers=self._get_init_containers(),
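Before the test changes, one more illustration: extend_object_field, the other module-level helper added to pod_generator.py above, appends list fields (base entries first) and rejects scalar fields outright. A hedged sketch, assuming the kubernetes client package is installed:

    import kubernetes.client.models as k8s

    from airflow.kubernetes.pod_generator import extend_object_field

    base = k8s.V1Container(name='base', image='base-image',
                           ports=[k8s.V1ContainerPort(container_port=1)])
    client = k8s.V1Container(name='client',
                             ports=[k8s.V1ContainerPort(container_port=2)])

    merged = extend_object_field(base, client, 'ports')
    assert [p.container_port for p in merged.ports] == [1, 2]  # base first, then client
    assert client.ports[0].container_port == 2                 # the input is not mutated

    try:
        extend_object_field(base, client, 'image')  # 'image' is a string, not a list
    except ValueError:
        pass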
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index 993c47a..2b3ed17 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -26,12 +26,12 @@ from urllib3 import HTTPResponse
 
 from airflow.utils import timezone
 from tests.compat import mock
-
+from tests.test_utils.config import conf_vars
 try:
     from kubernetes.client.rest import ApiException
     from airflow import configuration  # noqa: F401
     from airflow.configuration import conf  # noqa: F401
-    from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler
+    from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubeConfig
     from airflow.executors.kubernetes_executor import KubernetesExecutor
     from airflow.kubernetes import pod_generator
     from airflow.kubernetes.pod_generator import PodGenerator
@@ -124,6 +124,56 @@ class TestAirflowKubernetesScheduler(unittest.TestCase):
         self.assertEqual(datetime_obj, new_datetime_obj)
 
 
+class TestKubeConfig(unittest.TestCase):
+    def setUp(self):
+        if AirflowKubernetesScheduler is None:
+            self.skipTest("kubernetes python package is not installed")
+
+    @conf_vars({
+        ('kubernetes', 'git_ssh_known_hosts_configmap_name'): 'airflow-configmap',
+        ('kubernetes', 'git_ssh_key_secret_name'): 'airflow-secrets',
+        ('kubernetes_annotations', "iam.com/role"): "role-arn",
+        ('kubernetes_annotations', "other/annotation"): "value"
+    })
+    def test_kube_config_worker_annotations_properly_parsed(self):
+        annotations = KubeConfig().kube_annotations
+        self.assertEqual({'iam.com/role': 'role-arn', 'other/annotation': 'value'}, annotations)
+
+    @conf_vars({
+        ('kubernetes', 'git_ssh_known_hosts_configmap_name'): 'airflow-configmap',
+        ('kubernetes', 'git_ssh_key_secret_name'): 'airflow-secrets'
+    })
+    def test_kube_config_no_worker_annotations(self):
+        annotations = KubeConfig().kube_annotations
+        self.assertIsNone(annotations)
+
+    @conf_vars({
+        ('kubernetes', 'git_repo'): 'foo',
+        ('kubernetes', 'git_branch'): 'foo',
+        ('kubernetes', 'git_dags_folder_mount_point'): 'foo',
+        ('kubernetes', 'git_sync_run_as_user'): '0',
+    })
+    def test_kube_config_git_sync_run_as_user_root(self):
+        self.assertEqual(KubeConfig().git_sync_run_as_user, 0)
+
+    @conf_vars({
+        ('kubernetes', 'git_repo'): 'foo',
+        ('kubernetes', 'git_branch'): 'foo',
+        ('kubernetes', 'git_dags_folder_mount_point'): 'foo',
+    })
+    def test_kube_config_git_sync_run_as_user_not_present(self):
+        self.assertEqual(KubeConfig().git_sync_run_as_user, 65533)
+
+    @conf_vars({
+        ('kubernetes', 'git_repo'): 'foo',
+        ('kubernetes', 'git_branch'): 'foo',
+        ('kubernetes', 'git_dags_folder_mount_point'): 'foo',
+        ('kubernetes', 'git_sync_run_as_user'): '',
+    })
+    def test_kube_config_git_sync_run_as_user_empty_string(self):
+        self.assertEqual(KubeConfig().git_sync_run_as_user, '')
+
+
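The two annotation tests above pin down the `or None` added to KubeConfig: an absent or empty [kubernetes_annotations] section now yields None instead of {}, so the worker pod carries no empty annotations dict. A stand-alone restatement of the idiom (plain Python, no Airflow environment needed):

    def annotations_from(configuration_dict):
        # Mirrors KubeConfig: an empty dict collapses to None so that pod
        # construction can skip the annotations field entirely.
        return configuration_dict.get('kubernetes_annotations', {}) or None

    assert annotations_from({}) is None
    assert annotations_from({'kubernetes_annotations': {}}) is None
    assert annotations_from(
        {'kubernetes_annotations': {'iam.com/role': 'role-arn'}}
    ) == {'iam.com/role': 'role-arn'}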
 class TestKubernetesExecutor(unittest.TestCase):
     """
     Tests if an ApiException from the Kube Client will cause the task to
@@ -136,44 +186,45 @@ class TestKubernetesExecutor(unittest.TestCase):
     @mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
     def test_run_next_exception(self, mock_get_kube_client, mock_kubernetes_job_watcher):
         # When a quota is exceeded this is the ApiException we get
-        r = HTTPResponse(
+        response = HTTPResponse(
             body='{"kind": "Status", "apiVersion": "v1", "metadata": {}, "status": "Failure", '
                  '"message": "pods \\"podname\\" is forbidden: exceeded quota: compute-resources, '
                  'requested: limits.memory=4Gi, used: limits.memory=6508Mi, limited: limits.memory=10Gi", '
                  '"reason": "Forbidden", "details": {"name": "podname", "kind": "pods"}, "code": 403}')
-        r.status = 403
-        r.reason = "Forbidden"
+        response.status = 403
+        response.reason = "Forbidden"
 
         # A mock kube_client that throws errors when making a pod
         mock_kube_client = mock.patch('kubernetes.client.CoreV1Api', autospec=True)
         mock_kube_client.create_namespaced_pod = mock.MagicMock(
-            side_effect=ApiException(http_resp=r))
+            side_effect=ApiException(http_resp=response))
         mock_get_kube_client.return_value = mock_kube_client
         mock_api_client = mock.MagicMock()
         mock_api_client.sanitize_for_serialization.return_value = {}
         mock_kube_client.api_client = mock_api_client
 
-        kubernetesExecutor = KubernetesExecutor()
-        kubernetesExecutor.start()
+        kubernetes_executor = KubernetesExecutor()
+        kubernetes_executor.start()
 
         # Execute a task while the Api Throws errors
         try_number = 1
-        kubernetesExecutor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number),
-                                         command=['airflow', 'run', 'true', 'some_parameter'],
-                                         executor_config={})
-        kubernetesExecutor.sync()
-        kubernetesExecutor.sync()
+        kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number),
+                                          queue=None,
+                                          command=['airflow', 'run', 'command'],
+                                          executor_config={})
+        kubernetes_executor.sync()
+        kubernetes_executor.sync()
 
         assert mock_kube_client.create_namespaced_pod.called
-        self.assertFalse(kubernetesExecutor.task_queue.empty())
+        self.assertFalse(kubernetes_executor.task_queue.empty())
 
         # Disable the ApiException
         mock_kube_client.create_namespaced_pod.side_effect = None
 
         # Execute the task without errors should empty the queue
-        kubernetesExecutor.sync()
+        kubernetes_executor.sync()
         assert mock_kube_client.create_namespaced_pod.called
-        self.assertTrue(kubernetesExecutor.task_queue.empty())
+        self.assertTrue(kubernetes_executor.task_queue.empty())
 
     @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
     @mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor.sync')
@@ -187,22 +238,19 @@ class TestKubernetesExecutor(unittest.TestCase):
                  mock.call('executor.running_tasks', mock.ANY)]
         mock_stats_gauge.assert_has_calls(calls)
 
-    @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
     @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
     @mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
-    def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_watcher, mock_kube_config):
+    def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_watcher):
         executor = KubernetesExecutor()
         executor.start()
         key = ('dag_id', 'task_id', 'ex_time', 'try_number1')
         executor._change_state(key, State.RUNNING, 'pod_id', 'default')
         self.assertTrue(executor.event_buffer[key] == State.RUNNING)
 
-    @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
     @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
     @mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
     @mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
-    def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher,
-                                  mock_kube_config):
+    def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher):
         executor = KubernetesExecutor()
         executor.start()
         test_time = timezone.utcnow()
@@ -211,12 +259,10 @@ class TestKubernetesExecutor(unittest.TestCase):
         self.assertTrue(executor.event_buffer[key] == State.SUCCESS)
         mock_delete_pod.assert_called_once_with('pod_id', 'default')
 
-    @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
     @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
     @mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
     @mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
-    def test_change_state_failed(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher,
-                                 mock_kube_config):
+    def test_change_state_failed(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher):
         executor = KubernetesExecutor()
         executor.kube_config.delete_worker_pods = False
         executor.kube_config.delete_worker_pods_on_failure = False
@@ -227,12 +273,11 @@ class TestKubernetesExecutor(unittest.TestCase):
         self.assertTrue(executor.event_buffer[key] == State.FAILED)
         mock_delete_pod.assert_not_called()
 
-    @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
     @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
     @mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
     @mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
     def test_change_state_skip_pod_deletion(self, mock_delete_pod, mock_get_kube_client,
-                                            mock_kubernetes_job_watcher, mock_kube_config):
+                                            mock_kubernetes_job_watcher):
         test_time = timezone.utcnow()
         executor = KubernetesExecutor()
         executor.kube_config.delete_worker_pods = False
@@ -243,12 +288,11 @@ class TestKubernetesExecutor(unittest.TestCase):
         self.assertTrue(executor.event_buffer[key] == State.SUCCESS)
         mock_delete_pod.assert_not_called()
 
-    @mock.patch('airflow.executors.kubernetes_executor.KubeConfig')
     @mock.patch('airflow.executors.kubernetes_executor.KubernetesJobWatcher')
     @mock.patch('airflow.executors.kubernetes_executor.get_kube_client')
     @mock.patch('airflow.executors.kubernetes_executor.AirflowKubernetesScheduler.delete_pod')
     def test_change_state_failed_pod_deletion(self, mock_delete_pod, mock_get_kube_client,
-                                              mock_kubernetes_job_watcher, mock_kube_config):
+                                              mock_kubernetes_job_watcher):
         executor = KubernetesExecutor()
         executor.kube_config.delete_worker_pods_on_failure = True
 
diff --git a/tests/kubernetes/test_pod_generator.py b/tests/kubernetes/test_pod_generator.py
index 30839e7..a9a3aa5 100644
--- a/tests/kubernetes/test_pod_generator.py
+++ b/tests/kubernetes/test_pod_generator.py
@@ -20,15 +20,17 @@ from tests.compat import mock
 import uuid
 import kubernetes.client.models as k8s
 from kubernetes.client import ApiClient
-from airflow.kubernetes.secret import Secret
-from airflow.kubernetes.pod_generator import PodGenerator, PodDefaults
-from airflow.kubernetes.pod import Resources
+
 from airflow.kubernetes.k8s_model import append_to_pod
+from airflow.kubernetes.pod import Resources
+from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator, extend_object_field, merge_objects
+from airflow.kubernetes.secret import Secret
 
 
 class TestPodGenerator(unittest.TestCase):
 
     def setUp(self):
+        self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
         self.envs = {
             'ENVIRONMENT': 'prod',
             'LOG_LEVEL': 'warning'
@@ -41,9 +43,23 @@ class TestPodGenerator(unittest.TestCase):
             # This should produce a single secret mounted in env
             Secret('env', 'TARGET', 'secret_b', 'source_b'),
         ]
+        self.labels = {
+            'airflow-worker': 'uuid',
+            'dag_id': 'dag_id',
+            'execution_date': 'date',
+            'task_id': 'task_id',
+            'try_number': '3',
+            'airflow_version': mock.ANY,
+            'kubernetes_executor': 'True'
+        }
+        self.metadata = {
+            'labels': self.labels,
+            'name': 'pod_id-' + self.static_uuid.hex,
+            'namespace': 'namespace'
+        }
+
         self.resources = Resources('1Gi', 1, '2Gi', 2, 1)
         self.k8s_client = ApiClient()
-        self.static_uuid = uuid.UUID('cf4a56d2-8101-4217-b027-2af6216feb48')
         self.expected = {
             'apiVersion': 'v1',
             'kind': 'Pod',
@@ -171,9 +187,9 @@ class TestPodGenerator(unittest.TestCase):
                 fs_group=2000,
             ),
             ports=[k8s.V1ContainerPort(name='foo', container_port=1234)],
-            configmaps=['configmap_a', 'configmap_b']
+            configmaps=['configmap_a', 'configmap_b'],
+            extract_xcom=True
         )
-        pod_generator.extract_xcom = True
         result = pod_generator.gen_pod()
         result = append_to_pod(result, self.secrets)
         result = self.resources.attach_to_pod(result)
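As the test change above shows, extract_xcom is now a constructor argument rather than an attribute set after the fact, and gen_pod makes the pod name unique before the sidecar is added, so the sidecar and its shared volume are untouched by the rename. A sketch of the sidecar behavior, assuming the kubernetes client package is installed:

    from airflow.kubernetes.pod_generator import PodDefaults, PodGenerator

    pod = PodGenerator(
        image='airflow-worker:latest',
        name='xcom-example',
        extract_xcom=True,  # was: pod_generator.extract_xcom = True
    ).gen_pod()

    # add_sidecar appended the xcom container and the shared empty-dir volume.
    assert pod.spec.containers[-1].name == PodDefaults.SIDECAR_CONTAINER.name
    assert pod.spec.volumes[0].name == PodDefaults.VOLUME.name
    assert pod.spec.containers[0].volume_mounts[0].name == PodDefaults.VOLUME_MOUNT.name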
@@ -253,79 +269,452 @@ class TestPodGenerator(unittest.TestCase):
             }
         }, result)
 
-    def test_reconcile_pods(self):
-        with mock.patch('uuid.uuid4') as mock_uuid:
-            mock_uuid.return_value = self.static_uuid
-            base_pod = PodGenerator(
-                image='image1',
-                name='name1',
-                envs={'key1': 'val1'},
-                cmds=['/bin/command1.sh', 'arg1'],
-                ports=k8s.V1ContainerPort(name='port', container_port=2118),
-                volumes=[{
-                    'hostPath': {'path': '/tmp/'},
-                    'name': 'example-kubernetes-test-volume1'
-                }],
-                volume_mounts=[{
-                    'mountPath': '/foo/',
-                    'name': 'example-kubernetes-test-volume1'
-                }],
-            ).gen_pod()
-
-            mutator_pod = PodGenerator(
-                envs={'key2': 'val2'},
-                image='',
-                name='name2',
-                cmds=['/bin/command2.sh', 'arg2'],
-                volumes=[{
-                    'hostPath': {'path': '/tmp/'},
-                    'name': 'example-kubernetes-test-volume2'
-                }],
-                volume_mounts=[{
-                    'mountPath': '/foo/',
-                    'name': 'example-kubernetes-test-volume2'
-                }]
-            ).gen_pod()
-
-            result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
-            result = self.k8s_client.sanitize_for_serialization(result)
-            self.assertEqual(result, {
-                'apiVersion': 'v1',
-                'kind': 'Pod',
-                'metadata': {'name': 'name2-' + self.static_uuid.hex},
-                'spec': {
-                    'containers': [{
-                        'args': [],
-                        'command': ['/bin/command1.sh', 'arg1'],
-                        'env': [
-                            {'name': 'key1', 'value': 'val1'},
-                            {'name': 'key2', 'value': 'val2'}
-                        ],
-                        'envFrom': [],
-                        'image': 'image1',
-                        'imagePullPolicy': 'IfNotPresent',
-                        'name': 'base',
-                        'ports': {
-                            'containerPort': 2118,
-                            'name': 'port',
-                        },
-                        'volumeMounts': [{
-                            'mountPath': '/foo/',
-                            'name': 'example-kubernetes-test-volume1'
-                        }, {
-                            'mountPath': '/foo/',
-                            'name': 'example-kubernetes-test-volume2'
-                        }]
+    @mock.patch('uuid.uuid4')
+    def test_reconcile_pods_empty_mutator_pod(self, mock_uuid):
+        mock_uuid.return_value = self.static_uuid
+        base_pod = PodGenerator(
+            image='image1',
+            name='name1',
+            envs={'key1': 'val1'},
+            cmds=['/bin/command1.sh', 'arg1'],
+            ports=[k8s.V1ContainerPort(name='port', container_port=2118)],
+            volumes=[{
+                'hostPath': {'path': '/tmp/'},
+                'name': 'example-kubernetes-test-volume1'
+            }],
+            volume_mounts=[{
+                'mountPath': '/foo/',
+                'name': 'example-kubernetes-test-volume1'
+            }],
+        ).gen_pod()
+
+        mutator_pod = None
+        name = 'name1-' + self.static_uuid.hex
+
+        base_pod.metadata.name = name
+
+        result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
+        self.assertEqual(base_pod, result)
+
+        mutator_pod = k8s.V1Pod()
+        result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
+        self.assertEqual(base_pod, result)
+
+    @mock.patch('uuid.uuid4')
+    def test_reconcile_pods(self, mock_uuid):
+        mock_uuid.return_value = self.static_uuid
+        base_pod = PodGenerator(
+            image='image1',
+            name='name1',
+            envs={'key1': 'val1'},
+            cmds=['/bin/command1.sh', 'arg1'],
+            ports=[k8s.V1ContainerPort(name='port', container_port=2118)],
+            volumes=[{
+                'hostPath': {'path': '/tmp/'},
+                'name': 'example-kubernetes-test-volume1'
+            }],
+            volume_mounts=[{
+                'mountPath': '/foo/',
+                'name': 'example-kubernetes-test-volume1'
+            }],
+        ).gen_pod()
+
+        mutator_pod = PodGenerator(
+            envs={'key2': 'val2'},
+            image='',
+            name='name2',
+            cmds=['/bin/command2.sh', 'arg2'],
+            volumes=[{
+                'hostPath': {'path': '/tmp/'},
+                'name': 'example-kubernetes-test-volume2'
+            }],
+            volume_mounts=[{
+                'mountPath': '/foo/',
+                'name': 'example-kubernetes-test-volume2'
+            }]
+        ).gen_pod()
+
+        result = PodGenerator.reconcile_pods(base_pod, mutator_pod)
+        result = self.k8s_client.sanitize_for_serialization(result)
+        self.assertEqual(result, {
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': {'name': 'name2-' + self.static_uuid.hex},
+            'spec': {
+                'containers': [{
+                    'args': [],
+                    'command': ['/bin/command2.sh', 'arg2'],
+                    'env': [
+                        {'name': 'key1', 'value': 'val1'},
+                        {'name': 'key2', 'value': 'val2'}
+                    ],
+                    'envFrom': [],
+                    'image': 'image1',
+                    'imagePullPolicy': 'IfNotPresent',
+                    'name': 'base',
+                    'ports': [{
+                        'containerPort': 2118,
+                        'name': 'port',
                     }],
-                    'hostNetwork': False,
-                    'imagePullSecrets': [],
-                    'restartPolicy': 'Never',
-                    'volumes': [{
-                        'hostPath': {'path': '/tmp/'},
+                    'volumeMounts': [{
+                        'mountPath': '/foo/',
                         'name': 'example-kubernetes-test-volume1'
                     }, {
-                        'hostPath': {'path': '/tmp/'},
+                        'mountPath': '/foo/',
                         'name': 'example-kubernetes-test-volume2'
                     }]
+                }],
+                'hostNetwork': False,
+                'imagePullSecrets': [],
+                'restartPolicy': 'Never',
+                'volumes': [{
+                    'hostPath': {'path': '/tmp/'},
+                    'name': 'example-kubernetes-test-volume1'
+                }, {
+                    'hostPath': {'path': '/tmp/'},
+                    'name': 'example-kubernetes-test-volume2'
+                }]
             }
         })
+
+    @mock.patch('uuid.uuid4')
+    def test_construct_pod_empty_worker_config(self, mock_uuid):
+        mock_uuid.return_value = self.static_uuid
+        executor_config = k8s.V1Pod(
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name='',
+                        resources=k8s.V1ResourceRequirements(
+                            limits={
+                                'cpu': '1m',
+                                'memory': '1G'
+                            }
+                        )
+                    )
+                ]
+            )
+        )
+        worker_config = k8s.V1Pod()
+
+        result = PodGenerator.construct_pod(
+            'dag_id',
+            'task_id',
+            'pod_id',
+            3,
+            'date',
+            ['command'],
+            executor_config,
+            worker_config,
+            'namespace',
+            'uuid',
+        )
+        sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+        self.assertEqual({
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': self.metadata,
+            'spec': {
+                'containers': [{
+                    'args': [],
+                    'command': ['command'],
+                    'env': [],
+                    'envFrom': [],
+                    'imagePullPolicy': 'IfNotPresent',
+                    'name': 'base',
+                    'ports': [],
+                    'resources': {
+                        'limits': {
+                            'cpu': '1m',
+                            'memory': '1G'
+                        }
+                    },
+                    'volumeMounts': []
+                }],
+                'hostNetwork': False,
+                'imagePullSecrets': [],
+                'restartPolicy': 'Never',
+                'volumes': []
+            }
+        }, sanitized_result)
+
+    @mock.patch('uuid.uuid4')
+    def test_construct_pod_empty_executor_config(self, mock_uuid):
+        mock_uuid.return_value = self.static_uuid
+        worker_config = k8s.V1Pod(
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name='',
+                        resources=k8s.V1ResourceRequirements(
+                            limits={
+                                'cpu': '1m',
+                                'memory': '1G'
+                            }
+                        )
+                    )
+                ]
+            )
+        )
+        executor_config = None
+
+        result = PodGenerator.construct_pod(
+            'dag_id',
+            'task_id',
+            'pod_id',
+            3,
+            'date',
+            ['command'],
+            executor_config,
+            worker_config,
+            'namespace',
+            'uuid',
+        )
+        sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+        self.assertEqual({
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': self.metadata,
+            'spec': {
+                'containers': [{
+                    'args': [],
+                    'command': ['command'],
+                    'env': [],
+                    'envFrom': [],
+                    'imagePullPolicy': 'IfNotPresent',
+                    'name': 'base',
+                    'ports': [],
+                    'resources': {
+                        'limits': {
+                            'cpu': '1m',
+                            'memory': '1G'
+                        }
+                    },
+                    'volumeMounts': []
+                }],
+                'hostNetwork': False,
+                'imagePullSecrets': [],
+                'restartPolicy': 'Never',
+                'volumes': []
+            }
+        }, sanitized_result)
+
+    @mock.patch('uuid.uuid4')
+    def test_construct_pod(self, mock_uuid):
+        mock_uuid.return_value = self.static_uuid
+        worker_config = k8s.V1Pod(
+            metadata=k8s.V1ObjectMeta(
+                name='gets-overridden-by-dynamic-args',
+                annotations={
+                    'should': 'stay'
+                }
+            ),
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name='doesnt-override',
+                        resources=k8s.V1ResourceRequirements(
+                            limits={
+                                'cpu': '1m',
+                                'memory': '1G'
+                            }
+                        ),
+                        security_context=k8s.V1SecurityContext(
+                            run_as_user=1
+                        )
+                    )
+                ]
+            )
+        )
+        executor_config = k8s.V1Pod(
+            spec=k8s.V1PodSpec(
+                containers=[
+                    k8s.V1Container(
+                        name='doesnt-override-either',
+                        resources=k8s.V1ResourceRequirements(
+                            limits={
+                                'cpu': '2m',
+                                'memory': '2G'
+                            }
+                        )
+                    )
+                ]
+            )
+        )
+
+        result = PodGenerator.construct_pod(
+            'dag_id',
+            'task_id',
+            'pod_id',
+            3,
+            'date',
+            ['command'],
+            executor_config,
+            worker_config,
+            'namespace',
+            'uuid',
+        )
+        sanitized_result = self.k8s_client.sanitize_for_serialization(result)
+
+        self.metadata.update({'annotations': {'should': 'stay'}})
+
+        self.assertEqual({
+            'apiVersion': 'v1',
+            'kind': 'Pod',
+            'metadata': self.metadata,
+            'spec': {
+                'containers': [{
+                    'args': [],
+                    'command': ['command'],
+                    'env': [],
+                    'envFrom': [],
+                    'imagePullPolicy': 'IfNotPresent',
+                    'name': 'base',
+                    'ports': [],
+                    'resources': {
+                        'limits': {
+                            'cpu': '2m',
+                            'memory': '2G'
+                        }
+                    },
+                    'volumeMounts': [],
+                    'securityContext': {'runAsUser': 1}
+                }],
+                'hostNetwork': False,
+                'imagePullSecrets': [],
+                'restartPolicy': 'Never',
+                'volumes': []
+            }
+        }, sanitized_result)
+
+    def test_merge_objects_empty(self):
+        annotations = {'foo1': 'bar1'}
+        base_obj = k8s.V1ObjectMeta(annotations=annotations)
+        client_obj = None
+        res = merge_objects(base_obj, client_obj)
+        self.assertEqual(base_obj, res)
+
+        client_obj = k8s.V1ObjectMeta()
+        res = merge_objects(base_obj, client_obj)
+        self.assertEqual(base_obj, res)
+
+        client_obj = k8s.V1ObjectMeta(annotations=annotations)
+        base_obj = None
+        res = merge_objects(base_obj, client_obj)
+        self.assertEqual(client_obj, res)
+
+        base_obj = k8s.V1ObjectMeta()
+        res = merge_objects(base_obj, client_obj)
+        self.assertEqual(client_obj, res)
+
+    def test_merge_objects(self):
+        base_annotations = {'foo1': 'bar1'}
+        base_labels = {'foo1': 'bar1'}
+        client_annotations = {'foo2': 'bar2'}
+        base_obj = k8s.V1ObjectMeta(
+            annotations=base_annotations,
+            labels=base_labels
+        )
+        client_obj = k8s.V1ObjectMeta(annotations=client_annotations)
+        res = merge_objects(base_obj, client_obj)
+        client_obj.labels = base_labels
+        self.assertEqual(client_obj, res)
+
+    def test_extend_object_field_empty(self):
+        ports = [k8s.V1ContainerPort(container_port=1, name='port')]
+        base_obj = k8s.V1Container(name='base_container', ports=ports)
+        client_obj = k8s.V1Container(name='client_container')
+        res = extend_object_field(base_obj, client_obj, 'ports')
+        client_obj.ports = ports
+        self.assertEqual(client_obj, res)
+
+        base_obj = k8s.V1Container(name='base_container')
+        client_obj = k8s.V1Container(name='base_container', ports=ports)
+        res = extend_object_field(base_obj, client_obj, 'ports')
+        self.assertEqual(client_obj, res)
+
+    def test_extend_object_field_not_list(self):
+        base_obj = k8s.V1Container(name='base_container', image='image')
+        client_obj = k8s.V1Container(name='client_container')
+        with self.assertRaises(ValueError):
+            extend_object_field(base_obj, client_obj, 'image')
+        base_obj = k8s.V1Container(name='base_container')
+        client_obj = k8s.V1Container(name='client_container', image='image')
+        with self.assertRaises(ValueError):
+            extend_object_field(base_obj, client_obj, 'image')
+
+    def test_extend_object_field(self):
+        base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
+        base_obj = k8s.V1Container(name='base_container', ports=base_ports)
+        client_ports = [k8s.V1ContainerPort(container_port=1, name='client_port')]
+        client_obj = k8s.V1Container(name='client_container', ports=client_ports)
+        res = extend_object_field(base_obj, client_obj, 'ports')
+        client_obj.ports = base_ports + client_ports
+        self.assertEqual(client_obj, res)
+
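The reconcile_containers tests that follow exercise the pairwise walk: position i of the client list is merged with position i of the base list, and whichever list is longer contributes its tail unchanged. A compact sketch of the same behavior, assuming the kubernetes client package is installed:

    import kubernetes.client.models as k8s

    from airflow.kubernetes.pod_generator import PodGenerator

    base = [
        k8s.V1Container(name='b0', image='base-image'),
        k8s.V1Container(name='b1'),
    ]
    client = [
        k8s.V1Container(name='c0'),  # no image -> inherits 'base-image' from b0
    ]

    merged = PodGenerator.reconcile_containers(base, client)
    assert merged[0].name == 'c0' and merged[0].image == 'base-image'
    assert merged[1] == base[1]  # unmatched base containers pass through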
+    def test_reconcile_containers_empty(self):
+        base_objs = [k8s.V1Container(name='base_container')]
+        client_objs = []
+        res = PodGenerator.reconcile_containers(base_objs, client_objs)
+        self.assertEqual(base_objs, res)
+
+        client_objs = [k8s.V1Container(name='client_container')]
+        base_objs = []
+        res = PodGenerator.reconcile_containers(base_objs, client_objs)
+        self.assertEqual(client_objs, res)
+
+        res = PodGenerator.reconcile_containers([], [])
+        self.assertEqual(res, [])
+
+    def test_reconcile_containers(self):
+        base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
+        base_objs = [
+            k8s.V1Container(name='base_container1', ports=base_ports),
+            k8s.V1Container(name='base_container2', image='base_image'),
+        ]
+        client_ports = [k8s.V1ContainerPort(container_port=2, name='client_port')]
+        client_objs = [
+            k8s.V1Container(name='client_container1', ports=client_ports),
+            k8s.V1Container(name='client_container2', image='client_image'),
+        ]
+        res = PodGenerator.reconcile_containers(base_objs, client_objs)
+        client_objs[0].ports = base_ports + client_ports
+        self.assertEqual(client_objs, res)
+
+        base_ports = [k8s.V1ContainerPort(container_port=1, name='base_port')]
+        base_objs = [
+            k8s.V1Container(name='base_container1', ports=base_ports),
+            k8s.V1Container(name='base_container2', image='base_image'),
+        ]
+        client_ports = [k8s.V1ContainerPort(container_port=2, name='client_port')]
+        client_objs = [
+            k8s.V1Container(name='client_container1', ports=client_ports),
+            k8s.V1Container(name='client_container2', stdin=True),
+        ]
+        res = PodGenerator.reconcile_containers(base_objs, client_objs)
+        client_objs[0].ports = base_ports + client_ports
+        client_objs[1].image = 'base_image'
+        self.assertEqual(client_objs, res)
+
+    def test_reconcile_specs_empty(self):
+        base_spec = k8s.V1PodSpec(containers=[])
+        client_spec = None
+        res = PodGenerator.reconcile_specs(base_spec, client_spec)
+        self.assertEqual(base_spec, res)
+
+        base_spec = None
+        client_spec = k8s.V1PodSpec(containers=[])
+        res = PodGenerator.reconcile_specs(base_spec, client_spec)
+        self.assertEqual(client_spec, res)
+
+    def test_reconcile_specs(self):
+        base_objs = [k8s.V1Container(name='base_container1', image='base_image')]
+        client_objs = [k8s.V1Container(name='client_container1')]
+        base_spec = k8s.V1PodSpec(priority=1, active_deadline_seconds=100, containers=base_objs)
+        client_spec = k8s.V1PodSpec(priority=2, hostname='local', containers=client_objs)
+        res = PodGenerator.reconcile_specs(base_spec, client_spec)
+        client_spec.containers = [k8s.V1Container(name='client_container1', image='base_image')]
+        client_spec.active_deadline_seconds = 100
+        self.assertEqual(client_spec, res)
diff --git a/tests/kubernetes/test_worker_configuration.py b/tests/kubernetes/test_worker_configuration.py
index 8378f9f..74009a1 100644
--- a/tests/kubernetes/test_worker_configuration.py
+++ b/tests/kubernetes/test_worker_configuration.py
@@ -17,13 +17,12 @@
 #
 
 import unittest
-import uuid
-from datetime import datetime
 
 import six
 
 from tests.compat import mock
 from tests.test_utils.config import conf_vars
+
 try:
     from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler
     from airflow.executors.kubernetes_executor import KubeConfig
@@ -31,6 +30,7 @@ try:
     from airflow.kubernetes.pod_generator import PodGenerator
     from airflow.exceptions import AirflowConfigException
     from airflow.kubernetes.secret import Secret
+    from airflow.version import version as airflow_version
     import kubernetes.client.models as k8s
     from kubernetes.client.api_client import ApiClient
 except ImportError:
@@ -74,6 +74,11 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
         }
     ]
 
+    worker_annotations_config = {
+        'iam.amazonaws.com/role': 'role-arn',
+        'other/annotation': 'value'
+    }
+
     def setUp(self):
         if AirflowKubernetesScheduler is None:
             self.skipTest("kubernetes python package is not installed")
@@ -312,11 +317,39 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
         self.kube_config.git_subpath = 'path'
         worker_config = WorkerConfiguration(self.kube_config)
 
-        pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
-                                     "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+        pod = worker_config.as_pod()
 
         self.assertEqual(0, pod.spec.security_context.run_as_user)
 
+    def test_make_pod_assert_labels(self):
+        # Tests the pod created has all the expected labels set
+        self.kube_config.dags_folder = 'dags'
+
+        worker_config = WorkerConfiguration(self.kube_config)
+        pod = PodGenerator.construct_pod(
+            "test_dag_id",
+            "test_task_id",
+            "test_pod_id",
+            1,
+            "2019-11-21 11:08:22.920875",
+            ["bash -c 'ls /'"],
+            None,
+            worker_config.as_pod(),
+            "default",
+            "sample-uuid",
+        )
+        expected_labels = {
+            'airflow-worker': 'sample-uuid',
+            'airflow_version': airflow_version.replace('+', '-'),
+            'dag_id': 'test_dag_id',
+            'execution_date': '2019-11-21 11:08:22.920875',
+            'kubernetes_executor': 'True',
+            'task_id': 'test_task_id',
+            'try_number': '1'
+        }
+        self.assertEqual(pod.metadata.labels, expected_labels)
+
     def test_make_pod_git_sync_ssh_without_known_hosts(self):
         # Tests the pod created with git-sync SSH authentication option is correct without known hosts
         self.kube_config.airflow_configmap = 'airflow-configmap'
@@ -331,8 +364,7 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
 
         worker_config = WorkerConfiguration(self.kube_config)
 
-        pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
-                                     "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+        pod = worker_config.as_pod()
 
         init_containers = worker_config._get_init_containers()
         git_ssh_key_file = next((x.value for x in init_containers[0].env
@@ -361,8 +393,7 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
 
         worker_config = WorkerConfiguration(self.kube_config)
 
-        pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
-                                     "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+        pod = worker_config.as_pod()
 
         username_env = k8s.V1EnvVar(
             name='GIT_SYNC_USERNAME',
@@ -387,6 +418,29 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
         self.assertIn(password_env, pod.spec.init_containers[0].env,
                       'The password env for git credentials did not get into the init container')
 
+    def test_make_pod_git_sync_rev(self):
+        # Tests that a configured git_sync_rev gets passed to the init container as GIT_SYNC_REV
+        self.kube_config.git_sync_rev = 'sampletag'
+        self.kube_config.dags_volume_claim = None
+        self.kube_config.dags_volume_host = None
+        self.kube_config.dags_in_image = None
+        self.kube_config.worker_fs_group = None
+        self.kube_config.git_dags_folder_mount_point = 'dags'
+        self.kube_config.git_sync_dest = 'repo'
+        self.kube_config.git_subpath = 'path'
+
+        worker_config = WorkerConfiguration(self.kube_config)
+
+        pod = worker_config.as_pod()
+
+        rev_env = k8s.V1EnvVar(
+            name='GIT_SYNC_REV',
+            value=self.kube_config.git_sync_rev,
+        )
+
+        self.assertIn(rev_env, pod.spec.init_containers[0].env,
+                      'The git_sync_rev env did not get into the init container')
+
     def test_make_pod_git_sync_ssh_with_known_hosts(self):
         # Tests the pod created with git-sync SSH authentication option is correct with known hosts
         self.kube_config.airflow_configmap = 'airflow-configmap'
@@ -415,11 +469,10 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
     def test_make_pod_with_empty_executor_config(self):
         self.kube_config.kube_affinity = self.affinity_config
         self.kube_config.kube_tolerations = self.tolerations_config
+        self.kube_config.kube_annotations = self.worker_annotations_config
         self.kube_config.dags_folder = 'dags'
         worker_config = WorkerConfiguration(self.kube_config)
-
-        pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
-                                     "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+        pod = worker_config.as_pod()
 
         self.assertTrue(pod.spec.affinity['podAntiAffinity'] is not None)
         self.assertEqual('app',
@@ -431,6 +484,8 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
 
         self.assertEqual(2, len(pod.spec.tolerations))
         self.assertEqual('prod', pod.spec.tolerations[1]['key'])
+        self.assertEqual('role-arn', pod.metadata.annotations['iam.amazonaws.com/role'])
+        self.assertEqual('value', pod.metadata.annotations['other/annotation'])
 
     def test_make_pod_with_executor_config(self):
         self.kube_config.dags_folder = 'dags'
@@ -441,8 +496,7 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
             tolerations=self.tolerations_config,
         ).gen_pod()
 
-        pod = worker_config.make_pod("default", str(uuid.uuid4()), "test_pod_id", "test_dag_id",
-                                     "test_task_id", str(datetime.utcnow()), 1, "bash -c 'ls /'")
+        pod = worker_config.as_pod()
 
         result = PodGenerator.reconcile_pods(pod, config_pod)
 
@@ -607,3 +661,18 @@ class TestKubernetesWorkerConfiguration(unittest.TestCase):
             'dag_id': 'override_dag_id',
             'my_kube_executor_label': 'kubernetes'
         }, labels)
+
+    def test_make_pod_with_image_pull_secrets(self):
+        # Tests that the pod created with image_pull_secrets actually gets them in its config
+        self.kube_config.dags_volume_claim = None
+        self.kube_config.dags_volume_host = None
+        self.kube_config.dags_in_image = None
+        self.kube_config.git_dags_folder_mount_point = 'dags'
+        self.kube_config.git_sync_dest = 'repo'
+        self.kube_config.git_subpath = 'path'
+        self.kube_config.image_pull_secrets = 'image_pull_secret1,image_pull_secret2'
+
+        worker_config = WorkerConfiguration(self.kube_config)
+        pod = worker_config.as_pod()
+
+        self.assertEqual(2, len(pod.spec.image_pull_secrets))
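A closing note on that last test: image_pull_secrets stays a comma-separated string in airflow.cfg, and the worker configuration turns it into one V1LocalObjectReference per name, which is what the length assertion checks. A stand-alone sketch of that mapping (the helper name is hypothetical and mirrors the behavior, not Airflow's internal function):

    import kubernetes.client.models as k8s


    def image_pull_secrets_from(cfg_value):
        # 'secret_a,secret_b' -> [V1LocalObjectReference(name='secret_a'), ...]
        return [k8s.V1LocalObjectReference(name=name)
                for name in cfg_value.split(',') if name]

    secrets = image_pull_secrets_from('image_pull_secret1,image_pull_secret2')
    assert len(secrets) == 2
    assert secrets[0].name == 'image_pull_secret1'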
