This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5326da4b83 Add `airflow_kpo_in_cluster` label to KPO pods (#24658)
5326da4b83 is described below

commit 5326da4b83ed4405553e88d5d5464508256498d0
Author: Jed Cunningham <[email protected]>
AuthorDate: Tue Jun 28 09:30:02 2022 -0600

    Add `airflow_kpo_in_cluster` label to KPO pods (#24658)
    
    This allows one to determine if the pod was created with in_cluster config
    or not, both on the k8s side and in pod_mutation_hooks.
---
 .../providers/cncf/kubernetes/hooks/kubernetes.py    | 15 +++++++++++++++
 .../cncf/kubernetes/operators/kubernetes_pod.py      |  9 +++++++--
 kubernetes_tests/test_kubernetes_pod_operator.py     | 15 +++++++++++----
 .../test_kubernetes_pod_operator_backcompat.py       |  2 ++
 .../cncf/kubernetes/hooks/test_kubernetes.py         |  7 +++++++
 .../cncf/kubernetes/operators/test_kubernetes_pod.py | 20 ++++++++++++++------
 6 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index c4658ec8f3..ed794cd553 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -127,6 +127,8 @@ class KubernetesHook(BaseHook):
         self.disable_verify_ssl = disable_verify_ssl
         self.disable_tcp_keepalive = disable_tcp_keepalive
 
+        self._is_in_cluster: Optional[bool] = None
+
         # these params used for transition in KPO to K8s hook
         # for a deprecation period we will continue to consider k8s settings from airflow.cfg
         self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None
@@ -232,11 +234,13 @@ class KubernetesHook(BaseHook):
 
         if in_cluster:
             self.log.debug("loading kube_config from: in_cluster configuration")
+            self._is_in_cluster = True
             config.load_incluster_config()
             return client.ApiClient()
 
         if kubeconfig_path is not None:
             self.log.debug("loading kube_config from: %s", kubeconfig_path)
+            self._is_in_cluster = False
             config.load_kube_config(
                 config_file=kubeconfig_path,
                 client_configuration=self.client_configuration,
@@ -249,6 +253,7 @@ class KubernetesHook(BaseHook):
                 self.log.debug("loading kube_config from: connection kube_config")
                 temp_config.write(kubeconfig.encode())
                 temp_config.flush()
+                self._is_in_cluster = False
                 config.load_kube_config(
                     config_file=temp_config.name,
                     client_configuration=self.client_configuration,
@@ -265,14 +270,24 @@ class KubernetesHook(BaseHook):
         # in the default location
         try:
             config.load_incluster_config(client_configuration=self.client_configuration)
+            self._is_in_cluster = True
         except ConfigException:
             self.log.debug("loading kube_config from: default file")
+            self._is_in_cluster = False
             config.load_kube_config(
                 client_configuration=self.client_configuration,
                 context=cluster_context,
             )
         return client.ApiClient()
 
+    @property
+    def is_in_cluster(self):
+        """Expose whether the hook is configured with ``load_incluster_config`` or not"""
+        if self._is_in_cluster is not None:
+            return self._is_in_cluster
+        self.api_client  # so we can determine if we are in_cluster or not
+        return self._is_in_cluster
+
     @cached_property
     def api_client(self) -> Any:
         """Cached Kubernetes API client"""
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 1966b6bdb1..09cad504fe 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -342,6 +342,11 @@ class KubernetesPodOperator(BaseOperator):
         return PodManager(kube_client=self.client)
 
     def get_hook(self):
+        warnings.warn("get_hook is deprecated. Please use hook instead.", DeprecationWarning, stacklevel=2)
+        return self.hook
+
+    @cached_property
+    def hook(self) -> KubernetesHook:
         hook = KubernetesHook(
             conn_id=self.kubernetes_conn_id,
             in_cluster=self.in_cluster,
@@ -353,8 +358,7 @@ class KubernetesPodOperator(BaseOperator):
 
     @cached_property
     def client(self) -> CoreV1Api:
-        hook = self.get_hook()
-        return hook.core_v1_client
+        return self.hook.core_v1_client
 
     def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]:
         """Returns an already-running pod for this task instance if one exists."""
@@ -580,6 +584,7 @@ class KubernetesPodOperator(BaseOperator):
         pod.metadata.labels.update(
             {
                 'airflow_version': airflow_version.replace('+', '-'),
+                'airflow_kpo_in_cluster': str(self.hook.is_in_cluster),
             }
         )
         pod_mutation_hook(pod)
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index fb661e46b0..50e5978de7 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -93,6 +93,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
                     'foo': 'bar',
                     'kubernetes_pod_operator': 'True',
                     'airflow_version': airflow_version.replace('+', '-'),
+                    'airflow_kpo_in_cluster': 'False',
                     'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
                     'dag_id': 'dag',
                     'task_id': ANY,
@@ -734,6 +735,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
             'fizz': 'buzz',
             'foo': 'bar',
             'airflow_version': mock.ANY,
+            'airflow_kpo_in_cluster': 'False',
             'dag_id': 'dag',
             'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
             'kubernetes_pod_operator': 'True',
@@ -773,6 +775,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
             'fizz': 'buzz',
             'foo': 'bar',
             'airflow_version': mock.ANY,
+            'airflow_kpo_in_cluster': 'False',
             'dag_id': 'dag',
             'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
             'kubernetes_pod_operator': 'True',
@@ -815,6 +818,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
             'fizz': 'buzz',
             'foo': 'bar',
             'airflow_version': mock.ANY,
+            'airflow_kpo_in_cluster': 'False',
             'dag_id': 'dag',
             'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
             'kubernetes_pod_operator': 'True',
@@ -882,9 +886,10 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
     @mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom")
     @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
     @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
-    @mock.patch(HOOK_CLASS, new=MagicMock)
-    def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):
+    @mock.patch(HOOK_CLASS)
+    def test_pod_template_file(self, hook_mock, await_pod_completion_mock, extract_xcom_mock):
         # todo: This isn't really a system test
+        hook_mock.return_value.is_in_cluster = False
         extract_xcom_mock.return_value = '{}'
         path = sys.path[0] + '/tests/kubernetes/pod.yaml'
         k = KubernetesPodOperator(
@@ -920,6 +925,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
             'metadata': {
                 'annotations': {},
                 'labels': {
+                    'airflow_kpo_in_cluster': 'False',
                     'dag_id': 'dag',
                     'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
                     'kubernetes_pod_operator': 'True',
@@ -968,13 +974,14 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
 
     @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
     @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
-    @mock.patch(HOOK_CLASS, new=MagicMock)
-    def test_pod_priority_class_name(self, await_pod_completion_mock):
+    @mock.patch(HOOK_CLASS)
+    def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
         """
         Test ability to assign priorityClassName to pod
 
         todo: This isn't really a system test
         """
+        hook_mock.return_value.is_in_cluster = False
 
         priority_class_name = "medium-test"
         k = KubernetesPodOperator(
diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
index f15400edea..af2f0f38fe 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
@@ -90,6 +90,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
                     'foo': 'bar',
                     'kubernetes_pod_operator': 'True',
                     'airflow_version': airflow_version.replace('+', '-'),
+                    'airflow_kpo_in_cluster': 'False',
                     'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
                     'dag_id': 'dag',
                     'task_id': 'task',
@@ -571,6 +572,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
             'fizz': 'buzz',
             'foo': 'bar',
             'airflow_version': mock.ANY,
+            'airflow_kpo_in_cluster': 'False',
             'dag_id': 'dag',
             'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
             'kubernetes_pod_operator': 'True',
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index 572f6e2890..6bbe5926e9 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -106,6 +106,11 @@ class TestKubernetesHook:
         else:
             mock_get_default_client.assert_called()
         assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
+        if mock_get_default_client.called:
+            # get_default_client sets it, but it's mocked
+            assert kubernetes_hook.is_in_cluster is None
+        else:
+            assert kubernetes_hook.is_in_cluster is in_cluster_called
 
     @pytest.mark.parametrize('in_cluster_fails', [True, False])
     @patch("kubernetes.config.kube_config.KubeConfigLoader")
@@ -130,10 +135,12 @@ class TestKubernetesHook:
             mock_incluster.assert_called_once()
             mock_merger.assert_called_once_with(KUBE_CONFIG_PATH)
             mock_loader.assert_called_once()
+            assert kubernetes_hook.is_in_cluster is False
         else:
             mock_incluster.assert_called_once()
             mock_merger.assert_not_called()
             mock_loader.assert_not_called()
+            assert kubernetes_hook.is_in_cluster is True
         assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
 
     @pytest.mark.parametrize(
diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
index b771361d1a..88e0eeb07e 100644
--- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
@@ -100,6 +100,8 @@ class TestKubernetesPodOperator:
         remote_pod_mock = MagicMock()
         remote_pod_mock.status.phase = 'Succeeded'
         self.await_pod_mock.return_value = remote_pod_mock
+        if not isinstance(self.hook_mock.return_value.is_in_cluster, bool):
+            self.hook_mock.return_value.is_in_cluster = True
         operator.execute(context=context)
         return self.await_start_mock.call_args[1]['pod']
 
@@ -170,7 +172,9 @@ class TestKubernetesPodOperator:
         pod = self.run_pod(k)
         assert pod.spec.containers[0].env_from == env_from
 
-    def test_labels(self):
+    @pytest.mark.parametrize(("in_cluster",), ([True], [False]))
+    def test_labels(self, in_cluster):
+        self.hook_mock.return_value.is_in_cluster = in_cluster
         k = KubernetesPodOperator(
             namespace="default",
             image="ubuntu:16.04",
@@ -178,7 +182,7 @@ class TestKubernetesPodOperator:
             labels={"foo": "bar"},
             name="test",
             task_id="task",
-            in_cluster=False,
+            in_cluster=in_cluster,
             do_xcom_push=False,
         )
         pod = self.run_pod(k)
@@ -190,6 +194,7 @@ class TestKubernetesPodOperator:
             "try_number": "1",
             "airflow_version": mock.ANY,
             "run_id": "test",
+            "airflow_kpo_in_cluster": str(in_cluster),
         }
 
     def test_labels_mapped(self):
@@ -209,6 +214,7 @@ class TestKubernetesPodOperator:
             "airflow_version": mock.ANY,
             "run_id": "test",
             "map_index": "10",
+            "airflow_kpo_in_cluster": "True",
         }
 
     def test_find_pod_labels(self):
@@ -391,6 +397,7 @@ class TestKubernetesPodOperator:
             "task_id": "task",
             "try_number": "1",
             "airflow_version": mock.ANY,
+            "airflow_kpo_in_cluster": "True",
             "run_id": "test",
         }
 
@@ -429,6 +436,7 @@ class TestKubernetesPodOperator:
             "task_id": "task",
             "try_number": "1",
             "airflow_version": mock.ANY,
+            "airflow_kpo_in_cluster": "True",
             "run_id": "test",
         }
 
@@ -499,6 +507,7 @@ class TestKubernetesPodOperator:
             "task_id": "task",
             "try_number": "1",
             "airflow_version": mock.ANY,
+            "airflow_kpo_in_cluster": "True",
             "run_id": "test",
         }
         assert pod.metadata.namespace == "mynamespace"
@@ -568,6 +577,7 @@ class TestKubernetesPodOperator:
             "task_id": "task",
             "try_number": "1",
             "airflow_version": mock.ANY,
+            "airflow_kpo_in_cluster": "True",
             "run_id": "test",
         }
 
@@ -877,13 +887,11 @@ class TestKubernetesPodOperator:
         # the hook attr should be None
         op = KubernetesPodOperator(task_id='abc', name='hi')
         self.hook_patch.stop()
-        hook = op.get_hook()
-        assert getattr(hook, attr) is None
+        assert getattr(op.hook, attr) is None
         # now check behavior with a non-default value
         with conf_vars({('kubernetes', key): value}):
             op = KubernetesPodOperator(task_id='abc', name='hi')
-            hook = op.get_hook()
-            assert getattr(hook, attr) == patched_value
+            assert getattr(op.hook, attr) == patched_value
 
 
 def test__suppress():

Reply via email to