This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b435b8edef Push to xcom before `KubernetesPodOperator` deferral (#34209)
b435b8edef is described below

commit b435b8edefd181fa85e6cc6b2b822d113f562e27
Author: Victor Chiapaikeo <[email protected]>
AuthorDate: Sat Sep 9 14:25:34 2023 -0400

    Push to xcom before `KubernetesPodOperator` deferral (#34209)
---
 airflow/providers/cncf/kubernetes/operators/pod.py |  8 +--
 .../cncf/kubernetes/operators/test_pod.py          | 57 ++++++++--------------
 2 files changed, 23 insertions(+), 42 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 5f7269718f..5ec9e2decb 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -630,6 +630,10 @@ class KubernetesPodOperator(BaseOperator):
             pod_request_obj=self.pod_request_obj,
             context=context,
         )
+        ti = context["ti"]
+        ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
+        ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
+
         self.invoke_defer_method()
 
     def invoke_defer_method(self):
@@ -666,10 +670,6 @@ class KubernetesPodOperator(BaseOperator):
                     self.write_logs(pod)
                 raise AirflowException(event["message"])
             elif event["status"] == "success":
-                ti = context["ti"]
-                ti.xcom_push(key="pod_name", value=pod.metadata.name)
-                ti.xcom_push(key="pod_namespace", value=pod.metadata.namespace)
-
                 # fetch some logs when pod is executed successfully
                 if self.get_logs:
                     self.write_logs(pod)
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py
index 3ada6218a5..fde4e60159 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1471,12 +1471,15 @@ class TestKubernetesPodOperatorAsync:
         )
         return remote_pod_mock
 
+    @pytest.mark.parametrize("do_xcom_push", [True, False])
     @patch(KUB_OP_PATH.format("build_pod_request_obj"))
     @patch(KUB_OP_PATH.format("get_or_create_pod"))
-    def test_async_create_pod_should_execute_successfully(self, mocked_pod, 
mocked_pod_obj):
+    def test_async_create_pod_should_execute_successfully(self, mocked_pod, 
mocked_pod_obj, do_xcom_push):
         """
         Asserts that a task is deferred and the KubernetesCreatePodTrigger 
will be fired
         when the KubernetesPodOperator is executed in deferrable mode when 
deferrable=True.
+
+        pod name and namespace are *always* pushed; do_xcom_push only controls 
xcom sidecar
         """
 
         k = KubernetesPodOperator(
@@ -1491,10 +1494,23 @@ class TestKubernetesPodOperatorAsync:
             in_cluster=True,
             get_logs=True,
             deferrable=True,
+            do_xcom_push=do_xcom_push,
         )
         k.config_file_in_dict_representation = {"a": "b"}
+
+        mocked_pod.return_value.metadata.name = TEST_NAME
+        mocked_pod.return_value.metadata.namespace = TEST_NAMESPACE
+
+        context = create_context(k)
+        ti_mock = MagicMock()
+        context["ti"] = ti_mock
+
         with pytest.raises(TaskDeferred) as exc:
-            k.execute(create_context(k))
+            k.execute(context)
+
+        assert ti_mock.xcom_push.call_count == 2
+        ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
+        ti_mock.xcom_push.assert_any_call(key="pod_namespace", 
value=TEST_NAMESPACE)
         assert isinstance(exc.value.trigger, KubernetesPodTrigger)
 
     @patch(KUB_OP_PATH.format("cleanup"))
@@ -1655,34 +1671,6 @@ class TestKubernetesPodOperatorAsync:
             },
         )
 
-    @pytest.mark.parametrize("do_xcom_push", [True, False])
-    @patch(KUB_OP_PATH.format("post_complete_action"))
-    @patch(KUB_OP_PATH.format("extract_xcom"))
-    @patch(POD_MANAGER_CLASS)
-    @patch(HOOK_CLASS)
-    def test_async_push_xcom_check_xcom_values_should_execute_successfully(
-        self, mocked_hook, mock_manager, mock_extract_xcom, 
post_complete_action, do_xcom_push
-    ):
-        """pod name and namespace are *always* pushed; do_xcom_push only 
controls xcom sidecar"""
-
-        mocked_hook.return_value.get_pod.return_value = k8s.V1Pod(
-            metadata=k8s.V1ObjectMeta(name=TEST_NAME, namespace=TEST_NAMESPACE)
-        )
-        mock_manager.return_value.await_pod_completion.return_value = {}
-        mock_extract_xcom.return_value = "{}"
-        k = KubernetesPodOperator(
-            task_id="task",
-            do_xcom_push=do_xcom_push,
-            deferrable=True,
-        )
-
-        pod = self.run_pod_async(k)
-
-        pod_name = XCom.get_one(run_id=self.dag_run.run_id, task_id="task", 
key="pod_name")
-        pod_namespace = XCom.get_one(run_id=self.dag_run.run_id, 
task_id="task", key="pod_namespace")
-        assert pod_name == pod.metadata.name
-        assert pod_namespace == pod.metadata.namespace
-
     @pytest.mark.parametrize("get_logs", [True, False])
     @patch(KUB_OP_PATH.format("post_complete_action"))
     @patch(KUB_OP_PATH.format("write_logs"))
@@ -1780,8 +1768,6 @@ def test_async_kpo_wait_termination_before_cleanup_on_success(
         succeeded_state,
     ]
 
-    ti_mock = MagicMock()
-
     success_event = {
         "status": "success",
         "message": TEST_SUCCESS_MESSAGE,
@@ -1790,16 +1776,11 @@ def test_async_kpo_wait_termination_before_cleanup_on_success(
     }
 
     k = KubernetesPodOperator(task_id="task", deferrable=True, 
do_xcom_push=do_xcom_push)
-    k.execute_complete({"ti": ti_mock}, success_event)
+    k.execute_complete({}, success_event)
 
     # check if it gets the pod
     mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME, 
TEST_NAMESPACE)
 
-    # check if it pushes the xcom
-    assert ti_mock.xcom_push.call_count == 2
-    ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
-    ti_mock.xcom_push.assert_any_call(key="pod_namespace", 
value=TEST_NAMESPACE)
-
     # assert that the xcom are extracted/not extracted
     if do_xcom_push:
         mock_extract_xcom.assert_called_once()

Reply via email to