This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b5057e0e1f Add `progress_callback` parameter to `KubernetesPodOperator` (#34153)
b5057e0e1f is described below

commit b5057e0e1fc6b7a47e38037a97cac862706747f0
Author: zeotuan <[email protected]>
AuthorDate: Sun Sep 10 04:08:29 2023 +1000

    Add `progress_callback` parameter to `KubernetesPodOperator` (#34153)
    
    * add k8sPodOperator progress_callback
---
 airflow/providers/cncf/kubernetes/operators/pod.py     |  7 +++++--
 airflow/providers/cncf/kubernetes/utils/pod_manager.py |  7 ++++++-
 kubernetes_tests/test_kubernetes_pod_operator.py       | 18 ++++++++++++++++++
 .../cncf/kubernetes/utils/test_pod_manager.py          | 18 +++++++++++++++++-
 4 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 3f3438b1b7..5f7269718f 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -28,7 +28,7 @@ import warnings
 from collections.abc import Container
 from contextlib import AbstractContextManager
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
 
 from kubernetes.client import CoreV1Api, V1Pod, models as k8s
 from kubernetes.stream import stream
@@ -245,6 +245,7 @@ class KubernetesPodOperator(BaseOperator):
         Default value is "File"
     :param active_deadline_seconds: The active_deadline_seconds which matches to active_deadline_seconds
         in V1PodSpec.
+    :param progress_callback: Callback function for receiving k8s container logs.
     """
 
     # This field can be overloaded at the instance level via base_container_name
@@ -328,6 +329,7 @@ class KubernetesPodOperator(BaseOperator):
         is_delete_operator_pod: None | bool = None,
         termination_message_policy: str = "File",
         active_deadline_seconds: int | None = None,
+        progress_callback: Callable[[str], None] | None = None,
         **kwargs,
     ) -> None:
         # TODO: remove in provider 6.0.0 release. This is a mitigate step to advise users to switch to the
@@ -428,6 +430,7 @@ class KubernetesPodOperator(BaseOperator):
         self.active_deadline_seconds = active_deadline_seconds
 
         self._config_dict: dict | None = None  # TODO: remove it when removing convert_config_file_to_dict
+        self._progress_callback = progress_callback
 
     @cached_property
     def _incluster_namespace(self):
@@ -505,7 +508,7 @@ class KubernetesPodOperator(BaseOperator):
 
     @cached_property
     def pod_manager(self) -> PodManager:
-        return PodManager(kube_client=self.client)
+        return PodManager(kube_client=self.client, progress_callback=self._progress_callback)
 
     @cached_property
     def hook(self) -> PodOperatorHookProtocol:
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 46d1d4bcef..142528ddad 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -28,7 +28,7 @@ from collections.abc import Iterable
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from datetime import timedelta
-from typing import TYPE_CHECKING, Generator, Protocol, cast
+from typing import TYPE_CHECKING, Callable, Generator, Protocol, cast
 
 import pendulum
 import tenacity
@@ -282,14 +282,17 @@ class PodManager(LoggingMixin):
     def __init__(
         self,
         kube_client: client.CoreV1Api,
+        progress_callback: Callable[[str], None] | None = None,
     ):
         """
         Creates the launcher.
 
         :param kube_client: kubernetes client
+        :param progress_callback: Callback function invoked when fetching container log.
         """
         super().__init__()
         self._client = kube_client
+        self._progress_callback = progress_callback
         self._watch = watch.Watch()
 
     def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
@@ -413,6 +416,8 @@ class PodManager(LoggingMixin):
                 for raw_line in logs:
                     line = raw_line.decode("utf-8", errors="backslashreplace")
                     line_timestamp, message = self.parse_log_line(line)
+                    if self._progress_callback:
+                        self._progress_callback(line)
                     if line_timestamp is not None:
                         last_captured_timestamp = line_timestamp
                     self.log.info("[%s] %s", container_name, message)
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index aba1f95457..7534cec5f0 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -1211,6 +1211,24 @@ class TestKubernetesPodOperatorSystem:
         self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
         assert self.expected_pod["spec"] == actual_pod["spec"]
 
+    def test_progess_call(self, mock_get_connection):
+        progress_callback = MagicMock()
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels=self.labels,
+            task_id=str(uuid4()),
+            in_cluster=False,
+            do_xcom_push=False,
+            get_logs=True,
+            progress_callback=progress_callback,
+        )
+        context = create_context(k)
+        k.execute(context)
+        progress_callback.assert_called()
+
     def test_changing_base_container_name_no_logs(self, mock_get_connection):
         """
         This test checks BOTH a modified base container name AND the get_logs=False flow,
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
index a49d90674d..dfe06d9a74 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
@@ -46,8 +46,11 @@ from airflow.utils.timezone import utc
 
 class TestPodManager:
     def setup_method(self):
+        self.mock_progress_callback = mock.Mock()
         self.mock_kube_client = mock.Mock()
-        self.pod_manager = PodManager(kube_client=self.mock_kube_client)
+        self.pod_manager = PodManager(
+            kube_client=self.mock_kube_client, progress_callback=self.mock_progress_callback
+        )
 
     def test_read_pod_logs_successfully_returns_logs(self):
         mock.sentinel.metadata = mock.MagicMock()
@@ -268,6 +271,19 @@ class TestPodManager:
 
         assert status.last_log_time == cast(DateTime, pendulum.parse(timestamp_string))
 
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")
+    def test_fetch_container_logs_invoke_progress_callback(
+        self, mock_read_pod_logs, mock_container_is_running
+    ):
+        message = "2020-10-08T14:16:17.793417674Z message"
+        no_ts_message = "notimestamp"
+        mock_read_pod_logs.return_value = [bytes(message, "utf-8"), bytes(no_ts_message, "utf-8")]
+        mock_container_is_running.return_value = False
+
+        self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True)
+        self.mock_progress_callback.assert_has_calls([mock.call(message), mock.call(no_ts_message)])
+
     def test_parse_invalid_log_line(self, caplog):
         with caplog.at_level(logging.INFO):
             self.pod_manager.parse_log_line("2020-10-08T14:16:17.793417674ZInvalidmessage\n")

Reply via email to