This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 43b48a9ac1 Refresh properties on KubernetesPodOperator when k8s fails due to token expiration (#39325)
43b48a9ac1 is described below

commit 43b48a9ac1737b9dbafa706caa6f266399753521
Author: paolomoriello <[email protected]>
AuthorDate: Wed May 1 11:03:57 2024 +0200

    Refresh properties on KubernetesPodOperator when k8s fails due to token expiration (#39325)
    
    * Refresh KubernetesPodOperator properties when credentials expire
    
    * Add test
    
    * Fix linting errors
    
    * Implement review changes and add more tests
    
    * Improve tests
    
    * Fix linting errors
    
    * Fix linting errors
    
    * Move exception check to utils
    
    ---------
    
    Co-authored-by: pmoriello <[email protected]>
---
 airflow/providers/cncf/kubernetes/operators/pod.py | 28 ++++++++--
 .../providers/cncf/kubernetes/utils/pod_manager.py |  4 ++
 .../cncf/kubernetes/operators/test_pod.py          | 63 ++++++++++++++++++++--
 3 files changed, 88 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py 
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 0b387352ae..4a3c7312ed 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -33,6 +33,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
 
 import kubernetes
+import tenacity
 from deprecated import deprecated
 from kubernetes.client import CoreV1Api, V1Pod, models as k8s
 from kubernetes.stream import stream
@@ -77,6 +78,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager 
import (
     PodNotFoundException,
     PodOperatorHookProtocol,
     PodPhase,
+    check_exception_is_kubernetes_api_unauthorized,
     container_is_succeeded,
     get_container_termination_message,
 )
@@ -618,9 +620,7 @@ class KubernetesPodOperator(BaseOperator):
             if not self.get_logs or (
                 self.container_logs is not True and self.base_container_name 
not in self.container_logs
             ):
-                self.pod_manager.await_container_completion(
-                    pod=self.pod, container_name=self.base_container_name
-                )
+                self.await_container_completion(pod=self.pod, 
container_name=self.base_container_name)
             if self.callbacks:
                 self.callbacks.on_pod_completion(
                     pod=self.find_pod(self.pod.metadata.namespace, 
context=context),
@@ -647,6 +647,28 @@ class KubernetesPodOperator(BaseOperator):
         if self.do_xcom_push:
             return result
 
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(3),
+        wait=tenacity.wait_exponential(max=15),
+        retry=tenacity.retry_if_exception(lambda exc: 
check_exception_is_kubernetes_api_unauthorized(exc)),
+        reraise=True,
+    )
+    def await_container_completion(self, pod: k8s.V1Pod, container_name: str):
+        try:
+            self.pod_manager.await_container_completion(pod=pod, 
container_name=container_name)
+        except kubernetes.client.exceptions.ApiException as exc:
+            if exc.status and str(exc.status) == "401":
+                self.log.warning(
+                    "Failed to check container status due to permission error. 
Refreshing credentials and retrying."
+                )
+                self._refresh_cached_properties()
+            raise exc
+
+    def _refresh_cached_properties(self):
+        del self.hook
+        del self.client
+        del self.pod_manager
+
     def execute_async(self, context: Context):
         self.pod_request_obj = self.build_pod_request_obj(context)
         self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py 
b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index c179cb17dc..5668462ad6 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -190,6 +190,10 @@ def get_container_termination_message(pod: V1Pod, 
container_name: str):
         return container_status.state.terminated.message if container_status 
else None
 
 
+def check_exception_is_kubernetes_api_unauthorized(exc: BaseException):
+    return isinstance(exc, ApiException) and exc.status and str(exc.status) == 
"401"
+
+
 class PodLaunchTimeoutException(AirflowException):
     """When pod does not leave the ``Pending`` phase within specified 
timeout."""
 
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py 
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index db10bff330..f9277bab01 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -25,6 +25,7 @@ from unittest.mock import MagicMock, patch
 import pendulum
 import pytest
 from kubernetes.client import ApiClient, V1Pod, V1PodSecurityContext, 
V1PodStatus, models as k8s
+from kubernetes.client.rest import ApiException
 from urllib3 import HTTPResponse
 
 from airflow.exceptions import AirflowException, AirflowSkipException, 
TaskDeferred
@@ -38,10 +39,7 @@ from airflow.providers.cncf.kubernetes.operators.pod import (
 )
 from airflow.providers.cncf.kubernetes.secret import Secret
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
-from airflow.providers.cncf.kubernetes.utils.pod_manager import (
-    PodLoggingStatus,
-    PodPhase,
-)
+from airflow.providers.cncf.kubernetes.utils.pod_manager import 
PodLoggingStatus, PodPhase
 from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -1614,6 +1612,63 @@ class TestKubernetesPodOperator:
             "pod": remote_pod_mock,
         }
 
+    @patch(f"{POD_MANAGER_CLASS}.await_container_completion")
+    def test_await_container_completion_refreshes_properties_on_exception(
+        self, mock_await_container_completion
+    ):
+        container_name = "base"
+        k = KubernetesPodOperator(
+            task_id="task",
+        )
+        pod = self.run_pod(k)
+        client, hook, pod_manager = k.client, k.hook, k.pod_manager
+
+        # no exception doesn't update properties
+        k.await_container_completion(pod, container_name=container_name)
+        assert client == k.client
+        assert hook == k.hook
+        assert pod_manager == k.pod_manager
+
+        # exception refreshes properties
+        mock_await_container_completion.side_effect = 
[ApiException(status=401), mock.DEFAULT]
+        k.await_container_completion(pod, container_name=container_name)
+        mock_await_container_completion.assert_has_calls(
+            [mock.call(pod=pod, container_name=container_name)] * 3
+        )
+        assert client != k.client
+        assert hook != k.hook
+        assert pod_manager != k.pod_manager
+
+    @pytest.mark.parametrize(
+        "side_effect, exception_type, expect_exc",
+        [
+            ([ApiException(401), mock.DEFAULT], ApiException, True),  # works 
after one 401
+            ([ApiException(401)] * 10, ApiException, False),  # exc after 3 
retries on 401
+            ([ApiException(402)], ApiException, False),  # exc on non-401
+            ([ApiException(500)], ApiException, False),  # exc on non-401
+            ([Exception], Exception, False),  # exc on different exception
+        ],
+    )
+    @patch(f"{POD_MANAGER_CLASS}.await_container_completion")
+    def test_await_container_completion_retries_on_specific_exception(
+        self, mock_await_container_completion, side_effect, exception_type, 
expect_exc
+    ):
+        container_name = "base"
+        k = KubernetesPodOperator(
+            task_id="task",
+        )
+        pod = self.run_pod(k)
+        mock_await_container_completion.side_effect = side_effect
+        if expect_exc:
+            k.await_container_completion(pod, container_name=container_name)
+        else:
+            with pytest.raises(exception_type):
+                k.await_container_completion(pod, 
container_name=container_name)
+        expected_call_count = min(len(side_effect), 3)  # retry max 3 times
+        mock_await_container_completion.assert_has_calls(
+            [mock.call(pod=pod, container_name=container_name)] * 
expected_call_count
+        )
+
 
 class TestSuppress:
     def test__suppress(self, caplog):

Reply via email to