This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ff411bf5a4 Convert exceptions raised in Flask application to fastapi 
exceptions (#45625)
5ff411bf5a4 is described below

commit 5ff411bf5a47dc5df1d770894f8200685ac0dfd9
Author: Vincent <[email protected]>
AuthorDate: Tue Jan 14 10:31:12 2025 -0500

    Convert exceptions raised in Flask application to fastapi exceptions 
(#45625)
---
 airflow/api_fastapi/app.py                         |  2 ++
 airflow/api_fastapi/core_api/app.py                |  5 +++
 airflow/api_fastapi/core_api/middleware.py         | 39 ++++++++++++++++++++
 airflow/api_fastapi/core_api/security.py           |  5 +--
 .../providers/fab/www/api_connexion/exceptions.py  | 41 ++--------------------
 .../providers/fab/www/extensions/init_views.py     | 34 ++++++++++++++++--
 providers/src/airflow/providers/fab/www/views.py   | 13 -------
 7 files changed, 82 insertions(+), 57 deletions(-)

diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py
index 9cbb190d411..ff74deb2fee 100644
--- a/airflow/api_fastapi/app.py
+++ b/airflow/api_fastapi/app.py
@@ -28,6 +28,7 @@ from airflow.api_fastapi.core_api.app import (
     init_dag_bag,
     init_error_handlers,
     init_flask_plugins,
+    init_middlewares,
     init_plugins,
     init_views,
 )
@@ -74,6 +75,7 @@ def create_app(apps: str = "all") -> FastAPI:
         init_auth_manager(app)
         init_flask_plugins(app)
         init_error_handlers(app)
+        init_middlewares(app)
 
     if "execution" in apps_list or "all" in apps_list:
         task_exec_api_app = create_task_execution_api_app(app)
diff --git a/airflow/api_fastapi/core_api/app.py 
b/airflow/api_fastapi/core_api/app.py
index e94a7ea3f33..6099c5b654a 100644
--- a/airflow/api_fastapi/core_api/app.py
+++ b/airflow/api_fastapi/core_api/app.py
@@ -30,6 +30,7 @@ from starlette.responses import HTMLResponse
 from starlette.staticfiles import StaticFiles
 from starlette.templating import Jinja2Templates
 
+from airflow.api_fastapi.core_api.middleware import FlaskExceptionsMiddleware
 from airflow.exceptions import AirflowException
 from airflow.settings import AIRFLOW_PATH
 from airflow.www.extensions.init_dagbag import get_dag_bag
@@ -165,3 +166,7 @@ def init_error_handlers(app: FastAPI) -> None:
     # register database error handlers
     for handler in DatabaseErrorHandlers:
         app.add_exception_handler(handler.exception_cls, 
handler.exception_handler)
+
+
+def init_middlewares(app: FastAPI) -> None:
+    app.add_middleware(FlaskExceptionsMiddleware)
diff --git a/airflow/api_fastapi/core_api/middleware.py 
b/airflow/api_fastapi/core_api/middleware.py
new file mode 100644
index 00000000000..e88c9acc543
--- /dev/null
+++ b/airflow/api_fastapi/core_api/middleware.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from fastapi import HTTPException, Request
+from starlette.middleware.base import BaseHTTPMiddleware
+
+
+# Custom Middleware Class
+class FlaskExceptionsMiddleware(BaseHTTPMiddleware):
+    """Middleware that converts exceptions thrown in the Flask application to 
Fastapi exceptions."""
+
+    async def dispatch(self, request: Request, call_next):
+        response = await call_next(request)
+
+        # Check if the WSGI response contains an error
+        if response.status_code >= 400 and response.media_type == 
"application/json":
+            body = await response.json()
+            if "error" in body:
+                # Transform the WSGI app's exception into a FastAPI 
HTTPException
+                raise HTTPException(
+                    status_code=response.status_code,
+                    detail=body["error"],
+                )
+        return response
diff --git a/airflow/api_fastapi/core_api/security.py 
b/airflow/api_fastapi/core_api/security.py
index 7aaee4d0fae..baa852a4c54 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Callable
 from fastapi import Depends, HTTPException
 from fastapi.security import OAuth2PasswordBearer
 from jwt import InvalidTokenError
+from starlette import status
 
 from airflow.api_fastapi.app import get_auth_manager
 from airflow.auth.managers.models.base_user import BaseUser
@@ -50,7 +51,7 @@ def get_user(token_str: Annotated[str, 
Depends(oauth2_scheme)]) -> BaseUser:
         payload: dict[str, Any] = signer.verify_token(token_str)
         return get_auth_manager().deserialize_user(payload)
     except InvalidTokenError:
-        raise HTTPException(403, "Forbidden")
+        raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
 
 
 def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity 
| None = None) -> Callable:
@@ -75,4 +76,4 @@ def _requires_access(
     is_authorized_callback: Callable[[], bool],
 ) -> None:
     if not is_authorized_callback():
-        raise HTTPException(403, "Forbidden")
+        raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
diff --git 
a/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py 
b/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py
index ef2e2ab9b4b..feaf38e8fd0 100644
--- a/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py
+++ b/providers/src/airflow/providers/fab/www/api_connexion/exceptions.py
@@ -17,16 +17,12 @@
 from __future__ import annotations
 
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
-import werkzeug
-from connexion import FlaskApi, ProblemException, problem
+from connexion import ProblemException
 
 from airflow.utils.docs import get_docs_url
 
-if TYPE_CHECKING:
-    import flask
-
 doc_link = get_docs_url("stable-rest-api-ref.html")
 
 EXCEPTIONS_LINK_MAP = {
@@ -40,39 +36,6 @@ EXCEPTIONS_LINK_MAP = {
 }
 
 
-def common_error_handler(exception: BaseException) -> flask.Response:
-    """Use to capture connexion exceptions and add link to the type field."""
-    if isinstance(exception, ProblemException):
-        link = EXCEPTIONS_LINK_MAP.get(exception.status)
-        if link:
-            response = problem(
-                status=exception.status,
-                title=exception.title,
-                detail=exception.detail,
-                type=link,
-                instance=exception.instance,
-                headers=exception.headers,
-                ext=exception.ext,
-            )
-        else:
-            response = problem(
-                status=exception.status,
-                title=exception.title,
-                detail=exception.detail,
-                type=exception.type,
-                instance=exception.instance,
-                headers=exception.headers,
-                ext=exception.ext,
-            )
-    else:
-        if not isinstance(exception, werkzeug.exceptions.HTTPException):
-            exception = werkzeug.exceptions.InternalServerError()
-
-        response = problem(title=exception.name, detail=exception.description, 
status=exception.code)
-
-    return FlaskApi.get_response(response)
-
-
 class NotFound(ProblemException):
     """Raise when the object cannot be found."""
 
diff --git a/providers/src/airflow/providers/fab/www/extensions/init_views.py 
b/providers/src/airflow/providers/fab/www/extensions/init_views.py
index 382bcaf9ca7..588276735fc 100644
--- a/providers/src/airflow/providers/fab/www/extensions/init_views.py
+++ b/providers/src/airflow/providers/fab/www/extensions/init_views.py
@@ -23,6 +23,15 @@ from typing import TYPE_CHECKING
 from connexion import Resolver
 from connexion.decorators.validation import RequestBodyValidator
 from connexion.exceptions import BadRequestProblem
+from flask import jsonify
+from starlette import status
+
+from airflow.providers.fab.www.api_connexion.exceptions import (
+    BadRequest,
+    NotFound,
+    PermissionDenied,
+    Unauthenticated,
+)
 
 if TYPE_CHECKING:
     from flask import Flask
@@ -114,7 +123,26 @@ def init_plugins(app):
 
 def init_error_handlers(app: Flask):
     """Add custom errors handlers."""
-    from airflow.providers.fab.www import views
 
-    app.register_error_handler(500, views.show_traceback)
-    app.register_error_handler(404, views.not_found)
+    def handle_bad_request(error):
+        response = {"error": "Bad request"}
+        return jsonify(response), status.HTTP_400_BAD_REQUEST
+
+    def handle_not_found(error):
+        response = {"error": "Not found"}
+        return jsonify(response), status.HTTP_404_NOT_FOUND
+
+    def handle_unauthenticated(error):
+        response = {"error": "User is not authenticated"}
+        return jsonify(response), status.HTTP_401_UNAUTHORIZED
+
+    def handle_denied(error):
+        response = {"error": "Access is denied"}
+        return jsonify(response), status.HTTP_403_FORBIDDEN
+
+    app.register_error_handler(404, handle_not_found)
+
+    app.register_error_handler(BadRequest, handle_bad_request)
+    app.register_error_handler(NotFound, handle_not_found)
+    app.register_error_handler(Unauthenticated, handle_unauthenticated)
+    app.register_error_handler(PermissionDenied, handle_denied)
diff --git a/providers/src/airflow/providers/fab/www/views.py 
b/providers/src/airflow/providers/fab/www/views.py
index 48bf0bfddff..43ac276897e 100644
--- a/providers/src/airflow/providers/fab/www/views.py
+++ b/providers/src/airflow/providers/fab/www/views.py
@@ -30,19 +30,6 @@ from airflow.utils.net import get_hostname
 from airflow.version import version
 
 
-def not_found(error):
-    """Show Not Found on screen for any error in the Webserver."""
-    return (
-        render_template(
-            "airflow/error.html",
-            hostname=get_hostname() if conf.getboolean("webserver", 
"EXPOSE_HOSTNAME") else "",
-            status_code=404,
-            error_message="Page cannot be found.",
-        ),
-        404,
-    )
-
-
 def show_traceback(error):
     """Show Traceback for a given error."""
     is_logged_in = get_auth_manager().is_logged_in()

Reply via email to