This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f81dfd731f Add `pod_template_dict` field to `KubernetesPodOperator` (#33174)
f81dfd731f is described below

commit f81dfd731f576121c0219c2601e06ecfa4ccc765
Author: Onur Sönmez <[email protected]>
AuthorDate: Sun Dec 17 22:13:49 2023 +0300

    Add `pod_template_dict` field to `KubernetesPodOperator` (#33174)
    
    * add pod_template_content field to kubernetes pod operator
    
    * add test for pod_template_content
    
    * fix
    
    * fix test
    
    * test labels
    
    * accept dictionary instead of yaml
    
    * fix staticcheck warnings
    
    * fix import error
    
    * change import order
    
    ---------
    
    Co-authored-by: Hussein Awala <[email protected]>
    Co-authored-by: Elad Kalif <[email protected]>
---
 airflow/providers/cncf/kubernetes/operators/pod.py |  9 ++++
 .../cncf/kubernetes/operators/test_pod.py          | 59 ++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py 
b/airflow/providers/cncf/kubernetes/operators/pod.py
index 55fe1e4c36..5153b85965 100644
--- a/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -218,6 +218,7 @@ class KubernetesPodOperator(BaseOperator):
         /airflow/xcom/return.json in the container will also be pushed to an
         XCom when the container completes.
     :param pod_template_file: path to pod template file (templated)
+    :param pod_template_dict: pod template dictionary (templated)
     :param priority_class_name: priority class name for the launched Pod
     :param pod_runtime_info_envs: (Optional) A list of environment variables,
         to be set in the container.
@@ -267,6 +268,7 @@ class KubernetesPodOperator(BaseOperator):
         "labels",
         "config_file",
         "pod_template_file",
+        "pod_template_dict",
         "namespace",
         "container_resources",
         "volumes",
@@ -322,6 +324,7 @@ class KubernetesPodOperator(BaseOperator):
         log_events_on_failure: bool = False,
         do_xcom_push: bool = False,
         pod_template_file: str | None = None,
+        pod_template_dict: dict | None = None,
         priority_class_name: str | None = None,
         pod_runtime_info_envs: list[k8s.V1EnvVar] | None = None,
         termination_grace_period: int | None = None,
@@ -404,6 +407,7 @@ class KubernetesPodOperator(BaseOperator):
         self.log_events_on_failure = log_events_on_failure
         self.priority_class_name = priority_class_name
         self.pod_template_file = pod_template_file
+        self.pod_template_dict = pod_template_dict
         self.name = self._set_name(name)
         self.random_name_suffix = random_name_suffix
         self.termination_grace_period = termination_grace_period
@@ -897,6 +901,11 @@ class KubernetesPodOperator(BaseOperator):
             pod_template = 
pod_generator.PodGenerator.deserialize_model_file(self.pod_template_file)
             if self.full_pod_spec:
                 pod_template = PodGenerator.reconcile_pods(pod_template, 
self.full_pod_spec)
+        elif self.pod_template_dict:
+            self.log.debug("Pod template dict found, will parse for base pod")
+            pod_template = 
pod_generator.PodGenerator.deserialize_model_dict(self.pod_template_dict)
+            if self.full_pod_spec:
+                pod_template = PodGenerator.reconcile_pods(pod_template, 
self.full_pod_spec)
         elif self.full_pod_spec:
             pod_template = self.full_pod_spec
         else:
diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py 
b/tests/providers/cncf/kubernetes/operators/test_pod.py
index 5e9cbbb916..8402f0f6b2 100644
--- a/tests/providers/cncf/kubernetes/operators/test_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -30,6 +30,7 @@ from urllib3 import HTTPResponse
 from airflow.exceptions import AirflowException, AirflowSkipException, 
TaskDeferred
 from airflow.models import DAG, DagModel, DagRun, TaskInstance
 from airflow.models.xcom import XCom
+from airflow.providers.cncf.kubernetes import pod_generator
 from airflow.providers.cncf.kubernetes.operators.pod import 
KubernetesPodOperator, _optionally_suppress
 from airflow.providers.cncf.kubernetes.secret import Secret
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
@@ -1019,6 +1020,64 @@ class TestKubernetesPodOperator:
             "run_id": "test",
         }
 
+    @pytest.mark.parametrize(("randomize_name",), ([True], [False]))
+    def test_pod_template_dict(self, randomize_name):
+        templated_pod = k8s.V1Pod(
+            metadata=k8s.V1ObjectMeta(
+                namespace="templatenamespace",
+                name="hello",
+                labels={"release": "stable"},
+            ),
+            spec=k8s.V1PodSpec(
+                containers=[],
+                init_containers=[
+                    k8s.V1Container(
+                        name="git-clone",
+                        image="registry.k8s.io/git-sync:v3.1.1",
+                        args=[
+                            "[email protected]:airflow/some_repo.git",
+                            "--branch={{ params.get('repo_branch', 'master') 
}}",
+                        ],
+                    ),
+                ],
+            ),
+        )
+        k = KubernetesPodOperator(
+            task_id="task",
+            random_name_suffix=randomize_name,
+            
pod_template_dict=pod_generator.PodGenerator.serialize_pod(templated_pod),
+            labels={"hello": "world"},
+        )
+
+        # Render templated fields (incl. the nested pod_template_dict) before building the pod.
+        k.render_template_fields(context={"params": {"repo_branch": 
"test_branch"}})
+        pod = k.build_pod_request_obj(create_context(k))
+        # Name comes from the template dict; random_name_suffix keeps it only as a prefix.
+        if randomize_name:
+            assert pod.metadata.name.startswith("hello")
+            assert pod.metadata.name != "hello"
+        else:
+            assert pod.metadata.name == "hello"
+        # Template labels are merged with the operator's labels and the labels KPO adds.
+        assert pod.metadata.labels == {
+            "hello": "world",
+            "release": "stable",
+            "dag_id": "dag",
+            "kubernetes_pod_operator": "True",
+            "task_id": "task",
+            "try_number": "1",
+            "airflow_version": mock.ANY,
+            "airflow_kpo_in_cluster": str(k.hook.is_in_cluster),
+            "run_id": "test",
+        }
+        # Init container survives the dict round-trip and its Jinja arg was rendered from params.
+        assert pod.spec.init_containers[0].name == "git-clone"
+        assert pod.spec.init_containers[0].image == 
"registry.k8s.io/git-sync:v3.1.1"
+        assert pod.spec.init_containers[0].args == [
+            "[email protected]:airflow/some_repo.git",
+            "--branch=test_branch",
+        ]
+
     @patch(f"{POD_MANAGER_CLASS}.fetch_container_logs")
     @patch(f"{POD_MANAGER_CLASS}.await_container_completion", new=MagicMock)
     def test_no_handle_failure_on_success(self, fetch_container_mock):

Reply via email to