This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 61412b3af83 Add HTTP retry handling into task SDK api.client (#45121)
61412b3af83 is described below

commit 61412b3af83610cf21588ef3ddaff300dce50f37
Author: Jens Scheffler <[email protected]>
AuthorDate: Fri Dec 27 23:51:46 2024 +0100

    Add HTTP retry handling into task SDK api.client (#45121)
    
    * Add HTTP retry handling into task SDK api.client
    
    * Add logging of call failures
    
    * Prevent task sdk tests with LocalExecutor fail with retries
    
    * Review feedback
    
    * Review feedback, Adjust wording
    
    * Correct time parameters to float
    
    * Review Feedback
---
 task_sdk/pyproject.toml                            |   1 +
 task_sdk/src/airflow/sdk/api/client.py             |  28 ++++
 task_sdk/tests/api/test_client.py                  | 157 ++++++++++++++++-----
 .../commands/remote_commands/test_task_command.py  |   2 +
 4 files changed, 155 insertions(+), 33 deletions(-)

diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index aa0271c85fc..a27f4cb7c91 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "msgspec>=0.18.6",
     "psutil>=6.1.0",
     "structlog>=24.4.0",
+    "retryhttp>=1.2.0",
 ]
 classifiers = [
   "Framework :: Apache Airflow",
diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py
index 7488ef3e88a..ee4144c7f54 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -17,6 +17,8 @@
 
 from __future__ import annotations
 
+import logging
+import os
 import sys
 import uuid
 from http import HTTPStatus
@@ -26,6 +28,8 @@ import httpx
 import msgspec
 import structlog
 from pydantic import BaseModel
+from retryhttp import retry, wait_retry_after
+from tenacity import before_log, wait_random_exponential
 from uuid6 import uuid7
 
 from airflow.sdk import __version__
@@ -268,6 +272,15 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
     return httpx.Response(200, json={"text": "Hello, world!"})
 
 
+# Config options for SDK how retries on HTTP requests should be handled
+# Note: Given defaults make attempts after 1, 3, 7, 15, 31seconds, 1:03, 2:07, 3:37 and fails after 5:07min
+# So far there is no other config facility in SDK we use ENV for the moment
+# TODO: Consider these env variables while handling airflow confs in task sdk
+API_RETRIES = int(os.getenv("AIRFLOW__WORKERS__API_RETRIES", 10))
+API_RETRY_WAIT_MIN = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1.0))
+API_RETRY_WAIT_MAX = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90.0))
+
+
 class Client(httpx.Client):
     def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
         if (not base_url) ^ dry_run:
@@ -289,6 +302,21 @@ class Client(httpx.Client):
             **kwargs,
         )
 
+    _default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)
+
+    @retry(
+        reraise=True,
+        max_attempt_number=API_RETRIES,
+        wait_server_errors=_default_wait,
+        wait_network_errors=_default_wait,
+        wait_timeouts=_default_wait,
+        wait_rate_limited=wait_retry_after(fallback=_default_wait),  # No infinite timeout on HTTP 429
+        before_sleep=before_log(log, logging.WARNING),
+    )
+    def request(self, *args, **kwargs):
+        """Implement a convenience for httpx.Client.request with a retry layer."""
+        return super().request(*args, **kwargs)
+
     # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
     # methods on one object prefixed with the object type (`.task_instances.update` rather than
     # `task_instance_update` etc.)
diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py
index 279502793ee..c52feb96766 100644
--- a/task_sdk/tests/api/test_client.py
+++ b/task_sdk/tests/api/test_client.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import json
+from unittest import mock
 
 import httpx
 import pytest
@@ -30,18 +31,28 @@ from airflow.utils import timezone
 from airflow.utils.state import TerminalTIState
 
 
-class TestClient:
-    def test_error_parsing(self):
-        def handle_request(request: httpx.Request) -> httpx.Response:
-            """
-            A transport handle that always returns errors
-            """
+def make_client(transport: httpx.MockTransport) -> Client:
+    """Get a client with a custom transport"""
+    return Client(base_url="test://server", token="", transport=transport)
 
-            return httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})
 
-        client = Client(
-            base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
-        )
+def make_client_w_responses(responses: list[httpx.Response]) -> Client:
+    """Helper fixture to create a mock client with custom responses."""
+
+    def handle_request(request: httpx.Request) -> httpx.Response:
+        return responses.pop(0)
+
+    return Client(
+        base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
+    )
+
+
+class TestClient:
+    def test_error_parsing(self):
+        responses = [
+            httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})
+        ]
+        client = make_client_w_responses(responses)
 
         with pytest.raises(ServerResponseError) as err:
             client.get("http://error")
@@ -53,39 +64,92 @@ class TestClient:
         ]
 
     def test_error_parsing_plain_text(self):
-        def handle_request(request: httpx.Request) -> httpx.Response:
-            """
-            A transport handle that always returns errors
-            """
-
-            return httpx.Response(422, content=b"Internal Server Error")
-
-        client = Client(
-            base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
-        )
+        responses = [httpx.Response(422, content=b"Internal Server Error")]
+        client = make_client_w_responses(responses)
 
         with pytest.raises(httpx.HTTPStatusError) as err:
             client.get("http://error")
         assert not isinstance(err.value, ServerResponseError)
 
     def test_error_parsing_other_json(self):
-        def handle_request(request: httpx.Request) -> httpx.Response:
-            # Some other json than an error body.
-            return httpx.Response(404, json={"detail": "Not found"})
-
-        client = Client(
-            base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
-        )
+        responses = [httpx.Response(404, json={"detail": "Not found"})]
+        client = make_client_w_responses(responses)
 
         with pytest.raises(ServerResponseError) as err:
             client.get("http://error")
         assert err.value.args == ("Not found",)
         assert err.value.detail is None
 
+    @mock.patch("time.sleep", return_value=None)
+    def test_retry_handling_unrecoverable_error(self, mock_sleep):
+        responses: list[httpx.Response] = [
+            *[httpx.Response(500, text="Internal Server Error")] * 11,
+            httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}),
+            httpx.Response(400, json={"detail": "Should not get here"}),
+        ]
+        client = make_client_w_responses(responses)
 
-def make_client(transport: httpx.MockTransport) -> Client:
-    """Get a client with a custom transport"""
-    return Client(base_url="test://server", token="", transport=transport)
+        with pytest.raises(httpx.HTTPStatusError) as err:
+            client.get("http://error")
+        assert not isinstance(err.value, ServerResponseError)
+        assert len(responses) == 3
+        assert mock_sleep.call_count == 9
+
+    @mock.patch("time.sleep", return_value=None)
+    def test_retry_handling_recovered(self, mock_sleep):
+        responses: list[httpx.Response] = [
+            *[httpx.Response(500, text="Internal Server Error")] * 3,
+            httpx.Response(200, json={"detail": "Recovered from error"}),
+            httpx.Response(400, json={"detail": "Should not get here"}),
+        ]
+        client = make_client_w_responses(responses)
+
+        response = client.get("http://error")
+        assert response.status_code == 200
+        assert len(responses) == 1
+        assert mock_sleep.call_count == 3
+
+    @mock.patch("time.sleep", return_value=None)
+    def test_retry_handling_overload(self, mock_sleep):
+        responses: list[httpx.Response] = [
+            httpx.Response(429, text="I am really busy atm, please back-off", headers={"Retry-After": "37"}),
+            httpx.Response(200, json={"detail": "Recovered from error"}),
+            httpx.Response(400, json={"detail": "Should not get here"}),
+        ]
+        client = make_client_w_responses(responses)
+
+        response = client.get("http://error")
+        assert response.status_code == 200
+        assert len(responses) == 1
+        assert mock_sleep.call_count == 1
+        assert mock_sleep.call_args[0][0] == 37
+
+    @mock.patch("time.sleep", return_value=None)
+    def test_retry_handling_non_retry_error(self, mock_sleep):
+        responses: list[httpx.Response] = [
+            httpx.Response(422, json={"detail": "Somehow this is a bad request"}),
+            httpx.Response(400, json={"detail": "Should not get here"}),
+        ]
+        client = make_client_w_responses(responses)
+
+        with pytest.raises(ServerResponseError) as err:
+            client.get("http://error")
+        assert len(responses) == 1
+        assert mock_sleep.call_count == 0
+        assert err.value.args == ("Somehow this is a bad request",)
+
+    @mock.patch("time.sleep", return_value=None)
+    def test_retry_handling_ok(self, mock_sleep):
+        responses: list[httpx.Response] = [
+            httpx.Response(200, json={"detail": "Recovered from error"}),
+            httpx.Response(400, json={"detail": "Should not get here"}),
+        ]
+        client = make_client_w_responses(responses)
+
+        response = client.get("http://error")
+        assert response.status_code == 200
+        assert len(responses) == 1
+        assert mock_sleep.call_count == 0
 
 
 class TestTaskInstanceOperations:
@@ -95,7 +159,8 @@ class TestTaskInstanceOperations:
     response parsing.
     """
 
-    def test_task_instance_start(self, make_ti_context):
+    @mock.patch("time.sleep", return_value=None)  # To have retries not slowing down tests
+    def test_task_instance_start(self, mock_sleep, make_ti_context):
         # Simulate a successful response from the server that starts a task
         ti_id = uuid6.uuid7()
         start_date = "2024-10-31T12:00:00Z"
@@ -105,7 +170,14 @@ class TestTaskInstanceOperations:
             run_type="manual",
         )
 
+        # ...including a validation that retry really works
+        call_count = 0
+
         def handle_request(request: httpx.Request) -> httpx.Response:
+            nonlocal call_count
+            call_count += 1
+            if call_count < 4:
+                return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
             if request.url.path == f"/task-instances/{ti_id}/run":
                 actual_body = json.loads(request.read())
                 assert actual_body["pid"] == 100
@@ -120,6 +192,7 @@ class TestTaskInstanceOperations:
         client = make_client(transport=httpx.MockTransport(handle_request))
         resp = client.task_instances.start(ti_id, 100, start_date)
         assert resp == ti_context
+        assert call_count == 4
 
     @pytest.mark.parametrize("state", [state for state in TerminalTIState])
     def test_task_instance_finish(self, state):
@@ -245,9 +318,17 @@ class TestVariableOperations:
     response parsing.
     """
 
-    def test_variable_get_success(self):
+    @mock.patch("time.sleep", return_value=None)  # To have retries not slowing down tests
+    def test_variable_get_success(self, mock_sleep):
         # Simulate a successful response from the server with a variable
+        # ...including a validation that retry really works
+        call_count = 0
+
         def handle_request(request: httpx.Request) -> httpx.Response:
+            nonlocal call_count
+            call_count += 1
+            if call_count < 2:
+                return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
             if request.url.path == "/variables/test_key":
                 return httpx.Response(
                     status_code=200,
@@ -261,6 +342,7 @@ class TestVariableOperations:
         assert isinstance(result, VariableResponse)
         assert result.key == "test_key"
         assert result.value == "test_value"
+        assert call_count == 2
 
     def test_variable_not_found(self):
         # Simulate a 404 response from the server
@@ -323,9 +405,17 @@ class TestXCOMOperations:
             pytest.param({"key": "test_key", "value": {"key2": "value2"}}, id="nested-dict-value"),
         ],
     )
-    def test_xcom_get_success(self, value):
+    @mock.patch("time.sleep", return_value=None)  # To have retries not slowing down tests
+    def test_xcom_get_success(self, mock_sleep, value):
         # Simulate a successful response from the server when getting an xcom
+        # ...including a validation that retry really works
+        call_count = 0
+
         def handle_request(request: httpx.Request) -> httpx.Response:
+            nonlocal call_count
+            call_count += 1
+            if call_count < 3:
+                return httpx.Response(status_code=500, json={"detail": "Internal Server Error"})
             if request.url.path == "/xcoms/dag_id/run_id/task_id/key":
                 return httpx.Response(
                     status_code=201,
@@ -343,6 +433,7 @@ class TestXCOMOperations:
         assert isinstance(result, XComResponse)
         assert result.key == "test_key"
         assert result.value == value
+        assert call_count == 3
 
     def test_xcom_get_success_with_map_index(self):
         # Simulate a successful response from the server when getting an xcom with map_index passed
diff --git a/tests/cli/commands/remote_commands/test_task_command.py b/tests/cli/commands/remote_commands/test_task_command.py
index 66177c2d84e..843d6817cdc 100644
--- a/tests/cli/commands/remote_commands/test_task_command.py
+++ b/tests/cli/commands/remote_commands/test_task_command.py
@@ -496,6 +496,8 @@ class TestCliTasks:
                 mock.patch(
                     "airflow.executors.executor_loader.ExecutorLoader.get_default_executor"
                 ) as get_default_mock,
+                mock.patch("airflow.executors.local_executor.SimpleQueue"),  # Prevent a task being queued
+                mock.patch("airflow.executors.local_executor.LocalExecutor.end"),
             ):
                 EmptyOperator(task_id="task1")
                 EmptyOperator(task_id="task2", executor="foo_executor_alias")

Reply via email to