This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bb2cc419507 KubernetesPodOperator: Rework of Kubernetes API retry
behavior (#58397)
bb2cc419507 is described below
commit bb2cc4195076af1c66377c5b54c70e8ee4fab6e5
Author: AutomationDev85 <[email protected]>
AuthorDate: Sat Nov 22 23:14:34 2025 +0100
KubernetesPodOperator: Rework of Kubernetes API retry behavior (#58397)
* Move retry handling to the hook layer and update PodManager accordingly
* Removed overlapping code
* Clean up code
* Detailed logging and use of autouse fixture
* move no wait fixture into conftest
* Disabled no_retry_wait patch for explicitly marked unit tests.
* Fix unit test
* Generic retry logic can handle async and sync kubernetes api exceptions
---------
Co-authored-by: AutomationDev85 <AutomationDev85>
---
.../providers/cncf/kubernetes/hooks/kubernetes.py | 22 ++++--
.../cncf/kubernetes/kubernetes_helper_functions.py | 77 ++++++++++++++----
.../cncf/kubernetes/operators/resource.py | 10 +--
.../providers/cncf/kubernetes/utils/pod_manager.py | 90 +++-------------------
.../tests/unit/cncf/kubernetes/conftest.py | 22 ++++++
.../unit/cncf/kubernetes/hooks/test_kubernetes.py | 5 +-
.../cncf/kubernetes/operators/test_resource.py | 5 +-
.../kubernetes/test_kubernetes_helper_functions.py | 61 ++++++++++++++-
.../unit/cncf/kubernetes/utils/test_pod_manager.py | 47 +----------
9 files changed, 175 insertions(+), 164 deletions(-)
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 0c06c3d29d9..0e105d241b1 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -27,7 +27,6 @@ from typing import TYPE_CHECKING, Any, Protocol
import aiofiles
import requests
-import tenacity
from asgiref.sync import sync_to_async
from kubernetes import client, config, utils, watch
from kubernetes.client.models import V1Deployment
@@ -39,7 +38,7 @@ from airflow.exceptions import AirflowException,
AirflowNotFoundException
from airflow.models import Connection
from airflow.providers.cncf.kubernetes.exceptions import KubernetesApiError,
KubernetesApiPermissionError
from airflow.providers.cncf.kubernetes.kube_client import _disable_verify_ssl,
_enable_tcp_keepalive
-from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import
should_retry_creation
+from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import
generic_api_retry
from airflow.providers.cncf.kubernetes.utils.container import (
container_is_completed,
container_is_running,
@@ -390,6 +389,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
self.log.debug("Response: %s", response)
return response
+ @generic_api_retry
def get_custom_object(
self, group: str, version: str, plural: str, name: str, namespace: str
| None = None
):
@@ -412,6 +412,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
)
return response
+ @generic_api_retry
def delete_custom_object(
self, group: str, version: str, plural: str, name: str, namespace: str
| None = None, **kwargs
):
@@ -540,12 +541,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
name=name, namespace=namespace, pretty=True, **kwargs
)
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_random_exponential(),
- reraise=True,
- retry=tenacity.retry_if_exception(should_retry_creation),
- )
+ @generic_api_retry
def create_job(
self,
job: V1Job,
@@ -572,6 +568,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
raise e
return resp
+ @generic_api_retry
def get_job(self, job_name: str, namespace: str) -> V1Job:
"""
Get Job of specified name and namespace.
@@ -582,6 +579,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
"""
return self.batch_v1_client.read_namespaced_job(name=job_name,
namespace=namespace, pretty=True)
+ @generic_api_retry
def get_job_status(self, job_name: str, namespace: str) -> V1Job:
"""
Get job with status of specified name and namespace.
@@ -611,6 +609,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
self.log.info("The job '%s' is incomplete. Sleeping for %i sec.",
job_name, job_poll_interval)
sleep(job_poll_interval)
+ @generic_api_retry
def list_jobs_all_namespaces(self) -> V1JobList:
"""
Get list of Jobs from all namespaces.
@@ -619,6 +618,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
"""
return self.batch_v1_client.list_job_for_all_namespaces(pretty=True)
+ @generic_api_retry
def list_jobs_from_namespace(self, namespace: str) -> V1JobList:
"""
Get list of Jobs from dedicated namespace.
@@ -674,6 +674,7 @@ class KubernetesHook(BaseHook, PodOperatorHookProtocol):
return bool(next((c for c in conditions if c.type == "Complete"
and c.status), None))
return False
+ @generic_api_retry
def patch_namespaced_job(self, job_name: str, namespace: str, body:
object) -> V1Job:
"""
Update the specified Job.
@@ -879,6 +880,7 @@ class AsyncKubernetesHook(KubernetesHook):
if kube_client is not None:
await kube_client.close()
+ @generic_api_retry
async def get_pod(self, name: str, namespace: str) -> V1Pod:
"""
Get pod's object.
@@ -899,6 +901,7 @@ class AsyncKubernetesHook(KubernetesHook):
raise KubernetesApiPermissionError("Permission denied
(403) from Kubernetes API.") from e
raise KubernetesApiError from e
+ @generic_api_retry
async def delete_pod(self, name: str, namespace: str):
"""
Delete pod's object.
@@ -917,6 +920,7 @@ class AsyncKubernetesHook(KubernetesHook):
if str(e.status) != "404":
raise
+ @generic_api_retry
async def read_logs(
self, name: str, namespace: str, container_name: str | None = None,
since_seconds: int | None = None
) -> list[str]:
@@ -949,6 +953,7 @@ class AsyncKubernetesHook(KubernetesHook):
except HTTPError as e:
raise KubernetesApiError from e
+ @generic_api_retry
async def get_pod_events(self, name: str, namespace: str) ->
CoreV1EventList:
"""Get pod's events."""
async with self.get_conn() as connection:
@@ -964,6 +969,7 @@ class AsyncKubernetesHook(KubernetesHook):
raise KubernetesApiPermissionError("Permission denied
(403) from Kubernetes API.") from e
raise KubernetesApiError from e
+ @generic_api_retry
async def get_job_status(self, name: str, namespace: str) -> V1Job:
"""
Get job's status object.
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py
index 19de125af3a..f1e232db9a9 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py
@@ -23,10 +23,14 @@ from functools import cache
from typing import TYPE_CHECKING
import pendulum
-from kubernetes.client.rest import ApiException
+import tenacity
+from kubernetes.client.rest import ApiException as SyncApiException
+from kubernetes_asyncio.client.exceptions import ApiException as
AsyncApiException
from slugify import slugify
+from urllib3.exceptions import HTTPError
from airflow.configuration import conf
+from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.backcompat import get_logical_date_key
if TYPE_CHECKING:
@@ -39,6 +43,62 @@ alphanum_lower = string.ascii_lowercase + string.digits
POD_NAME_MAX_LENGTH = 63 # Matches Linux kernel's HOST_NAME_MAX default value
minus 1.
+class PodLaunchFailedException(AirflowException):
+ """When pod launching fails in KubernetesPodOperator."""
+
+
+class KubernetesApiException(AirflowException):
+ """When communication with the Kubernetes API fails."""
+
+
+API_RETRIES = conf.getint("workers", "api_retries", fallback=5)
+API_RETRY_WAIT_MIN = conf.getfloat("workers", "api_retry_wait_min", fallback=1)
+API_RETRY_WAIT_MAX = conf.getfloat("workers", "api_retry_wait_max",
fallback=15)
+
+_default_wait = tenacity.wait_exponential(min=API_RETRY_WAIT_MIN,
max=API_RETRY_WAIT_MAX)
+
+TRANSIENT_STATUS_CODES = {409, 429, 500, 502, 503, 504}
+
+
+def _should_retry_api(exc: BaseException) -> bool:
+ """Retry on selected ApiException status codes, plus plain HTTP/timeout
errors."""
+ if isinstance(exc, (SyncApiException, AsyncApiException)):
+ return exc.status in TRANSIENT_STATUS_CODES
+ return isinstance(exc, (HTTPError, KubernetesApiException))
+
+
+class WaitRetryAfterOrExponential(tenacity.wait.wait_base):
+ """Wait strategy that honors Retry-After header on 429, else falls back to
exponential backoff."""
+
+ def __call__(self, retry_state):
+ exc = retry_state.outcome.exception() if retry_state.outcome else None
+ if isinstance(exc, (SyncApiException, AsyncApiException)) and
exc.status == 429:
+ retry_after = (exc.headers or {}).get("Retry-After")
+ if retry_after:
+ try:
+ return float(int(retry_after))
+ except ValueError:
+ pass
+ # Inline exponential fallback
+ return _default_wait(retry_state)
+
+
+def generic_api_retry(func):
+ """
+ Retry Kubernetes API calls on transient failures.
+
+ - Retries only transient ApiException status codes.
+ - Honors Retry-After on 429.
+ """
+ return tenacity.retry(
+ stop=tenacity.stop_after_attempt(API_RETRIES),
+ wait=WaitRetryAfterOrExponential(),
+ retry=tenacity.retry_if_exception(_should_retry_api),
+ reraise=True,
+ before_sleep=tenacity.before_sleep_log(log, logging.WARNING),
+ )(func)
+
+
def rand_str(num):
"""
Generate random lowercase alphanumeric string of length num.
@@ -148,18 +208,3 @@ def annotations_for_logging_task_metadata(annotation_set):
else:
annotations_for_logging = "<omitted>"
return annotations_for_logging
-
-
-def should_retry_creation(exception: BaseException) -> bool:
- """
- Check if an Exception indicates a transient error and warrants retrying.
-
- This function is needed for preventing 'No agent available' error. The
error appears time to time
- when users try to create a Resource or Job. This issue is inside
kubernetes and in the current moment
- has no solution. Like a temporary solution we decided to retry Job or
Resource creation request each
- time when this error appears.
- More about this issue here:
https://github.com/cert-manager/cert-manager/issues/6457
- """
- if isinstance(exception, ApiException):
- return str(exception.status) == "500"
- return False
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py
index 75fc79360c4..0e287b93f48 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/resource.py
@@ -23,13 +23,12 @@ from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING
-import tenacity
import yaml
from kubernetes.utils import create_from_yaml
from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
-from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import
should_retry_creation
+from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import
generic_api_retry
from airflow.providers.cncf.kubernetes.utils.delete_from import
delete_from_yaml
from airflow.providers.cncf.kubernetes.utils.k8s_resource_iterator import
k8s_resource_iterator
from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_1_PLUS
@@ -132,12 +131,7 @@ class
KubernetesCreateResourceOperator(KubernetesResourceBaseOperator):
else:
self.custom_object_client.create_cluster_custom_object(group,
version, plural, body)
- @tenacity.retry(
- stop=tenacity.stop_after_attempt(3),
- wait=tenacity.wait_random_exponential(),
- reraise=True,
- retry=tenacity.retry_if_exception(should_retry_creation),
- )
+ @generic_api_retry
def _create_objects(self, objects):
self.log.info("Starting resource creation")
if not self.custom_resource_definition:
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index e67371f3e89..725ce6e48bc 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -20,7 +20,6 @@ from __future__ import annotations
import asyncio
import enum
-import functools
import json
import math
import time
@@ -31,7 +30,6 @@ from datetime import timedelta
from typing import TYPE_CHECKING, Literal, cast
import pendulum
-import tenacity
from kubernetes import client, watch
from kubernetes.client.rest import ApiException
from kubernetes.stream import stream as kubernetes_stream
@@ -39,9 +37,13 @@ from pendulum import DateTime
from pendulum.parsing.exceptions import ParserError
from urllib3.exceptions import HTTPError, TimeoutError
-from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode,
KubernetesPodOperatorCallback
+from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
+ KubernetesApiException,
+ PodLaunchFailedException,
+ generic_api_retry,
+)
from airflow.providers.cncf.kubernetes.utils.container import (
container_is_completed,
container_is_running,
@@ -73,76 +75,6 @@ Sentinel for no xcom result.
"""
-API_RETRIES = conf.getint("workers", "api_retries", fallback=5)
-API_RETRY_WAIT_MIN = conf.getfloat("workers", "api_retry_wait_min", fallback=1)
-API_RETRY_WAIT_MAX = conf.getfloat("workers", "api_retry_wait_max",
fallback=15)
-
-_default_wait = tenacity.wait_exponential(min=API_RETRY_WAIT_MIN,
max=API_RETRY_WAIT_MAX)
-
-
-def get_retry_after_seconds(retry_state) -> int:
- """Extract Retry-After header from ApiException if present and log wait
time."""
- exception = retry_state.outcome.exception() if retry_state.outcome else
None
- if exception and isinstance(exception, ApiException) and
str(exception.status) == "429":
- retry_after = exception.headers.get("Retry-After") if
exception.headers else None
- if retry_after:
- try:
- wait_seconds = int(retry_after)
- return wait_seconds
- except ValueError:
- pass
- # Default exponential backoff
- wait_seconds = int(_default_wait(retry_state))
- return wait_seconds
-
-
-def generic_api_retry(func):
- """Apply tenacity retry logic for generic Kubernetes API calls."""
-
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- retry_decorator = tenacity.retry(
- stop=tenacity.stop_after_attempt(API_RETRIES),
- wait=get_retry_after_seconds,
- reraise=True,
- )
- return retry_decorator(func)(*args, **kwargs)
-
- return wrapper
-
-
-def should_retry_start_pod(exception: BaseException) -> bool:
- """Check if an Exception indicates a transient error and warrants
retrying."""
- if isinstance(exception, ApiException):
- # Retry on 409 (conflict) and 429 (too many requests)
- return str(exception.status) in ("409", "429")
- return False
-
-
-def create_pod_api_retry(func):
- """
- Apply tenacity retry logic for pod creation.
-
- Retries on 409 and 429 errors, and respects Retry-After header for 429.
- """
-
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- retry_decorator = tenacity.retry(
- stop=tenacity.stop_after_attempt(API_RETRIES),
- wait=get_retry_after_seconds,
- reraise=True,
- retry=tenacity.retry_if_exception(should_retry_start_pod),
- )
- return retry_decorator(func)(*args, **kwargs)
-
- return wrapper
-
-
-class PodLaunchFailedException(AirflowException):
- """When pod launching fails in KubernetesPodOperator."""
-
-
class PodPhase:
"""
Possible pod phases.
@@ -405,6 +337,7 @@ class PodManager(LoggingMixin):
raise e
return resp
+ @generic_api_retry
def delete_pod(self, pod: V1Pod) -> None:
"""Delete POD."""
try:
@@ -416,7 +349,7 @@ class PodManager(LoggingMixin):
if str(e.status) != "404":
raise
- @create_pod_api_retry
+ @generic_api_retry
def create_pod(self, pod: V1Pod) -> V1Pod:
"""Launch the pod asynchronously."""
return self.run_pod_async(pod)
@@ -817,7 +750,6 @@ class PodManager(LoggingMixin):
post_termination_timeout=post_termination_timeout,
)
- @generic_api_retry
def get_init_container_names(self, pod: V1Pod) -> list[str]:
"""
Return container names from the POD except for the
airflow-xcom-sidecar container.
@@ -826,7 +758,6 @@ class PodManager(LoggingMixin):
"""
return [container_spec.name for container_spec in
pod.spec.init_containers]
- @generic_api_retry
def get_container_names(self, pod: V1Pod) -> list[str]:
"""
Return container names from the POD except for the
airflow-xcom-sidecar container.
@@ -848,7 +779,7 @@ class PodManager(LoggingMixin):
namespace=pod.metadata.namespace,
field_selector=f"involvedObject.name={pod.metadata.name}"
)
except HTTPError as e:
- raise AirflowException(f"There was an error reading the kubernetes
API: {e}")
+ raise KubernetesApiException(f"There was an error reading the
kubernetes API: {e}")
@generic_api_retry
def read_pod(self, pod: V1Pod) -> V1Pod:
@@ -856,7 +787,7 @@ class PodManager(LoggingMixin):
try:
return self._client.read_namespaced_pod(pod.metadata.name,
pod.metadata.namespace)
except HTTPError as e:
- raise AirflowException(f"There was an error reading the kubernetes
API: {e}")
+ raise KubernetesApiException(f"There was an error reading the
kubernetes API: {e}")
def await_xcom_sidecar_container_start(
self, pod: V1Pod, timeout: int = 900, log_interval: int = 30
@@ -1040,7 +971,6 @@ class AsyncPodManager(LoggingMixin):
self._callbacks = callbacks or []
self.stop_watching_events = False
- @generic_api_retry
async def read_pod(self, pod: V1Pod) -> V1Pod:
"""Read POD information."""
return await self._hook.get_pod(
@@ -1048,7 +978,6 @@ class AsyncPodManager(LoggingMixin):
pod.metadata.namespace,
)
- @generic_api_retry
async def read_pod_events(self, pod: V1Pod) -> CoreV1EventList:
"""Get pod's events."""
return await self._hook.get_pod_events(
@@ -1082,7 +1011,6 @@ class AsyncPodManager(LoggingMixin):
check_interval=check_interval,
)
- @generic_api_retry
async def fetch_container_logs_before_current_sec(
self, pod: V1Pod, container_name: str, since_time: DateTime | None =
None
) -> DateTime | None:
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py
index 7e99e6fa2d6..a467427eaa6 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/conftest.py
@@ -18,12 +18,34 @@
from __future__ import annotations
from pathlib import Path
+from unittest import mock
import pytest
DATA_FILE_DIRECTORY = Path(__file__).resolve().parent / "data_files"
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers", "no_wait_patch_disabled: disable autouse
WaitRetryAfterOrExponential patch"
+ )
+
+
[email protected](autouse=True)
+def no_retry_wait(request):
+ # Skip patching if test has marker
+ if request.node.get_closest_marker("no_wait_patch_disabled"):
+ yield
+ return
+ patcher = mock.patch(
+
"airflow.providers.cncf.kubernetes.kubernetes_helper_functions.WaitRetryAfterOrExponential.__call__",
+ return_value=0,
+ )
+ patcher.start()
+ yield
+ patcher.stop()
+
+
@pytest.fixture
def data_file():
"""Helper fixture for obtain data file from data directory."""
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py
index 616c73b5e82..6d13f5eb38f 100644
---
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py
@@ -677,19 +677,20 @@ class TestKubernetesHook:
@patch(f"{HOOK_MODULE}.json.dumps")
@patch(f"{HOOK_MODULE}.KubernetesHook.batch_v1_client")
- def test_create_job_retries_three_times(self, mock_client,
mock_json_dumps):
+ def test_create_job_retries_five_times(self, mock_client, mock_json_dumps):
mock_client.create_namespaced_job.side_effect = [
ApiException(status=500),
ApiException(status=500),
ApiException(status=500),
ApiException(status=500),
+ ApiException(status=500),
]
hook = KubernetesHook()
with pytest.raises(ApiException):
hook.create_job(job=mock.MagicMock())
- assert mock_client.create_namespaced_job.call_count == 3
+ assert mock_client.create_namespaced_job.call_count == 5
@pytest.mark.parametrize(
("given_namespace", "expected_namespace"),
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_resource.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_resource.py
index 9f4f004ce50..3bea6a051a2 100644
---
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_resource.py
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_resource.py
@@ -277,12 +277,13 @@ class TestKubernetesXResourceOperator:
@patch("kubernetes.config.load_kube_config")
@patch("airflow.providers.cncf.kubernetes.operators.resource.create_from_yaml")
- def test_create_objects_retries_three_times(self, mock_create_from_yaml,
mock_load_kube_config, context):
+ def test_create_objects_retries_five_times(self, mock_create_from_yaml,
mock_load_kube_config, context):
mock_create_from_yaml.side_effect = [
ApiException(status=500),
ApiException(status=500),
ApiException(status=500),
ApiException(status=500),
+ ApiException(status=500),
]
op = KubernetesCreateResourceOperator(
@@ -295,4 +296,4 @@ class TestKubernetesXResourceOperator:
with pytest.raises(ApiException):
op.execute(context)
- assert mock_create_from_yaml.call_count == 3
+ assert mock_create_from_yaml.call_count == 5
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_kubernetes_helper_functions.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_kubernetes_helper_functions.py
index 12fb2071704..3622a5c34f4 100644
---
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_kubernetes_helper_functions.py
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/test_kubernetes_helper_functions.py
@@ -18,14 +18,73 @@
from __future__ import annotations
import re
+from unittest import mock
import pytest
+from kubernetes.client.rest import ApiException as SyncApiException
+from kubernetes_asyncio.client.exceptions import ApiException as
AsyncApiException
+from urllib3.exceptions import HTTPError
-from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import
create_unique_id
+from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
+ KubernetesApiException,
+ WaitRetryAfterOrExponential,
+ _should_retry_api,
+ create_unique_id,
+)
pod_name_regex =
r"^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"
+class DummyRetryState:
+ def __init__(self, exception=None):
+ self.outcome = mock.Mock() if exception is not None else None
+ if self.outcome:
+ self.outcome.exception = mock.Mock(return_value=exception)
+
+
+def test_should_retry_api():
+ exc = HTTPError()
+ assert _should_retry_api(exc)
+
+ exc = KubernetesApiException()
+ assert _should_retry_api(exc)
+
+ exc = SyncApiException(status=500)
+ assert _should_retry_api(exc)
+
+ exc = AsyncApiException(status=500)
+ assert _should_retry_api(exc)
+
+ exc = SyncApiException(status=404)
+ assert not _should_retry_api(exc)
+
+ exc = AsyncApiException(status=404)
+ assert not _should_retry_api(exc)
+
+
+class TestWaitRetryAfterOrExponential:
+ @pytest.mark.no_wait_patch_disabled
+ @pytest.mark.parametrize(("exception"), [SyncApiException,
AsyncApiException])
+ def test_call_with_retry_after_header(self, exception):
+ exc = exception(status=429)
+ exc.headers = {"Retry-After": "15"}
+ retry_state = DummyRetryState(exception=exc)
+ wait = WaitRetryAfterOrExponential()(retry_state)
+ assert wait == 15
+
+ @pytest.mark.no_wait_patch_disabled
+ @pytest.mark.parametrize(
+ ("attempt_number", "expected_wait", "exception"),
+ [(1, 1, SyncApiException), (4, 8, AsyncApiException)],
+ )
+ def test_call_without_retry_after_header(self, attempt_number,
expected_wait, exception):
+ exc = exception(status=409)
+ retry_state = DummyRetryState(exception=exc)
+ retry_state.attempt_number = attempt_number
+ wait = WaitRetryAfterOrExponential()(retry_state)
+ assert wait == expected_wait
+
+
class TestCreateUniqueId:
@pytest.mark.parametrize(
("val", "expected"),
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py
index 187e085ad5c..1688470946b 100644
---
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_pod_manager.py
@@ -36,7 +36,6 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager
import (
PodLogsConsumer,
PodManager,
PodPhase,
- get_retry_after_seconds,
parse_log_line,
)
from airflow.utils.timezone import utc
@@ -47,10 +46,6 @@ if TYPE_CHECKING:
from pendulum import DateTime
-def wait_none(retry_state):
- return 0
-
-
def test_parse_log_line():
log_message = "This should return no timestamp"
timestamp, line = parse_log_line(log_message)
@@ -63,31 +58,6 @@ def test_parse_log_line():
assert line == log_message
-class DummyRetryState:
- def __init__(self, exception=None):
- # self.attempt_number = 1
- self.outcome = mock.Mock() if exception is not None else None
- if self.outcome:
- self.outcome.exception = mock.Mock(return_value=exception)
-
-
-def test_get_retry_after_seconds_with_retry_after_header():
- exc = ApiException(status=429)
- exc.headers = {"Retry-After": "15"}
- retry_state = DummyRetryState(exception=exc)
- wait = get_retry_after_seconds(retry_state)
- assert wait == 15
-
-
[email protected](("attempt_number", "expected_wait"), [(1, 1), (4, 8)])
-def test_get_retry_after_seconds_without_retry_after_header(attempt_number,
expected_wait):
- exc = ApiException(status=409)
- retry_state = DummyRetryState(exception=exc)
- retry_state.attempt_number = attempt_number
- wait = get_retry_after_seconds(retry_state)
- assert wait == expected_wait
-
-
class TestPodManager:
def setup_method(self):
self.mock_kube_client = mock.Mock()
@@ -103,7 +73,6 @@ class TestPodManager:
assert isinstance(logs, PodLogsConsumer)
assert logs.response == mock.sentinel.logs
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
def test_read_pod_logs_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod_log.side_effect = [
@@ -149,7 +118,6 @@ class TestPodManager:
self.pod_manager.fetch_container_logs(mock.MagicMock(),
"container-name", follow=True)
assert "[container-name] None" not in (record.message for record
in caplog.records)
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
def test_read_pod_logs_retries_fails(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod_log.side_effect = [
@@ -242,7 +210,6 @@ class TestPodManager:
events = self.pod_manager.read_pod_events(mock.sentinel)
assert mock.sentinel.events == events
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
def test_read_pod_events_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.list_namespaced_event.side_effect = [
@@ -264,7 +231,6 @@ class TestPodManager:
]
)
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
def test_read_pod_events_retries_fails(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.list_namespaced_event.side_effect = [
@@ -283,7 +249,6 @@ class TestPodManager:
pod_info = self.pod_manager.read_pod(mock.sentinel)
assert mock.sentinel.pod_info == pod_info
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
def test_read_pod_retries_successfully(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod.side_effect = [
@@ -316,7 +281,6 @@ class TestPodManager:
self.mock_kube_client.read_namespaced_pod_log.return_value =
mock_response
self.pod_manager.fetch_container_logs(mock.sentinel, "base")
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
def test_monitor_pod_logs_failures_non_fatal(self):
mock.sentinel.metadata = mock.MagicMock()
running_status = mock.MagicMock()
@@ -340,7 +304,6 @@ class TestPodManager:
self.pod_manager.fetch_container_logs(mock.sentinel, "base")
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
def test_read_pod_retries_fails(self):
mock.sentinel.metadata = mock.MagicMock()
self.mock_kube_client.read_namespaced_pod.side_effect = [
@@ -386,7 +349,6 @@ class TestPodManager:
]
)
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
def test_fetch_container_logs_failures(self, mock_container_is_running):
MockWrapper.reset()
@@ -440,7 +402,6 @@ class TestPodManager:
assert "ERROR" not in caplog.text
@pytest.mark.parametrize("status", [409, 429])
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.run_pod_async")
def test_start_pod_retries_on_409_or_429_error(self, mock_run_pod_async,
status):
mock_run_pod_async.side_effect = [
@@ -450,14 +411,12 @@ class TestPodManager:
self.pod_manager.create_pod(mock.sentinel)
assert mock_run_pod_async.call_count == 2
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.run_pod_async")
def test_start_pod_fails_on_other_exception(self, mock_run_pod_async):
- mock_run_pod_async.side_effect = [ApiException(status=504)]
+ mock_run_pod_async.side_effect = [ApiException(status=401)]
with pytest.raises(ApiException):
self.pod_manager.create_pod(mock.sentinel)
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.run_pod_async")
def test_start_pod_retries_three_times(self, mock_run_pod_async):
mock_run_pod_async.side_effect = [
@@ -679,7 +638,6 @@ class TestPodManager:
assert ret == xcom_json
assert mock_exec_xcom_kill.call_count == 1
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.kubernetes_stream")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom_kill")
def test_extract_xcom_failure(self, mock_exec_xcom_kill,
mock_kubernetes_stream):
@@ -708,7 +666,6 @@ class TestPodManager:
assert ret == xcom_result
assert mock_exec_xcom_kill.call_count == 1
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.kubernetes_stream")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom_kill")
def test_extract_xcom_none(self, mock_exec_xcom_kill,
mock_kubernetes_stream):
@@ -722,7 +679,6 @@ class TestPodManager:
self.pod_manager.extract_xcom(pod=mock_pod)
assert mock_exec_xcom_kill.call_count == 1
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_terminated")
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running")
def test_await_xcom_sidecar_container_timeout(
@@ -972,7 +928,6 @@ class TestAsyncPodManager:
unexpected_call = mock.call("[%s] %s", container_name,
not_expected)
assert unexpected_call not in mock_log_info.mock_calls
-
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.get_retry_after_seconds",
wait_none)
@pytest.mark.asyncio
async def
test_fetch_container_logs_before_current_sec_error_handling(self):
pod = mock.MagicMock()