This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5ff411bf5a4 Convert exceptions raised in Flask application to fastapi exceptions (#45625)
5ff411bf5a4 is described below
commit 5ff411bf5a47dc5df1d770894f8200685ac0dfd9
Author: Vincent <[email protected]>
AuthorDate: Tue Jan 14 10:31:12 2025 -0500
Convert exceptions raised in Flask application to fastapi exceptions (#45625)
---
airflow/api_fastapi/app.py | 2 ++
airflow/api_fastapi/core_api/app.py | 5 +++
airflow/api_fastapi/core_api/middleware.py | 39 ++++++++++++++++++++
airflow/api_fastapi/core_api/security.py | 5 +--
.../providers/fab/www/api_connexion/exceptions.py | 41 ++--------------------
.../providers/fab/www/extensions/init_views.py | 34 ++++++++++++++++--
providers/src/airflow/providers/fab/www/views.py | 13 -------
7 files changed, 82 insertions(+), 57 deletions(-)
diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py
index 9cbb190d411..ff74deb2fee 100644
--- a/airflow/api_fastapi/app.py
+++ b/airflow/api_fastapi/app.py
@@ -28,6 +28,7 @@ from airflow.api_fastapi.core_api.app import (
init_dag_bag,
init_error_handlers,
init_flask_plugins,
+ init_middlewares,
init_plugins,
init_views,
)
@@ -74,6 +75,7 @@ def create_app(apps: str = "all") -> FastAPI:
init_auth_manager(app)
init_flask_plugins(app)
init_error_handlers(app)
+ init_middlewares(app)
if "execution" in apps_list or "all" in apps_list:
task_exec_api_app = create_task_execution_api_app(app)
diff --git a/airflow/api_fastapi/core_api/app.py b/airflow/api_fastapi/core_api/app.py
index e94a7ea3f33..6099c5b654a 100644
--- a/airflow/api_fastapi/core_api/app.py
+++ b/airflow/api_fastapi/core_api/app.py
@@ -30,6 +30,7 @@ from starlette.responses import HTMLResponse
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
+from airflow.api_fastapi.core_api.middleware import FlaskExceptionsMiddleware
from airflow.exceptions import AirflowException
from airflow.settings import AIRFLOW_PATH
from airflow.www.extensions.init_dagbag import get_dag_bag
@@ -165,3 +166,7 @@ def init_error_handlers(app: FastAPI) -> None:
# register database error handlers
for handler in DatabaseErrorHandlers:
app.add_exception_handler(handler.exception_cls, handler.exception_handler)
+
+
+def init_middlewares(app: FastAPI) -> None:
+ app.add_middleware(FlaskExceptionsMiddleware)
diff --git a/airflow/api_fastapi/core_api/middleware.py b/airflow/api_fastapi/core_api/middleware.py
new file mode 100644
index 00000000000..e88c9acc543
--- /dev/null
+++ b/airflow/api_fastapi/core_api/middleware.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from fastapi import HTTPException, Request
+from starlette.middleware.base import BaseHTTPMiddleware
+
+
+# Custom Middleware Class
+class FlaskExceptionsMiddleware(BaseHTTPMiddleware):
+ """Middleware that converts exceptions thrown in the Flask application to Fastapi exceptions."""
+
+ async def dispatch(self, request: Request, call_next):
+ response = await call_next(request)
+
+ # Check if the WSGI response contains an error
+ if response.status_code >= 400 and response.media_type == "application/json":
+ body = await response.json()
+ if "error" in body:
+ # Transform the WSGI app's exception into a FastAPI HTTPException
+ raise HTTPException(
+ status_code=response.status_code,
+ detail=body["error"],
+ )
+ return response
diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py
index 7aaee4d0fae..baa852a4c54 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Callable
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jwt import InvalidTokenError
+from starlette import status
from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.models.base_user import BaseUser
@@ -50,7 +51,7 @@ def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
payload: dict[str, Any] = signer.verify_token(token_str)
return get_auth_manager().deserialize_user(payload)
except InvalidTokenError:
- raise HTTPException(403, "Forbidden")
+ raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable:
@@ -75,4 +76,4 @@ def _requires_access(
is_authorized_callback: Callable[[], bool],
) -> None:
if not is_authorized_callback():
- raise HTTPException(403, "Forbidden")
+ raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
diff --git a/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py b/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py
index ef2e2ab9b4b..feaf38e8fd0 100644
--- a/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py
+++ b/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py
@@ -17,16 +17,12 @@
from __future__ import annotations
from http import HTTPStatus
-from typing import TYPE_CHECKING, Any
+from typing import Any
-import werkzeug
-from connexion import FlaskApi, ProblemException, problem
+from connexion import ProblemException
from airflow.utils.docs import get_docs_url
-if TYPE_CHECKING:
- import flask
-
doc_link = get_docs_url("stable-rest-api-ref.html")
EXCEPTIONS_LINK_MAP = {
@@ -40,39 +36,6 @@ EXCEPTIONS_LINK_MAP = {
}
-def common_error_handler(exception: BaseException) -> flask.Response:
- """Use to capture connexion exceptions and add link to the type field."""
- if isinstance(exception, ProblemException):
- link = EXCEPTIONS_LINK_MAP.get(exception.status)
- if link:
- response = problem(
- status=exception.status,
- title=exception.title,
- detail=exception.detail,
- type=link,
- instance=exception.instance,
- headers=exception.headers,
- ext=exception.ext,
- )
- else:
- response = problem(
- status=exception.status,
- title=exception.title,
- detail=exception.detail,
- type=exception.type,
- instance=exception.instance,
- headers=exception.headers,
- ext=exception.ext,
- )
- else:
- if not isinstance(exception, werkzeug.exceptions.HTTPException):
- exception = werkzeug.exceptions.InternalServerError()
-
- response = problem(title=exception.name, detail=exception.description, status=exception.code)
-
- return FlaskApi.get_response(response)
-
-
class NotFound(ProblemException):
"""Raise when the object cannot be found."""
diff --git a/providers/src/airflow/providers/fab/www/extensions/init_views.py b/providers/src/airflow/providers/fab/www/extensions/init_views.py
index 382bcaf9ca7..588276735fc 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_views.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_views.py
@@ -23,6 +23,15 @@ from typing import TYPE_CHECKING
from connexion import Resolver
from connexion.decorators.validation import RequestBodyValidator
from connexion.exceptions import BadRequestProblem
+from flask import jsonify
+from starlette import status
+
+from airflow.providers.fab.www.api_connexion.exceptions import (
+ BadRequest,
+ NotFound,
+ PermissionDenied,
+ Unauthenticated,
+)
if TYPE_CHECKING:
from flask import Flask
@@ -114,7 +123,26 @@ def init_plugins(app):
def init_error_handlers(app: Flask):
"""Add custom errors handlers."""
- from airflow.providers.fab.www import views
- app.register_error_handler(500, views.show_traceback)
- app.register_error_handler(404, views.not_found)
+ def handle_bad_request(error):
+ response = {"error": "Bad request"}
+ return jsonify(response), status.HTTP_400_BAD_REQUEST
+
+ def handle_not_found(error):
+ response = {"error": "Not found"}
+ return jsonify(response), status.HTTP_404_NOT_FOUND
+
+ def handle_unauthenticated(error):
+ response = {"error": "User is not authenticated"}
+ return jsonify(response), status.HTTP_401_UNAUTHORIZED
+
+ def handle_denied(error):
+ response = {"error": "Access is denied"}
+ return jsonify(response), status.HTTP_403_FORBIDDEN
+
+ app.register_error_handler(404, handle_not_found)
+
+ app.register_error_handler(BadRequest, handle_bad_request)
+ app.register_error_handler(NotFound, handle_not_found)
+ app.register_error_handler(Unauthenticated, handle_unauthenticated)
+ app.register_error_handler(PermissionDenied, handle_denied)
diff --git a/providers/src/airflow/providers/fab/www/views.py b/providers/src/airflow/providers/fab/www/views.py
index 48bf0bfddff..43ac276897e 100644
--- a/providers/src/airflow/providers/fab/www/views.py
+++ b/providers/src/airflow/providers/fab/www/views.py
@@ -30,19 +30,6 @@ from airflow.utils.net import get_hostname
from airflow.version import version
-def not_found(error):
- """Show Not Found on screen for any error in the Webserver."""
- return (
- render_template(
- "airflow/error.html",
- hostname=get_hostname() if conf.getboolean("webserver", "EXPOSE_HOSTNAME") else "",
- status_code=404,
- error_message="Page cannot be found.",
- ),
- 404,
- )
-
-
def show_traceback(error):
"""Show Traceback for a given error."""
is_logged_in = get_auth_manager().is_logged_in()