This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 0be6430938 Revert "KPO Maintain backward compatibility for execute_complete and trigger run method (#37363)" (#37446)
0be6430938 is described below

commit 0be643093879e106f7ee1e41c155954edd14398f
Author: Jarek Potiuk <ja...@potiuk.com>
AuthorDate: Thu Feb 15 14:58:36 2024 +0100

    Revert "KPO Maintain backward compatibility for execute_complete and trigger run method (#37363)" (#37446)

    This reverts commit 0640e6d595c01dd96f2b90812a546bc091f87743.
---
 airflow/providers/cncf/kubernetes/operators/pod.py | 150 ++++++++++++---------
 airflow/providers/cncf/kubernetes/triggers/pod.py  |  70 +++-------
 .../cncf/kubernetes/operators/test_pod.py          |  34 ++---
 .../providers/cncf/kubernetes/triggers/test_pod.py |  92 ++++++-------
 .../cloud/triggers/test_kubernetes_engine.py       |  51 ++++---
 5 files changed, 189 insertions(+), 208 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 61442a6014..73389f4038 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
-import datetime
 import json
 import logging
 import re
@@ -31,7 +30,6 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
 
 import kubernetes
-from deprecated import deprecated
 from kubernetes.client import CoreV1Api, V1Pod, models as k8s
 from kubernetes.stream import stream
 from urllib3.exceptions import HTTPError
@@ -70,6 +68,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
     EMPTY_XCOM_RESULT,
     OnFinishAction,
     PodLaunchFailedException,
+    PodLaunchTimeoutException,
     PodManager,
     PodNotFoundException,
     PodOperatorHookProtocol,
@@ -80,6 +79,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
 from airflow.settings import pod_mutation_hook
 from airflow.utils import yaml
 from airflow.utils.helpers import prune_dict, validate_key
+from airflow.utils.timezone import utcnow
 from airflow.version import version as airflow_version
 
 if TYPE_CHECKING:
@@ -656,7 +656,7 @@ class KubernetesPodOperator(BaseOperator):
 
     def invoke_defer_method(self, last_log_time: DateTime | None = None):
         """Redefine triggers which are being used in child classes."""
-        trigger_start_time = datetime.datetime.now(tz=datetime.timezone.utc)
+        trigger_start_time = utcnow()
         self.defer(
             trigger=KubernetesPodTrigger(
                 pod_name=self.pod.metadata.name,  # type: ignore[union-attr]
@@ -678,87 +678,117 @@ class KubernetesPodOperator(BaseOperator):
             method_name="trigger_reentry",
         )
 
+    @staticmethod
+    def raise_for_trigger_status(event: dict[str, Any]) -> None:
+        """Raise exception if pod is not in expected state."""
+        if event["status"] == "error":
+            error_type = event["error_type"]
+            description = event["description"]
+            if error_type == "PodLaunchTimeoutException":
+                raise PodLaunchTimeoutException(description)
+            else:
+                raise AirflowException(description)
+
     def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
         """
         Point of re-entry from trigger.
 
-        If ``logging_interval`` is None, then at this point, the pod should be done, and we'll just fetch
+        If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
        the logs and exit.
 
-        If ``logging_interval`` is not None, it could be that the pod is still running, and we'll just
+        If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
         grab the latest logs and defer back to the trigger again.
         """
-        self.pod = None
+        remote_pod = None
        try:
-            pod_name = event["name"]
-            pod_namespace = event["namespace"]
+            self.pod_request_obj = self.build_pod_request_obj(context)
+            self.pod = self.find_pod(
+                namespace=self.namespace or self.pod_request_obj.metadata.namespace,
+                context=context,
+            )
 
-            self.pod = self.hook.get_pod(pod_name, pod_namespace)
+            # we try to find pod before possibly raising so that on_kill will have `pod` attr
+            self.raise_for_trigger_status(event)
 
             if not self.pod:
                 raise PodNotFoundException("Could not find pod after resuming from deferral")
 
-            if self.callbacks and event["status"] != "running":
-                self.callbacks.on_operator_resuming(
-                    pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC
+            if self.get_logs:
+                last_log_time = event and event.get("last_log_time")
+                if last_log_time:
+                    self.log.info("Resuming logs read from time %r", last_log_time)
+                pod_log_status = self.pod_manager.fetch_container_logs(
+                    pod=self.pod,
+                    container_name=self.BASE_CONTAINER_NAME,
+                    follow=self.logging_interval is None,
+                    since_time=last_log_time,
                 )
+                if pod_log_status.running:
+                    self.log.info("Container still running; deferring again.")
+                    self.invoke_defer_method(pod_log_status.last_log_time)
+
+            if self.do_xcom_push:
+                result = self.extract_xcom(pod=self.pod)
+            remote_pod = self.pod_manager.await_pod_completion(self.pod)
+        except TaskDeferred:
+            raise
+        except Exception:
+            self.cleanup(
+                pod=self.pod or self.pod_request_obj,
+                remote_pod=remote_pod,
+            )
+            raise
+        self.cleanup(
+            pod=self.pod or self.pod_request_obj,
+            remote_pod=remote_pod,
+        )
+        if self.do_xcom_push:
+            return result
+
+    def execute_complete(self, context: Context, event: dict, **kwargs):
+        self.log.debug("Triggered with event: %s", event)
+        pod = None
+        try:
+            pod = self.hook.get_pod(
+                event["name"],
+                event["namespace"],
+            )
+            if self.callbacks:
+                self.callbacks.on_operator_resuming(
+                    pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC
+                )
             if event["status"] in ("error", "failed", "timeout"):
+                # fetch some logs when pod is failed
+                if self.get_logs:
+                    self.write_logs(pod)
+                if "stack_trace" in event:
+                    message = f"{event['message']}\n{event['stack_trace']}"
+                else:
+                    message = event["message"]
                 if self.do_xcom_push:
-                    _ = self.extract_xcom(pod=self.pod)
-
-                message = event.get("stack_trace", event["message"])
+                    # In the event of base container failure, we need to kill the xcom sidecar.
+                    # We disregard xcom output and do that here
+                    _ = self.extract_xcom(pod=pod)
                 raise AirflowException(message)
-
-            elif event["status"] == "running":
+            elif event["status"] == "success":
+                # fetch some logs when pod is executed successfully
                 if self.get_logs:
-                    last_log_time = event.get("last_log_time")
-                    self.log.info("Resuming logs read from time %r", last_log_time)
-
-                    pod_log_status = self.pod_manager.fetch_container_logs(
-                        pod=self.pod,
-                        container_name=self.BASE_CONTAINER_NAME,
-                        follow=self.logging_interval is None,
-                        since_time=last_log_time,
-                    )
+                    self.write_logs(pod)
-                    if pod_log_status.running:
-                        self.log.info("Container still running; deferring again.")
-                        self.invoke_defer_method(pod_log_status.last_log_time)
-                    else:
-                        self.invoke_defer_method()
-
-            elif event["status"] == "success":
+
                 if self.do_xcom_push:
-                    xcom_sidecar_output = self.extract_xcom(pod=self.pod)
+                    xcom_sidecar_output = self.extract_xcom(pod=pod)
                     return xcom_sidecar_output
-                return
-        except TaskDeferred:
-            raise
         finally:
-            self._clean(event)
-
-    def _clean(self, event: dict[str, Any]):
-        if event["status"] == "running":
-            return
-        if self.get_logs:
-            self.write_logs(self.pod)
-        istio_enabled = self.is_istio_enabled(self.pod)
-        # Skip await_pod_completion when the event is 'timeout' due to the pod can hang
-        # on the ErrImagePull or ContainerCreating step and it will never complete
-        if event["status"] != "timeout":
-            self.pod = self.pod_manager.await_pod_completion(
-                self.pod, istio_enabled, self.base_container_name
-            )
-        if self.pod is not None:
-            self.post_complete_action(
-                pod=self.pod,
-                remote_pod=self.pod,
-            )
-
-    @deprecated(reason="use `trigger_reentry` instead.", category=AirflowProviderDeprecationWarning)
-    def execute_complete(self, context: Context, event: dict, **kwargs):
-        self.trigger_reentry(context=context, event=event)
+            istio_enabled = self.is_istio_enabled(pod)
+            # Skip await_pod_completion when the event is 'timeout' due to the pod can hang
+            # on the ErrImagePull or ContainerCreating step and it will never complete
+            if event["status"] != "timeout":
+                pod = self.pod_manager.await_pod_completion(pod, istio_enabled, self.base_container_name)
+            if pod is not None:
+                self.post_complete_action(
+                    pod=pod,
+                    remote_pod=pod,
+                )
 
     def write_logs(self, pod: k8s.V1Pod):
         try:
diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py
index c9b1e62226..e34a73f146 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -30,8 +30,10 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
     OnFinishAction,
     PodLaunchTimeoutException,
     PodPhase,
+    container_is_running,
 )
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils import timezone
 
 if TYPE_CHECKING:
     from kubernetes_asyncio.client.models import V1Pod
@@ -158,49 +160,22 @@ class KubernetesPodTrigger(BaseTrigger):
         self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
         try:
             state = await self._wait_for_pod_start()
-            if state == ContainerState.TERMINATED:
+            if state in PodPhase.terminal_states:
                 event = TriggerEvent(
-                    {
-                        "status": "success",
-                        "namespace": self.pod_namespace,
-                        "name": self.pod_name,
-                        "message": "All containers inside pod have started successfully.",
-                    }
-                )
-            elif state == ContainerState.FAILED:
-                event = TriggerEvent(
-                    {
-                        "status": "failed",
-                        "namespace": self.pod_namespace,
-                        "name": self.pod_name,
-                        "message": "pod failed",
-                    }
+                    {"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
                 )
             else:
                 event = await self._wait_for_container_completion()
             yield event
-            return
-        except PodLaunchTimeoutException as e:
-            message = self._format_exception_description(e)
-            yield TriggerEvent(
-                {
-                    "name": self.pod_name,
-                    "namespace": self.pod_namespace,
-                    "status": "timeout",
-                    "message": message,
-                }
-            )
         except Exception as e:
+            description = self._format_exception_description(e)
             yield TriggerEvent(
                 {
-                    "name": self.pod_name,
-                    "namespace": self.pod_namespace,
                     "status": "error",
-                    "message": str(e),
-                    "stack_trace": traceback.format_exc(),
+                    "error_type": e.__class__.__name__,
+                    "description": description,
                 }
             )
-            return
 
     def _format_exception_description(self, exc: Exception) -> Any:
         if isinstance(exc, PodLaunchTimeoutException):
@@ -214,13 +189,14 @@ class KubernetesPodTrigger(BaseTrigger):
             description += f"\ntrigger traceback:\n{curr_traceback}"
         return description
 
-    async def _wait_for_pod_start(self) -> ContainerState:
+    async def _wait_for_pod_start(self) -> Any:
         """Loops until pod phase leaves ``PENDING`` If timeout is reached, throws error."""
-        delta = datetime.datetime.now(tz=datetime.timezone.utc) - self.trigger_start_time
-        while self.startup_timeout >= delta.total_seconds():
+        start_time = timezone.utcnow()
+        timeout_end = start_time + datetime.timedelta(seconds=self.startup_timeout)
+        while timeout_end > timezone.utcnow():
             pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
             if not pod.status.phase == "Pending":
-                return self.define_container_state(pod)
+                return pod.status.phase
             self.log.info("Still waiting for pod to start. The pod state is %s", pod.status.phase)
             await asyncio.sleep(self.poll_interval)
         raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout")
@@ -232,30 +208,18 @@ class KubernetesPodTrigger(BaseTrigger):
         Waits until container is no longer in running state. If trigger is configured with a logging
         period, then will emit an event to resume the task for the purpose of fetching more logs.
         """
-        time_begin = datetime.datetime.now(tz=datetime.timezone.utc)
+        time_begin = timezone.utcnow()
         time_get_more_logs = None
         if self.logging_interval is not None:
             time_get_more_logs = time_begin + datetime.timedelta(seconds=self.logging_interval)
         while True:
             pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
-            container_state = self.define_container_state(pod)
-            if container_state == ContainerState.TERMINATED:
-                return TriggerEvent(
-                    {"status": "success", "namespace": self.pod_namespace, "name": self.pod_name}
-                )
-            elif container_state == ContainerState.FAILED:
-                return TriggerEvent(
-                    {"status": "failed", "namespace": self.pod_namespace, "name": self.pod_name}
-                )
-            if time_get_more_logs and datetime.datetime.now(tz=datetime.timezone.utc) > time_get_more_logs:
+            if not container_is_running(pod=pod, container_name=self.base_container_name):
                 return TriggerEvent(
-                    {
-                        "status": "running",
-                        "last_log_time": self.last_log_time,
-                        "namespace": self.pod_namespace,
-                        "name": self.pod_name,
-                    }
+                    {"status": "done", "namespace": self.pod_namespace, "pod_name": self.pod_name}
                 )
+            if time_get_more_logs and timezone.utcnow() > time_get_more_logs:
+                return TriggerEvent({"status": "running", "last_log_time": self.last_log_time})
             await asyncio.sleep(self.poll_interval)
 
     def _get_async_hook(self) -> AsyncKubernetesHook:
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py
index faa21eb7d7..c27cd23146 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -35,6 +35,7 @@ from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperato
 from airflow.providers.cncf.kubernetes.secret import Secret
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
 from airflow.providers.cncf.kubernetes.utils.pod_manager import (
+    PodLaunchTimeoutException,
     PodLoggingStatus,
     PodPhase,
 )
@@ -1972,39 +1973,41 @@ class TestKubernetesPodOperatorAsync:
             with pytest.raises(AirflowException, match=expect_match):
                 k.cleanup(pod, pod)
 
-    @mock.patch(f"{HOOK_CLASS}.get_pod")
+    @mock.patch(
+        "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.raise_for_trigger_status"
+    )
+    @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
     def test_get_logs_running(
         self,
         fetch_container_logs,
         await_pod_completion,
-        get_pod,
+        find_pod,
+        raise_for_trigger_status,
     ):
         """When logs fetch exits with status running, raise task deferred"""
         pod = MagicMock()
-        get_pod.return_value = pod
+        find_pod.return_value = pod
         op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
         await_pod_completion.return_value = None
         fetch_container_logs.return_value = PodLoggingStatus(True, None)
         with pytest.raises(TaskDeferred):
-            op.trigger_reentry(
-                create_context(op),
-                event={"name": TEST_NAME, "namespace": TEST_NAMESPACE, "status": "running"},
-            )
+            op.trigger_reentry(create_context(op), None)
         fetch_container_logs.is_called_with(pod, "base")
 
     @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
+    @mock.patch(
+        "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.raise_for_trigger_status"
+    )
     @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.find_pod")
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
-    def test_get_logs_not_running(self, fetch_container_logs, find_pod, cleanup):
+    def test_get_logs_not_running(self, fetch_container_logs, find_pod, raise_for_trigger_status, cleanup):
         pod = MagicMock()
         find_pod.return_value = pod
         op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
         fetch_container_logs.return_value = PodLoggingStatus(False, None)
-        op.trigger_reentry(
-            create_context(op), event={"name": TEST_NAME, "namespace": TEST_NAMESPACE, "status": "success"}
-        )
+        op.trigger_reentry(create_context(op), None)
         fetch_container_logs.is_called_with(pod, "base")
 
     @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup")
@@ -2013,15 +2016,14 @@ class TestKubernetesPodOperatorAsync:
         """Assert that trigger_reentry raise exception in case of error"""
         find_pod.return_value = MagicMock()
         op = KubernetesPodOperator(task_id="test_task", name="test-pod", get_logs=True)
-        with pytest.raises(AirflowException):
+        with pytest.raises(PodLaunchTimeoutException):
             context = create_context(op)
             op.trigger_reentry(
                 context,
                 {
-                    "status": "timeout",
-                    "message": "any message",
-                    "name": TEST_NAME,
-                    "namespace": TEST_NAMESPACE,
+                    "status": "error",
+                    "error_type": "PodLaunchTimeoutException",
+                    "description": "any message",
                 },
             )
diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py
index bed52811fc..d12100e4e3 100644
--- a/tests/providers/cncf/kubernetes/triggers/test_pod.py
+++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py
@@ -122,10 +122,9 @@
 
         expected_event = TriggerEvent(
             {
-                "status": "success",
-                "namespace": "default",
-                "name": "test-pod-name",
-                "message": "All containers inside pod have started successfully.",
+                "pod_name": POD_NAME,
+                "namespace": NAMESPACE,
+                "status": "done",
             }
         )
         actual_event = await trigger.run().asend(None)
@@ -133,11 +132,16 @@
         assert actual_event == expected_event
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_PATH}.define_container_state")
+    @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+    @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
+    @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
     @mock.patch(f"{TRIGGER_PATH}.hook")
-    async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigger, caplog):
+    async def test_run_loop_return_waiting_event(
+        self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
+    ):
         mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.WAITING
+        mock_container_is_running.return_value = True
 
         caplog.set_level(logging.INFO)
 
@@ -149,11 +153,16 @@
         assert f"Sleeping for {POLL_INTERVAL} seconds."
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_PATH}.define_container_state")
+    @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+    @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
+    @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
     @mock.patch(f"{TRIGGER_PATH}.hook")
-    async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigger, caplog):
+    async def test_run_loop_return_running_event(
+        self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
+    ):
         mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.RUNNING
+        mock_container_is_running.return_value = True
 
         caplog.set_level(logging.INFO)
 
@@ -178,7 +187,11 @@
         mock_method.return_value = ContainerState.FAILED
 
         expected_event = TriggerEvent(
-            {"status": "failed", "namespace": "default", "name": "test-pod-name", "message": "pod failed"}
+            {
+                "pod_name": POD_NAME,
+                "namespace": NAMESPACE,
+                "status": "done",
+            }
         )
         actual_event = await trigger.run().asend(None)
 
@@ -197,14 +210,8 @@
         generator = trigger.run()
         actual = await generator.asend(None)
-        actual_stack_trace = actual.payload.pop("stack_trace")
-        assert (
-            TriggerEvent(
-                {"name": POD_NAME, "namespace": NAMESPACE, "status": "error", "message": "Test exception"}
-            )
-            == actual
-        )
-        assert actual_stack_trace.startswith("Traceback (most recent call last):")
+        actual_stack_trace = actual.payload.pop("description")
+        assert actual_stack_trace.startswith("Trigger KubernetesPodTrigger failed with exception Exception")
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
@@ -228,24 +235,16 @@
     @pytest.mark.parametrize(
         "logging_interval, exp_event",
         [
-            param(
-                0,
-                {
-                    "status": "running",
-                    "last_log_time": DateTime(2022, 1, 1),
-                    "name": POD_NAME,
-                    "namespace": NAMESPACE,
-                },
-                id="short_interval",
-            ),
+            param(0, {"status": "running", "last_log_time": DateTime(2022, 1, 1)}, id="short_interval"),
+            param(None, {"status": "done", "namespace": mock.ANY, "pod_name": mock.ANY}, id="no_interval"),
         ],
     )
-    @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
-    @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.AsyncKubernetesHook.get_pod")
-    async def test_running_log_interval(
-        self, mock_get_pod, mock_wait_for_pod_start, define_container_state, logging_interval, exp_event
-    ):
+    @mock.patch(
+        "kubernetes_asyncio.client.CoreV1Api.read_namespaced_pod",
+        new=get_read_pod_mock_containers([1, 1, None, None]),
+    )
+    @mock.patch("kubernetes_asyncio.config.load_kube_config")
+    async def test_running_log_interval(self, load_kube_config, logging_interval, exp_event):
         """
         If log interval given, should emit event with running status and last log time.
         Otherwise, should make it to second loop and emit "done" event.
@@ -255,15 +254,14 @@
         interval is None, the second "running" status will just result in continuation of the loop. And
         when in the next loop we get a non-running status, the trigger fires a "done" event.
         """
-        define_container_state.return_value = "running"
         trigger = KubernetesPodTrigger(
-            pod_name=POD_NAME,
-            pod_namespace=NAMESPACE,
-            trigger_start_time=datetime.datetime.now(tz=datetime.timezone.utc),
-            base_container_name=BASE_CONTAINER_NAME,
+            pod_name=mock.ANY,
+            pod_namespace=mock.ANY,
+            trigger_start_time=mock.ANY,
+            base_container_name=mock.ANY,
             startup_timeout=5,
             poll_interval=1,
-            logging_interval=1,
+            logging_interval=logging_interval,
             last_log_time=DateTime(2022, 1, 1),
         )
         assert await trigger.run().__anext__() == TriggerEvent(exp_event)
@@ -308,12 +306,12 @@
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize("container_state", [ContainerState.WAITING, ContainerState.UNDEFINED])
-    @mock.patch(f"{TRIGGER_PATH}.define_container_state")
+    @mock.patch(f"{TRIGGER_PATH}._wait_for_pod_start")
     @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_run_loop_return_timeout_event(
         self, mock_hook, mock_method, trigger, caplog, container_state
     ):
-        trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(minutes=2)
+        trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(seconds=5)
         mock_hook.get_pod.return_value = self._mock_pod_result(
             mock.MagicMock(
                 status=mock.MagicMock(
@@ -327,14 +325,4 @@
         generator = trigger.run()
         actual = await generator.asend(None)
-        assert (
-            TriggerEvent(
-                {
-                    "name": POD_NAME,
-                    "namespace": NAMESPACE,
-                    "status": "timeout",
-                    "message": "Pod did not leave 'Pending' phase within specified timeout",
-                }
-            )
-            == actual
-        )
+        assert actual == TriggerEvent({"status": "done", "namespace": NAMESPACE, "pod_name": POD_NAME})
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index c6a2d4e72f..ca7b7ba358 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -108,20 +108,19 @@
         }
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
+    @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_success_event_should_execute_successfully(
-        self, mock_hook, mock_wait_pod, trigger
+        self, mock_hook, mock_method, trigger
     ):
         mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
-        mock_wait_pod.return_value = ContainerState.TERMINATED
+        mock_method.return_value = ContainerState.TERMINATED
 
         expected_event = TriggerEvent(
             {
-                "name": POD_NAME,
+                "pod_name": POD_NAME,
                 "namespace": NAMESPACE,
-                "status": "success",
-                "message": "All containers inside pod have started successfully.",
+                "status": "done",
             }
         )
         actual_event = await trigger.run().asend(None)
@@ -129,10 +128,10 @@
         assert actual_event == expected_event
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
+    @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_failed_event_should_execute_successfully(
-        self, mock_hook, mock_wait_pod, trigger
+        self, mock_hook, mock_method, trigger
     ):
         mock_hook.get_pod.return_value = self._mock_pod_result(
             mock.MagicMock(
@@ -141,14 +140,13 @@
                 )
             )
         )
-        mock_wait_pod.return_value = ContainerState.FAILED
+        mock_method.return_value = ContainerState.FAILED
 
         expected_event = TriggerEvent(
             {
-                "name": POD_NAME,
+                "pod_name": POD_NAME,
                 "namespace": NAMESPACE,
-                "status": "failed",
-                "message": "pod failed",
+                "status": "done",
             }
         )
         actual_event = await trigger.run().asend(None)
@@ -156,15 +154,18 @@
         assert actual_event == expected_event
 
     @pytest.mark.asyncio
+    @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+    @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
     @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
-    @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_waiting_event_should_execute_successfully(
-        self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
+        self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
     ):
         mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
-        mock_method.return_value = ContainerState.WAITING
+        mock_method.return_value = ContainerState.RUNNING
+        mock_container_is_running.return_value = True
+        trigger.logging_interval = 10
 
         caplog.set_level(logging.INFO)
 
         task = asyncio.create_task(trigger.run().__anext__())
@@ -175,13 +176,15 @@
         assert f"Sleeping for {POLL_INTERVAL} seconds."
 
     @pytest.mark.asyncio
+    @mock.patch("airflow.providers.cncf.kubernetes.triggers.pod.container_is_running")
+    @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook.get_pod")
     @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
-    @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_running_event_should_execute_successfully(
-        self, mock_hook, mock_method, mock_wait_pod, trigger, caplog
+        self, mock_hook, mock_method, mock_get_pod, mock_container_is_running, trigger, caplog
     ):
         mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_container_is_running.return_value = True
         mock_method.return_value = ContainerState.RUNNING
 
         caplog.set_level(logging.INFO)
@@ -194,10 +197,9 @@
         assert f"Sleeping for {POLL_INTERVAL} seconds."
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_KUB_PATH}._wait_for_pod_start")
     @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_logging_in_trigger_when_exception_should_execute_successfully(
-        self, mock_hook, mock_wait_pod, trigger, caplog
+        self, mock_hook, trigger, caplog
     ):
         """
         Test that GKEStartPodTrigger fires the correct event in case of an error.
@@ -206,14 +208,9 @@
 
         generator = trigger.run()
         actual = await generator.asend(None)
-        actual_stack_trace = actual.payload.pop("stack_trace")
-        assert (
-            TriggerEvent(
-                {"name": POD_NAME, "namespace": NAMESPACE, "status": "error", "message": "Test exception"}
-            )
-            == actual
-        )
-        assert actual_stack_trace.startswith("Traceback (most recent call last):")
+
+        actual_stack_trace = actual.payload.pop("description")
+        assert actual_stack_trace.startswith("Trigger GKEStartPodTrigger failed with exception Exception")
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")