This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 848c69a194 Refresh GKE OAuth2 tokens (#32673)
848c69a194 is described below
commit 848c69a194c03ed3a5badc909e26b5c1bda03050
Author: Freddy Demiane <[email protected]>
AuthorDate: Thu Jul 20 16:16:32 2023 +0200
Refresh GKE OAuth2 tokens (#32673)
* Refresh token for sync mode
---
.../google/cloud/hooks/kubernetes_engine.py | 4 ++
.../google/cloud/hooks/test_kubernetes_engine.py | 46 ++++++++++++++++++++++
2 files changed, 50 insertions(+)
diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index d8b1d92ffa..00df2b9b28 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -387,8 +387,12 @@ class GKEPodHook(GoogleBaseHook, PodOperatorHookProtocol):
def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
+ configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)
+ def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
+ configuration.api_key = {"authorization": self._get_token(self.get_credentials())}
+
def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index 53963c3a32..a3f0c59c12 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -30,6 +30,7 @@ from airflow.providers.google.cloud.hooks.kubernetes_engine import (
GKEAsyncHook,
GKEHook,
GKEPodAsyncHook,
+ GKEPodHook,
)
from airflow.providers.google.common.consts import CLIENT_INFO
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
@@ -397,3 +398,48 @@ class TestGKEAsyncHook:
mock_async_gke_cluster_client.get_operation.assert_called_once_with(
name=operation_path,
)
+
+
+class TestGKEPodHook:
+ def setup_method(self):
+ with mock.patch(
+ BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
+ ):
+ self.gke_hook = GKEPodHook(gcp_conn_id="test", ssl_ca_cert=None, cluster_url=None)
+ self.gke_hook._client = mock.Mock()
+
+ def refresh_token(request):
+ self.credentials.token = "New"
+
+ self.credentials = mock.MagicMock()
+ self.credentials.token = "Old"
+ self.credentials.expired = False
+ self.credentials.refresh = refresh_token
+
+ @mock.patch(GKE_STRING.format("google_requests.Request"))
+ def test_get_connection_update_hook_with_invalid_token(self, mock_request):
+ self.gke_hook._get_config = self._get_config
+ self.gke_hook.get_credentials = self._get_credentials
+ self.gke_hook.get_credentials().expired = True
+ the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()
+
+ the_client.configuration.refresh_api_key_hook(the_client.configuration)
+
+ assert self.gke_hook.get_credentials().token == "New"
+
+ @mock.patch(GKE_STRING.format("google_requests.Request"))
+ def test_get_connection_update_hook_with_valid_token(self, mock_request):
+ self.gke_hook._get_config = self._get_config
+ self.gke_hook.get_credentials = self._get_credentials
+ self.gke_hook.get_credentials().expired = False
+ the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()
+
+ the_client.configuration.refresh_api_key_hook(the_client.configuration)
+
+ assert self.gke_hook.get_credentials().token == "Old"
+
+ def _get_config(self):
+ return kubernetes.client.configuration.Configuration()
+
+ def _get_credentials(self):
+ return self.credentials