This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7f0c223951 Properly return None values from methods through Internal API (#33927)
7f0c223951 is described below
commit 7f0c223951f8e06eae0105ad646b21feaec79aea
Author: mhenc <[email protected]>
AuthorDate: Thu Aug 31 18:18:29 2023 +0200
Properly return None values from methods through Internal API (#33927)
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 14 +++++---------
airflow/api_internal/internal_api_call.py | 4 +++-
tests/api_internal/endpoints/test_rpc_api_endpoint.py | 10 ++++++----
tests/api_internal/test_internal_api_call.py | 17 +++++++++++++++++
4 files changed, 31 insertions(+), 14 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index f241e1149b..926c955190 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -90,19 +90,15 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
params_json = json.loads(str(body.get("params")))
params = BaseSerialization.deserialize(params_json,
use_pydantic_models=True)
except Exception as err:
- log.error("Error deserializing parameters.")
- log.error(err)
+ log.error("Error (%s) when deserializing parameters: %s", err,
params_json)
return Response(response="Error deserializing parameters.", status=400)
- log.debug("Calling method %.", {method_name})
+ log.debug("Calling method %s.", method_name)
try:
output = handler(**params)
output_json = BaseSerialization.serialize(output,
use_pydantic_models=True)
- log.debug("Returning response")
- return Response(
- response=json.dumps(output_json or "{}"), headers={"Content-Type":
"application/json"}
- )
+ response = json.dumps(output_json) if output_json is not None else None
+ return Response(response=response, headers={"Content-Type":
"application/json"})
except Exception as e:
- log.error("Error when calling method %s.", method_name)
- log.error(e)
+ log.error("Error (%s) when calling method %s.", e, method_name)
return Response(response=f"Error executing method: {method_name}.",
status=500)
diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index d9c49a138f..698bc26383 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -108,7 +108,7 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
return response.content
@wraps(func)
- def wrapper(*args, **kwargs) -> RT:
+ def wrapper(*args, **kwargs):
use_internal_api = InternalApiConfig.get_use_internal_api()
if not use_internal_api:
return func(*args, **kwargs)
@@ -125,6 +125,8 @@ def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]:
args_json = json.dumps(BaseSerialization.serialize(arguments_dict,
use_pydantic_models=True))
method_name = f"{func.__module__}.{func.__qualname__}"
result = make_jsonrpc_request(method_name, args_json)
+ if result is None or result == b"":
+ return None
return BaseSerialization.deserialize(json.loads(result),
use_pydantic_models=True)
return wrapper
diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
index 0c8c7c0c0b..a81d03f6ce 100644
--- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py
+++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -80,7 +80,7 @@ class TestRpcApiEndpoint:
@pytest.mark.parametrize(
"input_params, method_result, result_cmp_func, method_params",
[
- ("", None, equals, {}),
+ ("", None, lambda got, _: got == b"", {}),
("", "test_me", equals, {}),
(
json.dumps(BaseSerialization.serialize({"dag_id": 15,
"task_id": "fake-task"})),
@@ -103,8 +103,7 @@ class TestRpcApiEndpoint:
],
)
def test_method(self, input_params, method_result, result_cmp_func,
method_params):
- if method_result:
- mock_test_method.return_value = method_result
+ mock_test_method.return_value = method_result
input_data = {
"jsonrpc": "2.0",
@@ -119,7 +118,10 @@ class TestRpcApiEndpoint:
assert response.status_code == 200
if method_result:
response_data =
BaseSerialization.deserialize(json.loads(response.data),
use_pydantic_models=True)
- assert result_cmp_func(response_data, method_result)
+ else:
+ response_data = response.data
+
+ assert result_cmp_func(response_data, method_result)
mock_test_method.assert_called_once_with(**method_params)
diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py
index c3eda6e184..9d96fa0756 100644
--- a/tests/api_internal/test_internal_api_call.py
+++ b/tests/api_internal/test_internal_api_call.py
@@ -142,6 +142,23 @@ class TestInternalApiCall:
headers={"Content-Type": "application/json"},
)
+ @conf_vars(
+ {
+ ("core", "database_access_isolation"): "true",
+ ("core", "internal_api_url"): "http://localhost:8888",
+ }
+ )
+ @mock.patch("airflow.api_internal.internal_api_call.requests")
+ def test_remote_call_with_none_result(self, mock_requests):
+ response = requests.Response()
+ response.status_code = 200
+ response._content = b""
+
+ mock_requests.post.return_value = response
+
+ result = TestInternalApiCall.fake_method()
+ assert result is None
+
@conf_vars(
{
("core", "database_access_isolation"): "true",