This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 848c69a194 Refresh GKE OAuth2 tokens (#32673)
848c69a194 is described below

commit 848c69a194c03ed3a5badc909e26b5c1bda03050
Author: Freddy Demiane <[email protected]>
AuthorDate: Thu Jul 20 16:16:32 2023 +0200

    Refresh GKE OAuth2 tokens (#32673)
    
    * Refresh token for sync mode
---
 .../google/cloud/hooks/kubernetes_engine.py        |  4 ++
 .../google/cloud/hooks/test_kubernetes_engine.py   | 46 ++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index d8b1d92ffa..00df2b9b28 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -387,8 +387,12 @@ class GKEPodHook(GoogleBaseHook, PodOperatorHookProtocol):
 
     def get_conn(self) -> client.ApiClient:
         configuration = self._get_config()
+        configuration.refresh_api_key_hook = self._refresh_api_key_hook
         return client.ApiClient(configuration)
 
+    def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
+        configuration.api_key = {"authorization": self._get_token(self.get_credentials())}
+
     def _get_config(self) -> client.configuration.Configuration:
         configuration = client.Configuration(
             host=self._cluster_url,
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index 53963c3a32..a3f0c59c12 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -30,6 +30,7 @@ from airflow.providers.google.cloud.hooks.kubernetes_engine import (
     GKEAsyncHook,
     GKEHook,
     GKEPodAsyncHook,
+    GKEPodHook,
 )
 from airflow.providers.google.common.consts import CLIENT_INFO
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
@@ -397,3 +398,48 @@ class TestGKEAsyncHook:
         mock_async_gke_cluster_client.get_operation.assert_called_once_with(
             name=operation_path,
         )
+
+
+class TestGKEPodHook:
+    def setup_method(self):
+        with mock.patch(
+            BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
+        ):
+            self.gke_hook = GKEPodHook(gcp_conn_id="test", ssl_ca_cert=None, cluster_url=None)
+        self.gke_hook._client = mock.Mock()
+
+        def refresh_token(request):
+            self.credentials.token = "New"
+
+        self.credentials = mock.MagicMock()
+        self.credentials.token = "Old"
+        self.credentials.expired = False
+        self.credentials.refresh = refresh_token
+
+    @mock.patch(GKE_STRING.format("google_requests.Request"))
+    def test_get_connection_update_hook_with_invalid_token(self, mock_request):
+        self.gke_hook._get_config = self._get_config
+        self.gke_hook.get_credentials = self._get_credentials
+        self.gke_hook.get_credentials().expired = True
+        the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()
+
+        the_client.configuration.refresh_api_key_hook(the_client.configuration)
+
+        assert self.gke_hook.get_credentials().token == "New"
+
+    @mock.patch(GKE_STRING.format("google_requests.Request"))
+    def test_get_connection_update_hook_with_valid_token(self, mock_request):
+        self.gke_hook._get_config = self._get_config
+        self.gke_hook.get_credentials = self._get_credentials
+        self.gke_hook.get_credentials().expired = False
+        the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()
+
+        the_client.configuration.refresh_api_key_hook(the_client.configuration)
+
+        assert self.gke_hook.get_credentials().token == "Old"
+
+    def _get_config(self):
+        return kubernetes.client.configuration.Configuration()
+
+    def _get_credentials(self):
+        return self.credentials

Reply via email to