This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7fc442e288 [AIP-44] Add internal API definition. (#27892)
7fc442e288 is described below
commit 7fc442e288b2cf65de120cd66ffffd7ae891e0cf
Author: mhenc <[email protected]>
AuthorDate: Thu Dec 8 13:23:35 2022 +0100
[AIP-44] Add internal API definition. (#27892)
---
airflow/api_internal/__init__.py | 16 +++
airflow/api_internal/endpoints/__init__.py | 16 +++
airflow/api_internal/endpoints/rpc_api_endpoint.py | 80 +++++++++++
airflow/api_internal/internal_api_call.py | 114 ++++++++++++++++
airflow/api_internal/openapi/internal_api_v1.yaml | 91 +++++++++++++
airflow/config_templates/config.yml | 20 +++
airflow/config_templates/default_airflow.cfg | 10 ++
airflow/dag_processing/processor.py | 34 +++--
airflow/www/app.py | 3 +
airflow/www/extensions/init_views.py | 21 +++
tests/api_internal/__init__.py | 16 +++
tests/api_internal/endpoints/__init__.py | 16 +++
.../endpoints/test_rpc_api_endpoint.py | 124 +++++++++++++++++
tests/api_internal/test_internal_api_call.py | 151 +++++++++++++++++++++
tests/test_utils/decorators.py | 1 +
15 files changed, 702 insertions(+), 11 deletions(-)
diff --git a/airflow/api_internal/__init__.py b/airflow/api_internal/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/api_internal/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/api_internal/endpoints/__init__.py
b/airflow/api_internal/endpoints/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/api_internal/endpoints/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
new file mode 100644
index 0000000000..90bb23e112
--- /dev/null
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import logging
+
+from flask import Response
+
+from airflow.api_connexion.types import APIResponse
+from airflow.dag_processing.processor import DagFileProcessor
+from airflow.serialization.serialized_objects import BaseSerialization
+
+log = logging.getLogger(__name__)
+
+
+def _build_methods_map(list) -> dict:
+ return {f"{func.__module__}.{func.__name__}": func for func in list}
+
+
+METHODS_MAP = _build_methods_map(
+ [
+ DagFileProcessor.update_import_errors,
+ ]
+)
+
+
+def internal_airflow_api(
+ body: dict,
+) -> APIResponse:
+ """Handler for Internal API /internal_api/v1/rpcapi endpoint."""
+ log.debug("Got request")
+ json_rpc = body.get("jsonrpc")
+ if json_rpc != "2.0":
+ log.error("Not jsonrpc-2.0 request.")
+ return Response(response="Expected jsonrpc 2.0 request.", status=400)
+
+ method_name = str(body.get("method"))
+ if method_name not in METHODS_MAP:
+ log.error("Unrecognized method: %s.", method_name)
+ return Response(response=f"Unrecognized method: {method_name}.",
status=400)
+
+ handler = METHODS_MAP[method_name]
+ params = {}
+ try:
+ if body.get("params"):
+ params_json = json.loads(str(body.get("params")))
+ params = BaseSerialization.deserialize(params_json)
+ except Exception as err:
+ log.error("Error deserializing parameters.")
+ log.error(err)
+ return Response(response="Error deserializing parameters.", status=400)
+
+ log.debug("Calling method %.", {method_name})
+ try:
+ output = handler(**params)
+ output_json = BaseSerialization.serialize(output)
+ log.debug("Returning response")
+ return Response(
+ response=json.dumps(output_json or "{}"), headers={"Content-Type":
"application/json"}
+ )
+ except Exception as e:
+ log.error("Error when calling method %s.", method_name)
+ log.error(e)
+ return Response(response=f"Error executing method: {method_name}.",
status=500)
diff --git a/airflow/api_internal/internal_api_call.py
b/airflow/api_internal/internal_api_call.py
new file mode 100644
index 0000000000..9de4d33c86
--- /dev/null
+++ b/airflow/api_internal/internal_api_call.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+import json
+from functools import wraps
+from typing import Callable, TypeVar
+
+import requests
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException, AirflowException
+from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.typing_compat import ParamSpec
+
+# Type variables used to annotate ``internal_api_call`` so the decorated callable
+# keeps its parameter and return types.
+PS = ParamSpec("PS")
+RT = TypeVar("RT")
+
+
+class InternalApiConfig:
+ """Stores and caches configuration for Internal API."""
+
+ _initialized = False
+ _use_internal_api = False
+ _internal_api_endpoint = ""
+
+ @staticmethod
+ def get_use_internal_api():
+ if not InternalApiConfig._initialized:
+ InternalApiConfig._init_values()
+ return InternalApiConfig._use_internal_api
+
+ @staticmethod
+ def get_internal_api_endpoint():
+ if not InternalApiConfig._initialized:
+ InternalApiConfig._init_values()
+ return InternalApiConfig._internal_api_endpoint
+
+ @staticmethod
+ def _init_values():
+ use_internal_api = conf.getboolean("core", "database_access_isolation")
+ internal_api_endpoint = ""
+ if use_internal_api:
+ internal_api_url = conf.get("core", "internal_api_url")
+ internal_api_endpoint = internal_api_url +
"/internal_api/v1/rpcapi"
+ if not internal_api_endpoint.startswith("http://"):
+ raise AirflowConfigException("[core]internal_api_url must
start with http://")
+
+ InternalApiConfig._initialized = True
+ InternalApiConfig._use_internal_api = use_internal_api
+ InternalApiConfig._internal_api_endpoint = internal_api_endpoint
+
+
+def internal_api_call(func: Callable[PS, RT | None]) -> Callable[PS, RT |
None]:
+ """Decorator for methods which may be executed in database isolation mode.
+
+ If [core]database_access_isolation is true then such method are not
executed locally,
+ but instead RPC call is made to Database API (aka Internal API). This
makes some components
+ decouple from direct Airflow database access.
+ Each decorated method must be present in METHODS list in
airflow.api_internal.endpoints.rpc_api_endpoint.
+ Only static methods can be decorated. This decorator must be before
"provide_session".
+
+ See
[AIP-44](https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API)
+ for more information .
+ """
+ headers = {
+ "Content-Type": "application/json",
+ }
+
+ def make_jsonrpc_request(method_name: str, params_json: str) -> bytes:
+ data = {"jsonrpc": "2.0", "method": method_name, "params": params_json}
+ internal_api_endpoint = InternalApiConfig.get_internal_api_endpoint()
+ response = requests.post(url=internal_api_endpoint,
data=json.dumps(data), headers=headers)
+ if response.status_code != 200:
+ raise AirflowException(
+ f"Got {response.status_code}:{response.reason} when sending
the internal api request."
+ )
+ return response.content
+
+ @wraps(func)
+ def wrapper(*args, **kwargs) -> RT | None:
+ use_internal_api = InternalApiConfig.get_use_internal_api()
+ if not use_internal_api:
+ return func(*args, **kwargs)
+
+ bound = inspect.signature(func).bind(*args, **kwargs)
+ arguments_dict = dict(bound.arguments)
+ if "session" in arguments_dict:
+ del arguments_dict["session"]
+ args_json = json.dumps(BaseSerialization.serialize(arguments_dict))
+ method_name = f"{func.__module__}.{func.__name__}"
+ result = make_jsonrpc_request(method_name, args_json)
+ if result:
+ return BaseSerialization.deserialize(json.loads(result))
+ else:
+ return None
+
+ return wrapper
diff --git a/airflow/api_internal/openapi/internal_api_v1.yaml
b/airflow/api_internal/openapi/internal_api_v1.yaml
new file mode 100644
index 0000000000..58ef962179
--- /dev/null
+++ b/airflow/api_internal/openapi/internal_api_v1.yaml
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+openapi: 3.0.2
+info:
+  title: Airflow Internal API
+  version: 1.0.0
+  description: |
+    This is Airflow Internal API - which is a proxy for components running
+    customer code for connecting to Airflow Database.
+
+    It is not intended to be used by any external code.
+
+    You can find more information in AIP-44
+    https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API
+
+
+servers:
+  - url: /internal_api/v1
+    description: Airflow Internal API
+paths:
+  "/rpcapi":
+    post:
+      deprecated: false
+      x-openapi-router-controller: airflow.api_internal.endpoints.rpc_api_endpoint
+      # Single operationId: connexion resolves it against the router controller above.
+      operationId: internal_airflow_api
+      tags:
+      - JSONRPC
+      parameters: []
+      responses:
+        '200':
+          description: Successful response
+      requestBody:
+        x-body-name: body
+        required: true
+        content:
+          application/json:
+            schema:
+              type: object
+              required:
+              - method
+              - jsonrpc
+              - params
+              properties:
+                jsonrpc:
+                  type: string
+                  default: '2.0'
+                  description: JSON-RPC Version (2.0)
+                method:
+                  type: string
+                  description: Method name
+                params:
+                  title: Parameters
+                  type: string
+x-headers: []
+x-explorer-enabled: true
+x-proxy-enabled: true
+components:
+  schemas:
+    JsonRpcRequired:
+      type: object
+      required:
+      - method
+      - jsonrpc
+      properties:
+        method:
+          type: string
+          description: Method name
+        jsonrpc:
+          type: string
+          default: '2.0'
+          description: JSON-RPC Version (2.0)
+      discriminator:
+        propertyName: method
+tags: []
diff --git a/airflow/config_templates/config.yml
b/airflow/config_templates/config.yml
index b15981aab9..09b02fa84b 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -418,6 +418,19 @@
type: string
default: ~
example: '{"some_param": "some_value"}'
+ - name: database_access_isolation
+ description: (experimental) Whether components should use Airflow
Internal API for DB connectivity.
+ version_added: 2.6.0
+ type: boolean
+ example: ~
+ default: "False"
+ - name: internal_api_url
+ description: |
+ (experimental)Airflow Internal API url. Only used if [core]
database_access_isolation is True.
+ version_added: 2.6.0
+ type: string
+ default: ~
+ example: 'http://localhost:8080'
- name: database
description: ~
@@ -1482,6 +1495,13 @@
type: string
example: "dagrun_cleared,failed"
default: ~
+ - name: run_internal_api
+ description: |
+ Boolean for running Internal API in the webserver.
+ version_added: 2.6.0
+ type: boolean
+ example: ~
+ default: "False"
- name: email
description: |
diff --git a/airflow/config_templates/default_airflow.cfg
b/airflow/config_templates/default_airflow.cfg
index 146d0077a9..1dfd0e98d2 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -240,6 +240,13 @@ daemon_umask = 0o077
# Example: dataset_manager_kwargs = {{"some_param": "some_value"}}
# dataset_manager_kwargs =
+# (experimental) Whether components should use Airflow Internal API for DB
connectivity.
+database_access_isolation = False
+
+# (experimental)Airflow Internal API url. Only used if [core]
database_access_isolation is True.
+# Example: internal_api_url = http://localhost:8080
+# internal_api_url =
+
[database]
# The SqlAlchemy connection string to the metadata database.
# SqlAlchemy supports many different database engines.
@@ -752,6 +759,9 @@ audit_view_excluded_events =
gantt,landing_times,tries,duration,calendar,graph,g
# Example: audit_view_included_events = dagrun_cleared,failed
# audit_view_included_events =
+# Boolean for running Internal API in the webserver.
+run_internal_api = False
+
[email]
# Configuration email backend and whether to
diff --git a/airflow/dag_processing/processor.py
b/airflow/dag_processing/processor.py
index 323ff94c5e..fbb2bd298d 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import datetime
import logging
import multiprocessing
import os
@@ -24,7 +23,7 @@ import signal
import threading
import time
from contextlib import redirect_stderr, redirect_stdout, suppress
-from datetime import timedelta
+from datetime import datetime, timedelta
from multiprocessing.connection import Connection as MultiprocessingConnection
from typing import TYPE_CHECKING, Iterator
@@ -33,6 +32,7 @@ from sqlalchemy import exc, func, or_
from sqlalchemy.orm.session import Session
from airflow import settings
+from airflow.api_internal.internal_api_call import internal_api_call
from airflow.callbacks.callback_requests import (
CallbackRequest,
DagCallbackRequest,
@@ -94,7 +94,7 @@ class DagFileProcessorProcess(LoggingMixin,
MultiprocessingStartMethodMixin):
# Whether the process is done running.
self._done = False
# When the process started.
- self._start_time: datetime.datetime | None = None
+ self._start_time: datetime | None = None
# This ID is use to uniquely name the process / thread that's launched
# by this processor instance
self._instance_id = DagFileProcessorProcess.class_creation_counter
@@ -327,7 +327,7 @@ class DagFileProcessorProcess(LoggingMixin,
MultiprocessingStartMethodMixin):
return self._result
@property
- def start_time(self) -> datetime.datetime:
+ def start_time(self) -> datetime:
"""Time when this started to process the file."""
if self._start_time is None:
raise AirflowException("Tried to get start time before it
started!")
@@ -448,7 +448,7 @@ class DagFileProcessor(LoggingMixin):
.all()
)
if slas:
- sla_dates: list[datetime.datetime] = [sla.execution_date for sla
in slas]
+ sla_dates: list[datetime] = [sla.execution_date for sla in slas]
fetched_tis: list[TI] = (
session.query(TI)
.filter(TI.state != State.SUCCESS,
TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
@@ -524,17 +524,21 @@ class DagFileProcessor(LoggingMixin):
session.commit()
@staticmethod
- def update_import_errors(session: Session, dagbag: DagBag) -> None:
+ @internal_api_call
+ @provide_session
+ def update_import_errors(
+ file_last_changed: dict[str, datetime], import_errors: dict[str, str],
session: Session = NEW_SESSION
+ ) -> None:
"""
Update any import errors to be displayed in the UI.
For the DAGs in the given DagBag, record any associated import errors
and clears
errors for files that no longer have them. These are usually displayed
through the
Airflow UI so that users know that there are issues parsing DAGs.
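- :param session: session for ORM operations
- :param dagbag: DagBag containing DAGs with import errors
+ :param file_last_changed: mapping of processed file paths to the time each file last changed
+ :param import_errors: mapping of file paths to the import error stacktrace recorded for each
+ :param session: session for ORM operations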
"""
- files_without_error = dagbag.file_last_changed -
dagbag.import_errors.keys()
+ files_without_error = file_last_changed - import_errors.keys()
# Clear the errors of the processed files
# that no longer have errors
@@ -547,7 +551,7 @@ class DagFileProcessor(LoggingMixin):
existing_import_error_files = [x.filename for x in
session.query(errors.ImportError.filename).all()]
# Add the errors of the processed files
- for filename, stacktrace in dagbag.import_errors.items():
+ for filename, stacktrace in import_errors.items():
if filename in existing_import_error_files:
session.query(errors.ImportError).filter(errors.ImportError.filename ==
filename).update(
dict(filename=filename, timestamp=timezone.utcnow(),
stacktrace=stacktrace),
@@ -754,7 +758,11 @@ class DagFileProcessor(LoggingMixin):
self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(),
file_path)
else:
self.log.warning("No viable dags retrieved from %s", file_path)
- self.update_import_errors(session, dagbag)
+ DagFileProcessor.update_import_errors(
+ file_last_changed=dagbag.file_last_changed,
+ import_errors=dagbag.import_errors,
+ session=session,
+ )
if callback_requests:
# If there were callback requests for this file but there was a
# parse error we still need to progress the state of TIs,
@@ -781,7 +789,11 @@ class DagFileProcessor(LoggingMixin):
# Record import errors into the ORM
try:
- self.update_import_errors(session, dagbag)
+ DagFileProcessor.update_import_errors(
+ file_last_changed=dagbag.file_last_changed,
+ import_errors=dagbag.import_errors,
+ session=session,
+ )
except Exception:
self.log.exception("Error logging import errors!")
diff --git a/airflow/www/app.py b/airflow/www/app.py
index 6a64401e1c..19d2831dfd 100644
--- a/airflow/www/app.py
+++ b/airflow/www/app.py
@@ -48,6 +48,7 @@ from airflow.www.extensions.init_session import
init_airflow_session_interface
from airflow.www.extensions.init_views import (
init_api_connexion,
init_api_experimental,
+ init_api_internal,
init_appbuilder_views,
init_connection_form,
init_error_handlers,
@@ -149,6 +150,8 @@ def create_app(config=None, testing=False):
init_connection_form()
init_error_handlers(flask_app)
init_api_connexion(flask_app)
+ if conf.getboolean("webserver", "run_internal_api", fallback=False):
+ init_api_internal(flask_app)
init_api_experimental(flask_app)
sync_appbuilder_roles(flask_app)
diff --git a/airflow/www/extensions/init_views.py
b/airflow/www/extensions/init_views.py
index 25d5d5898c..86f94d2f22 100644
--- a/airflow/www/extensions/init_views.py
+++ b/airflow/www/extensions/init_views.py
@@ -220,6 +220,27 @@ def init_api_connexion(app: Flask) -> None:
app.extensions["csrf"].exempt(api_bp)
+def init_api_internal(app: Flask) -> None:
+ """Initialize Internal API"""
+ if not conf.getboolean("webserver", "run_internal_api", fallback=False):
+ return
+ base_path = "/internal_api/v1"
+
+ spec_dir = path.join(ROOT_APP_DIR, "api_internal", "openapi")
+ internal_app = App(__name__, specification_dir=spec_dir,
skip_error_handlers=True)
+ internal_app.app = app
+ api_bp = internal_app.add_api(
+ specification="internal_api_v1.yaml",
+ base_path=base_path,
+ validate_responses=True,
+ strict_validation=True,
+ ).blueprint
+ # Like "api_bp.after_request", but the BP is already registered, so we have
+ # to register it in the app directly.
+ app.after_request_funcs.setdefault(api_bp.name,
[]).append(set_cors_headers_on_response)
+ app.extensions["csrf"].exempt(api_bp)
+
+
def init_api_experimental(app):
"""Initialize Experimental API"""
if not conf.getboolean("api", "enable_experimental_api", fallback=False):
diff --git a/tests/api_internal/__init__.py b/tests/api_internal/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/api_internal/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/api_internal/endpoints/__init__.py
b/tests/api_internal/endpoints/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/api_internal/endpoints/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py
b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
new file mode 100644
index 0000000000..68f22fe6cc
--- /dev/null
+++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from unittest import mock
+
+import pytest
+from flask import Flask
+
+from airflow.api_internal.endpoints import rpc_api_endpoint
+from airflow.serialization.serialized_objects import BaseSerialization
+from airflow.www import app
+from tests.test_utils.config import conf_vars
+from tests.test_utils.decorators import dont_initialize_flask_app_submodules
+
+TEST_METHOD_NAME = "test_method"
+
+mock_test_method = mock.MagicMock()
+
+
[email protected](scope="session")
+def minimal_app_for_internal_api() -> Flask:
+ @dont_initialize_flask_app_submodules(
+ skip_all_except=[
+ "init_appbuilder",
+ "init_api_internal",
+ ]
+ )
+ def factory() -> Flask:
+ with conf_vars({("webserver", "run_internal_api"): "true"}):
+ return app.create_app(testing=True, config={"WTF_CSRF_ENABLED":
False}) # type:ignore
+
+ return factory()
+
+
+class TestRpcApiEndpoint:
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> None:
+ rpc_api_endpoint.METHODS_MAP[TEST_METHOD_NAME] = mock_test_method
+ self.app = minimal_app_for_internal_api
+ self.client = self.app.test_client() # type:ignore
+ mock_test_method.reset_mock()
+ mock_test_method.side_effect = None
+
+ @pytest.mark.parametrize(
+ "input_data, method_result, method_params, expected_code",
+ [
+ ({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""},
"test_me", None, 200),
+ ({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""},
None, None, 200),
+ (
+ {
+ "jsonrpc": "2.0",
+ "method": TEST_METHOD_NAME,
+ "params":
json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})),
+ },
+ ("dag_id_15", "fake-task", 1),
+ {"dag_id": 15, "task_id": "fake-task"},
+ 200,
+ ),
+ ],
+ )
+ def test_method(self, input_data, method_result, method_params,
expected_code):
+ if method_result:
+ mock_test_method.return_value = method_result
+
+ response = self.client.post(
+ "/internal_api/v1/rpcapi",
+ headers={"Content-Type": "application/json"},
+ data=json.dumps(input_data),
+ )
+ assert response.status_code == expected_code
+ if method_result:
+ response_data =
BaseSerialization.deserialize(json.loads(response.data))
+ assert response_data == method_result
+ if method_params:
+ mock_test_method.assert_called_once_with(**method_params)
+ else:
+ mock_test_method.assert_called_once()
+
+ def test_method_with_exception(self):
+ mock_test_method.side_effect = ValueError("Error!!!")
+ data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}
+
+ response = self.client.post(
+ "/internal_api/v1/rpcapi", headers={"Content-Type":
"application/json"}, data=json.dumps(data)
+ )
+ assert response.status_code == 500
+ assert response.data, b"Error executing method: test_method."
+ mock_test_method.assert_called_once()
+
+ def test_unknown_method(self):
+ data = {"jsonrpc": "2.0", "method": "i-bet-it-does-not-exist",
"params": ""}
+
+ response = self.client.post(
+ "/internal_api/v1/rpcapi", headers={"Content-Type":
"application/json"}, data=json.dumps(data)
+ )
+ assert response.status_code == 400
+ assert response.data == b"Unrecognized method:
i-bet-it-does-not-exist."
+ mock_test_method.assert_not_called()
+
+ def test_invalid_jsonrpc(self):
+ data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": ""}
+
+ response = self.client.post(
+ "/internal_api/v1/rpcapi", headers={"Content-Type":
"application/json"}, data=json.dumps(data)
+ )
+ assert response.status_code == 400
+ assert response.data == b"Expected jsonrpc 2.0 request."
+ mock_test_method.assert_not_called()
diff --git a/tests/api_internal/test_internal_api_call.py
b/tests/api_internal/test_internal_api_call.py
new file mode 100644
index 0000000000..579a7720cc
--- /dev/null
+++ b/tests/api_internal/test_internal_api_call.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from __future__ import annotations
+
+import json
+import unittest
+from unittest import mock
+
+import requests
+
+from airflow.api_internal.internal_api_call import InternalApiConfig,
internal_api_call
+from airflow.serialization.serialized_objects import BaseSerialization
+from tests.test_utils.config import conf_vars
+
+
+class TestInternalApiConfig(unittest.TestCase):
+ def setUp(self):
+ InternalApiConfig._initialized = False
+
+ @conf_vars(
+ {
+ ("core", "database_access_isolation"): "false",
+ ("core", "internal_api_url"): "http://localhost:8888",
+ }
+ )
+ def test_get_use_internal_api_disabled(self):
+ self.assertFalse(InternalApiConfig.get_use_internal_api())
+
+ @conf_vars(
+ {
+ ("core", "database_access_isolation"): "true",
+ ("core", "internal_api_url"): "http://localhost:8888",
+ }
+ )
+ def test_get_use_internal_api_enabled(self):
+ self.assertTrue(InternalApiConfig.get_use_internal_api())
+ self.assertEqual(
+ InternalApiConfig.get_internal_api_endpoint(),
+ "http://localhost:8888/internal_api/v1/rpcapi",
+ )
+
+
+@internal_api_call
+def fake_method() -> str:
+ return "local-call"
+
+
+@internal_api_call
+def fake_method_with_params(dag_id: str, task_id: int) -> str:
+ return f"local-call-with-params-{dag_id}-{task_id}"
+
+
+class TestInternalApiCall(unittest.TestCase):
+ def setUp(self):
+ InternalApiConfig._initialized = False
+
+ @conf_vars(
+ {
+ ("core", "database_access_isolation"): "false",
+ ("core", "internal_api_url"): "http://localhost:8888",
+ }
+ )
+ @mock.patch("airflow.api_internal.internal_api_call.requests")
+ def test_local_call(self, mock_requests):
+ result = fake_method()
+
+ self.assertEqual(result, "local-call")
+ mock_requests.post.assert_not_called()
+
+ @conf_vars(
+ {
+ ("core", "database_access_isolation"): "true",
+ ("core", "internal_api_url"): "http://localhost:8888",
+ }
+ )
+ @mock.patch("airflow.api_internal.internal_api_call.requests")
+ def test_remote_call(self, mock_requests):
+ response = requests.Response()
+ response.status_code = 200
+
+ response._content =
json.dumps(BaseSerialization.serialize("remote-call"))
+
+ mock_requests.post.return_value = response
+
+ result = fake_method()
+ self.assertEqual(result, "remote-call")
+ expected_data = json.dumps(
+ {
+ "jsonrpc": "2.0",
+ "method":
"tests.api_internal.test_internal_api_call.fake_method",
+ "params": json.dumps(BaseSerialization.serialize({})),
+ }
+ )
+ mock_requests.post.assert_called_once_with(
+ url="http://localhost:8888/internal_api/v1/rpcapi",
+ data=expected_data,
+ headers={"Content-Type": "application/json"},
+ )
+
+ @conf_vars(
+ {
+ ("core", "database_access_isolation"): "true",
+ ("core", "internal_api_url"): "http://localhost:8888",
+ }
+ )
+ @mock.patch("airflow.api_internal.internal_api_call.requests")
+ def test_remote_call_with_params(self, mock_requests):
+ response = requests.Response()
+ response.status_code = 200
+
+ response._content =
json.dumps(BaseSerialization.serialize("remote-call"))
+
+ mock_requests.post.return_value = response
+
+ result = fake_method_with_params("fake-dag", task_id=123)
+ self.assertEqual(result, "remote-call")
+ expected_data = json.dumps(
+ {
+ "jsonrpc": "2.0",
+ "method":
"tests.api_internal.test_internal_api_call.fake_method_with_params",
+ "params": json.dumps(
+ BaseSerialization.serialize(
+ {
+ "dag_id": "fake-dag",
+ "task_id": 123,
+ }
+ )
+ ),
+ }
+ )
+ mock_requests.post.assert_called_once_with(
+ url="http://localhost:8888/internal_api/v1/rpcapi",
+ data=expected_data,
+ headers={"Content-Type": "application/json"},
+ )
diff --git a/tests/test_utils/decorators.py b/tests/test_utils/decorators.py
index d0b71b502c..7d809834da 100644
--- a/tests/test_utils/decorators.py
+++ b/tests/test_utils/decorators.py
@@ -39,6 +39,7 @@ def dont_initialize_flask_app_submodules(_func=None, *,
skip_all_except=None):
"init_connection_form",
"init_error_handlers",
"init_api_connexion",
+ "init_api_internal",
"init_api_experimental",
"sync_appbuilder_roles",
"init_jinja_globals",