This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 60eb9e106f Use KubernetesHook to create api client in KubernetesPodOperator (#20578)
60eb9e106f is described below
commit 60eb9e106f5915398eafd6aa339ec710c102dc09
Author: Daniel Standish <[email protected]>
AuthorDate: Tue May 31 11:37:42 2022 -0700
Use KubernetesHook to create api client in KubernetesPodOperator (#20578)
Add support for the k8s hook in KPO and use it always (even when no conn id is given); continue to consider the core k8s settings that KPO already takes into account, but emit a deprecation warning about them.

KPO historically takes a few settings from the core Airflow config (e.g. verify ssl, tcp keepalive, context, config file, and in_cluster), so to use the hook to generate the client, the hook somehow has to take these settings into account. But we don't want the hook to consider these settings in general, so we read them in KPO and, if necessary, patch the hook and warn.
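As a rough sketch of the new usage (the DAG id, task name, image, and connection id below are illustrative; it assumes a Kubernetes connection named "k8s_default" already exists):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator

    with DAG(dag_id="kpo_hook_example", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
        # Client settings (context, config file, SSL verification, TCP keepalive) are now
        # resolved through KubernetesHook via the connection rather than core airflow.cfg.
        KubernetesPodOperator(
            task_id="echo",
            name="echo",
            namespace="default",
            image="alpine:3.16",
            cmds=["echo", "hello"],
            kubernetes_conn_id="k8s_default",
        )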
---
airflow/providers/cncf/kubernetes/CHANGELOG.rst | 13 +++
.../providers/cncf/kubernetes/hooks/kubernetes.py | 119 +++++++++++++++++--
.../cncf/kubernetes/operators/kubernetes_pod.py | 61 +++++++---
.../connections/kubernetes.rst | 13 ++-
kubernetes_tests/test_kubernetes_pod_operator.py | 67 ++++++-----
.../test_kubernetes_pod_operator_backcompat.py | 19 +--
.../cncf/kubernetes/hooks/test_kubernetes.py | 130 ++++++++++++++++++++-
.../kubernetes/operators/test_kubernetes_pod.py | 51 +++++---
8 files changed, 392 insertions(+), 81 deletions(-)
diff --git a/airflow/providers/cncf/kubernetes/CHANGELOG.rst b/airflow/providers/cncf/kubernetes/CHANGELOG.rst
index 514035c8ca..bb2d236594 100644
--- a/airflow/providers/cncf/kubernetes/CHANGELOG.rst
+++ b/airflow/providers/cncf/kubernetes/CHANGELOG.rst
@@ -19,6 +19,19 @@
Changelog
---------
+main
+....
+
+Features
+~~~~~~~~
+
+KubernetesPodOperator now uses KubernetesHook
+`````````````````````````````````````````````
+
+Previously, KubernetesPodOperator relied on core Airflow configuration (namely the settings for the kubernetes executor) for certain settings used in client generation. Now KubernetesPodOperator uses KubernetesHook, and the consideration of core k8s settings is officially deprecated.
+
+If you are using Airflow configuration settings (as opposed to operator params) to configure the kubernetes client, then prior to the next major release you will need to add an Airflow connection and set your KPO tasks to use that connection.
+
4.0.2
.....
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 5719918ce7..e15dce67ef 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -16,10 +16,13 @@
# under the License.
import sys
import tempfile
-from typing import Any, Dict, Generator, Optional, Tuple, Union
+import warnings
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from kubernetes.config import ConfigException
+from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive
+
if sys.version_info >= (3, 8):
from functools import cached_property
else:
@@ -63,6 +66,14 @@ class KubernetesHook(BaseHook):
    :param conn_id: The :ref:`kubernetes connection <howto/connection:kubernetes>`
        to Kubernetes cluster.
+    :param client_configuration: Optional dictionary of client configuration params.
+        Passed on to kubernetes client.
+    :param cluster_context: Optionally specify a context to use (e.g. if you have multiple
+        in your kubeconfig).
+    :param config_file: Path to kubeconfig file.
+    :param in_cluster: Set to ``True`` if running from within a kubernetes cluster.
+    :param disable_verify_ssl: Set to ``True`` if SSL verification should be disabled.
+    :param disable_tcp_keepalive: Set to ``True`` if you want to disable keepalive logic.
"""
conn_name_attr = 'kubernetes_conn_id'
@@ -91,6 +102,8 @@ class KubernetesHook(BaseHook):
"extra__kubernetes__cluster_context": StringField(
lazy_gettext('Cluster context'), widget=BS3TextFieldWidget()
),
+        "extra__kubernetes__disable_verify_ssl": BooleanField(lazy_gettext('Disable SSL')),
+        "extra__kubernetes__disable_tcp_keepalive": BooleanField(lazy_gettext('Disable TCP keepalive')),
}
@staticmethod
@@ -108,6 +121,8 @@ class KubernetesHook(BaseHook):
cluster_context: Optional[str] = None,
config_file: Optional[str] = None,
in_cluster: Optional[bool] = None,
+ disable_verify_ssl: Optional[bool] = None,
+ disable_tcp_keepalive: Optional[bool] = None,
) -> None:
super().__init__()
self.conn_id = conn_id
@@ -115,6 +130,16 @@ class KubernetesHook(BaseHook):
self.cluster_context = cluster_context
self.config_file = config_file
self.in_cluster = in_cluster
+ self.disable_verify_ssl = disable_verify_ssl
+ self.disable_tcp_keepalive = disable_tcp_keepalive
+
+        # these params are used for the transition of KPO to the K8s hook
+        # for a deprecation period we will continue to consider k8s settings from airflow.cfg
+ self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None
+ self._deprecated_core_disable_verify_ssl: Optional[bool] = None
+ self._deprecated_core_in_cluster: Optional[bool] = None
+ self._deprecated_core_cluster_context: Optional[str] = None
+ self._deprecated_core_config_file: Optional[str] = None
@staticmethod
def _coalesce_param(*params):
@@ -122,23 +147,51 @@ class KubernetesHook(BaseHook):
if param is not None:
return param
- def get_conn(self) -> Any:
- """Returns kubernetes api session for use with requests"""
+ @cached_property
+ def conn_extras(self):
if self.conn_id:
connection = self.get_connection(self.conn_id)
extras = connection.extra_dejson
else:
extras = {}
+ return extras
+
+ def _get_field(self, field_name):
+ if field_name.startswith('extra_'):
+ raise ValueError(
+                f"Got prefixed name {field_name}; please remove the 'extra__kubernetes__' prefix "
+ f"when using this method."
+ )
+ if field_name in self.conn_extras:
+ return self.conn_extras[field_name] or None
+ prefixed_name = f"extra__kubernetes__{field_name}"
+ return self.conn_extras.get(prefixed_name) or None
+
+ @staticmethod
+ def _deprecation_warning_core_param(deprecation_warnings):
+        settings_list_str = ''.join([f"\n\t{k}={v!r}" for k, v in deprecation_warnings])
+        warnings.warn(
+            f"\nApplying core Airflow settings from section [kubernetes] with the following keys:"
+            f"{settings_list_str}\n"
+            "In a future release, KubernetesPodOperator will no longer consider core\n"
+            "Airflow settings; define an Airflow connection instead.",
+ DeprecationWarning,
+ )
+
+ def get_conn(self) -> Any:
+ """Returns kubernetes api session for use with requests"""
+
        in_cluster = self._coalesce_param(
-            self.in_cluster, extras.get("extra__kubernetes__in_cluster") or None
+            self.in_cluster, self.conn_extras.get("extra__kubernetes__in_cluster") or None
        )
        cluster_context = self._coalesce_param(
-            self.cluster_context, extras.get("extra__kubernetes__cluster_context") or None
+            self.cluster_context, self.conn_extras.get("extra__kubernetes__cluster_context") or None
        )
        kubeconfig_path = self._coalesce_param(
-            self.config_file, extras.get("extra__kubernetes__kube_config_path") or None
+            self.config_file, self.conn_extras.get("extra__kubernetes__kube_config_path") or None
        )
-        kubeconfig = extras.get("extra__kubernetes__kube_config") or None
+
+        kubeconfig = self.conn_extras.get("extra__kubernetes__kube_config") or None
        num_selected_configuration = len([o for o in [in_cluster, kubeconfig, kubeconfig_path] if o])
if num_selected_configuration > 1:
@@ -147,6 +200,43 @@ class KubernetesHook(BaseHook):
"kube_config, in_cluster are mutually exclusive. "
"You can only use one option at a time."
)
+
+        disable_verify_ssl = self._coalesce_param(
+            self.disable_verify_ssl, _get_bool(self._get_field("disable_verify_ssl"))
+        )
+        disable_tcp_keepalive = self._coalesce_param(
+            self.disable_tcp_keepalive, _get_bool(self._get_field("disable_tcp_keepalive"))
+        )
+
+        # BEGIN apply settings from core kubernetes configuration
+        # this section should be removed in next major release
+        deprecation_warnings: List[Tuple[str, Any]] = []
+        if disable_verify_ssl is None and self._deprecated_core_disable_verify_ssl is True:
+            deprecation_warnings.append(('verify_ssl', False))
+            disable_verify_ssl = self._deprecated_core_disable_verify_ssl
+        # by default, hook will try in_cluster first. so we only need to
+        # apply core airflow config and alert when False and in_cluster not otherwise set.
+        if in_cluster is None and self._deprecated_core_in_cluster is False:
+            deprecation_warnings.append(('in_cluster', self._deprecated_core_in_cluster))
+            in_cluster = self._deprecated_core_in_cluster
+        if not cluster_context and self._deprecated_core_cluster_context:
+            deprecation_warnings.append(('cluster_context', self._deprecated_core_cluster_context))
+            cluster_context = self._deprecated_core_cluster_context
+        if not kubeconfig_path and self._deprecated_core_config_file:
+            deprecation_warnings.append(('config_file', self._deprecated_core_config_file))
+            kubeconfig_path = self._deprecated_core_config_file
+        if disable_tcp_keepalive is None and self._deprecated_core_disable_tcp_keepalive is True:
+            deprecation_warnings.append(('enable_tcp_keepalive', False))
+            disable_tcp_keepalive = True
+ if deprecation_warnings:
+ self._deprecation_warning_core_param(deprecation_warnings)
+ # END apply settings from core kubernetes configuration
+
+ if disable_verify_ssl is True:
+ _disable_verify_ssl()
+ if disable_tcp_keepalive is not True:
+ _enable_tcp_keepalive()
+
if in_cluster:
            self.log.debug("loading kube_config from: in_cluster configuration")
config.load_incluster_config()
@@ -316,3 +406,18 @@ class KubernetesHook(BaseHook):
_preload_content=False,
namespace=namespace if namespace else self.get_namespace(),
)
+
+
+def _get_bool(val) -> Optional[bool]:
+ """
+    Converts val to bool if it can be done with certainty.
+ If we cannot infer intention we return None.
+ """
+ if isinstance(val, bool):
+ return val
+ elif isinstance(val, str):
+ if val.strip().lower() == 'true':
+ return True
+ elif val.strip().lower() == 'false':
+ return False
+ return None
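For reference, a minimal sketch of instantiating the hook directly with the new parameters (the connection id and namespace are illustrative; explicit hook params take precedence over connection extras):

    from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook

    hook = KubernetesHook(
        conn_id="k8s_default",       # hypothetical connection id
        disable_verify_ssl=True,     # overrides the connection extra if both are set
        disable_tcp_keepalive=False,
    )
    client = hook.core_v1_client     # CoreV1Api built from the merged settings
    pods = client.list_namespaced_pod(namespace="default")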
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index fba88f80e1..69f120e823 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -25,8 +25,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from kubernetes.client import CoreV1Api, models as k8s
+from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.kubernetes import kube_client, pod_generator
+from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.kubernetes.secret import Secret
from airflow.models import BaseOperator
@@ -142,6 +143,7 @@ class KubernetesPodOperator(BaseOperator):
:param priority_class_name: priority class name for the launched Pod
    :param termination_grace_period: Termination grace period if task killed in UI,
        defaults to kubernetes default
+    :param kubernetes_conn_id: to retrieve credentials for your k8s cluster from an Airflow connection
"""
BASE_CONTAINER_NAME = 'base'
@@ -209,7 +211,6 @@ class KubernetesPodOperator(BaseOperator):
if kwargs.get('xcom_push') is not None:
            raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
super().__init__(resources=None, **kwargs)
-
self.kubernetes_conn_id = kubernetes_conn_id
self.do_xcom_push = do_xcom_push
self.image = image
@@ -324,19 +325,20 @@ class KubernetesPodOperator(BaseOperator):
def pod_manager(self) -> PodManager:
return PodManager(kube_client=self.client)
- @cached_property
- def client(self) -> CoreV1Api:
- if self.kubernetes_conn_id:
- hook = KubernetesHook(conn_id=self.kubernetes_conn_id)
- return hook.core_v1_client
-
- kwargs: Dict[str, Any] = dict(
- cluster_context=self.cluster_context,
+ def get_hook(self):
+ hook = KubernetesHook(
+ conn_id=self.kubernetes_conn_id,
+ in_cluster=self.in_cluster,
config_file=self.config_file,
+ cluster_context=self.cluster_context,
)
- if self.in_cluster is not None:
- kwargs.update(in_cluster=self.in_cluster)
- return kube_client.get_kube_client(**kwargs)
+ self._patch_deprecated_k8s_settings(hook)
+ return hook
+
+ @cached_property
+ def client(self) -> CoreV1Api:
+ hook = self.get_hook()
+ return hook.core_v1_client
    def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]:
        """Returns an already-running pod for this task instance if one exists."""
@@ -573,6 +575,39 @@ class KubernetesPodOperator(BaseOperator):
pod = self.build_pod_request_obj()
print(yaml.dump(prune_dict(pod.to_dict(), mode='strict')))
+ def _patch_deprecated_k8s_settings(self, hook: KubernetesHook):
+ """
+ Here we read config from core Airflow config [kubernetes] section.
+        In a future release we will stop looking at this section and require users
+        to use Airflow connections to configure KPO.
+
+        When we find values there that we need to apply on the hook, we patch special
+        hook attributes here.
+ """
+
+ # default for enable_tcp_keepalive is True; patch if False
+ if conf.getboolean('kubernetes', 'enable_tcp_keepalive') is False:
+ hook._deprecated_core_disable_tcp_keepalive = True
+
+ # default verify_ssl is True; patch if False.
+ if conf.getboolean('kubernetes', 'verify_ssl') is False:
+ hook._deprecated_core_disable_verify_ssl = True
+
+ # default for in_cluster is True; patch if False and no KPO param.
+ conf_in_cluster = conf.getboolean('kubernetes', 'in_cluster')
+ if self.in_cluster is None and conf_in_cluster is False:
+ hook._deprecated_core_in_cluster = conf_in_cluster
+
+        # there's no default for cluster context; if we get something (and no KPO param) patch it.
+        conf_cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None)
+ if not self.cluster_context and conf_cluster_context:
+ hook._deprecated_core_cluster_context = conf_cluster_context
+
+        # there's no default for config_file; if we get something (and no KPO param) patch it.
+ conf_config_file = conf.get('kubernetes', 'config_file', fallback=None)
+ if not self.config_file and conf_config_file:
+ hook._deprecated_core_config_file = conf_config_file
+
class _suppress(AbstractContextManager):
"""
diff --git a/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst b/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst
index c0a8539832..3d0dca268e 100644
--- a/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst
+++ b/docs/apache-airflow-providers-cncf-kubernetes/connections/kubernetes.rst
@@ -58,12 +58,17 @@ Kube config (JSON format)
Namespace
Default Kubernetes namespace for the connection.
-When specifying the connection in environment variable you should specify
-it using URI syntax.
+Cluster context
+ When using a kube config, can specify which context to use.
-Note that all components of the URI should be URL-encoded.
+Disable verify SSL
+    Can optionally disable SSL certificate verification. By default SSL is verified.
-For example:
+Disable TCP keepalive
+    TCP keepalive is a feature (enabled by default) that tries to keep long-running connections
+ alive. Set this parameter to True to disable this feature.
+
+Example storing connection in env var using URI format:
.. code-block:: bash
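For illustration only (the connection values below are hypothetical, not the example from the docs page), such a connection can be assembled in Python and exported through an AIRFLOW_CONN_* environment variable:

    import json

    from airflow.models.connection import Connection

    conn = Connection(
        conn_type="kubernetes",
        extra=json.dumps(
            {
                "extra__kubernetes__kube_config_path": "/path/to/kubeconfig",
                "extra__kubernetes__disable_verify_ssl": True,
            }
        ),
    )
    # Store the printed URI in, e.g., AIRFLOW_CONN_KUBERNETES_DEFAULT.
    print(conn.get_uri())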
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 56705c7a9d..4992827451 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -33,9 +33,9 @@ from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
from airflow.exceptions import AirflowException
-from airflow.kubernetes import kube_client
from airflow.kubernetes.secret import Secret
from airflow.models import DAG, XCOM_RETURN_KEY, DagRun, TaskInstance
+from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
@@ -43,6 +43,9 @@ from airflow.utils import timezone
from airflow.utils.types import DagRunType
from airflow.version import version as airflow_version
+HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"
+POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
+
def create_context(task):
dag = DAG(dag_id="dag")
@@ -123,7 +126,8 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
}
def tearDown(self) -> None:
- client = kube_client.get_kube_client(in_cluster=False)
+ hook = KubernetesHook(conn_id=None, in_cluster=False)
+ client = hook.core_v1_client
client.delete_collection_namespaced_pod(namespace="default")
import time
@@ -632,10 +636,12 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
self.expected_pod['spec']['containers'].append(container)
assert self.expected_pod == actual_pod
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
-    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_envs_from_secrets(self, mock_client, await_pod_completion_mock, create_pod):
+ @mock.patch(f"{POD_MANAGER_CLASS}.create_pod")
+ @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
+ @mock.patch(HOOK_CLASS, new=MagicMock)
+ def test_envs_from_secrets(self, await_pod_completion_mock, create_pod):
+ # todo: This isn't really a system test
+
# GIVEN
secret_ref = 'secret_name'
@@ -696,6 +702,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
assert self.expected_pod == actual_pod
def test_pod_template_file_system(self):
+        """Note: this test requires that you have a namespace ``mem-example`` in your cluster."""
fixture = sys.path[0] + '/tests/kubernetes/basic_pod.yaml'
k = KubernetesPodOperator(
task_id="task" + self.get_current_task_name(),
@@ -872,11 +879,12 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
]
assert self.expected_pod == actual_pod
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
-    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_pod_template_file(self, mock_client, await_pod_completion_mock, create_mock, extract_xcom_mock):
+ @mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom")
+ @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
+ @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
+ @mock.patch(HOOK_CLASS, new=MagicMock)
+    def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):
+ # todo: This isn't really a system test
extract_xcom_mock.return_value = '{}'
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
k = KubernetesPodOperator(
@@ -958,11 +966,15 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
del actual_pod['metadata']['labels']['airflow_version']
assert expected_dict == actual_pod
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
-    @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_pod_priority_class_name(self, mock_client, await_pod_completion_mock, create_mock):
- """Test ability to assign priorityClassName to pod"""
+ @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
+ @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
+ @mock.patch(HOOK_CLASS, new=MagicMock)
+ def test_pod_priority_class_name(self, await_pod_completion_mock):
+ """
+ Test ability to assign priorityClassName to pod
+
+ todo: This isn't really a system test
+ """
priority_class_name = "medium-test"
k = KubernetesPodOperator(
@@ -1002,10 +1014,10 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
do_xcom_push=False,
)
-    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
- def test_on_kill(self, await_pod_completion_mock):
-
- client = kube_client.get_kube_client(in_cluster=False)
+ @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion", new=MagicMock)
+ def test_on_kill(self):
+ hook = KubernetesHook(conn_id=None, in_cluster=False)
+ client = hook.core_v1_client
name = "test"
namespace = "default"
k = KubernetesPodOperator(
@@ -1032,7 +1044,8 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
client.read_namespaced_pod(name=name, namespace=namespace)
def test_reattach_failing_pod_once(self):
- client = kube_client.get_kube_client(in_cluster=False)
+ hook = KubernetesHook(conn_id=None, in_cluster=False)
+ client = hook.core_v1_client
name = "test"
namespace = "default"
@@ -1056,9 +1069,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
context = create_context(k)
# launch pod
-        with mock.patch(
-            "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion"
-        ) as await_pod_completion_mock:
+        with mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") as await_pod_completion_mock:
pod_mock = MagicMock()
pod_mock.status.phase = 'Succeeded'
@@ -1082,9 +1093,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
        # `create_pod` should not be called because there's a pod there it should find
        # should use the found pod and patch as "already_checked" (in failure block)
-        with mock.patch(
-            "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod"
-        ) as create_mock:
+        with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock:
with pytest.raises(AirflowException):
k.execute(context)
pod = client.read_namespaced_pod(name=name, namespace=namespace)
@@ -1096,9 +1105,7 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
        # `create_pod` should be called because though there's still a pod to be found,
# it will be `already_checked`
-        with mock.patch(
-            "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod"
-        ) as create_mock:
+        with mock.patch(f"{POD_MANAGER_CLASS}.create_pod") as create_mock:
with pytest.raises(AirflowException):
k.execute(context)
create_mock.assert_called_once()
diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
index 21a3bda7d3..5a4efc73d4 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
@@ -29,13 +29,13 @@ from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
from airflow.exceptions import AirflowException
-from airflow.kubernetes import kube_client
from airflow.kubernetes.pod import Port
from airflow.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
from airflow.kubernetes.secret import Secret
from airflow.kubernetes.volume import Volume
from airflow.kubernetes.volume_mount import VolumeMount
from airflow.models import DAG, DagRun, TaskInstance
+from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
@@ -45,6 +45,8 @@ from airflow.version import version as airflow_version
# noinspection DuplicatedCode
+HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"
+
def create_context(task):
dag = DAG(dag_id="dag")
@@ -121,13 +123,14 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
}
def tearDown(self):
- client = kube_client.get_kube_client(in_cluster=False)
+ hook = KubernetesHook(conn_id=None, in_cluster=False)
+ client = hook.core_v1_client
client.delete_collection_namespaced_pod(namespace="default")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
- @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_image_pull_secrets_correctly_set(self, mock_client, await_pod_completion_mock, create_mock):
+    @mock.patch(HOOK_CLASS, new=MagicMock)
+    def test_image_pull_secrets_correctly_set(self, await_pod_completion_mock, create_mock):
fake_pull_secrets = "fakeSecret"
k = KubernetesPodOperator(
namespace='default',
@@ -461,8 +464,8 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
- @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
- def test_envs_from_configmaps(self, mock_client, mock_monitor, mock_start):
+ @mock.patch(HOOK_CLASS, new=MagicMock)
+ def test_envs_from_configmaps(self, mock_monitor, mock_start):
# GIVEN
configmap = 'test-configmap'
# WHEN
@@ -490,8 +493,8 @@ class TestKubernetesPodOperatorSystem(unittest.TestCase):
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.create_pod")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_pod_completion")
- @mock.patch("airflow.kubernetes.kube_client.get_kube_client")
-    def test_envs_from_secrets(self, mock_client, await_pod_completion_mock, create_mock):
+ @mock.patch(HOOK_CLASS, new=MagicMock)
+ def test_envs_from_secrets(self, await_pod_completion_mock, create_mock):
# GIVEN
secret_ref = 'secret_name'
secrets = [Secret('env', None, secret_ref)]
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index 5194061e7a..572f6e2890 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -21,7 +21,7 @@ import json
import os
import tempfile
from unittest import mock
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import kubernetes
import pytest
@@ -34,6 +34,7 @@ from airflow.utils import db
from tests.test_utils.db import clear_db_connections
KUBE_CONFIG_PATH = os.getenv('KUBECONFIG', '~/.kube/config')
+HOOK_MODULE = "airflow.providers.cncf.kubernetes.hooks.kubernetes"
class TestKubernetesHook:
@@ -41,9 +42,9 @@ class TestKubernetesHook:
def setup_class(cls) -> None:
for conn_id, extra in [
('in_cluster', {'extra__kubernetes__in_cluster': True}),
+ ('in_cluster_empty', {'extra__kubernetes__in_cluster': ''}),
            ('kube_config', {'extra__kubernetes__kube_config': '{"test": "kube"}'}),
            ('kube_config_path', {'extra__kubernetes__kube_config_path': 'path/to/file'}),
- ('in_cluster_empty', {'extra__kubernetes__in_cluster': ''}),
('kube_config_empty', {'extra__kubernetes__kube_config': ''}),
            ('kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}),
('kube_config_empty', {'extra__kubernetes__kube_config': ''}),
@@ -52,6 +53,10 @@ class TestKubernetesHook:
('context', {'extra__kubernetes__cluster_context': 'my-context'}),
            ('with_namespace', {'extra__kubernetes__namespace': 'mock_namespace'}),
('default_kube_config', {}),
+            ('disable_verify_ssl', {'extra__kubernetes__disable_verify_ssl': True}),
+            ('disable_verify_ssl_empty', {'extra__kubernetes__disable_verify_ssl': ''}),
+            ('disable_tcp_keepalive', {'extra__kubernetes__disable_tcp_keepalive': True}),
+            ('disable_tcp_keepalive_empty', {'extra__kubernetes__disable_tcp_keepalive': ''}),
]:
            db.merge_conn(Connection(conn_type='kubernetes', conn_id=conn_id, extra=json.dumps(extra)))
@@ -76,7 +81,7 @@ class TestKubernetesHook:
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch("kubernetes.config.incluster_config.InClusterConfigLoader")
-    @patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook._get_default_client")
+ @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client")
def test_in_cluster_connection(
self,
mock_get_default_client,
@@ -131,6 +136,70 @@ class TestKubernetesHook:
mock_loader.assert_not_called()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
+ @pytest.mark.parametrize(
+ 'disable_verify_ssl, conn_id, disable_called',
+ (
+ (True, None, True),
+ (None, None, False),
+ (False, None, False),
+ (None, 'disable_verify_ssl', True),
+ (True, 'disable_verify_ssl', True),
+ (False, 'disable_verify_ssl', False),
+ (None, 'disable_verify_ssl_empty', False),
+ (True, 'disable_verify_ssl_empty', True),
+ (False, 'disable_verify_ssl_empty', False),
+ ),
+ )
+    @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock())
+ @patch(f"{HOOK_MODULE}._disable_verify_ssl")
+ def test_disable_verify_ssl(
+ self,
+ mock_disable,
+ disable_verify_ssl,
+ conn_id,
+ disable_called,
+ ):
+ """
+        Verifies whether disable verify ssl is called depending on combination of hook param and
+ connection extra. Hook param should beat extra.
+ """
+        kubernetes_hook = KubernetesHook(conn_id=conn_id, disable_verify_ssl=disable_verify_ssl)
+ api_conn = kubernetes_hook.get_conn()
+ assert mock_disable.called is disable_called
+ assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
+
+ @pytest.mark.parametrize(
+ 'disable_tcp_keepalive, conn_id, expected',
+ (
+ (True, None, False),
+ (None, None, True),
+ (False, None, True),
+ (None, 'disable_tcp_keepalive', False),
+ (True, 'disable_tcp_keepalive', False),
+ (False, 'disable_tcp_keepalive', True),
+ (None, 'disable_tcp_keepalive_empty', True),
+ (True, 'disable_tcp_keepalive_empty', False),
+ (False, 'disable_tcp_keepalive_empty', True),
+ ),
+ )
+    @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock())
+ @patch(f"{HOOK_MODULE}._enable_tcp_keepalive")
+ def test_disable_tcp_keepalive(
+ self,
+ mock_enable,
+ disable_tcp_keepalive,
+ conn_id,
+ expected,
+ ):
+ """
+        Verifies whether enable tcp keepalive is called depending on combination of hook
+ param and connection extra. Hook param should beat extra.
+ """
+        kubernetes_hook = KubernetesHook(conn_id=conn_id, disable_tcp_keepalive=disable_tcp_keepalive)
+ api_conn = kubernetes_hook.get_conn()
+ assert mock_enable.called is expected
+ assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
+
@pytest.mark.parametrize(
'config_path_param, conn_id, call_path',
(
@@ -239,6 +308,61 @@ class TestKubernetesHook:
assert isinstance(hook.api_client, kubernetes.client.ApiClient)
assert isinstance(hook.get_conn(), kubernetes.client.ApiClient)
+ @patch(f"{HOOK_MODULE}._disable_verify_ssl")
+ @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client", new=MagicMock)
+ def test_patch_core_settings_verify_ssl(self, mock_disable_verify_ssl):
+ hook = KubernetesHook()
+ hook.get_conn()
+ mock_disable_verify_ssl.assert_not_called()
+ mock_disable_verify_ssl.reset_mock()
+ hook._deprecated_core_disable_verify_ssl = True
+ hook.get_conn()
+ mock_disable_verify_ssl.assert_called()
+
+ @patch(f"{HOOK_MODULE}._enable_tcp_keepalive")
+ @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client", new=MagicMock)
+    def test_patch_core_settings_tcp_keepalive(self, mock_enable_tcp_keepalive):
+ hook = KubernetesHook()
+ hook.get_conn()
+ mock_enable_tcp_keepalive.assert_called()
+ mock_enable_tcp_keepalive.reset_mock()
+ hook._deprecated_core_disable_tcp_keepalive = True
+ hook.get_conn()
+ mock_enable_tcp_keepalive.assert_not_called()
+
+ @patch("kubernetes.config.kube_config.KubeConfigLoader", new=MagicMock())
+ @patch("kubernetes.config.kube_config.KubeConfigMerger", new=MagicMock())
+ @patch("kubernetes.config.incluster_config.InClusterConfigLoader")
+ @patch(f"{HOOK_MODULE}.KubernetesHook._get_default_client")
+    def test_patch_core_settings_in_cluster(self, mock_get_default_client, mock_in_cluster_loader):
+ hook = KubernetesHook(conn_id=None)
+ hook.get_conn()
+ mock_in_cluster_loader.assert_not_called()
+ mock_in_cluster_loader.reset_mock()
+ hook._deprecated_core_in_cluster = False
+ hook.get_conn()
+ mock_in_cluster_loader.assert_not_called()
+ mock_get_default_client.assert_called()
+
+ @pytest.mark.parametrize(
+ 'key, key_val, attr, attr_val',
+ [
+ ('in_cluster', False, '_deprecated_core_in_cluster', False),
+ ('verify_ssl', False, '_deprecated_core_disable_verify_ssl', True),
+            ('cluster_context', 'hi', '_deprecated_core_cluster_context', 'hi'),
+            ('config_file', '/path/to/file.txt', '_deprecated_core_config_file', '/path/to/file.txt'),
+            ('enable_tcp_keepalive', False, '_deprecated_core_disable_tcp_keepalive', True),
+ ],
+ )
+    @patch("kubernetes.config.incluster_config.InClusterConfigLoader", new=MagicMock())
+ @patch("kubernetes.config.kube_config.KubeConfigLoader", new=MagicMock())
+ @patch("kubernetes.config.kube_config.KubeConfigMerger", new=MagicMock())
+ def test_core_settings_warnings(self, key, key_val, attr, attr_val):
+ hook = KubernetesHook(conn_id=None)
+ setattr(hook, attr, attr_val)
+        with pytest.warns(DeprecationWarning, match=rf'.*Airflow settings.*\n.*{key}={key_val!r}.*'):
+ hook.get_conn()
+
class TestKubernetesHookIncorrectConfiguration:
@pytest.mark.parametrize(
diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
index e70bf88326..cac1b8a1b5 100644
--- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
+++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
@@ -29,9 +29,12 @@ from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
from tests.test_utils import db
+from tests.test_utils.config import conf_vars
DEFAULT_DATE = timezone.datetime(2016, 1, 1, 1, 0, 0)
KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod"
+POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
+HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesHook"
@pytest.fixture(scope='function', autouse=True)
@@ -66,28 +69,22 @@ def create_context(task, persist_to_db=False):
}
-POD_MANAGER_CLASS = "airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
-
-
class TestKubernetesPodOperator:
@pytest.fixture(autouse=True)
def setup(self, dag_maker):
self.create_pod_patch = mock.patch(f"{POD_MANAGER_CLASS}.create_pod")
        self.await_pod_patch = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start")
        self.await_pod_completion_patch = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
-        self.client_patch = mock.patch("airflow.kubernetes.kube_client.get_kube_client")
+ self.hook_patch = mock.patch(HOOK_CLASS)
self.create_mock = self.create_pod_patch.start()
self.await_start_mock = self.await_pod_patch.start()
self.await_pod_mock = self.await_pod_completion_patch.start()
- self.client_mock = self.client_patch.start()
+ self.hook_mock = self.hook_patch.start()
self.dag_maker = dag_maker
yield
- self.create_pod_patch.stop()
- self.await_pod_patch.stop()
- self.await_pod_completion_patch.stop()
- self.client_patch.stop()
+ mock.patch.stopall()
    def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V1Pod:
with self.dag_maker(dag_id='dag') as dag:
@@ -127,9 +124,9 @@ class TestKubernetesPodOperator:
remote_pod_mock = MagicMock()
remote_pod_mock.status.phase = 'Succeeded'
self.await_pod_mock.return_value = remote_pod_mock
- self.client_mock.list_namespaced_pod.return_value = []
self.run_pod(k)
- self.client_mock.assert_called_once_with(
+ self.hook_mock.assert_called_once_with(
+ conn_id=None,
in_cluster=False,
cluster_context="default",
config_file=file_path,
@@ -226,8 +223,7 @@ class TestKubernetesPodOperator:
do_xcom_push=False,
)
self.run_pod(k)
- self.client_mock.return_value.list_namespaced_pod.assert_called_once()
- _, kwargs = self.client_mock.return_value.list_namespaced_pod.call_args
+ _, kwargs = k.client.list_namespaced_pod.call_args
assert kwargs['label_selector'] == (
'dag_id=dag,kubernetes_pod_operator=True,run_id=test,task_id=task,'
'already_checked!=True,!airflow-worker'
@@ -576,7 +572,7 @@ class TestKubernetesPodOperator:
context = create_context(k)
k.execute(context=context)
- assert not self.client_mock.return_value.read_namespaced_pod.called
+ assert k.client.read_namespaced_pod.called is False
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.fetch_container_logs")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.await_container_completion")
@@ -794,8 +790,8 @@ class TestKubernetesPodOperator:
task_id="task",
)
self.run_pod(k)
- self.client_mock.return_value.list_namespaced_pod.assert_called_once()
- _, kwargs = self.client_mock.return_value.list_namespaced_pod.call_args
+ k.client.list_namespaced_pod.assert_called_once()
+ _, kwargs = k.client.list_namespaced_pod.call_args
assert 'already_checked!=True' in kwargs['label_selector']
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.delete_pod")
@@ -842,6 +838,29 @@ class TestKubernetesPodOperator:
mock_patch_already_checked.assert_called_once()
mock_delete_pod.assert_not_called()
+ @pytest.mark.parametrize(
+ 'key, value, attr, patched_value',
+ [
+            ('verify_ssl', 'False', '_deprecated_core_disable_verify_ssl', True),
+            ('in_cluster', 'False', '_deprecated_core_in_cluster', False),
+            ('cluster_context', 'hi', '_deprecated_core_cluster_context', 'hi'),
+            ('config_file', '/path/to/file.txt', '_deprecated_core_config_file', '/path/to/file.txt'),
+            ('enable_tcp_keepalive', 'False', '_deprecated_core_disable_tcp_keepalive', True),
+ ],
+ )
+ def test_patch_core_settings(self, key, value, attr, patched_value):
+ # first verify the behavior for the default value
+ # the hook attr should be None
+ op = KubernetesPodOperator(task_id='abc', name='hi')
+ self.hook_patch.stop()
+ hook = op.get_hook()
+ assert getattr(hook, attr) is None
+ # now check behavior with a non-default value
+ with conf_vars({('kubernetes', key): value}):
+ op = KubernetesPodOperator(task_id='abc', name='hi')
+ hook = op.get_hook()
+ assert getattr(hook, attr) == patched_value
+
def test__suppress():
with mock.patch('logging.Logger.error') as mock_error: