This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 810fb5f2a8 feat(GKEPodAsyncHook): use async credentials token implementation (#37486) 810fb5f2a8 is described below commit 810fb5f2a8dbf624048d5f1a12398114c5fc7953 Author: Cedrik Neumann <7921017+m1rac...@users.noreply.github.com> AuthorDate: Thu Feb 22 01:42:05 2024 +0100 feat(GKEPodAsyncHook): use async credentials token implementation (#37486) We utilize the existing implementation of `_CredentialsToken` by using the async hook's `get_token` method. This implementation allows us to leverage several features of the Google connection from `Keyfile Path` or `Keyfile JSON` (see #37081) to impersonation chain on hook or connection level. We therefore do not need to rely on the async hook's `service_file_as_context` method, which does not support impersonation chain. With this change we effectively gain support for impersonation chain in GKEStartPodOperator in deferrable mode. --- .../google/cloud/hooks/kubernetes_engine.py | 83 ++++++++++------------ .../google/cloud/hooks/test_kubernetes_engine.py | 52 +++++--------- 2 files changed, 58 insertions(+), 77 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py index 4fa97372f6..4eef792d61 100644 --- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -25,7 +25,6 @@ from functools import cached_property from typing import TYPE_CHECKING, Sequence from deprecated import deprecated -from gcloud.aio.auth import Token from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.auth.transport import requests as google_requests @@ -54,6 +53,7 @@ from airflow.providers.google.common.hooks.base_google import ( if TYPE_CHECKING: import google.auth.credentials + from gcloud.aio.auth import Token from google.api_core.retry import Retry from kubernetes_asyncio.client.models import V1Pod @@ -709,15 +709,14 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook): :param name: Name of the pod. :param namespace: Name of the pod's namespace. """ - with await self.service_file_as_context() as service_file: # type: ignore[attr-defined] - async with Token(scopes=self.scopes, service_file=service_file) as token: - async with self.get_conn(token) as connection: - v1_api = async_client.CoreV1Api(connection) - pod: V1Pod = await v1_api.read_namespaced_pod( - name=name, - namespace=namespace, - ) - return pod + token = await self.get_token() + async with self.get_conn(token) as connection: + v1_api = async_client.CoreV1Api(connection) + pod: V1Pod = await v1_api.read_namespaced_pod( + name=name, + namespace=namespace, + ) + return pod async def delete_pod(self, name: str, namespace: str): """Delete a pod. @@ -725,21 +724,19 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook): :param name: Name of the pod. :param namespace: Name of the pod's namespace. """ - with await self.service_file_as_context() as service_file: # type: ignore[attr-defined] - async with Token(scopes=self.scopes, service_file=service_file) as token, self.get_conn( - token - ) as connection: - try: - v1_api = async_client.CoreV1Api(connection) - await v1_api.delete_namespaced_pod( - name=name, - namespace=namespace, - body=client.V1DeleteOptions(), - ) - except async_client.ApiException as e: - # If the pod is already deleted - if e.status != 404: - raise + token = await self.get_token() + async with self.get_conn(token) as connection: + try: + v1_api = async_client.CoreV1Api(connection) + await v1_api.delete_namespaced_pod( + name=name, + namespace=namespace, + body=client.V1DeleteOptions(), + ) + except async_client.ApiException as e: + # If the pod is already deleted + if e.status != 404: + raise async def read_logs(self, name: str, namespace: str): """Read logs inside the pod while starting containers inside. @@ -752,22 +749,20 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook): :param name: Name of the pod. :param namespace: Name of the pod's namespace. """ - with await self.service_file_as_context() as service_file: # type: ignore[attr-defined] - async with Token(scopes=self.scopes, service_file=service_file) as token, self.get_conn( - token - ) as connection: - try: - v1_api = async_client.CoreV1Api(connection) - logs = await v1_api.read_namespaced_pod_log( - name=name, - namespace=namespace, - follow=False, - timestamps=True, - ) - logs = logs.splitlines() - for line in logs: - self.log.info("Container logs from %s", line) - return logs - except HTTPError: - self.log.exception("There was an error reading the kubernetes API.") - raise + token = await self.get_token() + async with self.get_conn(token) as connection: + try: + v1_api = async_client.CoreV1Api(connection) + logs = await v1_api.read_namespaced_pod_log( + name=name, + namespace=namespace, + follow=False, + timestamps=True, + ) + logs = logs.splitlines() + for line in logs: + self.log.info("Container logs from %s", line) + return logs + except HTTPError: + self.log.exception("There was an error reading the kubernetes API.") + raise diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py index f00b7d4efd..3d9ff1ce97 100644 --- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py @@ -507,47 +507,38 @@ class TestGKEPodAsyncHook: ) @pytest.mark.asyncio - @pytest.mark.parametrize("mock_service_file", ("/tmp/service_file.json", None)) - @mock.patch(GKE_STRING.format("Token")) + @mock.patch(BASE_STRING.format("_CredentialsToken")) @mock.patch(GKE_STRING.format("GKEPodAsyncHook.get_conn")) @mock.patch(GKE_STRING.format("async_client.CoreV1Api.read_namespaced_pod")) - async def test_get_pod( - self, read_namespace_pod_mock, get_conn_mock, mock_token, async_hook, mock_service_file - ): - async_hook.service_file_as_context = mock.AsyncMock() - async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file + async def test_get_pod(self, read_namespace_pod_mock, get_conn_mock, mock_token, async_hook): + async_hook.get_token = mock.AsyncMock() + async_hook.get_token.return_value = mock_token self.make_mock_awaitable(read_namespace_pod_mock) await async_hook.get_pod(name=POD_NAME, namespace=POD_NAMESPACE) - mock_token.assert_called_with( - scopes=["https://www.googleapis.com/auth/cloud-platform"], service_file=mock_service_file - ) - get_conn_mock.assert_called_once() + + async_hook.get_token.assert_called_once() + get_conn_mock.assert_called_once_with(mock_token) read_namespace_pod_mock.assert_called_with( name=POD_NAME, namespace=POD_NAMESPACE, ) @pytest.mark.asyncio - @pytest.mark.parametrize("mock_service_file", ("/tmp/service_file.json", None)) - @mock.patch(GKE_STRING.format("Token")) + @mock.patch(BASE_STRING.format("_CredentialsToken")) @mock.patch(GKE_STRING.format("GKEPodAsyncHook.get_conn")) @mock.patch(GKE_STRING.format("async_client.CoreV1Api.delete_namespaced_pod")) - async def test_delete_pod( - self, delete_namespaced_pod, get_conn_mock, mock_token, async_hook, mock_service_file - ): - async_hook.service_file_as_context = mock.AsyncMock() - async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file + async def test_delete_pod(self, delete_namespaced_pod, get_conn_mock, mock_token, async_hook): + async_hook.get_token = mock.AsyncMock() + async_hook.get_token.return_value = mock_token self.make_mock_awaitable(delete_namespaced_pod) await async_hook.delete_pod(name=POD_NAME, namespace=POD_NAMESPACE) - mock_token.assert_called_with( - scopes=["https://www.googleapis.com/auth/cloud-platform"], service_file=mock_service_file - ) - get_conn_mock.assert_called_once() + async_hook.get_token.assert_called_once() + get_conn_mock.assert_called_once_with(mock_token) delete_namespaced_pod.assert_called_with( name=POD_NAME, namespace=POD_NAMESPACE, @@ -555,24 +546,19 @@ class TestGKEPodAsyncHook: ) @pytest.mark.asyncio - @pytest.mark.parametrize("mock_service_file", ("/tmp/service_file.json", None)) - @mock.patch(GKE_STRING.format("Token")) + @mock.patch(BASE_STRING.format("_CredentialsToken")) @mock.patch(GKE_STRING.format("GKEPodAsyncHook.get_conn")) @mock.patch(GKE_STRING.format("async_client.CoreV1Api.read_namespaced_pod_log")) - async def test_read_logs( - self, read_namespaced_pod_log, get_conn_mock, mock_token, async_hook, mock_service_file, caplog - ): - async_hook.service_file_as_context = mock.AsyncMock() - async_hook.service_file_as_context.return_value.__enter__.return_value = mock_service_file + async def test_read_logs(self, read_namespaced_pod_log, get_conn_mock, mock_token, async_hook, caplog): + async_hook.get_token = mock.AsyncMock() + async_hook.get_token.return_value = mock_token self.make_mock_awaitable(read_namespaced_pod_log, result="Test string #1\nTest string #2\n") await async_hook.read_logs(name=POD_NAME, namespace=POD_NAMESPACE) - mock_token.assert_called_with( - scopes=["https://www.googleapis.com/auth/cloud-platform"], service_file=mock_service_file - ) - get_conn_mock.assert_called_once() + async_hook.get_token.assert_called_once() + get_conn_mock.assert_called_once_with(mock_token) read_namespaced_pod_log.assert_called_with( name=POD_NAME, namespace=POD_NAMESPACE,