This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b435b8edef Push to xcom before `KubernetesPodOperator` deferral (#34209)
b435b8edef is described below
commit b435b8edefd181fa85e6cc6b2b822d113f562e27
Author: Victor Chiapaikeo <[email protected]>
AuthorDate: Sat Sep 9 14:25:34 2023 -0400
Push to xcom before `KubernetesPodOperator` deferral (#34209)
---
airflow/providers/cncf/kubernetes/operators/pod.py | 8 +--
.../cncf/kubernetes/operators/test_pod.py | 57 ++++++++--------------
2 files changed, 23 insertions(+), 42 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 5f7269718f..5ec9e2decb 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -630,6 +630,10 @@ class KubernetesPodOperator(BaseOperator):
pod_request_obj=self.pod_request_obj,
context=context,
)
+ ti = context["ti"]
+ ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
+ ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
+
self.invoke_defer_method()
def invoke_defer_method(self):
@@ -666,10 +670,6 @@ class KubernetesPodOperator(BaseOperator):
self.write_logs(pod)
raise AirflowException(event["message"])
elif event["status"] == "success":
- ti = context["ti"]
- ti.xcom_push(key="pod_name", value=pod.metadata.name)
- ti.xcom_push(key="pod_namespace", value=pod.metadata.namespace)
-
# fetch some logs when pod is executed successfully
if self.get_logs:
self.write_logs(pod)
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index 3ada6218a5..fde4e60159 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1471,12 +1471,15 @@ class TestKubernetesPodOperatorAsync:
)
return remote_pod_mock
+ @pytest.mark.parametrize("do_xcom_push", [True, False])
@patch(KUB_OP_PATH.format("build_pod_request_obj"))
@patch(KUB_OP_PATH.format("get_or_create_pod"))
- def test_async_create_pod_should_execute_successfully(self, mocked_pod,
mocked_pod_obj):
+ def test_async_create_pod_should_execute_successfully(self, mocked_pod,
mocked_pod_obj, do_xcom_push):
"""
Asserts that a task is deferred and the KubernetesCreatePodTrigger
will be fired
when the KubernetesPodOperator is executed in deferrable mode when
deferrable=True.
+
+ pod name and namespace are *always* pushed; do_xcom_push only controls
xcom sidecar
"""
k = KubernetesPodOperator(
@@ -1491,10 +1494,23 @@ class TestKubernetesPodOperatorAsync:
in_cluster=True,
get_logs=True,
deferrable=True,
+ do_xcom_push=do_xcom_push,
)
k.config_file_in_dict_representation = {"a": "b"}
+
+ mocked_pod.return_value.metadata.name = TEST_NAME
+ mocked_pod.return_value.metadata.namespace = TEST_NAMESPACE
+
+ context = create_context(k)
+ ti_mock = MagicMock()
+ context["ti"] = ti_mock
+
with pytest.raises(TaskDeferred) as exc:
- k.execute(create_context(k))
+ k.execute(context)
+
+ assert ti_mock.xcom_push.call_count == 2
+ ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
+ ti_mock.xcom_push.assert_any_call(key="pod_namespace",
value=TEST_NAMESPACE)
assert isinstance(exc.value.trigger, KubernetesPodTrigger)
@patch(KUB_OP_PATH.format("cleanup"))
@@ -1655,34 +1671,6 @@ class TestKubernetesPodOperatorAsync:
},
)
- @pytest.mark.parametrize("do_xcom_push", [True, False])
- @patch(KUB_OP_PATH.format("post_complete_action"))
- @patch(KUB_OP_PATH.format("extract_xcom"))
- @patch(POD_MANAGER_CLASS)
- @patch(HOOK_CLASS)
- def test_async_push_xcom_check_xcom_values_should_execute_successfully(
- self, mocked_hook, mock_manager, mock_extract_xcom,
post_complete_action, do_xcom_push
- ):
- """pod name and namespace are *always* pushed; do_xcom_push only
controls xcom sidecar"""
-
- mocked_hook.return_value.get_pod.return_value = k8s.V1Pod(
- metadata=k8s.V1ObjectMeta(name=TEST_NAME, namespace=TEST_NAMESPACE)
- )
- mock_manager.return_value.await_pod_completion.return_value = {}
- mock_extract_xcom.return_value = "{}"
- k = KubernetesPodOperator(
- task_id="task",
- do_xcom_push=do_xcom_push,
- deferrable=True,
- )
-
- pod = self.run_pod_async(k)
-
- pod_name = XCom.get_one(run_id=self.dag_run.run_id, task_id="task",
key="pod_name")
- pod_namespace = XCom.get_one(run_id=self.dag_run.run_id,
task_id="task", key="pod_namespace")
- assert pod_name == pod.metadata.name
- assert pod_namespace == pod.metadata.namespace
-
@pytest.mark.parametrize("get_logs", [True, False])
@patch(KUB_OP_PATH.format("post_complete_action"))
@patch(KUB_OP_PATH.format("write_logs"))
@@ -1780,8 +1768,6 @@ def
test_async_kpo_wait_termination_before_cleanup_on_success(
succeeded_state,
]
- ti_mock = MagicMock()
-
success_event = {
"status": "success",
"message": TEST_SUCCESS_MESSAGE,
@@ -1790,16 +1776,11 @@ def
test_async_kpo_wait_termination_before_cleanup_on_success(
}
k = KubernetesPodOperator(task_id="task", deferrable=True,
do_xcom_push=do_xcom_push)
- k.execute_complete({"ti": ti_mock}, success_event)
+ k.execute_complete({}, success_event)
# check if it gets the pod
mocked_hook.return_value.get_pod.assert_called_once_with(TEST_NAME,
TEST_NAMESPACE)
- # check if it pushes the xcom
- assert ti_mock.xcom_push.call_count == 2
- ti_mock.xcom_push.assert_any_call(key="pod_name", value=TEST_NAME)
- ti_mock.xcom_push.assert_any_call(key="pod_namespace",
value=TEST_NAMESPACE)
-
# assert that the xcom are extracted/not extracted
if do_xcom_push:
mock_extract_xcom.assert_called_once()