This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 70b84b51a5 Allow setting the name for the base container within K8s Pod Operator (#28808)
70b84b51a5 is described below

commit 70b84b51a5802b72dc7a8fb9bf8133699adcc79c
Author: Charles Machalow <[email protected]>
AuthorDate: Mon Jan 23 11:46:52 2023 -0800

    Allow setting the name for the base container within K8s Pod Operator (#28808)
    
    Some downstream machinery may require a specific base container name.
---
 .../cncf/kubernetes/operators/kubernetes_pod.py    |  18 ++-
 kubernetes_tests/test_kubernetes_pod_operator.py   | 132 +++++++++++++++++++++
 2 files changed, 145 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py 
b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index c34fc5c02b..4bae3a9abf 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -167,7 +167,7 @@ class KubernetesPodOperator(BaseOperator):
         during the next try. If False, always create a new pod for each try.
     :param labels: labels to apply to the Pod. (templated)
     :param startup_timeout_seconds: timeout in seconds to startup the pod.
-    :param get_logs: get the stdout of the container as logs of the tasks.
+    :param get_logs: get the stdout of the base container as logs of the tasks.
     :param image_pull_policy: Specify a policy to cache or always pull an 
image.
     :param annotations: non-identifying metadata you can attach to the Pod.
         Can be a large range of data, and can include characters
@@ -206,9 +206,15 @@ class KubernetesPodOperator(BaseOperator):
         to populate the environment variables with. The contents of the target
         ConfigMap's Data field will represent the key-value pairs as 
environment variables.
         Extends env_from.
+    :param base_container_name: The name of the base container in the pod. 
This container's logs
+        will appear as part of this task's logs if get_logs is True. Defaults 
to None. If None,
+        will consult the class variable BASE_CONTAINER_NAME (which defaults to 
"base") for the base
+        container name to use.
     """
 
+    # This field can be overloaded at the instance level via 
base_container_name
     BASE_CONTAINER_NAME = "base"
+
     POD_CHECKED_KEY = "already_checked"
 
     template_fields: Sequence[str] = (
@@ -272,6 +278,7 @@ class KubernetesPodOperator(BaseOperator):
         pod_runtime_info_envs: list[k8s.V1EnvVar] | None = None,
         termination_grace_period: int | None = None,
         configmaps: list[str] | None = None,
+        base_container_name: str | None = None,
         **kwargs,
     ) -> None:
         # TODO: remove in provider 6.0.0 release. This is a mitigate step to 
advise users to switch to the
@@ -338,6 +345,7 @@ class KubernetesPodOperator(BaseOperator):
         self.termination_grace_period = termination_grace_period
         self.pod_request_obj: k8s.V1Pod | None = None
         self.pod: k8s.V1Pod | None = None
+        self.base_container_name = base_container_name or 
self.BASE_CONTAINER_NAME
 
     @cached_property
     def _incluster_namespace(self):
@@ -498,12 +506,12 @@ class KubernetesPodOperator(BaseOperator):
             if self.get_logs:
                 self.pod_manager.fetch_container_logs(
                     pod=self.pod,
-                    container_name=self.BASE_CONTAINER_NAME,
+                    container_name=self.base_container_name,
                     follow=True,
                 )
             else:
                 self.pod_manager.await_container_completion(
-                    pod=self.pod, container_name=self.BASE_CONTAINER_NAME
+                    pod=self.pod, container_name=self.base_container_name
                 )
 
             if self.do_xcom_push:
@@ -535,7 +543,7 @@ class KubernetesPodOperator(BaseOperator):
             if self.log_events_on_failure:
                 self._read_pod_log_events(pod, reraise=False)
             self.process_pod_deletion(remote_pod, reraise=False)
-            error_message = get_container_termination_message(remote_pod, 
self.BASE_CONTAINER_NAME)
+            error_message = get_container_termination_message(remote_pod, 
self.base_container_name)
             raise AirflowException(
                 f"Pod {pod and pod.metadata.name} returned a 
failure:\n{error_message}\n"
                 f"remote_pod: {remote_pod}"
@@ -621,7 +629,7 @@ class KubernetesPodOperator(BaseOperator):
                 containers=[
                     k8s.V1Container(
                         image=self.image,
-                        name=self.BASE_CONTAINER_NAME,
+                        name=self.base_container_name,
                         command=self.cmds,
                         ports=self.ports,
                         image_pull_policy=self.image_pull_policy,
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py 
b/kubernetes_tests/test_kubernetes_pod_operator.py
index 52665b7343..b31fe024bb 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -1151,3 +1151,135 @@ class TestKubernetesPodOperatorSystem:
                 do_xcom_push=False,
                 resources=resources,
             )
+
+    def test_changing_base_container_name_with_get_logs(self):
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels=self.labels,
+            task_id=str(uuid4()),
+            in_cluster=False,
+            do_xcom_push=False,
+            get_logs=True,
+            base_container_name="apple-sauce",
+        )
+        assert k.base_container_name == "apple-sauce"
+        context = create_context(k)
+        with mock.patch.object(
+            k.pod_manager, "fetch_container_logs", 
wraps=k.pod_manager.fetch_container_logs
+        ) as mock_fetch_container_logs:
+            k.execute(context)
+
+        assert mock_fetch_container_logs.call_args[1]["container_name"] == 
"apple-sauce"
+        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
+        self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
+        assert self.expected_pod["spec"] == actual_pod["spec"]
+
+    def test_changing_base_container_name_no_logs(self):
+        """
+        This test checks BOTH a modified base container name AND the 
get_logs=False flow,
+        and as a result, also checks that the flow works with fast containers
+        See https://github.com/apache/airflow/issues/26796
+        """
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["echo 10"],
+            labels=self.labels,
+            task_id=str(uuid4()),
+            in_cluster=False,
+            do_xcom_push=False,
+            get_logs=False,
+            base_container_name="apple-sauce",
+        )
+        assert k.base_container_name == "apple-sauce"
+        context = create_context(k)
+        with mock.patch.object(
+            k.pod_manager, "await_container_completion", 
wraps=k.pod_manager.await_container_completion
+        ) as mock_await_container_completion:
+            k.execute(context)
+
+        assert mock_await_container_completion.call_args[1]["container_name"] 
== "apple-sauce"
+        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
+        self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
+        assert self.expected_pod["spec"] == actual_pod["spec"]
+
+    def test_changing_base_container_name_no_logs_long(self):
+        """
+        Similar to test_changing_base_container_name_no_logs, but ensures that
+        pods running longer than 1 second work too.
+        See https://github.com/apache/airflow/issues/26796
+        """
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["bash", "-cx"],
+            arguments=["sleep 3"],
+            labels=self.labels,
+            task_id=str(uuid4()),
+            in_cluster=False,
+            do_xcom_push=False,
+            get_logs=False,
+            base_container_name="apple-sauce",
+        )
+        assert k.base_container_name == "apple-sauce"
+        context = create_context(k)
+        with mock.patch.object(
+            k.pod_manager, "await_container_completion", 
wraps=k.pod_manager.await_container_completion
+        ) as mock_await_container_completion:
+            k.execute(context)
+
+        assert mock_await_container_completion.call_args[1]["container_name"] 
== "apple-sauce"
+        actual_pod = self.api_client.sanitize_for_serialization(k.pod)
+        self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce"
+        self.expected_pod["spec"]["containers"][0]["args"] = ["sleep 3"]
+        assert self.expected_pod["spec"] == actual_pod["spec"]
+
+    def test_changing_base_container_name_failure(self):
+        k = KubernetesPodOperator(
+            namespace="default",
+            image="ubuntu:16.04",
+            cmds=["exit"],
+            arguments=["1"],
+            labels=self.labels,
+            task_id=str(uuid4()),
+            in_cluster=False,
+            do_xcom_push=False,
+            base_container_name="apple-sauce",
+        )
+        assert k.base_container_name == "apple-sauce"
+        context = create_context(k)
+
+        class ShortCircuitException(Exception):
+            pass
+
+        with mock.patch(
+            
"airflow.providers.cncf.kubernetes.operators.kubernetes_pod.get_container_termination_message",
+            side_effect=ShortCircuitException(),
+        ) as mock_get_container_termination_message:
+            with pytest.raises(ShortCircuitException):
+                k.execute(context)
+
+        assert mock_get_container_termination_message.call_args[0][1] == 
"apple-sauce"
+
+    def test_base_container_name_init_precedence(self):
+        assert (
+            KubernetesPodOperator(base_container_name="apple-sauce", 
task_id=str(uuid4())).base_container_name
+            == "apple-sauce"
+        )
+        assert (
+            KubernetesPodOperator(task_id=str(uuid4())).base_container_name
+            == KubernetesPodOperator.BASE_CONTAINER_NAME
+        )
+
+        class MyK8SPodOperator(KubernetesPodOperator):
+            BASE_CONTAINER_NAME = "tomato-sauce"
+
+        assert (
+            MyK8SPodOperator(base_container_name="apple-sauce", 
task_id=str(uuid4())).base_container_name
+            == "apple-sauce"
+        )
+        assert MyK8SPodOperator(task_id=str(uuid4())).base_container_name == 
"tomato-sauce"

Reply via email to