vincbeck commented on code in PR #27892: URL: https://github.com/apache/airflow/pull/27892#discussion_r1036456952
########## airflow/api_internal/endpoints/rpc_api_endpoint.py: ########## @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging + +from flask import Response + +from airflow.api_connexion.types import APIResponse +from airflow.dag_processing.processor import DagFileProcessor +from airflow.serialization.serialized_objects import BaseSerialization + +log = logging.getLogger(__name__) + + +def _build_methods_map(list) -> dict: + return {f"{func.__module__}.{func.__name__}": func for func in list} + + +METHODS_MAP = _build_methods_map( Review Comment: Love that ########## airflow/api_internal/endpoints/rpc_api_endpoint.py: ########## @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging + +from flask import Response + +from airflow.api_connexion.types import APIResponse +from airflow.dag_processing.processor import DagFileProcessor +from airflow.serialization.serialized_objects import BaseSerialization + +log = logging.getLogger(__name__) + + +def _build_methods_map(list) -> dict: + return {f"{func.__module__}.{func.__name__}": func for func in list} + + +METHODS_MAP = _build_methods_map( + [ + DagFileProcessor.update_import_errors, + ] +) + + +def json_rpc( + body: dict, +) -> APIResponse: + """Handler for Internal API /internal/v1/rpcapi endpoint.""" + log.debug("Got request") + json_rpc = body.get("jsonrpc") + if json_rpc != "2.0": + log.error("Not jsonrpc-2.0 request.") + return Response(response="Expected jsonrpc 2.0 request.", status=400) + + method_name = str(body.get("method")) + if method_name not in METHODS_MAP: + log.error("Unrecognized method: %s.", method_name) + return Response(response=f"Unrecognized method: {method_name}.", status=400) + + handler = METHODS_MAP[method_name] + try: + params = {} Review Comment: super nit. I just feel it makes more sense to have it outside of the try statement. Feel free to ignore if you disagree ```suggestion params = {} try: ``` ########## airflow/api_internal/endpoints/rpc_api_endpoint.py: ########## @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging + +from flask import Response + +from airflow.api_connexion.types import APIResponse +from airflow.dag_processing.processor import DagFileProcessor +from airflow.serialization.serialized_objects import BaseSerialization + +log = logging.getLogger(__name__) + + +def _build_methods_map(list) -> dict: + return {f"{func.__module__}.{func.__name__}": func for func in list} + + +METHODS_MAP = _build_methods_map( + [ + DagFileProcessor.update_import_errors, + ] +) + + +def json_rpc( + body: dict, +) -> APIResponse: + """Handler for Internal API /internal/v1/rpcapi endpoint.""" + log.debug("Got request") + json_rpc = body.get("jsonrpc") + if json_rpc != "2.0": + log.error("Not jsonrpc-2.0 request.") + return Response(response="Expected jsonrpc 2.0 request.", status=400) + + method_name = str(body.get("method")) + if method_name not in METHODS_MAP: + log.error("Unrecognized method: %s.", method_name) + return Response(response=f"Unrecognized method: {method_name}.", status=400) + + handler = METHODS_MAP[method_name] + try: + params = {} + if body.get("params"): + params_json = json.loads(str(body.get("params"))) + params = BaseSerialization.deserialize(params_json) + except Exception as err: + log.error("Error deserializing 
parameters.") + log.error(err) + return Response(response="Error deserializing parameters.", status=400) + + log.debug("Calling method %.", {method_name}) + try: + handler = METHODS_MAP[method_name] Review Comment: Duplicate of line 58 ```suggestion ``` ########## airflow/www/extensions/init_views.py: ########## @@ -220,6 +220,25 @@ def _handle_method_not_allowed(ex): app.extensions["csrf"].exempt(api_bp) +def init_api_internal(app: Flask) -> None: + """Initialize Internal API""" Review Comment: As an additional layer of security, should we check the value of `run_internal_api` here and, if it is `False`, do nothing (or raise an exception)? ########## tests/api_internal/test_internal_api_call.py: ########## @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ + +from __future__ import annotations + +import json +import unittest +from unittest import mock + +import requests + +from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.serialization.serialized_objects import BaseSerialization +from tests.test_utils.config import conf_vars + + +class TestInternalApiConfig(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_disabled(self): + self.assertFalse(InternalApiConfig.get_use_internal_api()) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_enabled(self): + self.assertTrue(InternalApiConfig.get_use_internal_api()) + self.assertEqual( + InternalApiConfig.get_internal_api_endpoint(), + "http://localhost:8888/internal/v1/rpcapi", + ) + + +@internal_api_call +def fake_method() -> str: + return "local-call" + + +@internal_api_call +def fake_method_with_params(dag_id: str, task_id: int) -> str: + return f"local-call-with-params-{dag_id}-{task_id}" + + +class TestInternalApiCall(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_local_call(self, mock_requests): + result = fake_method() + + self.assertEqual(result, "local-call") + mock_requests.post.assert_not_called() + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call(self, mock_requests): + response = 
requests.Response() + response.status_code = 200 + + response._content = json.dumps(BaseSerialization.serialize("remote-call")) + + mock_requests.post.return_value = response + + result = fake_method() + self.assertEqual(result, "remote-call") + expected_data = json.dumps( + { + "jsonrpc": "2.0", + "method": "tests.api_internal.test_internal_api_call.fake_method", + "params": '{"__var": {}, "__type": "dict"}', + } + ) + mock_requests.post.assert_called_once_with( + url="http://localhost:8888/internal/v1/rpcapi", + data=expected_data, + headers={"Content-Type": "application/json"}, + ) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call_with_params(self, mock_requests): Review Comment: Same as my previous comment, this test can be merged with the previous test using `@pytest.mark.parametrize`. ########## tests/api_internal/endpoints/test_rpc_api_endpoint.py: ########## @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+from __future__ import annotations + +import json +import unittest +from unittest import mock + +import pytest +from flask import Flask + +from airflow.api_internal.endpoints import rpc_api_endpoint +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.www import app +from tests.test_utils.decorators import dont_initialize_flask_app_submodules + +TEST_METHOD_NAME = "test_method" + +mock_test_method = mock.MagicMock() + + +@pytest.fixture(scope="session") +def minimal_app_for_internal_api() -> Flask: + @dont_initialize_flask_app_submodules( + skip_all_except=[ + "init_appbuilder", + "init_api_internal", + ] + ) + def factory() -> Flask: + return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + + return factory() + + +class TestRpcApiEndpoint(unittest.TestCase): + @pytest.fixture(autouse=True) + def setup_attrs(self, minimal_app_for_internal_api: Flask) -> None: + rpc_api_endpoint.METHODS_MAP[TEST_METHOD_NAME] = mock_test_method + self.app = minimal_app_for_internal_api + self.client = self.app.test_client() # type:ignore + mock_test_method.reset_mock() + mock_test_method.side_effect = None + + def test_method_without_params(self): + mock_test_method.return_value = "test_me" + data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, b'"test_me"') + mock_test_method.assert_called_once() + + def test_method_without_result(self): + data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 200) + mock_test_method.assert_called_once() + + def test_method_with_params(self): Review Comment: You can merge these 3
tests into one using `@pytest.mark.parametrize`. @potiuk added a bunch recently in a [separate PR](https://github.com/apache/airflow/pull/27912) (and separate topic). You can find some examples there. ########## airflow/api_internal/openapi/internal_api_v1.yaml: ########## @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +openapi: 3.0.2 +info: + title: Airflow Internal API + version: 1.0.0 + description: | + This is Airflow Internal API - which is a proxy for components running + customer code for connecting to Airflow Database. + + It is not intended to be used by any external code. 
+ + You can find more information in AIP-44 + https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API + + +servers: + - url: /internal/v1 + description: Airflow Internal API +paths: + "/rpcapi": + post: + operationId: rpcapi + deprecated: false + x-openapi-router-controller: airflow.api_internal.endpoints.rpc_api_endpoint + operationId: json_rpc + tags: + - JSONRPC + parameters: [] + responses: + '200': + description: Successful response + requestBody: + x-body-name: body + required: true + content: + application/json: + schema: + type: object + required: + - method + - jsonrpc + - params + properties: + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + method: + type: string + description: Method name + params: + title: Parameters + type: string +x-headers: [] +x-explorer-enabled: true +x-proxy-enabled: true +x-samples-enabled: true +components: + schemas: + JsonRpcRequired: Review Comment: This one seems not used? ########## tests/api_internal/test_internal_api_call.py: ########## @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +from __future__ import annotations + +import json +import unittest +from unittest import mock + +import requests + +from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.serialization.serialized_objects import BaseSerialization +from tests.test_utils.config import conf_vars + + +class TestInternalApiConfig(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_disabled(self): + self.assertFalse(InternalApiConfig.get_use_internal_api()) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_enabled(self): + self.assertTrue(InternalApiConfig.get_use_internal_api()) + self.assertEqual( + InternalApiConfig.get_internal_api_endpoint(), + "http://localhost:8888/internal/v1/rpcapi", + ) + + +@internal_api_call +def fake_method() -> str: + return "local-call" + + +@internal_api_call +def fake_method_with_params(dag_id: str, task_id: int) -> str: + return f"local-call-with-params-{dag_id}-{task_id}" + + +class TestInternalApiCall(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_local_call(self, mock_requests): + result = fake_method() + + self.assertEqual(result, "local-call") + mock_requests.post.assert_not_called() + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call(self, mock_requests): + response = 
requests.Response() + response.status_code = 200 + + response._content = json.dumps(BaseSerialization.serialize("remote-call")) + + mock_requests.post.return_value = response + + result = fake_method() + self.assertEqual(result, "remote-call") + expected_data = json.dumps( + { + "jsonrpc": "2.0", + "method": "tests.api_internal.test_internal_api_call.fake_method", + "params": '{"__var": {}, "__type": "dict"}', Review Comment: Where does it come from? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
