This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b3ce116192 Fix async KPO by waiting pod termination in 
`execute_complete` before cleanup (#32467)
b3ce116192 is described below

commit b3ce1161926efb880c3f525ac0a031ab4812fb95
Author: Hussein Awala <[email protected]>
AuthorDate: Wed Jul 12 12:05:03 2023 +0200

    Fix async KPO by waiting pod termination in `execute_complete` before 
cleanup (#32467)
    
    * Fix async KPO by waiting pod termination in `execute_complete` before 
cleanup (#32467)
    
    ---------
    
    Signed-off-by: Hussein Awala <[email protected]>
---
 airflow/providers/cncf/kubernetes/operators/pod.py | 13 +--
 airflow/providers/cncf/kubernetes/triggers/pod.py  | 26 ++----
 .../cncf/kubernetes/operators/test_pod.py          | 99 ++++++++++++++++++++++
 .../providers/cncf/kubernetes/triggers/test_pod.py | 32 +------
 .../cloud/triggers/test_kubernetes_engine.py       | 12 +--
 5 files changed, 114 insertions(+), 68 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py 
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 124685e792..49940144b5 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -636,15 +636,11 @@ class KubernetesPodOperator(BaseOperator):
 
     def execute_complete(self, context: Context, event: dict, **kwargs):
         pod = None
-        remote_pod = None
         try:
             pod = self.hook.get_pod(
                 event["name"],
                 event["namespace"],
             )
-            # It is done to coincide with the current implementation of the 
general logic of the cleanup
-            # method. If it's going to be remade in future then it must be 
changed
-            remote_pod = pod
             if event["status"] in ("error", "failed", "timeout"):
                 # fetch some logs when pod is failed
                 if self.get_logs:
@@ -661,16 +657,13 @@ class KubernetesPodOperator(BaseOperator):
 
                 if self.do_xcom_push:
                     xcom_sidecar_output = self.extract_xcom(pod=pod)
-                    pod = self.pod_manager.await_pod_completion(pod)
-                    # It is done to coincide with the current implementation 
of the general logic of
-                    # the cleanup method. If it's going to be remade in future 
then it must be changed
-                    remote_pod = pod
                     return xcom_sidecar_output
         finally:
-            if pod is not None and remote_pod is not None:
+            pod = self.pod_manager.await_pod_completion(pod)
+            if pod is not None:
                 self.post_complete_action(
                     pod=pod,
-                    remote_pod=remote_pod,
+                    remote_pod=pod,
                 )
 
     def write_logs(self, pod: k8s.V1Pod):
diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py 
b/airflow/providers/cncf/kubernetes/triggers/pod.py
index 6fdf763ece..6443dfb63f 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -154,23 +154,15 @@ class KubernetesPodTrigger(BaseTrigger):
                 self.log.debug("Container %s status: %s", 
self.base_container_name, container_state)
 
                 if container_state == ContainerState.TERMINATED:
-                    if pod_status not in PodPhase.terminal_states:
-                        self.log.info(
-                            "Pod %s is still running. Sleeping for %s 
seconds.",
-                            self.pod_name,
-                            self.poll_interval,
-                        )
-                        await asyncio.sleep(self.poll_interval)
-                    else:
-                        yield TriggerEvent(
-                            {
-                                "name": self.pod_name,
-                                "namespace": self.pod_namespace,
-                                "status": "success",
-                                "message": "All containers inside pod have 
started successfully.",
-                            }
-                        )
-                        return
+                    yield TriggerEvent(
+                        {
+                            "name": self.pod_name,
+                            "namespace": self.pod_namespace,
+                            "status": "success",
+                            "message": "All containers inside pod have started 
successfully.",
+                        }
+                    )
+                    return
                 elif self.should_wait(pod_phase=pod_status, 
container_state=container_state):
                     self.log.info("Container is not completed and still 
working.")
 
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py 
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index 1ce25190ba..15ca6553cd 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1389,10 +1389,12 @@ class TestKubernetesPodOperatorAsync:
             ({"skip_on_exit_code": None}, 100, AirflowException, "Failed", 
"error"),
         ],
     )
+    @patch(KUB_OP_PATH.format("pod_manager"))
     @patch(HOOK_CLASS)
     def test_async_create_pod_with_skip_on_exit_code_should_skip(
         self,
         mocked_hook,
+        mock_manager,
         extra_kwargs,
         actual_exit_code,
         expected_exc,
@@ -1426,6 +1428,7 @@ class TestKubernetesPodOperatorAsync:
         remote_pod.status.phase = pod_status
         remote_pod.status.container_statuses = [base_container, 
sidecar_container]
         mocked_hook.return_value.get_pod.return_value = remote_pod
+        mock_manager.await_pod_completion.return_value = remote_pod
 
         context = {
             "ti": MagicMock(),
@@ -1608,3 +1611,99 @@ class TestKubernetesPodOperatorAsync:
         pod.status = V1PodStatus(phase=PodPhase.FAILED)
         with pytest.raises(AirflowException, match=expect_match):
             k.cleanup(pod, pod)
+
+
[email protected]("do_xcom_push", [True, False])
+@patch(KUB_OP_PATH.format("extract_xcom"))
+@patch(KUB_OP_PATH.format("post_complete_action"))
+@patch(HOOK_CLASS)
+def test_async_kpo_wait_termination_before_cleanup_on_success(
+    mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
+):
+    metadata = {"metadata.name": TEST_NAME, "metadata.namespace": 
TEST_NAMESPACE}
+    running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"})
+    succeeded_state = mock.MagicMock(**metadata, **{"status.phase": 
"Succeeded"})
+    mocked_hook.return_value.get_pod.return_value = running_state
+    read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
+    read_pod_mock.side_effect = [
+        running_state,
+        running_state,
+        succeeded_state,
+    ]
+
+    ti_mock = MagicMock()
+
+    success_event = {
+        "status": "success",
+        "message": TEST_SUCCESS_MESSAGE,
+        "name": TEST_NAME,
+        "namespace": TEST_NAMESPACE,
+    }
+
+    k = KubernetesPodOperator(task_id="task", deferrable=True, 
do_xcom_push=do_xcom_push)
+    k.execute_complete({"ti": ti_mock}, success_event)
+
+    # check if it gets the pod
+    mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, 
TEST_NAMESPACE)
+
+    # check if it pushes the xcom
+    assert ti_mock.xcom_push.call_count == 2
+    ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
+    ti_mock.xcom_push.assert_any_call(key="pod_namespace", 
value=TEST_NAMESPACE)
+
+    # assert that the xcom are extracted/not extracted
+    if do_xcom_push:
+        mock_extract_xcom.assert_called_once()
+    else:
+        mock_extract_xcom.assert_not_called()
+
+    # check if it waits for the pod to complete
+    assert read_pod_mock.call_count == 3
+
+    # assert that the cleanup is called
+    post_complete_action.assert_called_once()
+
+
[email protected]("do_xcom_push", [True, False])
+@patch(KUB_OP_PATH.format("extract_xcom"))
+@patch(KUB_OP_PATH.format("post_complete_action"))
+@patch(HOOK_CLASS)
+def test_async_kpo_wait_termination_before_cleanup_on_failure(
+    mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
+):
+    metadata = {"metadata.name": TEST_NAME, "metadata.namespace": 
TEST_NAMESPACE}
+    running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"})
+    failed_state = mock.MagicMock(**metadata, **{"status.phase": "Failed"})
+    mocked_hook.return_value.get_pod.return_value = running_state
+    read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
+    read_pod_mock.side_effect = [
+        running_state,
+        running_state,
+        failed_state,
+    ]
+
+    ti_mock = MagicMock()
+
+    success_event = {"status": "failed", "message": "error", "name": 
TEST_NAME, "namespace": TEST_NAMESPACE}
+
+    post_complete_action.side_effect = AirflowException()
+
+    k = KubernetesPodOperator(task_id="task", deferrable=True, 
do_xcom_push=do_xcom_push)
+
+    with pytest.raises(AirflowException):
+        k.execute_complete({"ti": ti_mock}, success_event)
+
+    # check if it gets the pod
+    mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, 
TEST_NAMESPACE)
+
+    # assert that it does not push the xcom
+    ti_mock.xcom_push.assert_not_called()
+
+    # assert that the xcom are not extracted
+    mock_extract_xcom.assert_not_called()
+
+    # check if it waits for the pod to complete
+    assert read_pod_mock.call_count == 3
+
+    # assert that the cleanup is called
+    post_complete_action.assert_called_once()
diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py 
b/tests/providers/cncf/kubernetes/triggers/test_pod.py
index 4ed731b425..fbfff17278 100644
--- a/tests/providers/cncf/kubernetes/triggers/test_pod.py
+++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py
@@ -96,8 +96,7 @@ class TestKubernetesPodTrigger:
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
     async def test_run_loop_return_success_event(self, mock_hook, mock_method, 
trigger):
-        pod_mock = mock.MagicMock(**{"status.phase": "Succeeded"})
-        mock_hook.return_value.get_pod.return_value = 
self._mock_pod_result(pod_mock)
+        mock_hook.return_value.get_pod.return_value = 
self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.TERMINATED
 
         expected_event = TriggerEvent(
@@ -112,35 +111,6 @@ class TestKubernetesPodTrigger:
 
         assert actual_event == expected_event
 
-    @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
-    async def 
test_run_loop_wait_pod_termination_before_returning_success_event(
-        self, mock_hook, mock_method, trigger
-    ):
-        running_state = mock.MagicMock(**{"status.phase": "Running"})
-        succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
-        mock_hook.return_value.get_pod.side_effect = [
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(succeeded_state),
-        ]
-        mock_method.return_value = ContainerState.TERMINATED
-
-        expected_event = TriggerEvent(
-            {
-                "name": POD_NAME,
-                "namespace": NAMESPACE,
-                "status": "success",
-                "message": "All containers inside pod have started 
successfully.",
-            }
-        )
-        with mock.patch.object(asyncio, "sleep") as mock_sleep:
-            actual_event = await (trigger.run()).asend(None)
-
-        assert actual_event == expected_event
-        assert mock_sleep.call_count == 2
-
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
     @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py 
b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index 154908a6c4..e695822d38 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -110,13 +110,7 @@ class TestGKEStartPodTrigger:
     async def test_run_loop_return_success_event_should_execute_successfully(
         self, mock_hook, mock_method, trigger
     ):
-        running_state = mock.MagicMock(**{"status.phase": "Running"})
-        succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
-        mock_hook.return_value.get_pod.side_effect = [
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(running_state),
-            self._mock_pod_result(succeeded_state),
-        ]
+        mock_hook.return_value.get_pod.return_value = 
self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.TERMINATED
 
         expected_event = TriggerEvent(
@@ -127,11 +121,9 @@ class TestGKEStartPodTrigger:
                 "message": "All containers inside pod have started 
successfully.",
             }
         )
-        with mock.patch.object(asyncio, "sleep") as mock_sleep:
-            actual_event = await (trigger.run()).asend(None)
+        actual_event = await (trigger.run()).asend(None)
 
         assert actual_event == expected_event
-        assert mock_sleep.call_count == 2
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")

Reply via email to the mailing list.