This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ab43d5541 Move KubernetesPodTrigger hook to a cached property (#36290)
5ab43d5541 is described below

commit 5ab43d5541a68c5c90fe849f19e344bcdeddd44f
Author: Hussein Awala <[email protected]>
AuthorDate: Tue Dec 19 09:11:52 2023 +0100

    Move KubernetesPodTrigger hook to a cached property (#36290)
---
 airflow/providers/cncf/kubernetes/triggers/pod.py  | 28 +++++++-------
 .../google/cloud/triggers/kubernetes_engine.py     |  4 +-
 .../providers/cncf/kubernetes/triggers/test_pod.py | 44 +++++++++++-----------
 .../cloud/triggers/test_kubernetes_engine.py       | 32 ++++++++--------
 4 files changed, 56 insertions(+), 52 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py
index b7f0348b66..3dd9eb173c 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -22,6 +22,7 @@ import traceback
 import warnings
 from asyncio import CancelledError
 from enum import Enum
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, AsyncIterator
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -116,7 +117,6 @@ class KubernetesPodTrigger(BaseTrigger):
             self.on_finish_action = OnFinishAction(on_finish_action)
             self.should_delete_pod = self.on_finish_action == OnFinishAction.DELETE_POD
 
-        self._hook: AsyncKubernetesHook | None = None
         self._since_time = None
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -142,11 +142,10 @@ class KubernetesPodTrigger(BaseTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current pod status and yield a TriggerEvent."""
-        hook = self._get_async_hook()
         self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
         try:
             while True:
-                pod = await hook.get_pod(
+                pod = await self.hook.get_pod(
                     name=self.pod_name,
                     namespace=self.pod_namespace,
                 )
@@ -206,13 +205,13 @@ class KubernetesPodTrigger(BaseTrigger):
             # That means that task was marked as failed
             if self.get_logs:
                 self.log.info("Outputting container logs...")
-                await self._get_async_hook().read_logs(
+                await self.hook.read_logs(
                     name=self.pod_name,
                     namespace=self.pod_namespace,
                 )
             if self.on_finish_action == OnFinishAction.DELETE_POD:
                 self.log.info("Deleting pod...")
-                await self._get_async_hook().delete_pod(
+                await self.hook.delete_pod(
                     name=self.pod_name,
                     namespace=self.pod_namespace,
                 )
@@ -237,14 +236,17 @@ class KubernetesPodTrigger(BaseTrigger):
             )
 
     def _get_async_hook(self) -> AsyncKubernetesHook:
-        if self._hook is None:
-            self._hook = AsyncKubernetesHook(
-                conn_id=self.kubernetes_conn_id,
-                in_cluster=self.in_cluster,
-                config_file=self.config_file,
-                cluster_context=self.cluster_context,
-            )
-        return self._hook
+        # TODO: Remove this method when the min version of kubernetes provider is 7.12.0 in Google provider.
+        return AsyncKubernetesHook(
+            conn_id=self.kubernetes_conn_id,
+            in_cluster=self.in_cluster,
+            config_file=self.config_file,
+            cluster_context=self.cluster_context,
+        )
+
+    @cached_property
+    def hook(self) -> AsyncKubernetesHook:
+        return self._get_async_hook()
 
     def define_container_state(self, pod: V1Pod) -> ContainerState:
         pod_containers = pod.status.container_statuses
diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index 1fbaef72a9..eb1194369e 100644
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import asyncio
 import warnings
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
 
 from google.cloud.container_v1.types import Operation
@@ -137,7 +138,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
             },
         )
 
-    def _get_async_hook(self) -> GKEPodAsyncHook:  # type: ignore[override]
+    @cached_property
+    def hook(self) -> GKEPodAsyncHook:  # type: ignore[override]
         return GKEPodAsyncHook(
             cluster_url=self._cluster_url,
             ssl_ca_cert=self._ssl_ca_cert,
diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py
index 42a0196ed7..9c016ea8cf 100644
--- a/tests/providers/cncf/kubernetes/triggers/test_pod.py
+++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py
@@ -94,9 +94,9 @@ class TestKubernetesPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.TERMINATED
 
         expected_event = TriggerEvent(
@@ -113,9 +113,9 @@ class TestKubernetesPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_run_loop_return_failed_event(self, mock_hook, mock_method, trigger):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
+        mock_hook.get_pod.return_value = self._mock_pod_result(
             mock.MagicMock(
                 status=mock.MagicMock(
                     message=FAILED_RESULT_MSG,
@@ -138,9 +138,9 @@ class TestKubernetesPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_run_loop_return_waiting_event(self, mock_hook, mock_method, trigger, caplog):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.WAITING
 
         caplog.set_level(logging.INFO)
@@ -154,9 +154,9 @@ class TestKubernetesPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_run_loop_return_running_event(self, mock_hook, mock_method, trigger, caplog):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.RUNNING
 
         caplog.set_level(logging.INFO)
@@ -169,7 +169,7 @@ class TestKubernetesPodTrigger:
         assert f"Sleeping for {POLL_INTERVAL} seconds."
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_logging_in_trigger_when_exception_should_execute_successfully(
         self, mock_hook, trigger, caplog
     ):
@@ -177,7 +177,7 @@ class TestKubernetesPodTrigger:
         Test that KubernetesPodTrigger fires the correct event in case of an error.
         """
 
-        mock_hook.return_value.get_pod.side_effect = Exception("Test exception")
+        mock_hook.get_pod.side_effect = Exception("Test exception")
 
         generator = trigger.run()
         actual = await generator.asend(None)
@@ -192,7 +192,7 @@ class TestKubernetesPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_logging_in_trigger_when_fail_should_execute_successfully(
         self, mock_hook, mock_method, trigger, caplog
     ):
@@ -200,7 +200,7 @@ class TestKubernetesPodTrigger:
         Test that KubernetesPodTrigger fires the correct event in case of fail.
         """
 
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.FAILED
         caplog.set_level(logging.INFO)
 
@@ -209,7 +209,7 @@ class TestKubernetesPodTrigger:
         assert "Container logs:"
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and_delete_pod(
         self,
         mock_hook,
@@ -219,9 +219,9 @@ class TestKubernetesPodTrigger:
         Test that KubernetesPodTrigger fires the correct event in case if the task was cancelled.
         """
 
-        mock_hook.return_value.get_pod.side_effect = CancelledError()
-        mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
-        mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.side_effect = CancelledError()
+        mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
 
         trigger = KubernetesPodTrigger(
             pod_name=POD_NAME,
@@ -255,7 +255,7 @@ class TestKubernetesPodTrigger:
         assert "Deleting pod..." in caplog.text
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_logging_in_trigger_when_cancelled_should_execute_successfully_without_delete_pod(
         self,
         mock_hook,
@@ -265,9 +265,9 @@ class TestKubernetesPodTrigger:
         Test that KubernetesPodTrigger fires the correct event if the task was cancelled.
         """
 
-        mock_hook.return_value.get_pod.side_effect = CancelledError()
-        mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
-        mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.side_effect = CancelledError()
+        mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
 
         trigger = KubernetesPodTrigger(
             pod_name=POD_NAME,
@@ -341,12 +341,12 @@ class TestKubernetesPodTrigger:
     @pytest.mark.asyncio
     @pytest.mark.parametrize("container_state", [ContainerState.WAITING, ContainerState.UNDEFINED])
     @mock.patch(f"{TRIGGER_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_PATH}.hook")
     async def test_run_loop_return_timeout_event(
         self, mock_hook, mock_method, trigger, caplog, container_state
     ):
         trigger.trigger_start_time = TRIGGER_START_TIME - datetime.timedelta(minutes=2)
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
+        mock_hook.get_pod.return_value = self._mock_pod_result(
             mock.MagicMock(
                 status=mock.MagicMock(
                     phase=PodPhase.PENDING,
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index b28d9418ef..b252ea4e30 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -105,11 +105,11 @@ class TestGKEStartPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_success_event_should_execute_successfully(
         self, mock_hook, mock_method, trigger
     ):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.TERMINATED
 
         expected_event = TriggerEvent(
@@ -126,11 +126,11 @@ class TestGKEStartPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_failed_event_should_execute_successfully(
         self, mock_hook, mock_method, trigger
     ):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
+        mock_hook.get_pod.return_value = self._mock_pod_result(
             mock.MagicMock(
                 status=mock.MagicMock(
                     message=FAILED_RESULT_MSG,
@@ -153,11 +153,11 @@ class TestGKEStartPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_waiting_event_should_execute_successfully(
         self, mock_hook, mock_method, trigger, caplog
     ):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.WAITING
 
         caplog.set_level(logging.INFO)
@@ -171,11 +171,11 @@ class TestGKEStartPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_run_loop_return_running_event_should_execute_successfully(
         self, mock_hook, mock_method, trigger, caplog
     ):
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.RUNNING
 
         caplog.set_level(logging.INFO)
@@ -188,14 +188,14 @@ class TestGKEStartPodTrigger:
         assert f"Sleeping for {POLL_INTERVAL} seconds."
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_logging_in_trigger_when_exception_should_execute_successfully(
         self, mock_hook, trigger, caplog
     ):
         """
         Test that GKEStartPodTrigger fires the correct event in case of an error.
         """
-        mock_hook.return_value.get_pod.side_effect = Exception("Test exception")
+        mock_hook.get_pod.side_effect = Exception("Test exception")
 
         generator = trigger.run()
         actual = await generator.asend(None)
@@ -210,14 +210,14 @@ class TestGKEStartPodTrigger:
 
     @pytest.mark.asyncio
     @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
-    @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_logging_in_trigger_when_fail_should_execute_successfully(
         self, mock_hook, mock_method, trigger, caplog
     ):
         """
         Test that GKEStartPodTrigger fires the correct event in case of fail.
         """
-        mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
         mock_method.return_value = ContainerState.FAILED
         caplog.set_level(logging.INFO)
 
@@ -226,16 +226,16 @@ class TestGKEStartPodTrigger:
         assert "Container logs:"
 
     @pytest.mark.asyncio
-    @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+    @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
     async def test_logging_in_trigger_when_cancelled_should_execute_successfully(
         self, mock_hook, trigger, caplog
     ):
         """
         Test that GKEStartPodTrigger fires the correct event in case if the task was cancelled.
         """
-        mock_hook.return_value.get_pod.side_effect = CancelledError()
-        mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
-        mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.get_pod.side_effect = CancelledError()
+        mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock())
+        mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock())
 
         generator = trigger.run()
         actual = await generator.asend(None)

Reply via email to