This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f0c223951 Properly return None values from methods through Internal API (#33927)
7f0c223951 is described below

commit 7f0c223951f8e06eae0105ad646b21feaec79aea
Author: mhenc <[email protected]>
AuthorDate: Thu Aug 31 18:18:29 2023 +0200

    Properly return None values from methods through Internal API (#33927)
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py    | 14 +++++---------
 airflow/api_internal/internal_api_call.py             |  4 +++-
 tests/api_internal/endpoints/test_rpc_api_endpoint.py | 10 ++++++----
 tests/api_internal/test_internal_api_call.py          | 17 +++++++++++++++++
 4 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index f241e1149b..926c955190 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -90,19 +90,15 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
             params_json = json.loads(str(body.get("params")))
             params = BaseSerialization.deserialize(params_json, use_pydantic_models=True)
     except Exception as err:
-        log.error("Error deserializing parameters.")
-        log.error(err)
+        log.error("Error (%s) when deserializing parameters: %s", err, params_json)
         return Response(response="Error deserializing parameters.", status=400)
 
-    log.debug("Calling method %.", {method_name})
+    log.debug("Calling method %s.", method_name)
     try:
         output = handler(**params)
         output_json = BaseSerialization.serialize(output, use_pydantic_models=True)
-        log.debug("Returning response")
-        return Response(
-            response=json.dumps(output_json or "{}"), headers={"Content-Type": "application/json"}
-        )
+        response = json.dumps(output_json) if output_json is not None else None
+        return Response(response=response, headers={"Content-Type": "application/json"})
     except Exception as e:
-        log.error("Error when calling method %s.", method_name)
-        log.error(e)
+        log.error("Error (%s) when calling method %s.", e, method_name)
         return Response(response=f"Error executing method: {method_name}.", status=500)
diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index d9c49a138f..698bc26383 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -108,7 +108,7 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
         return response.content
 
     @wraps(func)
-    def wrapper(*args, **kwargs) -> RT:
+    def wrapper(*args, **kwargs):
         use_internal_api = InternalApiConfig.get_use_internal_api()
         if not use_internal_api:
             return func(*args, **kwargs)
@@ -125,6 +125,8 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
         args_json = json.dumps(BaseSerialization.serialize(arguments_dict, use_pydantic_models=True))
         method_name = f"{func.__module__}.{func.__qualname__}"
         result = make_jsonrpc_request(method_name, args_json)
+        if result is None or result == b"":
+            return None
         return BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True)
 
     return wrapper
diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
index 0c8c7c0c0b..a81d03f6ce 100644
--- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py
+++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -80,7 +80,7 @@ class TestRpcApiEndpoint:
     @pytest.mark.parametrize(
         "input_params, method_result, result_cmp_func, method_params",
         [
-            ("", None, equals, {}),
+            ("", None, lambda got, _: got == b"", {}),
             ("", "test_me", equals, {}),
             (
                 json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})),
@@ -103,8 +103,7 @@ class TestRpcApiEndpoint:
         ],
     )
     def test_method(self, input_params, method_result, result_cmp_func, method_params):
-        if method_result:
-            mock_test_method.return_value = method_result
+        mock_test_method.return_value = method_result
 
         input_data = {
             "jsonrpc": "2.0",
@@ -119,7 +118,10 @@ class TestRpcApiEndpoint:
         assert response.status_code == 200
         if method_result:
             response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True)
-            assert result_cmp_func(response_data, method_result)
+        else:
+            response_data = response.data
+
+        assert result_cmp_func(response_data, method_result)
 
         mock_test_method.assert_called_once_with(**method_params)
 
diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py
index c3eda6e184..9d96fa0756 100644
--- a/tests/api_internal/test_internal_api_call.py
+++ b/tests/api_internal/test_internal_api_call.py
@@ -142,6 +142,23 @@ class TestInternalApiCall:
             headers={"Content-Type": "application/json"},
         )
 
+    @conf_vars(
+        {
+            ("core", "database_access_isolation"): "true",
+            ("core", "internal_api_url"): "http://localhost:8888",
+        }
+    )
+    @mock.patch("airflow.api_internal.internal_api_call.requests")
+    def test_remote_call_with_none_result(self, mock_requests):
+        response = requests.Response()
+        response.status_code = 200
+        response._content = b""
+
+        mock_requests.post.return_value = response
+
+        result = TestInternalApiCall.fake_method()
+        assert result is None
+
     @conf_vars(
         {
             ("core", "database_access_isolation"): "true",

Reply via email to