This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a0f41ba53 Logging from all containers in KubernetesOperatorPod (#31663)
9a0f41ba53 is described below

commit 9a0f41ba53185031bc2aa56ead2928ae4b20de99
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jul 6 15:19:40 2023 +0530

    Logging from all containers in KubernetesOperatorPod (#31663)
    
    * Logging from all containers in KubernetesOperatorPod
    
    * review comments from uranusjr
    
    * Fixing init logic
    
    * Fixing docs
    
    * Fixing sphinx logging
    
    * Addressing review comments from vincbeck
    
    * nits from potiuk
    
    * nits from potiuk
    
    * reverting return type
    
    * comment from uranusjr
    
    * fixing tests
    
    * fixing tests
    
    * review comments from hussein
    
    * handling nits from hussein
    
    * fixing tests
    
    ---------
    
    Co-authored-by: Amogh <[email protected]>
    Co-authored-by: Amogh Desai <[email protected]>
---
 airflow/providers/cncf/kubernetes/operators/pod.py | 20 +++--
 .../providers/cncf/kubernetes/utils/pod_manager.py | 95 +++++++++++++++++++++-
 kubernetes_tests/test_kubernetes_pod_operator.py   |  2 +-
 .../cncf/kubernetes/utils/test_pod_manager.py      | 40 +++++++++
 4 files changed, 147 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index e3ac7708e7..124685e792 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -27,7 +27,7 @@ import warnings
 from collections.abc import Container
 from contextlib import AbstractContextManager
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
 from kubernetes.client import CoreV1Api, models as k8s
 from slugify import slugify
@@ -62,6 +62,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
     get_container_termination_message,
 )
 from airflow.settings import pod_mutation_hook
+from airflow.typing_compat import Literal
 from airflow.utils import yaml
 from airflow.utils.helpers import prune_dict, validate_key
 from airflow.utils.timezone import utcnow
@@ -178,6 +179,10 @@ class KubernetesPodOperator(BaseOperator):
     :param labels: labels to apply to the Pod. (templated)
     :param startup_timeout_seconds: timeout in seconds to startup the pod.
     :param get_logs: get the stdout of the base container as logs of the tasks.
+    :param container_logs: list of containers whose logs will be published to stdout.
+        Takes a sequence of container names, a single container name or True. If True,
+        the logs of all containers (except the xcom sidecar) are published. Only used
+        when get_logs is True. The default value is the base container.
     :param image_pull_policy: Specify a policy to cache or always pull an image.
     :param annotations: non-identifying metadata you can attach to the Pod.
         Can be a large range of data, and can include characters
@@ -278,6 +283,7 @@ class KubernetesPodOperator(BaseOperator):
         reattach_on_restart: bool = True,
         startup_timeout_seconds: int = 120,
         get_logs: bool = True,
+        container_logs: Iterable[str] | str | Literal[True] = BASE_CONTAINER_NAME,
         image_pull_policy: str | None = None,
         annotations: dict | None = None,
         container_resources: k8s.V1ResourceRequirements | None = None,
@@ -350,6 +356,11 @@ class KubernetesPodOperator(BaseOperator):
         self.cluster_context = cluster_context
         self.reattach_on_restart = reattach_on_restart
         self.get_logs = get_logs
+        self.container_logs = container_logs
+        if self.container_logs == KubernetesPodOperator.BASE_CONTAINER_NAME:
+            self.container_logs = (
+                base_container_name or KubernetesPodOperator.BASE_CONTAINER_NAME
+            )
         self.image_pull_policy = image_pull_policy
         self.node_selector = node_selector or {}
         self.annotations = annotations or {}
@@ -572,11 +583,10 @@ class KubernetesPodOperator(BaseOperator):
             self.await_pod_start(pod=self.pod)
 
             if self.get_logs:
-                self.pod_manager.fetch_container_logs(
+                self.pod_manager.fetch_requested_container_logs(
                     pod=self.pod,
-                    container_name=self.base_container_name,
-                    follow=True,
-                    post_termination_timeout=self.POST_TERMINATION_TIMEOUT,
+                    container_logs=self.container_logs,
+                    follow_logs=True,
                 )
             else:
                 self.pod_manager.await_container_completion(
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 2251f1d438..979180f1d9 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -23,6 +23,7 @@ import logging
 import math
 import time
 import warnings
+from collections.abc import Iterable
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from datetime import datetime, timedelta
@@ -43,7 +44,7 @@ from urllib3.response import HTTPResponse
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.kubernetes.pod_generator import PodDefaults
-from airflow.typing_compat import Protocol
+from airflow.typing_compat import Literal, Protocol
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timezone import utcnow
 
@@ -125,6 +126,17 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool:
     return container_status.state.running is not None
 
 
+def container_is_completed(pod: V1Pod, container_name: str) -> bool:
+    """
+    Examines V1Pod ``pod`` to determine whether ``container_name`` is completed.
+    Returns True if that container is present and in the terminated state, False otherwise.
+    """
+    container_status = get_container_status(pod, container_name)
+    if not container_status:
+        return False
+    return container_status.state.terminated is not None
+
+
 def container_is_terminated(pod: V1Pod, container_name: str) -> bool:
     """
     Examines V1Pod ``pod`` to determine whether ``container_name`` is terminated.
@@ -379,11 +391,12 @@ class PodManager(LoggingMixin):
                 for raw_line in logs:
                     line = raw_line.decode("utf-8", errors="backslashreplace")
                     timestamp, message = self.parse_log_line(line)
-                    self.log.info(message)
+                    self.log.info("[%s] %s", container_name, message)
             except BaseHTTPError as e:
                 self.log.warning(
-                    "Reading of logs interrupted with error %r; will retry. "
+                    "Reading of logs interrupted for container %r with error %r; will retry. "
                     "Set log level to DEBUG for traceback.",
+                    container_name,
                     e,
                 )
                 self.log.debug(
@@ -413,6 +426,65 @@ class PodManager(LoggingMixin):
                 )
                 time.sleep(1)
 
+    def fetch_requested_container_logs(
+        self, pod: V1Pod, container_logs: Iterable[str] | str | Literal[True], follow_logs: bool = False
+    ) -> list[PodLoggingStatus]:
+        """
+        Follow the logs of the containers selected by ``container_logs`` and publish them
+        to Airflow logging; return one PodLoggingStatus per container whose logs were fetched.
+        """
+        pod_logging_statuses = []
+        all_containers = self.get_container_names(pod)
+        if len(all_containers) == 0:
+            self.log.error("Could not retrieve containers for the pod: %s", pod.metadata.name)
+        else:
+            if isinstance(container_logs, str):
+                # fetch logs only for requested container if only one container is provided
+                if container_logs in all_containers:
+                    status = self.fetch_container_logs(
+                        pod=pod, container_name=container_logs, follow=follow_logs
+                    )
+                    pod_logging_statuses.append(status)
+                else:
+                    self.log.error(
+                        "Container %s whose logs were requested was not found in the pod %s",
+                        container_logs,
+                        pod.metadata.name,
+                    )
+            elif isinstance(container_logs, bool):
+                # if True is provided, get logs for all the containers
+                if container_logs is True:
+                    for container_name in all_containers:
+                        status = self.fetch_container_logs(
+                            pod=pod, container_name=container_name, follow=follow_logs
+                        )
+                        pod_logging_statuses.append(status)
+                else:
+                    self.log.error(
+                        "False is not a valid value for container_logs",
+                    )
+            else:
+                # if an iterable of container names is provided, fetch logs for each one found in the pod
+                if isinstance(container_logs, Iterable):
+                    for container in container_logs:
+                        if container in all_containers:
+                            status = self.fetch_container_logs(
+                                pod=pod, container_name=container, follow=follow_logs
+                            )
+                            pod_logging_statuses.append(status)
+                        else:
+                            self.log.error(
+                                "Container %s whose logs were requested was not found in the pod %s",
+                                container,
+                                pod.metadata.name,
+                            )
+                else:
+                    self.log.error(
+                        "Invalid type %s specified for container names input parameter", type(container_logs)
+                    )
+
+        return pod_logging_statuses
+
     def await_container_completion(self, pod: V1Pod, container_name: str) -> None:
         """
         Waits for the given container in the given pod to be completed.
@@ -420,7 +492,12 @@ class PodManager(LoggingMixin):
         :param pod: pod spec that will be monitored
         :param container_name: name of the container within the pod to monitor
         """
-        while not self.container_is_terminated(pod=pod, container_name=container_name):
+        while True:
+            remote_pod = self.read_pod(pod)
+            completed = container_is_completed(remote_pod, container_name)
+            if completed:
+                break
+            self.log.info("Waiting for container '%s' state to be completed", container_name)
             time.sleep(1)
 
     def await_pod_completion(self, pod: V1Pod) -> V1Pod:
@@ -513,6 +590,16 @@ class PodManager(LoggingMixin):
             post_termination_timeout=post_termination_timeout,
         )
 
+    @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
+    def get_container_names(self, pod: V1Pod) -> list[str]:
+        """Return container names from the POD except the xcom sidecar (PodDefaults.SIDECAR_CONTAINER_NAME)."""
+        pod_info = self.read_pod(pod)
+        return [
+            container_spec.name
+            for container_spec in pod_info.spec.containers
+            if container_spec.name != PodDefaults.SIDECAR_CONTAINER_NAME
+        ]
+
     @tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
     def read_pod_events(self, pod: V1Pod) -> CoreV1EventList:
         """Reads events from the POD."""
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index e650932b53..2b66a3082b 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -500,7 +500,7 @@ class TestKubernetesPodOperatorSystem:
             )
             context = create_context(k)
             k.execute(context=context)
-            mock_logger.info.assert_any_call("retrieved from mount")
+            mock_logger.info.assert_any_call("[%s] %s", "base", "retrieved from mount")
             actual_pod = self.api_client.sanitize_for_serialization(k.pod)
             self.expected_pod["spec"]["containers"][0]["args"] = args
             self.expected_pod["spec"]["containers"][0]["volumeMounts"] = [
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
index 8f28d33dfd..c915a03a4a 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
@@ -317,6 +317,46 @@ class TestPodManager:
         assert ret.last_log_time is None
         assert ret.running is False
 
+    # covers all valid types for container_logs
+    @pytest.mark.parametrize("follow", [True, False])
+    @pytest.mark.parametrize("container_logs", ["base", "alpine", True, ["base", "alpine"]])
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
+    def test_fetch_requested_container_logs(self, container_is_running, container_logs, follow):
+        mock_pod = MagicMock()
+        self.pod_manager.read_pod = MagicMock()
+        self.pod_manager.get_container_names = MagicMock()
+        self.pod_manager.get_container_names.return_value = ["base", "alpine"]
+        container_is_running.return_value = False
+        self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(
+            stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
+        )
+
+        ret_values = self.pod_manager.fetch_requested_container_logs(
+            pod=mock_pod, container_logs=container_logs, follow_logs=follow
+        )
+        assert len(ret_values) == (1 if isinstance(container_logs, str) else 2)
+        assert all(ret.running is False for ret in ret_values)
+
+    # covers all invalid types for container_logs
+    @pytest.mark.parametrize("container_logs", [1, None, 6.8, False])
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
+    def test_fetch_requested_container_logs_invalid(self, container_is_running, container_logs):
+        mock_pod = MagicMock()
+        self.pod_manager.read_pod = MagicMock()
+        self.pod_manager.get_container_names = MagicMock()
+        self.pod_manager.get_container_names.return_value = ["base", "alpine"]
+        container_is_running.return_value = False
+        self.mock_kube_client.read_namespaced_pod_log.return_value = mock.MagicMock(
+            stream=mock.MagicMock(return_value=[b"2021-01-01 hi"])
+        )
+
+        ret_values = self.pod_manager.fetch_requested_container_logs(
+            pod=mock_pod,
+            container_logs=container_logs,
+        )
+
+        assert len(ret_values) == 0
+
     @mock.patch("pendulum.now")
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.container_is_running")
     @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodLogsConsumer.logs_available")

Reply via email to