This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch mlnsharma/main
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 292c23283e3f97b8d281fbcd581c2f4d7abfa98e
Author: Narasimha Sharma <[email protected]>
AuthorDate: Mon Jan 16 22:10:55 2023 -0800

    Read logs from all containers in KPO
---
 .../cncf/kubernetes/operators/kubernetes_pod.py    | 13 +++-
 .../providers/cncf/kubernetes/utils/pod_manager.py | 87 ++++++++++++++++++++--
 kubernetes_tests/test_kubernetes_pod_operator.py   |  2 +-
 .../cncf/kubernetes/utils/test_pod_manager.py      | 23 ++++++
 4 files changed, 117 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py 
b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index c34fc5c02b..96fb98fb53 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -168,6 +168,9 @@ class KubernetesPodOperator(BaseOperator):
     :param labels: labels to apply to the Pod. (templated)
     :param startup_timeout_seconds: timeout in seconds to startup the pod.
     :param get_logs: get the stdout of the container as logs of the tasks.
+    :param log_containers: list of container names or bool value to collect 
logs.
+        If bool value is True, all container logs are collected,
+        if False, only 'base' container logs are collected.
     :param image_pull_policy: Specify a policy to cache or always pull an 
image.
     :param annotations: non-identifying metadata you can attach to the Pod.
         Can be a large range of data, and can include characters
@@ -248,6 +251,7 @@ class KubernetesPodOperator(BaseOperator):
         reattach_on_restart: bool = True,
         startup_timeout_seconds: int = 120,
         get_logs: bool = True,
+        log_containers: list[str] | bool = False,
         image_pull_policy: str | None = None,
         annotations: dict | None = None,
         container_resources: k8s.V1ResourceRequirements | None = None,
@@ -311,6 +315,7 @@ class KubernetesPodOperator(BaseOperator):
         self.cluster_context = cluster_context
         self.reattach_on_restart = reattach_on_restart
         self.get_logs = get_logs
+        self.log_containers = log_containers
         self.image_pull_policy = image_pull_policy
         self.node_selector = node_selector or {}
         self.annotations = annotations or {}
@@ -496,9 +501,13 @@ class KubernetesPodOperator(BaseOperator):
             self.await_pod_start(pod=self.pod)
 
             if self.get_logs:
-                self.pod_manager.fetch_container_logs(
+                # if log_containers is False, fetch logs from base container, 
otherwise
+                # fetch logs for all containers or for the specified input 
list of container names
+                self.pod_manager.fetch_input_container_logs(
                     pod=self.pod,
-                    container_name=self.BASE_CONTAINER_NAME,
+                    log_containers=(
+                        self.log_containers if self.log_containers else 
[self.BASE_CONTAINER_NAME]
+                    ),
                     follow=True,
                 )
             else:
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py 
b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 56ef95eef0..1d1b2f06a6 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -84,6 +84,20 @@ def container_is_running(pod: V1Pod, container_name: str) -> 
bool:
     return container_status.state.running is not None
 
 
+def container_is_completed(pod: V1Pod, container_name: str) -> bool:
+    """
+    Examines V1Pod ``pod`` to determine whether ``container_name`` has completed.
+    If that container is present and terminated, returns True. Returns False otherwise.
+    """
+    container_statuses = pod.status.container_statuses if pod and pod.status 
else None
+    if not container_statuses:
+        return False
+    container_status = next((status for status in container_statuses if 
status.name == container_name), None)
+    if not container_status:
+        return False
+    return container_status.state.terminated is not None
+
+
 def get_container_termination_message(pod: V1Pod, container_name: str):
     with suppress(AttributeError, TypeError):
         container_statuses = pod.status.container_statuses
@@ -264,15 +278,68 @@ class PodManager(LoggingMixin):
                 )
                 time.sleep(1)
 
-    def await_container_completion(self, pod: V1Pod, container_name: str) -> 
None:
+    def fetch_input_container_logs(
+        self, pod: V1Pod, log_containers: list[str] | bool, follow=False
+    ) -> list[PodLoggingStatus]:
+        """
+        Follow the logs of containers in the pod specified by input parameter 
and stream to airflow logging.
+        Returns when all the containers exit
         """
-        Waits for the given container in the given pod to be completed
+        pod_logging_statuses = []
+        all_container_names = self.get_container_names(pod)
+        if len(all_container_names) == 0:
+            self.log.error("Failed to retrieve container names from the pod, 
unable to collect logs")
+        else:
+            # if log_containers is list type, collect logs of the input 
container names
+            if type(log_containers) == list:
+                for container_name in log_containers:
+                    if container_name in all_container_names:
+                        status = self.fetch_container_logs(
+                            pod=pod, container_name=container_name, 
follow=follow
+                        )
+                        pod_logging_statuses.append(status)
+                    else:
+                        self.log.error(
+                            "container name '%s' specified in input parameter 
is not found in the pod",
+                            container_name,
+                        )
+            # if log_containers is bool value True, collect logs from all 
containers
+            # if log_containers is bool value False, collect logs from the 
base (first) container
+            elif type(log_containers) == bool:
+                if log_containers is True:
+                    for container_name in all_container_names:
+                        status = self.fetch_container_logs(
+                            pod=pod, container_name=container_name, 
follow=follow
+                        )
+                        pod_logging_statuses.append(status)
+                else:
+                    status = self.fetch_container_logs(
+                        pod=pod, container_name=all_container_names[0], 
follow=follow
+                    )
+                    pod_logging_statuses.append(status)
+            else:
+                self.log.error(
+                    "Invalid type '%s' specified for container names input 
parameter", type(log_containers)
+                )
 
-        :param pod: pod spec that will be monitored
-        :param container_name: name of the container within the pod to monitor
+        return pod_logging_statuses
+
+    def await_container_completion(self, pod: V1Pod, container_name: str) -> 
bool:
+        """
+        Waits for the given container in the given pod to reach the Terminated
state, polling the pod status once per second.
+        :param pod: pod spec that will be monitored
+        :param container_name: name of the container within the pod to monitor
+        :return: True once the container has terminated
         """
-        while self.container_is_running(pod=pod, 
container_name=container_name):
+        while True:
+            remote_pod = self.read_pod(pod)
+            terminated = container_is_completed(remote_pod, container_name)
+            if terminated:
+                break
+            self.log.info("Waiting for container '%s' state to be Terminated", 
container_name)
             time.sleep(1)
+        return terminated
 
     def await_pod_completion(self, pod: V1Pod) -> V1Pod:
         """
@@ -350,6 +417,16 @@ class PodManager(LoggingMixin):
             self.log.exception("There was an error reading the kubernetes 
API.")
             raise
 
+    @tenacity.retry(stop=tenacity.stop_after_attempt(3), 
wait=tenacity.wait_exponential(), reraise=True)
+    def get_container_names(self, pod: V1Pod) -> list[str]:
+        """Return Container names from POD except for the airflow-xcom-sidecar 
container"""
+        container_names = []
+        pod_info = self.read_pod(pod)
+        for container_spec in pod_info.spec.containers:
+            if container_spec.name != "airflow-xcom-sidecar":
+                container_names.append(container_spec.name)
+        return container_names
+
     @tenacity.retry(stop=tenacity.stop_after_attempt(3), 
wait=tenacity.wait_exponential(), reraise=True)
     def read_pod_events(self, pod: V1Pod) -> CoreV1EventList:
         """Reads events from the POD"""
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py 
b/kubernetes_tests/test_kubernetes_pod_operator.py
index caa0ad87b5..56bb764e0c 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -496,7 +496,7 @@ class TestKubernetesPodOperatorSystem:
             )
             context = create_context(k)
             k.execute(context=context)
-            mock_logger.info.assert_any_call("retrieved from mount")
+            mock_logger.info.assert_any_call("[%s] %s", "base", "retrieved 
from mount")
             actual_pod = self.api_client.sanitize_for_serialization(k.pod)
             self.expected_pod["spec"]["containers"][0]["args"] = args
             self.expected_pod["spec"]["containers"][0]["volumeMounts"] = [
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py 
b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
index a089e86119..1cd255e6a3 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
@@ -302,6 +302,29 @@ class TestPodManager:
         assert ret.last_log_time == DateTime(2021, 1, 1, 
tzinfo=Timezone("UTC"))
         assert ret.running is False
 
+    @pytest.mark.parametrize("follow", [True, False])
+    @pytest.mark.parametrize("log_containers", ["base", "alpine"])
+    
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
+    def test_fetch_input_container_logs(self, container_running, 
log_containers, follow):
+        mock_pod = MagicMock()
+        self.pod_manager.read_pod = MagicMock()
+        self.pod_manager.get_container_names = MagicMock()
+        container_running.return_value = False
+        self.mock_kube_client.read_namespaced_pod_log.return_value = 
[b"2021-01-01 hi"]
+        ret_values = self.pod_manager.fetch_input_container_logs(
+            pod=mock_pod, log_containers=log_containers, follow=follow
+        )
+        for ret in ret_values:
+            assert ret.last_log_time == DateTime(2021, 1, 1, 
tzinfo=Timezone("UTC"))
+            assert ret.running is False
+
+    
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_completed")
+    def test_await_container_completion(self, container_completed):
+        mock_pod = MagicMock()
+        container_completed.return_value = True
+        status_completed = 
self.pod_manager.await_container_completion(pod=mock_pod, container_name="base")
+        assert status_completed is True
+
     @mock.patch("pendulum.now")
     
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
     def test_fetch_container_since_time(self, container_running, mock_now):

Reply via email to