This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d0473a1357 AIP-44 Initialize methods map for Internal API RPC endpoint in the method (#28841)
d0473a1357 is described below
commit d0473a1357a85c1c1b61a0a79310c352b8edebce
Author: mhenc <[email protected]>
AuthorDate: Thu Jan 19 23:05:11 2023 +0100
AIP-44 Initialize methods map for Internal API RPC endpoint in the method (#28841)
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 26 +++++++++++-----------
.../endpoints/test_rpc_api_endpoint.py | 10 ++++++---
2 files changed, 20 insertions(+), 16 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 90bb23e112..bdf124345b 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -17,32 +17,30 @@
from __future__ import annotations
+import functools
import json
import logging
+from typing import Any, Callable
from flask import Response
from airflow.api_connexion.types import APIResponse
-from airflow.dag_processing.processor import DagFileProcessor
from airflow.serialization.serialized_objects import BaseSerialization
log = logging.getLogger(__name__)
-def _build_methods_map(list) -> dict:
- return {f"{func.__module__}.{func.__name__}": func for func in list}
+@functools.lru_cache()
+def _initialize_map() -> dict[str, Callable]:
+ from airflow.dag_processing.processor import DagFileProcessor
-
-METHODS_MAP = _build_methods_map(
- [
+ functions: list[Callable] = [
DagFileProcessor.update_import_errors,
]
-)
+ return {f"{func.__module__}.{func.__name__}": func for func in functions}
-def internal_airflow_api(
- body: dict,
-) -> APIResponse:
+def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
"""Handler for Internal API /internal_api/v1/rpcapi endpoint."""
log.debug("Got request")
json_rpc = body.get("jsonrpc")
@@ -50,12 +48,14 @@ def internal_airflow_api(
log.error("Not jsonrpc-2.0 request.")
return Response(response="Expected jsonrpc 2.0 request.", status=400)
- method_name = str(body.get("method"))
- if method_name not in METHODS_MAP:
+ methods_map = _initialize_map()
+
+ method_name = body.get("method")
+ if method_name not in methods_map:
log.error("Unrecognized method: %s.", method_name)
return Response(response=f"Unrecognized method: {method_name}.", status=400)
- handler = METHODS_MAP[method_name]
+ handler = methods_map[method_name]
params = {}
try:
if body.get("params"):
diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
index 68f22fe6cc..0c45bcbbe0 100644
--- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py
+++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -17,12 +17,12 @@
from __future__ import annotations
import json
+from typing import Generator
from unittest import mock
import pytest
from flask import Flask
-from airflow.api_internal.endpoints import rpc_api_endpoint
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.www import app
from tests.test_utils.config import conf_vars
@@ -50,12 +50,16 @@ def minimal_app_for_internal_api() -> Flask:
class TestRpcApiEndpoint:
@pytest.fixture(autouse=True)
- def setup_attrs(self, minimal_app_for_internal_api: Flask) -> None:
- rpc_api_endpoint.METHODS_MAP[TEST_METHOD_NAME] = mock_test_method
+ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
self.app = minimal_app_for_internal_api
self.client = self.app.test_client() # type:ignore
mock_test_method.reset_mock()
mock_test_method.side_effect = None
+ with mock.patch(
+ "airflow.api_internal.endpoints.rpc_api_endpoint._initialize_map"
+ ) as mock_initialize_map:
+ mock_initialize_map.return_value = {TEST_METHOD_NAME: mock_test_method}
+ yield mock_initialize_map
@pytest.mark.parametrize(
"input_data, method_result, method_params, expected_code",