This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 900ad8c190 Fix: Configurable Docker image of `xcom_sidecar` (#32858)
900ad8c190 is described below

commit 900ad8c1907d3342ba1777ad99db37a0d3f5d61a
Author: pegasas <[email protected]>
AuthorDate: Sat Aug 5 02:16:25 2023 +0800

    Fix: Configurable Docker image of `xcom_sidecar` (#32858)
    
    * Configurable Docker image of xcom_sidecar
    
    * Update airflow/providers/cncf/kubernetes/utils/pod_manager.py
    
    * Update airflow/providers/cncf/kubernetes/utils/pod_manager.py
    
    * Update kubernetes.py
    
    ---------
    
    Co-authored-by: eladkal <[email protected]>
---
 .../providers/cncf/kubernetes/hooks/kubernetes.py  | 18 +++++++++
 airflow/providers/cncf/kubernetes/operators/pod.py |  6 ++-
 .../providers/cncf/kubernetes/utils/pod_manager.py |  6 +++
 kubernetes_tests/test_kubernetes_pod_operator.py   |  2 +
 .../cncf/kubernetes/decorators/test_kubernetes.py  | 11 +++++-
 .../cncf/kubernetes/hooks/test_kubernetes.py       | 46 ++++++++++++++++++++++
 6 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 56852fb1a2..ddb8cb27ad 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import contextlib
+import json
 import tempfile
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Generator
@@ -99,6 +100,12 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
             "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()),
             "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")),
             "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")),
+            "xcom_sidecar_container_image": StringField(
+                lazy_gettext("XCom sidecar image"), widget=BS3TextFieldWidget()
+            ),
+            "xcom_sidecar_container_resources": StringField(
+                lazy_gettext("XCom sidecar resources (JSON format)"), widget=BS3TextFieldWidget()
+            ),
         }
 
     @staticmethod
@@ -356,6 +363,17 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
             return self._get_field("namespace")
         return None
 
+    def get_xcom_sidecar_container_image(self):
+        """Returns the xcom sidecar image that defined in the connection."""
+        return self._get_field("xcom_sidecar_container_image")
+
+    def get_xcom_sidecar_container_resources(self):
+        """Returns the xcom sidecar resources that defined in the connection."""
+        field = self._get_field("xcom_sidecar_container_resources")
+        if not field:
+            return None
+        return json.loads(field)
+
     def get_pod_log_stream(
         self,
         pod_name: str,
diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py
index 28810b92ff..c707f9446f 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -878,7 +878,11 @@ class KubernetesPodOperator(BaseOperator):
             pod = secret.attach_to_pod(pod)
         if self.do_xcom_push:
             self.log.debug("Adding xcom sidecar to task %s", self.task_id)
-            pod = xcom_sidecar.add_xcom_sidecar(pod)
+            pod = xcom_sidecar.add_xcom_sidecar(
+                pod,
+                sidecar_container_image=self.hook.get_xcom_sidecar_container_image(),
+                sidecar_container_resources=self.hook.get_xcom_sidecar_container_resources(),
+            )
 
         labels = self._get_ti_pod_labels(context)
         self.log.info("Building pod %s with labels: %s", pod.metadata.name, labels)
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index c8ac74382d..81b6c1b2ca 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -101,6 +101,12 @@ class PodOperatorHookProtocol(Protocol):
     def get_namespace(self) -> str | None:
         """Returns the namespace that defined in the connection."""
 
+    def get_xcom_sidecar_container_image(self) -> str | None:
+        """Returns the xcom sidecar image that defined in the connection."""
+
+    def get_xcom_sidecar_container_resources(self) -> str | None:
+        """Returns the xcom sidecar resources that defined in the connection."""
+
 
 def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus | None:
     """Retrieves container status."""
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 002394611d..7ba097f18d 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -897,6 +897,8 @@ class TestKubernetesPodOperatorSystem:
         # todo: This isn't really a system test
         await_xcom_sidecar_container_start_mock.return_value = None
         hook_mock.return_value.is_in_cluster = False
+        hook_mock.return_value.get_xcom_sidecar_container_image.return_value = None
+        hook_mock.return_value.get_xcom_sidecar_container_resources.return_value = None
         hook_mock.return_value.get_connection.return_value = Connection(conn_id="kubernetes_default")
         extract_xcom_mock.return_value = "{}"
         path = sys.path[0] + "/tests/providers/cncf/kubernetes/pod.yaml"
diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
index 9bd7c06e40..b3ac936fda 100644
--- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py
@@ -30,6 +30,7 @@ DEFAULT_DATE = timezone.datetime(2021, 9, 1)
 KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.pod"
 POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
 HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesHook"
+XCOM_IMAGE = "XCOM_IMAGE"
 
 
 @pytest.fixture(autouse=True)
@@ -122,6 +123,12 @@ def test_kubernetes_with_input_output(
 
         f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
 
+    mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE
+    mock_hook.return_value.get_xcom_sidecar_container_resources.return_value = {
+        "requests": {"cpu": "1m", "memory": "10Mi"},
+        "limits": {"cpu": "1m", "memory": "50Mi"},
+    }
+
     dr = dag_maker.create_dagrun()
     (ti,) = dr.task_instances
 
@@ -134,6 +141,8 @@ def test_kubernetes_with_input_output(
         config_file="/tmp/fake_file",
     )
     assert mock_create_pod.call_count == 1
+    assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
+    assert mock_hook.return_value.get_xcom_sidecar_container_resources.call_count == 1
 
     containers = mock_create_pod.call_args.kwargs["pod"].spec.containers
 
@@ -152,7 +161,7 @@ def test_kubernetes_with_input_output(
     assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1": "kwarg1"}}
 
     # Second container is xcom image
-    assert containers[1].image == "alpine"
+    assert containers[1].image == XCOM_IMAGE
     assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
 
 
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index ba151efaf7..7b5428481e 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -88,6 +88,20 @@ class TestKubernetesHook:
             ("disable_verify_ssl_empty", {"disable_verify_ssl": ""}),
             ("disable_tcp_keepalive", {"disable_tcp_keepalive": True}),
             ("disable_tcp_keepalive_empty", {"disable_tcp_keepalive": ""}),
+            ("sidecar_container_image", {"xcom_sidecar_container_image": "private.repo.com/alpine:3.16"}),
+            ("sidecar_container_image_empty", {"xcom_sidecar_container_image": ""}),
+            (
+                "sidecar_container_resources",
+                {
+                    "xcom_sidecar_container_resources": json.dumps(
+                        {
+                            "requests": {"cpu": "1m", "memory": "10Mi"},
+                            "limits": {"cpu": "1m", "memory": "50Mi"},
+                        }
+                    ),
+                },
+            ),
+            ("sidecar_container_resources_empty", {"xcom_sidecar_container_resources": ""}),
         ]:
             db.merge_conn(Connection(conn_type="kubernetes", conn_id=conn_id, extra=json.dumps(extra)))
 
@@ -342,6 +356,38 @@ class TestKubernetesHook:
                 "and rename _get_namespace to get_namespace."
             )
 
+    @pytest.mark.parametrize(
+        "conn_id, expected",
+        (
+            pytest.param("sidecar_container_image", "private.repo.com/alpine:3.16", id="sidecar-with-image"),
+            pytest.param("sidecar_container_image_empty", None, id="sidecar-without-image"),
+        ),
+    )
+    def test_get_xcom_sidecar_container_image(self, conn_id, expected):
+        hook = KubernetesHook(conn_id=conn_id)
+        assert hook.get_xcom_sidecar_container_image() == expected
+
+    @pytest.mark.parametrize(
+        "conn_id, expected",
+        (
+            pytest.param(
+                "sidecar_container_resources",
+                {
+                    "requests": {"cpu": "1m", "memory": "10Mi"},
+                    "limits": {
+                        "cpu": "1m",
+                        "memory": "50Mi",
+                    },
+                },
+                id="sidecar-with-resources",
+            ),
+            pytest.param("sidecar_container_resources_empty", None, id="sidecar-without-resources"),
+        ),
+    )
+    def test_get_xcom_sidecar_container_resources(self, conn_id, expected):
+        hook = KubernetesHook(conn_id=conn_id)
+        assert hook.get_xcom_sidecar_container_resources() == expected
+
     @patch("kubernetes.config.kube_config.KubeConfigLoader")
     @patch("kubernetes.config.kube_config.KubeConfigMerger")
     def test_client_types(self, mock_kube_config_merger, mock_kube_config_loader):

Reply via email to