This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b3ce116192 Fix async KPO by waiting for pod termination in
`execute_complete` before cleanup (#32467)
b3ce116192 is described below
commit b3ce1161926efb880c3f525ac0a031ab4812fb95
Author: Hussein Awala <[email protected]>
AuthorDate: Wed Jul 12 12:05:03 2023 +0200
Fix async KPO by waiting for pod termination in `execute_complete` before
cleanup (#32467)
* Fix async KPO by waiting for pod termination in `execute_complete` before
cleanup (#32467)
---------
Signed-off-by: Hussein Awala <[email protected]>
---
airflow/providers/cncf/kubernetes/operators/pod.py | 13 +--
airflow/providers/cncf/kubernetes/triggers/pod.py | 26 ++----
.../cncf/kubernetes/operators/test_pod.py | 99 ++++++++++++++++++++++
.../providers/cncf/kubernetes/triggers/test_pod.py | 32 +------
.../cloud/triggers/test_kubernetes_engine.py | 12 +--
5 files changed, 114 insertions(+), 68 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 124685e792..49940144b5 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -636,15 +636,11 @@ class KubernetesPodOperator(BaseOperator):
def execute_complete(self, context: Context, event: dict, **kwargs):
pod = None
- remote_pod = None
try:
pod = self.hook.get_pod(
event["name"],
event["namespace"],
)
- # It is done to coincide with the current implementation of the
general logic of the cleanup
- # method. If it's going to be remade in future then it must be
changed
- remote_pod = pod
if event["status"] in ("error", "failed", "timeout"):
# fetch some logs when pod is failed
if self.get_logs:
@@ -661,16 +657,13 @@ class KubernetesPodOperator(BaseOperator):
if self.do_xcom_push:
xcom_sidecar_output = self.extract_xcom(pod=pod)
- pod = self.pod_manager.await_pod_completion(pod)
- # It is done to coincide with the current implementation
of the general logic of
- # the cleanup method. If it's going to be remade in future
then it must be changed
- remote_pod = pod
return xcom_sidecar_output
finally:
- if pod is not None and remote_pod is not None:
+ pod = self.pod_manager.await_pod_completion(pod)
+ if pod is not None:
self.post_complete_action(
pod=pod,
- remote_pod=remote_pod,
+ remote_pod=pod,
)
def write_logs(self, pod: k8s.V1Pod):
diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py
b/airflow/providers/cncf/kubernetes/triggers/pod.py
index 6fdf763ece..6443dfb63f 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -154,23 +154,15 @@ class KubernetesPodTrigger(BaseTrigger):
self.log.debug("Container %s status: %s",
self.base_container_name, container_state)
if container_state == ContainerState.TERMINATED:
- if pod_status not in PodPhase.terminal_states:
- self.log.info(
- "Pod %s is still running. Sleeping for %s
seconds.",
- self.pod_name,
- self.poll_interval,
- )
- await asyncio.sleep(self.poll_interval)
- else:
- yield TriggerEvent(
- {
- "name": self.pod_name,
- "namespace": self.pod_namespace,
- "status": "success",
- "message": "All containers inside pod have
started successfully.",
- }
- )
- return
+ yield TriggerEvent(
+ {
+ "name": self.pod_name,
+ "namespace": self.pod_namespace,
+ "status": "success",
+ "message": "All containers inside pod have started
successfully.",
+ }
+ )
+ return
elif self.should_wait(pod_phase=pod_status,
container_state=container_state):
self.log.info("Container is not completed and still
working.")
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index 1ce25190ba..15ca6553cd 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1389,10 +1389,12 @@ class TestKubernetesPodOperatorAsync:
({"skip_on_exit_code": None}, 100, AirflowException, "Failed",
"error"),
],
)
+ @patch(KUB_OP_PATH.format("pod_manager"))
@patch(HOOK_CLASS)
def test_async_create_pod_with_skip_on_exit_code_should_skip(
self,
mocked_hook,
+ mock_manager,
extra_kwargs,
actual_exit_code,
expected_exc,
@@ -1426,6 +1428,7 @@ class TestKubernetesPodOperatorAsync:
remote_pod.status.phase = pod_status
remote_pod.status.container_statuses = [base_container,
sidecar_container]
mocked_hook.return_value.get_pod.return_value = remote_pod
+ mock_manager.await_pod_completion.return_value = remote_pod
context = {
"ti": MagicMock(),
@@ -1608,3 +1611,99 @@ class TestKubernetesPodOperatorAsync:
pod.status = V1PodStatus(phase=PodPhase.FAILED)
with pytest.raises(AirflowException, match=expect_match):
k.cleanup(pod, pod)
+
+
[email protected]("do_xcom_push", [True, False])
+@patch(KUB_OP_PATH.format("extract_xcom"))
+@patch(KUB_OP_PATH.format("post_complete_action"))
+@patch(HOOK_CLASS)
+def test_async_kpo_wait_termination_before_cleanup_on_success(
+ mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
+):
+ metadata = {"metadata.name": TEST_NAME, "metadata.namespace":
TEST_NAMESPACE}
+ running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"})
+ succeeded_state = mock.MagicMock(**metadata, **{"status.phase":
"Succeeded"})
+ mocked_hook.return_value.get_pod.return_value = running_state
+ read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
+ read_pod_mock.side_effect = [
+ running_state,
+ running_state,
+ succeeded_state,
+ ]
+
+ ti_mock = MagicMock()
+
+ success_event = {
+ "status": "success",
+ "message": TEST_SUCCESS_MESSAGE,
+ "name": TEST_NAME,
+ "namespace": TEST_NAMESPACE,
+ }
+
+ k = KubernetesPodOperator(task_id="task", deferrable=True,
do_xcom_push=do_xcom_push)
+ k.execute_complete({"ti": ti_mock}, success_event)
+
+ # check if it gets the pod
+ mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME,
TEST_NAMESPACE)
+
+ # check if it pushes the xcom
+ assert ti_mock.xcom_push.call_count == 2
+ ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
+ ti_mock.xcom_push.assert_any_call(key="pod_namespace",
value=TEST_NAMESPACE)
+
+ # assert that the xcom are extracted/not extracted
+ if do_xcom_push:
+ mock_extract_xcom.assert_called_once()
+ else:
+ mock_extract_xcom.assert_not_called()
+
+ # check if it waits for the pod to complete
+ assert read_pod_mock.call_count == 3
+
+ # assert that the cleanup is called
+ post_complete_action.assert_called_once()
+
+
[email protected]("do_xcom_push", [True, False])
+@patch(KUB_OP_PATH.format("extract_xcom"))
+@patch(KUB_OP_PATH.format("post_complete_action"))
+@patch(HOOK_CLASS)
+def test_async_kpo_wait_termination_before_cleanup_on_failure(
+ mocked_hook, post_complete_action, mock_extract_xcom, do_xcom_push
+):
+ metadata = {"metadata.name": TEST_NAME, "metadata.namespace":
TEST_NAMESPACE}
+ running_state = mock.MagicMock(**metadata, **{"status.phase": "Running"})
+ failed_state = mock.MagicMock(**metadata, **{"status.phase": "Failed"})
+ mocked_hook.return_value.get_pod.return_value = running_state
+ read_pod_mock = mocked_hook.return_value.core_v1_client.read_namespaced_pod
+ read_pod_mock.side_effect = [
+ running_state,
+ running_state,
+ failed_state,
+ ]
+
+ ti_mock = MagicMock()
+
+ success_event = {"status": "failed", "message": "error", "name":
TEST_NAME, "namespace": TEST_NAMESPACE}
+
+ post_complete_action.side_effect = AirflowException()
+
+ k = KubernetesPodOperator(task_id="task", deferrable=True,
do_xcom_push=do_xcom_push)
+
+ with pytest.raises(AirflowException):
+ k.execute_complete({"ti": ti_mock}, success_event)
+
+ # check if it gets the pod
+ mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME,
TEST_NAMESPACE)
+
+ # assert that it does not push the xcom
+ ti_mock.xcom_push.assert_not_called()
+
+ # assert that the xcom are not extracted
+ mock_extract_xcom.assert_not_called()
+
+ # check if it waits for the pod to complete
+ assert read_pod_mock.call_count == 3
+
+ # assert that the cleanup is called
+ post_complete_action.assert_called_once()
diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py
b/tests/providers/cncf/kubernetes/triggers/test_pod.py
index 4ed731b425..fbfff17278 100644
--- a/tests/providers/cncf/kubernetes/triggers/test_pod.py
+++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py
@@ -96,8 +96,7 @@ class TestKubernetesPodTrigger:
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
async def test_run_loop_return_success_event(self, mock_hook, mock_method,
trigger):
- pod_mock = mock.MagicMock(**{"status.phase": "Succeeded"})
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(pod_mock)
+ mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED
expected_event = TriggerEvent(
@@ -112,35 +111,6 @@ class TestKubernetesPodTrigger:
assert actual_event == expected_event
- @pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
- async def
test_run_loop_wait_pod_termination_before_returning_success_event(
- self, mock_hook, mock_method, trigger
- ):
- running_state = mock.MagicMock(**{"status.phase": "Running"})
- succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
- mock_hook.return_value.get_pod.side_effect = [
- self._mock_pod_result(running_state),
- self._mock_pod_result(running_state),
- self._mock_pod_result(succeeded_state),
- ]
- mock_method.return_value = ContainerState.TERMINATED
-
- expected_event = TriggerEvent(
- {
- "name": POD_NAME,
- "namespace": NAMESPACE,
- "status": "success",
- "message": "All containers inside pod have started
successfully.",
- }
- )
- with mock.patch.object(asyncio, "sleep") as mock_sleep:
- actual_event = await (trigger.run()).asend(None)
-
- assert actual_event == expected_event
- assert mock_sleep.call_count == 2
-
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
@mock.patch(f"{TRIGGER_PATH}._get_async_hook")
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index 154908a6c4..e695822d38 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -110,13 +110,7 @@ class TestGKEStartPodTrigger:
async def test_run_loop_return_success_event_should_execute_successfully(
self, mock_hook, mock_method, trigger
):
- running_state = mock.MagicMock(**{"status.phase": "Running"})
- succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"})
- mock_hook.return_value.get_pod.side_effect = [
- self._mock_pod_result(running_state),
- self._mock_pod_result(running_state),
- self._mock_pod_result(succeeded_state),
- ]
+ mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED
expected_event = TriggerEvent(
@@ -127,11 +121,9 @@ class TestGKEStartPodTrigger:
"message": "All containers inside pod have started
successfully.",
}
)
- with mock.patch.object(asyncio, "sleep") as mock_sleep:
- actual_event = await (trigger.run()).asend(None)
+ actual_event = await (trigger.run()).asend(None)
assert actual_event == expected_event
- assert mock_sleep.call_count == 2
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")