This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 810fb5f2a8 feat(GKEPodAsyncHook): use async credentials token implementation (#37486)
810fb5f2a8 is described below

commit 810fb5f2a8dbf624048d5f1a12398114c5fc7953
Author: Cedrik Neumann <7921017+m1rac...@users.noreply.github.com>
AuthorDate: Thu Feb 22 01:42:05 2024 +0100

    feat(GKEPodAsyncHook): use async credentials token implementation (#37486)
    
    We utilize the existing implementation of `_CredentialsToken` by using
    the async hook's `get_token` method. This implementation allows us to
    leverage several features of the Google connection from `Keyfile Path`
    or `Keyfile JSON` (see #37081) to impersonation chain on hook or
    connection level. We therefore do not need to rely on the async hook's
    `service_file_as_context` method, which does not support impersonation
    chain.
    
    With this change we effectively gain support for impersonation chain in
    GKEStartPodOperator in deferrable mode.
---
 .../google/cloud/hooks/kubernetes_engine.py        | 83 ++++++++++------------
 .../google/cloud/hooks/test_kubernetes_engine.py   | 52 +++++---------
 2 files changed, 58 insertions(+), 77 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index 4fa97372f6..4eef792d61 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -25,7 +25,6 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
 from deprecated import deprecated
-from gcloud.aio.auth import Token
 from google.api_core.exceptions import NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.auth.transport import requests as google_requests
@@ -54,6 +53,7 @@ from airflow.providers.google.common.hooks.base_google import (
 
 if TYPE_CHECKING:
     import google.auth.credentials
+    from gcloud.aio.auth import Token
     from google.api_core.retry import Retry
     from kubernetes_asyncio.client.models import V1Pod
 
@@ -709,15 +709,14 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
         """
-        with await self.service_file_as_context() as service_file:  # type: ignore[attr-defined]
-            async with Token(scopes=self.scopes, service_file=service_file) as token:
-                async with self.get_conn(token) as connection:
-                    v1_api = async_client.CoreV1Api(connection)
-                    pod: V1Pod = await v1_api.read_namespaced_pod(
-                        name=name,
-                        namespace=namespace,
-                    )
-                return pod
+        token = await self.get_token()
+        async with self.get_conn(token) as connection:
+            v1_api = async_client.CoreV1Api(connection)
+            pod: V1Pod = await v1_api.read_namespaced_pod(
+                name=name,
+                namespace=namespace,
+            )
+            return pod
 
     async def delete_pod(self, name: str, namespace: str):
         """Delete a pod.
@@ -725,21 +724,19 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
         """
-        with await self.service_file_as_context() as service_file:  # type: ignore[attr-defined]
-            async with Token(scopes=self.scopes, service_file=service_file) as token, self.get_conn(
-                token
-            ) as connection:
-                try:
-                    v1_api = async_client.CoreV1Api(connection)
-                    await v1_api.delete_namespaced_pod(
-                        name=name,
-                        namespace=namespace,
-                        body=client.V1DeleteOptions(),
-                    )
-                except async_client.ApiException as e:
-                    # If the pod is already deleted
-                    if e.status != 404:
-                        raise
+        token = await self.get_token()
+        async with self.get_conn(token) as connection:
+            try:
+                v1_api = async_client.CoreV1Api(connection)
+                await v1_api.delete_namespaced_pod(
+                    name=name,
+                    namespace=namespace,
+                    body=client.V1DeleteOptions(),
+                )
+            except async_client.ApiException as e:
+                # If the pod is already deleted
+                if e.status != 404:
+                    raise
 
     async def read_logs(self, name: str, namespace: str):
         """Read logs inside the pod while starting containers inside.
@@ -752,22 +749,20 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
         """
-        with await self.service_file_as_context() as service_file:  # type: ignore[attr-defined]
-            async with Token(scopes=self.scopes, service_file=service_file) as token, self.get_conn(
-                token
-            ) as connection:
-                try:
-                    v1_api = async_client.CoreV1Api(connection)
-                    logs = await v1_api.read_namespaced_pod_log(
-                        name=name,
-                        namespace=namespace,
-                        follow=False,
-                        timestamps=True,
-                    )
-                    logs = logs.splitlines()
-                    for line in logs:
-                        self.log.info("Container logs from %s", line)
-                    return logs
-                except HTTPError:
-                    self.log.exception("There was an error reading the kubernetes API.")
-                    raise
+        token = await self.get_token()
+        async with self.get_conn(token) as connection:
+            try:
+                v1_api = async_client.CoreV1Api(connection)
+                logs = await v1_api.read_namespaced_pod_log(
+                    name=name,
+                    namespace=namespace,
+                    follow=False,
+                    timestamps=True,
+                )
+                logs = logs.splitlines()
+                for line in logs:
+                    self.log.info("Container logs from %s", line)
+                return logs
+            except HTTPError:
+                self.log.exception("There was an error reading the kubernetes API.")
+                raise
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index f00b7d4efd..3d9ff1ce97 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -507,47 +507,38 @@ class TestGKEPodAsyncHook:
         )
 
     @pytest.mark.asyncio
-    @pytest.mark.parametrize("mock_service_file", ("/tmp/service_file.json", None))
-    @mock.patch(GKE_STRING.format("Token"))
+    @mock.patch(BASE_STRING.format("_CredentialsToken"))
     @mock.patch(GKE_STRING.format("GKEPodAsyncHook.get_conn"))
     @mock.patch(GKE_STRING.format("async_client.CoreV1Api.read_namespaced_pod"))
-    async def test_get_pod(
-        self, read_namespace_pod_mock, get_conn_mock, mock_token, async_hook, mock_service_file
-    ):
-        async_hook.service_file_as_context = mock.AsyncMock()
-        async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file
+    async def test_get_pod(self, read_namespace_pod_mock, get_conn_mock, mock_token, async_hook):
+        async_hook.get_token = mock.AsyncMock()
+        async_hook.get_token.return_value = mock_token
 
         self.make_mock_awaitable(read_namespace_pod_mock)
 
         await async_hook.get_pod(name=POD_NAME, namespace=POD_NAMESPACE)
-        mock_token.assert_called_with(
-            scopes=["https://www.googleapis.com/auth/cloud-platform"], service_file=mock_service_file
-        )
-        get_conn_mock.assert_called_once()
+
+        async_hook.get_token.assert_called_once()
+        get_conn_mock.assert_called_once_with(mock_token)
         read_namespace_pod_mock.assert_called_with(
             name=POD_NAME,
             namespace=POD_NAMESPACE,
         )
 
     @pytest.mark.asyncio
-    @pytest.mark.parametrize("mock_service_file", ("/tmp/service_file.json", None))
-    @mock.patch(GKE_STRING.format("Token"))
+    @mock.patch(BASE_STRING.format("_CredentialsToken"))
     @mock.patch(GKE_STRING.format("GKEPodAsyncHook.get_conn"))
     @mock.patch(GKE_STRING.format("async_client.CoreV1Api.delete_namespaced_pod"))
-    async def test_delete_pod(
-        self, delete_namespaced_pod, get_conn_mock, mock_token, async_hook, mock_service_file
-    ):
-        async_hook.service_file_as_context = mock.AsyncMock()
-        async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file
+    async def test_delete_pod(self, delete_namespaced_pod, get_conn_mock, mock_token, async_hook):
+        async_hook.get_token = mock.AsyncMock()
+        async_hook.get_token.return_value = mock_token
 
         self.make_mock_awaitable(delete_namespaced_pod)
 
         await async_hook.delete_pod(name=POD_NAME, namespace=POD_NAMESPACE)
 
-        mock_token.assert_called_with(
-            scopes=["https://www.googleapis.com/auth/cloud-platform"], service_file=mock_service_file
-        )
-        get_conn_mock.assert_called_once()
+        async_hook.get_token.assert_called_once()
+        get_conn_mock.assert_called_once_with(mock_token)
         delete_namespaced_pod.assert_called_with(
             name=POD_NAME,
             namespace=POD_NAMESPACE,
@@ -555,24 +546,19 @@ class TestGKEPodAsyncHook:
         )
 
     @pytest.mark.asyncio
-    @pytest.mark.parametrize("mock_service_file", ("/tmp/service_file.json", None))
-    @mock.patch(GKE_STRING.format("Token"))
+    @mock.patch(BASE_STRING.format("_CredentialsToken"))
     @mock.patch(GKE_STRING.format("GKEPodAsyncHook.get_conn"))
     @mock.patch(GKE_STRING.format("async_client.CoreV1Api.read_namespaced_pod_log"))
-    async def test_read_logs(
-        self, read_namespaced_pod_log, get_conn_mock, mock_token, async_hook, mock_service_file, caplog
-    ):
-        async_hook.service_file_as_context = mock.AsyncMock()
-        async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file
+    async def test_read_logs(self, read_namespaced_pod_log, get_conn_mock, mock_token, async_hook, caplog):
+        async_hook.get_token = mock.AsyncMock()
+        async_hook.get_token.return_value = mock_token
 
         self.make_mock_awaitable(read_namespaced_pod_log, result="Test string #1\nTest string #2\n")
 
         await async_hook.read_logs(name=POD_NAME, namespace=POD_NAMESPACE)
 
-        mock_token.assert_called_with(
-            scopes=["https://www.googleapis.com/auth/cloud-platform"], service_file=mock_service_file
-        )
-        get_conn_mock.assert_called_once()
+        async_hook.get_token.assert_called_once()
+        get_conn_mock.assert_called_once_with(mock_token)
         read_namespaced_pod_log.assert_called_with(
             name=POD_NAME,
             namespace=POD_NAMESPACE,

Reply via email to