This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2bc6f09ea38 AIP-84 | Add Auth for Import Error (#47270)
2bc6f09ea38 is described below

commit 2bc6f09ea3865cb0c0479efa88cab5256e96f7e3
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Mon Mar 17 21:37:54 2025 +0800

    AIP-84 | Add Auth for Import Error (#47270)
    
    * AIP-84 | Add Auth for Import Error
    
    * fix: remove outdated requires_access_view
    
    * Add permitted dags with import_error
    
    * Add test for import error
    
    * Refactor import_error
    - Remove requires_access_dag depends
    - Refactor get_file_dag_ids helper
    - Rename get_permitted_dag_ids to get_authorized_dag_ids
    - Refactor get_import_errors
        - Early return if the user has access to all DAGs
        - Add subquery to get corresponding file_dag_ids for each fileloc in 
single query
    
    * Refactor test_import_error
    - Make test more readable by adding separate comment
    - Add set auth_manager attribute helper
    - Refactor test setup
    
    * Remove get_file_dag_ids helper
    
    * Remove 403 in api doc, early return for get_import_error router
    
    * feat(import_error): rewrite groupby logic as array_agg is only supported 
in pg
    
    * test(api_fastapi): refactor test cases
    
    * style: fix type checking
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 .../api_fastapi/core_api/openapi/v1-generated.yaml |   4 +
 .../core_api/routes/public/import_error.py         |  98 ++++++++-
 .../core_api/routes/public/test_import_error.py    | 240 ++++++++++++++++-----
 3 files changed, 290 insertions(+), 52 deletions(-)

diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml 
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index ec8ac63515a..35c62fb7c47 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -3935,6 +3935,8 @@ paths:
       summary: Get Import Error
       description: Get an import error.
       operationId: get_import_error
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: import_error_id
         in: path
@@ -3980,6 +3982,8 @@ paths:
       summary: Get Import Errors
       description: Get all import errors.
       operationId: get_import_errors
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: limit
         in: query
diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py 
b/airflow/api_fastapi/core_api/routes/public/import_error.py
index 01caf9048e2..4beb0ea2cd4 100644
--- a/airflow/api_fastapi/core_api/routes/public/import_error.py
+++ b/airflow/api_fastapi/core_api/routes/public/import_error.py
@@ -16,11 +16,19 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Iterable, Sequence
+from itertools import groupby
+from operator import itemgetter
 from typing import Annotated
 
 from fastapi import Depends, HTTPException, status
 from sqlalchemy import select
 
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.auth.managers.models.batch_apis import 
IsAuthorizedDagRequest
+from airflow.api_fastapi.auth.managers.models.resource_details import (
+    DagDetails,
+)
 from airflow.api_fastapi.common.db.common import (
     SessionDep,
     paginated_select,
@@ -36,18 +44,29 @@ from airflow.api_fastapi.core_api.datamodels.import_error 
import (
     ImportErrorResponse,
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
+from airflow.api_fastapi.core_api.security import (
+    AccessView,
+    GetUserDep,
+    requires_access_view,
+)
+from airflow.models import DagModel
 from airflow.models.errors import ParseImportError
 
+REDACTED_STACKTRACE = "REDACTED - you do not have read permission on all DAGs 
in the file"
 import_error_router = AirflowRouter(tags=["Import Error"], 
prefix="/importErrors")
 
 
 @import_error_router.get(
     "/{import_error_id}",
     responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+    dependencies=[
+        Depends(requires_access_view(AccessView.IMPORT_ERRORS)),
+    ],
 )
 def get_import_error(
     import_error_id: int,
     session: SessionDep,
+    user: GetUserDep,
 ) -> ImportErrorResponse:
     """Get an import error."""
     error = session.scalar(select(ParseImportError).where(ParseImportError.id 
== import_error_id))
@@ -56,12 +75,37 @@ def get_import_error(
             status.HTTP_404_NOT_FOUND,
             f"The ImportError with import_error_id: `{import_error_id}` was 
not found",
         )
+    session.expunge(error)
+
+    auth_manager = get_auth_manager()
+    can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
+    if can_read_all_dags:
+        # Early return if the user has access to all DAGs
+        return error
+
+    readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user)
+    # We need file_dag_ids as a set for intersection, issubset operations
+    file_dag_ids = set(
+        session.scalars(select(DagModel.dag_id).where(DagModel.fileloc == 
error.filename)).all()
+    )
+    # Can the user read any DAGs in the file?
+    if not readable_dag_ids.intersection(file_dag_ids):
+        raise HTTPException(
+            status.HTTP_403_FORBIDDEN,
+            "You do not have read permission on any of the DAGs in the file",
+        )
 
+    # Check if user has read access to all the DAGs defined in the file
+    if not file_dag_ids.issubset(readable_dag_ids):
+        error.stacktrace = REDACTED_STACKTRACE
     return error
 
 
 @import_error_router.get(
     "",
+    dependencies=[
+        Depends(requires_access_view(AccessView.IMPORT_ERRORS)),
+    ],
 )
 def get_import_errors(
     limit: QueryLimit,
@@ -83,6 +127,7 @@ def get_import_errors(
         ),
     ],
     session: SessionDep,
+    user: GetUserDep,
 ) -> ImportErrorCollectionResponse:
     """Get all import errors."""
     import_errors_select, total_entries = paginated_select(
@@ -92,7 +137,58 @@ def get_import_errors(
         limit=limit,
         session=session,
     )
-    import_errors = session.scalars(import_errors_select)
+
+    auth_manager = get_auth_manager()
+    can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
+    if can_read_all_dags:
+        # Early return if the user has access to all DAGs
+        import_errors = session.scalars(import_errors_select).all()
+        return ImportErrorCollectionResponse(
+            import_errors=import_errors,
+            total_entries=total_entries,
+        )
+
+    # if the user doesn't have access to all DAGs, only display errors from 
visible DAGs
+    readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET", 
user=user)
+    # Build a cte that fetches dag_ids for each file location
+    visiable_files_cte = (
+        select(DagModel.fileloc, 
DagModel.dag_id).where(DagModel.dag_id.in_(readable_dag_ids)).cte()
+    )
+
+    # Prepare the import errors query by joining with the cte.
+    # Each returned row will be a tuple: (ParseImportError, dag_id)
+    import_errors_stmt = (
+        select(ParseImportError, visiable_files_cte.c.dag_id)
+        .join(visiable_files_cte, ParseImportError.filename == 
visiable_files_cte.c.fileloc)
+        .order_by(ParseImportError.id)
+    )
+
+    # Paginate the import errors query
+    import_errors_select, total_entries = paginated_select(
+        statement=import_errors_stmt,
+        order_by=order_by,
+        offset=offset,
+        limit=limit,
+        session=session,
+    )
+    import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] = 
groupby(
+        session.execute(import_errors_select), itemgetter(0)
+    )
+
+    import_errors = []
+    for import_error, file_dag_ids in import_errors_result:
+        # Check if user has read access to all the DAGs defined in the file
+        requests: Sequence[IsAuthorizedDagRequest] = [
+            {
+                "method": "GET",
+                "details": DagDetails(id=dag_id),
+            }
+            for dag_id in file_dag_ids
+        ]
+        if not auth_manager.batch_is_authorized_dag(requests, user=user):
+            session.expunge(import_error)
+            import_error.stacktrace = REDACTED_STACKTRACE
+        import_errors.append(import_error)
 
     return ImportErrorCollectionResponse(
         import_errors=import_errors,
diff --git a/tests/api_fastapi/core_api/routes/public/test_import_error.py 
b/tests/api_fastapi/core_api/routes/public/test_import_error.py
index 38965ab7184..84d7871585b 100644
--- a/tests/api_fastapi/core_api/routes/public/test_import_error.py
+++ b/tests/api_fastapi/core_api/routes/public/test_import_error.py
@@ -17,15 +17,21 @@
 from __future__ import annotations
 
 from datetime import datetime, timezone
+from typing import TYPE_CHECKING
+from unittest import mock
 
 import pytest
 
+from airflow.models import DagModel
 from airflow.models.errors import ParseImportError
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 
-from tests_common.test_utils.db import clear_db_import_errors
+from tests_common.test_utils.db import clear_db_dags, clear_db_import_errors
 from tests_common.test_utils.format_datetime import 
from_datetime_to_zulu_without_ms
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
 pytestmark = pytest.mark.db_test
 
 FILENAME1 = "test_filename1.py"
@@ -42,95 +48,171 @@ IMPORT_ERROR_NON_EXISTED_KEY = "non_existed_key"
 BUNDLE_NAME = "dag_maker"
 
 
-class TestImportErrorEndpoint:
-    """Common class for /public/importErrors related unit tests."""
+@pytest.fixture(scope="class")
+@provide_session
+def permitted_dag_model(session: Session = NEW_SESSION) -> DagModel:
+    dag_model = DagModel(fileloc=FILENAME1, dag_id="dag_id1", is_paused=False)
+    session.add(dag_model)
+    session.commit()
+    return dag_model
 
-    @staticmethod
-    def _clear_db():
-        clear_db_import_errors()
 
-    @pytest.fixture(autouse=True)
-    @provide_session
-    def setup(self, session=None) -> dict[str, ParseImportError]:
-        """
-        Setup method which is run before every test.
-        """
-        self._clear_db()
-        import_error1 = ParseImportError(
-            bundle_name=BUNDLE_NAME,
-            filename=FILENAME1,
-            stacktrace=STACKTRACE1,
-            timestamp=TIMESTAMP1,
-        )
-        import_error2 = ParseImportError(
+@pytest.fixture(scope="class")
+@provide_session
+def not_permitted_dag_model(session: Session = NEW_SESSION) -> DagModel:
+    dag_model = DagModel(fileloc=FILENAME1, dag_id="dag_id4", is_paused=False)
+    session.add(dag_model)
+    session.commit()
+    return dag_model
+
+
+@pytest.fixture(scope="class", autouse=True)
+def clear_db():
+    clear_db_import_errors()
+    clear_db_dags()
+
+    yield
+
+    clear_db_import_errors()
+    clear_db_dags()
+
+
+@pytest.fixture(autouse=True, scope="class")
+@provide_session
+def import_errors(session: Session = NEW_SESSION) -> list[ParseImportError]:
+    _import_errors = [
+        ParseImportError(
             bundle_name=BUNDLE_NAME,
-            filename=FILENAME2,
-            stacktrace=STACKTRACE2,
-            timestamp=TIMESTAMP2,
+            filename=filename,
+            stacktrace=stacktrace,
+            timestamp=timestamp,
         )
-        import_error3 = ParseImportError(
-            bundle_name=BUNDLE_NAME,
-            filename=FILENAME3,
-            stacktrace=STACKTRACE3,
-            timestamp=TIMESTAMP3,
+        for filename, stacktrace, timestamp in zip(
+            (FILENAME1, FILENAME2, FILENAME3),
+            (STACKTRACE1, STACKTRACE2, STACKTRACE3),
+            (TIMESTAMP1, TIMESTAMP2, TIMESTAMP3),
         )
-        session.add_all([import_error1, import_error2, import_error3])
-        session.commit()
-        return {FILENAME1: import_error1, FILENAME2: import_error2, FILENAME3: 
import_error3}
+    ]
+
+    session.add_all(_import_errors)
+    return _import_errors
+
 
-    def teardown_method(self) -> None:
-        self._clear_db()
+def set_mock_auth_manager__is_authorized_dag(
+    mock_auth_manager: mock.Mock, is_authorized_dag_return_value: bool = False
+) -> mock.Mock:
+    mock_is_authorized_dag = mock_auth_manager.return_value.is_authorized_dag
+    mock_is_authorized_dag.return_value = is_authorized_dag_return_value
+    return mock_is_authorized_dag
 
 
-class TestGetImportError(TestImportErrorEndpoint):
+def set_mock_auth_manager__get_authorized_dag_ids(
+    mock_auth_manager: mock.Mock, get_authorized_dag_ids_return_value: 
set[str] | None = None
+) -> mock.Mock:
+    if get_authorized_dag_ids_return_value is None:
+        get_authorized_dag_ids_return_value = set()
+    mock_get_authorized_dag_ids = 
mock_auth_manager.return_value.get_authorized_dag_ids
+    mock_get_authorized_dag_ids.return_value = 
get_authorized_dag_ids_return_value
+    return mock_get_authorized_dag_ids
+
+
+def set_mock_auth_manager__batch_is_authorized_dag(
+    mock_auth_manager: mock.Mock, batch_is_authorized_dag_return_value: bool = 
False
+) -> mock.Mock:
+    mock_batch_is_authorized_dag = 
mock_auth_manager.return_value.batch_is_authorized_dag
+    mock_batch_is_authorized_dag.return_value = 
batch_is_authorized_dag_return_value
+    return mock_batch_is_authorized_dag
+
+
+class TestGetImportError:
     @pytest.mark.parametrize(
-        "import_error_key, expected_status_code, expected_body",
+        "prepared_import_error_idx, expected_status_code, expected_body",
         [
             (
-                FILENAME1,
+                0,
                 200,
                 {
-                    "import_error_id": 1,
-                    "timestamp": TIMESTAMP1,
+                    "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1),
                     "filename": FILENAME1,
                     "stack_trace": STACKTRACE1,
                     "bundle_name": BUNDLE_NAME,
                 },
             ),
             (
-                FILENAME2,
+                1,
                 200,
                 {
-                    "import_error_id": 2,
-                    "timestamp": TIMESTAMP2,
+                    "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP2),
                     "filename": FILENAME2,
                     "stack_trace": STACKTRACE2,
                     "bundle_name": BUNDLE_NAME,
                 },
             ),
-            (IMPORT_ERROR_NON_EXISTED_KEY, 404, {}),
+            (None, 404, {}),
         ],
     )
     def test_get_import_error(
-        self, test_client, setup, import_error_key, expected_status_code, 
expected_body
+        self, prepared_import_error_idx, expected_status_code, expected_body, 
test_client, import_errors
     ):
-        import_error: ParseImportError | None = setup.get(import_error_key)
+        import_error: ParseImportError | None = (
+            import_errors[prepared_import_error_idx] if 
prepared_import_error_idx is not None else None
+        )
         import_error_id = import_error.id if import_error else 
IMPORT_ERROR_NON_EXISTED_ID
         response = test_client.get(f"/public/importErrors/{import_error_id}")
         assert response.status_code == expected_status_code
         if expected_status_code != 200:
             return
-        expected_json = {
+
+        expected_body.update({"import_error_id": import_error_id})
+        assert response.json() == expected_body
+
+    def test_should_raises_401_unauthenticated(self, 
unauthenticated_test_client, import_errors):
+        import_error_id = import_errors[0].id
+        response = 
unauthenticated_test_client.get(f"/public/importErrors/{import_error_id}")
+        assert response.status_code == 401
+
+    def test_should_raises_403_unauthorized(self, unauthorized_test_client, 
import_errors):
+        import_error_id = import_errors[0].id
+        response = 
unauthorized_test_client.get(f"/public/importErrors/{import_error_id}")
+        assert response.status_code == 403
+
+    
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
+    def 
test_should_raises_403_unauthorized__user_can_not_read_any_dags_in_file(
+        self, mock_get_auth_manager, test_client, import_errors
+    ):
+        import_error_id = import_errors[0].id
+        # Mock auth_manager
+        mock_is_authorized_dag = 
set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
+        mock_get_authorized_dag_ids = 
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager)
+        # Act
+        response = test_client.get(f"/public/importErrors/{import_error_id}")
+        # Assert
+        mock_is_authorized_dag.assert_called_once_with(method="GET", 
user=mock.ANY)
+        mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY)
+        assert response.status_code == 403
+        assert response.json() == {"detail": "You do not have read permission 
on any of the DAGs in the file"}
+
+    
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
+    def 
test_get_import_error__user_dont_have_read_permission_to_read_all_dags_in_file(
+        self, mock_get_auth_manager, test_client, permitted_dag_model, 
not_permitted_dag_model, import_errors
+    ):
+        import_error_id = import_errors[0].id
+        set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
+        set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager, 
{permitted_dag_model.dag_id})
+        # Act
+        response = test_client.get(f"/public/importErrors/{import_error_id}")
+        # Assert
+        assert response.status_code == 200
+        assert response.json() == {
             "import_error_id": import_error_id,
-            "timestamp": 
from_datetime_to_zulu_without_ms(expected_body["timestamp"]),
-            "filename": expected_body["filename"],
-            "stack_trace": expected_body["stack_trace"],
+            "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1),
+            "filename": FILENAME1,
+            "stack_trace": "REDACTED - you do not have read permission on all 
DAGs in the file",
             "bundle_name": BUNDLE_NAME,
         }
-        assert response.json() == expected_json
 
 
-class TestGetImportErrors(TestImportErrorEndpoint):
+class TestGetImportErrors:
     @pytest.mark.parametrize(
         "query_params, expected_status_code, expected_total_entries, 
expected_filenames",
         [
@@ -225,3 +307,59 @@ class TestGetImportErrors(TestImportErrorEndpoint):
         assert [
             import_error["filename"] for import_error in 
response_json["import_errors"]
         ] == expected_filenames
+
+    def test_should_raises_401_unauthenticated(self, 
unauthenticated_test_client):
+        response = unauthenticated_test_client.get("/public/importErrors")
+        assert response.status_code == 401
+
+    def test_should_raises_403_unauthorized(self, unauthorized_test_client):
+        response = unauthorized_test_client.get("/public/importErrors")
+        assert response.status_code == 403
+
+    @pytest.mark.parametrize(
+        "batch_is_authorized_dag_return_value, expected_stack_trace",
+        [
+            pytest.param(True, STACKTRACE1, 
id="user_has_read_access_to_all_dags_in_current_file"),
+            pytest.param(
+                False,
+                "REDACTED - you do not have read permission on all DAGs in the 
file",
+                
id="user_does_not_have_read_access_to_all_dags_in_current_file",
+            ),
+        ],
+    )
+    @pytest.mark.usefixtures("permitted_dag_model")
+    
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
+    def test_user_can_not_read_all_dags_in_file(
+        self,
+        mock_get_auth_manager,
+        test_client,
+        batch_is_authorized_dag_return_value,
+        expected_stack_trace,
+        permitted_dag_model,
+        import_errors,
+    ):
+        set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
+        mock_get_authorized_dag_ids = 
set_mock_auth_manager__get_authorized_dag_ids(
+            mock_get_auth_manager, {permitted_dag_model.dag_id}
+        )
+        set_mock_auth_manager__batch_is_authorized_dag(
+            mock_get_auth_manager, batch_is_authorized_dag_return_value
+        )
+        # Act
+        response = test_client.get("/public/importErrors")
+        # Assert
+        mock_get_authorized_dag_ids.assert_called_once_with(method="GET", 
user=mock.ANY)
+        assert response.status_code == 200
+        response_json = response.json()
+        assert response_json == {
+            "total_entries": 1,
+            "import_errors": [
+                {
+                    "import_error_id": import_errors[0].id,
+                    "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1),
+                    "filename": FILENAME1,
+                    "stack_trace": expected_stack_trace,
+                    "bundle_name": BUNDLE_NAME,
+                }
+            ],
+        }

Reply via email to