This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b5057e0e1f Add `progress_callback` parameter to `KubernetesPodOperator` (#34153)
b5057e0e1f is described below
commit b5057e0e1fc6b7a47e38037a97cac862706747f0
Author: zeotuan <[email protected]>
AuthorDate: Sun Sep 10 04:08:29 2023 +1000
Add `progress_callback` parameter to `KubernetesPodOperator` (#34153)
* add k8sPodOperator progress_callback
---
airflow/providers/cncf/kubernetes/operators/pod.py | 7 +++++--
airflow/providers/cncf/kubernetes/utils/pod_manager.py | 7 ++++++-
kubernetes_tests/test_kubernetes_pod_operator.py | 18 ++++++++++++++++++
.../cncf/kubernetes/utils/test_pod_manager.py | 18 +++++++++++++++++-
4 files changed, 46 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 3f3438b1b7..5f7269718f 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -28,7 +28,7 @@ import warnings
from collections.abc import Container
from contextlib import AbstractContextManager
from functools import cached_property
-from typing import TYPE_CHECKING, Any, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
from kubernetes.client import CoreV1Api, V1Pod, models as k8s
from kubernetes.stream import stream
@@ -245,6 +245,7 @@ class KubernetesPodOperator(BaseOperator):
Default value is "File"
:param active_deadline_seconds: The active_deadline_seconds which matches to active_deadline_seconds in V1PodSpec.
+ :param progress_callback: Callback function for receiving k8s container logs.
"""
# This field can be overloaded at the instance level via base_container_name
@@ -328,6 +329,7 @@ class KubernetesPodOperator(BaseOperator):
is_delete_operator_pod: None | bool = None,
termination_message_policy: str = "File",
active_deadline_seconds: int | None = None,
+ progress_callback: Callable[[str], None] | None = None,
**kwargs,
) -> None:
# TODO: remove in provider 6.0.0 release. This is a mitigate step to advise users to switch to the
@@ -428,6 +430,7 @@ class KubernetesPodOperator(BaseOperator):
self.active_deadline_seconds = active_deadline_seconds
self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict
+ self._progress_callback = progress_callback
@cached_property
def _incluster_namespace(self):
@@ -505,7 +508,7 @@ class KubernetesPodOperator(BaseOperator):
@cached_property
def pod_manager(self) -> PodManager:
- return PodManager(kube_client=self.client)
+ return PodManager(kube_client=self.client, progress_callback=self._progress_callback)
@cached_property
def hook(self) -> PodOperatorHookProtocol:
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 46d1d4bcef..142528ddad 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -28,7 +28,7 @@ from collections.abc import Iterable
from contextlib import closing, suppress
from dataclasses import dataclass
from datetime import timedelta
-from typing import TYPE_CHECKING, Generator, Protocol, cast
+from typing import TYPE_CHECKING, Callable, Generator, Protocol, cast
import pendulum
import tenacity
@@ -282,14 +282,17 @@ class PodManager(LoggingMixin):
def __init__(
self,
kube_client: client.CoreV1Api,
+ progress_callback: Callable[[str], None] | None = None,
):
"""
Creates the launcher.
:param kube_client: kubernetes client
+ :param progress_callback: Callback function invoked when fetching container log.
"""
super().__init__()
self._client = kube_client
+ self._progress_callback = progress_callback
self._watch = watch.Watch()
def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
@@ -413,6 +416,8 @@ class PodManager(LoggingMixin):
for raw_line in logs:
line = raw_line.decode("utf-8", errors="backslashreplace")
line_timestamp, message = self.parse_log_line(line)
+ if self._progress_callback:
+ self._progress_callback(line)
if line_timestamp is not None:
last_captured_timestamp = line_timestamp
self.log.info("[%s] %s", container_name, message)
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index aba1f95457..7534cec5f0 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -1211,6 +1211,24 @@ class TestKubernetesPodOperatorSystem:
self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
assert self.expected_pod["spec"] == actual_pod["spec"]
+ def test_progess_call(self, mock_get_connection):
+ progress_callback = MagicMock()
+ k = KubernetesPodOperator(
+ namespace="default",
+ image="ubuntu:16.04",
+ cmds=["bash", "-cx"],
+ arguments=["echo 10"],
+ labels=self.labels,
+ task_id=str(uuid4()),
+ in_cluster=False,
+ do_xcom_push=False,
+ get_logs=True,
+ progress_callback=progress_callback,
+ )
+ context = create_context(k)
+ k.execute(context)
+ progress_callback.assert_called()
+
def test_changing_base_container_name_no_logs(self, mock_get_connection):
"""
This test checks BOTH a modified base container name AND the get_logs=False flow,
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
index a49d90674d..dfe06d9a74 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
@@ -46,8 +46,11 @@ from airflow.utils.timezone import utc
class TestPodManager:
def setup_method(self):
+ self.mock_progress_callback = mock.Mock()
self.mock_kube_client = mock.Mock()
- self.pod_manager = PodManager(kube_client=self.mock_kube_client)
+ self.pod_manager = PodManager(
+ kube_client=self.mock_kube_client, progress_callback=self.mock_progress_callback
+ )
def test_read_pod_logs_successfully_returns_logs(self):
mock.sentinel.metadata = mock.MagicMock()
@@ -268,6 +271,19 @@ class TestPodManager:
assert status.last_log_time == cast(DateTime, pendulum.parse(timestamp_string))
+ @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
+ @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
+ def test_fetch_container_logs_invoke_progress_callback(
+ self, mock_read_pod_logs, mock_container_is_running
+ ):
+ message = "2020-10-08T14:16:17.793417674Z message"
+ no_ts_message = "notimestamp"
+ mock_read_pod_logs.return_value = [bytes(message, "utf-8"), bytes(no_ts_message, "utf-8")]
+ mock_container_is_running.return_value = False
+
+ self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)
+ self.mock_progress_callback.assert_has_calls([mock.call(message), mock.call(no_ts_message)])
+
def test_parse_invalid_log_line(self, caplog):
with caplog.at_level(logging.INFO):
self.pod_manager.parse_log_line("2020-10-08T14:16:17.793417674ZInvalidmessage\n")