This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 032d27640b Refresh properties on KubernetesPodOperator on token expiration also when logging (#39789)
032d27640b is described below

commit 032d27640b5124e8fd85ba93042b50989881895a
Author: paolomoriello <[email protected]>
AuthorDate: Thu May 30 15:18:02 2024 +0200

    Refresh properties on KubernetesPodOperator on token expiration also when logging (#39789)
    
    * Refresh KubernetesPodOperator properties when credentials expire also when logs are enabled
    
    * Linting
    
    * Rename function
    
    ---------
    
    Co-authored-by: pmoriello <[email protected]>
---
 airflow/providers/cncf/kubernetes/operators/pod.py | 24 +++++++--------
 .../cncf/kubernetes/operators/test_pod.py          | 34 +++++++++++++---------
 2 files changed, 32 insertions(+), 26 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py 
b/airflow/providers/cncf/kubernetes/operators/pod.py
index ef5366d6c4..67f63f6678 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -612,16 +612,7 @@ class KubernetesPodOperator(BaseOperator):
                     mode=ExecutionMode.SYNC,
                 )
 
-            if self.get_logs:
-                self.pod_manager.fetch_requested_container_logs(
-                    pod=self.pod,
-                    containers=self.container_logs,
-                    follow_logs=True,
-                )
-            if not self.get_logs or (
-                self.container_logs is not True and self.base_container_name 
not in self.container_logs
-            ):
-                self.await_container_completion(pod=self.pod, 
container_name=self.base_container_name)
+            self.await_pod_completion(pod=self.pod)
             if self.callbacks:
                 self.callbacks.on_pod_completion(
                     pod=self.find_pod(self.pod.metadata.namespace, 
context=context),
@@ -654,9 +645,18 @@ class KubernetesPodOperator(BaseOperator):
         retry=tenacity.retry_if_exception(lambda exc: 
check_exception_is_kubernetes_api_unauthorized(exc)),
         reraise=True,
     )
-    def await_container_completion(self, pod: k8s.V1Pod, container_name: str):
+    def await_pod_completion(self, pod: k8s.V1Pod):
         try:
-            self.pod_manager.await_container_completion(pod=pod, 
container_name=container_name)
+            if self.get_logs:
+                self.pod_manager.fetch_requested_container_logs(
+                    pod=pod,
+                    containers=self.container_logs,
+                    follow_logs=True,
+                )
+            if not self.get_logs or (
+                self.container_logs is not True and self.base_container_name 
not in self.container_logs
+            ):
+                self.pod_manager.await_container_completion(pod=pod, 
container_name=self.base_container_name)
         except kubernetes.client.exceptions.ApiException as exc:
             if exc.status and str(exc.status) == "401":
                 self.log.warning(
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py 
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index f7a8e24d69..c1e0c29ff4 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -1621,29 +1621,35 @@ class TestKubernetesPodOperator:
             "pod": remote_pod_mock,
         }
 
+    @pytest.mark.parametrize("get_logs", [True, False])
+    @patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs")
     @patch(f"{POD_MANAGER_CLASS}.await_container_completion")
     def test_await_container_completion_refreshes_properties_on_exception(
-        self, mock_await_container_completion
+        self, mock_await_container_completion, fetch_requested_container_logs, 
get_logs
     ):
-        container_name = "base"
-        k = KubernetesPodOperator(
-            task_id="task",
-        )
+        k = KubernetesPodOperator(task_id="task", get_logs=get_logs)
         pod = self.run_pod(k)
         client, hook, pod_manager = k.client, k.hook, k.pod_manager
 
         # no exception doesn't update properties
-        k.await_container_completion(pod, container_name=container_name)
+        k.await_pod_completion(pod)
         assert client == k.client
         assert hook == k.hook
         assert pod_manager == k.pod_manager
 
         # exception refreshes properties
         mock_await_container_completion.side_effect = 
[ApiException(status=401), mock.DEFAULT]
-        k.await_container_completion(pod, container_name=container_name)
-        mock_await_container_completion.assert_has_calls(
-            [mock.call(pod=pod, container_name=container_name)] * 3
-        )
+        fetch_requested_container_logs.side_effect = 
[ApiException(status=401), mock.DEFAULT]
+        k.await_pod_completion(pod)
+
+        if get_logs:
+            fetch_requested_container_logs.assert_has_calls(
+                [mock.call(pod=pod, containers=k.container_logs, 
follow_logs=True)] * 3
+            )
+        else:
+            mock_await_container_completion.assert_has_calls(
+                [mock.call(pod=pod, container_name=k.base_container_name)] * 3
+            )
         assert client != k.client
         assert hook != k.hook
         assert pod_manager != k.pod_manager
@@ -1662,20 +1668,20 @@ class TestKubernetesPodOperator:
     def test_await_container_completion_retries_on_specific_exception(
         self, mock_await_container_completion, side_effect, exception_type, 
expect_exc
     ):
-        container_name = "base"
         k = KubernetesPodOperator(
             task_id="task",
+            get_logs=False,
         )
         pod = self.run_pod(k)
         mock_await_container_completion.side_effect = side_effect
         if expect_exc:
-            k.await_container_completion(pod, container_name=container_name)
+            k.await_pod_completion(pod)
         else:
             with pytest.raises(exception_type):
-                k.await_container_completion(pod, 
container_name=container_name)
+                k.await_pod_completion(pod)
         expected_call_count = min(len(side_effect), 3)  # retry max 3 times
         mock_await_container_completion.assert_has_calls(
-            [mock.call(pod=pod, container_name=container_name)] * 
expected_call_count
+            [mock.call(pod=pod, container_name=k.base_container_name)] * 
expected_call_count
         )
 
 

Reply via email to