This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5ab43d5541 Move KubernetesPodTrigger hook to a cached property (#36290)
5ab43d5541 is described below
commit 5ab43d5541a68c5c90fe849f19e344bcdeddd44f
Author: Hussein Awala <[email protected]>
AuthorDate: Tue Dec 19 09:11:52 2023 +0100
Move KubernetesPodTrigger hook to a cached property (#36290)
---
airflow/providers/cncf/kubernetes/triggers/pod.py | 28 +++++++-------
.../google/cloud/triggers/kubernetes_engine.py | 4 +-
.../providers/cncf/kubernetes/triggers/test_pod.py | 44 +++++++++++-----------
.../cloud/triggers/test_kubernetes_engine.py | 32 ++++++++--------
4 files changed, 56 insertions(+), 52 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py
index b7f0348b66..3dd9eb173c 100644
--- a/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -22,6 +22,7 @@ import traceback
import warnings
from asyncio import CancelledError
from enum import Enum
+from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator
from airflow.exceptions import AirflowProviderDeprecationWarning
@@ -116,7 +117,6 @@ class KubernetesPodTrigger(BaseTrigger):
self.on_finish_action = OnFinishAction(on_finish_action)
self.should_delete_pod = self.on_finish_action ==
OnFinishAction.DELETE_POD
- self._hook: AsyncKubernetesHook | None = None
self._since_time = None
def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -142,11 +142,10 @@ class KubernetesPodTrigger(BaseTrigger):
async def run(self) -> AsyncIterator[TriggerEvent]: # type:
ignore[override]
"""Get current pod status and yield a TriggerEvent."""
- hook = self._get_async_hook()
self.log.info("Checking pod %r in namespace %r.", self.pod_name,
self.pod_namespace)
try:
while True:
- pod = await hook.get_pod(
+ pod = await self.hook.get_pod(
name=self.pod_name,
namespace=self.pod_namespace,
)
@@ -206,13 +205,13 @@ class KubernetesPodTrigger(BaseTrigger):
# That means that task was marked as failed
if self.get_logs:
self.log.info("Outputting container logs...")
- await self._get_async_hook().read_logs(
+ await self.hook.read_logs(
name=self.pod_name,
namespace=self.pod_namespace,
)
if self.on_finish_action == OnFinishAction.DELETE_POD:
self.log.info("Deleting pod...")
- await self._get_async_hook().delete_pod(
+ await self.hook.delete_pod(
name=self.pod_name,
namespace=self.pod_namespace,
)
@@ -237,14 +236,17 @@ class KubernetesPodTrigger(BaseTrigger):
)
def _get_async_hook(self) -> AsyncKubernetesHook:
- if self._hook is None:
- self._hook = AsyncKubernetesHook(
- conn_id=self.kubernetes_conn_id,
- in_cluster=self.in_cluster,
- config_file=self.config_file,
- cluster_context=self.cluster_context,
- )
- return self._hook
+ # TODO: Remove this method when the min version of kubernetes provider
is 7.12.0 in Google provider.
+ return AsyncKubernetesHook(
+ conn_id=self.kubernetes_conn_id,
+ in_cluster=self.in_cluster,
+ config_file=self.config_file,
+ cluster_context=self.cluster_context,
+ )
+
+ @cached_property
+ def hook(self) -> AsyncKubernetesHook:
+ return self._get_async_hook()
def define_container_state(self, pod: V1Pod) -> ContainerState:
pod_containers = pod.status.container_statuses
diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index 1fbaef72a9..eb1194369e 100644
--- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import asyncio
import warnings
+from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
from google.cloud.container_v1.types import Operation
@@ -137,7 +138,8 @@ class GKEStartPodTrigger(KubernetesPodTrigger):
},
)
- def _get_async_hook(self) -> GKEPodAsyncHook: # type: ignore[override]
+ @cached_property
+ def hook(self) -> GKEPodAsyncHook: # type: ignore[override]
return GKEPodAsyncHook(
cluster_url=self._cluster_url,
ssl_ca_cert=self._ssl_ca_cert,
diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py
index 42a0196ed7..9c016ea8cf 100644
--- a/tests/providers/cncf/kubernetes/triggers/test_pod.py
+++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py
@@ -94,9 +94,9 @@ class TestKubernetesPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_success_event(self, mock_hook, mock_method,
trigger):
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED
expected_event = TriggerEvent(
@@ -113,9 +113,9 @@ class TestKubernetesPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_failed_event(self, mock_hook, mock_method,
trigger):
- mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
+ mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
status=mock.MagicMock(
message=FAILED_RESULT_MSG,
@@ -138,9 +138,9 @@ class TestKubernetesPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_waiting_event(self, mock_hook, mock_method,
trigger, caplog):
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.WAITING
caplog.set_level(logging.INFO)
@@ -154,9 +154,9 @@ class TestKubernetesPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_running_event(self, mock_hook, mock_method,
trigger, caplog):
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.RUNNING
caplog.set_level(logging.INFO)
@@ -169,7 +169,7 @@ class TestKubernetesPodTrigger:
assert f"Sleeping for {POLL_INTERVAL} seconds."
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def
test_logging_in_trigger_when_exception_should_execute_successfully(
self, mock_hook, trigger, caplog
):
@@ -177,7 +177,7 @@ class TestKubernetesPodTrigger:
Test that KubernetesPodTrigger fires the correct event in case of an
error.
"""
- mock_hook.return_value.get_pod.side_effect = Exception("Test
exception")
+ mock_hook.get_pod.side_effect = Exception("Test exception")
generator = trigger.run()
actual = await generator.asend(None)
@@ -192,7 +192,7 @@ class TestKubernetesPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def test_logging_in_trigger_when_fail_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
@@ -200,7 +200,7 @@ class TestKubernetesPodTrigger:
Test that KubernetesPodTrigger fires the correct event in case of fail.
"""
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.FAILED
caplog.set_level(logging.INFO)
@@ -209,7 +209,7 @@ class TestKubernetesPodTrigger:
assert "Container logs:"
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def
test_logging_in_trigger_when_cancelled_should_execute_successfully_and_delete_pod(
self,
mock_hook,
@@ -219,9 +219,9 @@ class TestKubernetesPodTrigger:
Test that KubernetesPodTrigger fires the correct event in case if the
task was cancelled.
"""
- mock_hook.return_value.get_pod.side_effect = CancelledError()
- mock_hook.return_value.read_logs.return_value =
self._mock_pod_result(mock.MagicMock())
- mock_hook.return_value.delete_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.side_effect = CancelledError()
+ mock_hook.read_logs.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.delete_pod.return_value =
self._mock_pod_result(mock.MagicMock())
trigger = KubernetesPodTrigger(
pod_name=POD_NAME,
@@ -255,7 +255,7 @@ class TestKubernetesPodTrigger:
assert "Deleting pod..." in caplog.text
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def
test_logging_in_trigger_when_cancelled_should_execute_successfully_without_delete_pod(
self,
mock_hook,
@@ -265,9 +265,9 @@ class TestKubernetesPodTrigger:
Test that KubernetesPodTrigger fires the correct event if the task was
cancelled.
"""
- mock_hook.return_value.get_pod.side_effect = CancelledError()
- mock_hook.return_value.read_logs.return_value =
self._mock_pod_result(mock.MagicMock())
- mock_hook.return_value.delete_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.side_effect = CancelledError()
+ mock_hook.read_logs.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.delete_pod.return_value =
self._mock_pod_result(mock.MagicMock())
trigger = KubernetesPodTrigger(
pod_name=POD_NAME,
@@ -341,12 +341,12 @@ class TestKubernetesPodTrigger:
@pytest.mark.asyncio
@pytest.mark.parametrize("container_state", [ContainerState.WAITING,
ContainerState.UNDEFINED])
@mock.patch(f"{TRIGGER_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_PATH}.hook")
async def test_run_loop_return_timeout_event(
self, mock_hook, mock_method, trigger, caplog, container_state
):
trigger.trigger_start_time = TRIGGER_START_TIME -
datetime.timedelta(minutes=2)
- mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
+ mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
status=mock.MagicMock(
phase=PodPhase.PENDING,
diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
index b28d9418ef..b252ea4e30 100644
--- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py
@@ -105,11 +105,11 @@ class TestGKEStartPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_success_event_should_execute_successfully(
self, mock_hook, mock_method, trigger
):
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.TERMINATED
expected_event = TriggerEvent(
@@ -126,11 +126,11 @@ class TestGKEStartPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_failed_event_should_execute_successfully(
self, mock_hook, mock_method, trigger
):
- mock_hook.return_value.get_pod.return_value = self._mock_pod_result(
+ mock_hook.get_pod.return_value = self._mock_pod_result(
mock.MagicMock(
status=mock.MagicMock(
message=FAILED_RESULT_MSG,
@@ -153,11 +153,11 @@ class TestGKEStartPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_waiting_event_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.WAITING
caplog.set_level(logging.INFO)
@@ -171,11 +171,11 @@ class TestGKEStartPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_run_loop_return_running_event_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.RUNNING
caplog.set_level(logging.INFO)
@@ -188,14 +188,14 @@ class TestGKEStartPodTrigger:
assert f"Sleeping for {POLL_INTERVAL} seconds."
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def
test_logging_in_trigger_when_exception_should_execute_successfully(
self, mock_hook, trigger, caplog
):
"""
Test that GKEStartPodTrigger fires the correct event in case of an
error.
"""
- mock_hook.return_value.get_pod.side_effect = Exception("Test
exception")
+ mock_hook.get_pod.side_effect = Exception("Test exception")
generator = trigger.run()
actual = await generator.asend(None)
@@ -210,14 +210,14 @@ class TestGKEStartPodTrigger:
@pytest.mark.asyncio
@mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")
- @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def test_logging_in_trigger_when_fail_should_execute_successfully(
self, mock_hook, mock_method, trigger, caplog
):
"""
Test that GKEStartPodTrigger fires the correct event in case of fail.
"""
- mock_hook.return_value.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.return_value =
self._mock_pod_result(mock.MagicMock())
mock_method.return_value = ContainerState.FAILED
caplog.set_level(logging.INFO)
@@ -226,16 +226,16 @@ class TestGKEStartPodTrigger:
assert "Container logs:"
@pytest.mark.asyncio
- @mock.patch(f"{TRIGGER_GKE_PATH}._get_async_hook")
+ @mock.patch(f"{TRIGGER_GKE_PATH}.hook")
async def
test_logging_in_trigger_when_cancelled_should_execute_successfully(
self, mock_hook, trigger, caplog
):
"""
Test that GKEStartPodTrigger fires the correct event in case if the
task was cancelled.
"""
- mock_hook.return_value.get_pod.side_effect = CancelledError()
- mock_hook.return_value.read_logs.return_value =
self._mock_pod_result(mock.MagicMock())
- mock_hook.return_value.delete_pod.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.get_pod.side_effect = CancelledError()
+ mock_hook.read_logs.return_value =
self._mock_pod_result(mock.MagicMock())
+ mock_hook.delete_pod.return_value =
self._mock_pod_result(mock.MagicMock())
generator = trigger.run()
actual = await generator.asend(None)