This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2bc6f09ea38 AIP-84 | Add Auth for Import Error (#47270)
2bc6f09ea38 is described below
commit 2bc6f09ea3865cb0c0479efa88cab5256e96f7e3
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Mon Mar 17 21:37:54 2025 +0800
AIP-84 | Add Auth for Import Error (#47270)
* AIP-84 | Add Auth for Import Error
* fix: remove outdated requires_access_view
* Add permitted dags with import_error
* Add test for import error
* Refactor import_error
- Remove requires_access_dag depends
- Refactor get_file_dag_ids helper
- Rename get_permitted_dag_ids to get_authorized_dag_ids
- Refactor get_import_errors
- Early return if the user has access to all DAGs
- Add subquery to get corresponding file_dag_ids for each fileloc in a
single query
* Refactor test_import_error
- Make test more readable by adding separate comment
- Add set auth_manager attribute helper
- Refactor test setup
* Remove get_file_dag_ids helper
* Remove 403 in api doc, early return for get_import_error router
* feat(import_error): rewrite groupby logic because array_agg is only supported
in PostgreSQL
* test(api_fastapi): refactor test cases
* style: fix type checking
---------
Co-authored-by: Wei Lee <[email protected]>
---
.../api_fastapi/core_api/openapi/v1-generated.yaml | 4 +
.../core_api/routes/public/import_error.py | 98 ++++++++-
.../core_api/routes/public/test_import_error.py | 240 ++++++++++++++++-----
3 files changed, 290 insertions(+), 52 deletions(-)
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index ec8ac63515a..35c62fb7c47 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -3935,6 +3935,8 @@ paths:
summary: Get Import Error
description: Get an import error.
operationId: get_import_error
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: import_error_id
in: path
@@ -3980,6 +3982,8 @@ paths:
summary: Get Import Errors
description: Get all import errors.
operationId: get_import_errors
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: limit
in: query
diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py
b/airflow/api_fastapi/core_api/routes/public/import_error.py
index 01caf9048e2..4beb0ea2cd4 100644
--- a/airflow/api_fastapi/core_api/routes/public/import_error.py
+++ b/airflow/api_fastapi/core_api/routes/public/import_error.py
@@ -16,11 +16,19 @@
# under the License.
from __future__ import annotations
+from collections.abc import Iterable, Sequence
+from itertools import groupby
+from operator import itemgetter
from typing import Annotated
from fastapi import Depends, HTTPException, status
from sqlalchemy import select
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.auth.managers.models.batch_apis import
IsAuthorizedDagRequest
+from airflow.api_fastapi.auth.managers.models.resource_details import (
+ DagDetails,
+)
from airflow.api_fastapi.common.db.common import (
SessionDep,
paginated_select,
@@ -36,18 +44,29 @@ from airflow.api_fastapi.core_api.datamodels.import_error
import (
ImportErrorResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
+from airflow.api_fastapi.core_api.security import (
+ AccessView,
+ GetUserDep,
+ requires_access_view,
+)
+from airflow.models import DagModel
from airflow.models.errors import ParseImportError
+REDACTED_STACKTRACE = "REDACTED - you do not have read permission on all DAGs
in the file"
import_error_router = AirflowRouter(tags=["Import Error"],
prefix="/importErrors")
@import_error_router.get(
"/{import_error_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+ dependencies=[
+ Depends(requires_access_view(AccessView.IMPORT_ERRORS)),
+ ],
)
def get_import_error(
import_error_id: int,
session: SessionDep,
+ user: GetUserDep,
) -> ImportErrorResponse:
"""Get an import error."""
error = session.scalar(select(ParseImportError).where(ParseImportError.id
== import_error_id))
@@ -56,12 +75,37 @@ def get_import_error(
status.HTTP_404_NOT_FOUND,
f"The ImportError with import_error_id: `{import_error_id}` was
not found",
)
+ session.expunge(error)
+
+ auth_manager = get_auth_manager()
+ can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
+ if can_read_all_dags:
+ # Early return if the user has access to all DAGs
+ return error
+
+ readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user)
+ # We need file_dag_ids as a set for intersection, issubset operations
+ file_dag_ids = set(
+ session.scalars(select(DagModel.dag_id).where(DagModel.fileloc ==
error.filename)).all()
+ )
+ # Can the user read any DAGs in the file?
+ if not readable_dag_ids.intersection(file_dag_ids):
+ raise HTTPException(
+ status.HTTP_403_FORBIDDEN,
+ "You do not have read permission on any of the DAGs in the file",
+ )
+ # Check if user has read access to all the DAGs defined in the file
+ if not file_dag_ids.issubset(readable_dag_ids):
+ error.stacktrace = REDACTED_STACKTRACE
return error
@import_error_router.get(
"",
+ dependencies=[
+ Depends(requires_access_view(AccessView.IMPORT_ERRORS)),
+ ],
)
def get_import_errors(
limit: QueryLimit,
@@ -83,6 +127,7 @@ def get_import_errors(
),
],
session: SessionDep,
+ user: GetUserDep,
) -> ImportErrorCollectionResponse:
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
@@ -92,7 +137,58 @@ def get_import_errors(
limit=limit,
session=session,
)
- import_errors = session.scalars(import_errors_select)
+
+ auth_manager = get_auth_manager()
+ can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
+ if can_read_all_dags:
+ # Early return if the user has access to all DAGs
+ import_errors = session.scalars(import_errors_select).all()
+ return ImportErrorCollectionResponse(
+ import_errors=import_errors,
+ total_entries=total_entries,
+ )
+
+ # if the user doesn't have access to all DAGs, only display errors from
visible DAGs
+ readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET",
user=user)
+ # Build a cte that fetches dag_ids for each file location
+ visiable_files_cte = (
+ select(DagModel.fileloc,
DagModel.dag_id).where(DagModel.dag_id.in_(readable_dag_ids)).cte()
+ )
+
+ # Prepare the import errors query by joining with the cte.
+ # Each returned row will be a tuple: (ParseImportError, dag_id)
+ import_errors_stmt = (
+ select(ParseImportError, visiable_files_cte.c.dag_id)
+ .join(visiable_files_cte, ParseImportError.filename ==
visiable_files_cte.c.fileloc)
+ .order_by(ParseImportError.id)
+ )
+
+ # Paginate the import errors query
+ import_errors_select, total_entries = paginated_select(
+ statement=import_errors_stmt,
+ order_by=order_by,
+ offset=offset,
+ limit=limit,
+ session=session,
+ )
+ import_errors_result: Iterable[tuple[ParseImportError, Iterable[str]]] =
groupby(
+ session.execute(import_errors_select), itemgetter(0)
+ )
+
+ import_errors = []
+ for import_error, file_dag_ids in import_errors_result:
+ # Check if user has read access to all the DAGs defined in the file
+ requests: Sequence[IsAuthorizedDagRequest] = [
+ {
+ "method": "GET",
+ "details": DagDetails(id=dag_id),
+ }
+ for dag_id in file_dag_ids
+ ]
+ if not auth_manager.batch_is_authorized_dag(requests, user=user):
+ session.expunge(import_error)
+ import_error.stacktrace = REDACTED_STACKTRACE
+ import_errors.append(import_error)
return ImportErrorCollectionResponse(
import_errors=import_errors,
diff --git a/tests/api_fastapi/core_api/routes/public/test_import_error.py
b/tests/api_fastapi/core_api/routes/public/test_import_error.py
index 38965ab7184..84d7871585b 100644
--- a/tests/api_fastapi/core_api/routes/public/test_import_error.py
+++ b/tests/api_fastapi/core_api/routes/public/test_import_error.py
@@ -17,15 +17,21 @@
from __future__ import annotations
from datetime import datetime, timezone
+from typing import TYPE_CHECKING
+from unittest import mock
import pytest
+from airflow.models import DagModel
from airflow.models.errors import ParseImportError
-from airflow.utils.session import provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
-from tests_common.test_utils.db import clear_db_import_errors
+from tests_common.test_utils.db import clear_db_dags, clear_db_import_errors
from tests_common.test_utils.format_datetime import
from_datetime_to_zulu_without_ms
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
pytestmark = pytest.mark.db_test
FILENAME1 = "test_filename1.py"
@@ -42,95 +48,171 @@ IMPORT_ERROR_NON_EXISTED_KEY = "non_existed_key"
BUNDLE_NAME = "dag_maker"
-class TestImportErrorEndpoint:
- """Common class for /public/importErrors related unit tests."""
[email protected](scope="class")
+@provide_session
+def permitted_dag_model(session: Session = NEW_SESSION) -> DagModel:
+ dag_model = DagModel(fileloc=FILENAME1, dag_id="dag_id1", is_paused=False)
+ session.add(dag_model)
+ session.commit()
+ return dag_model
- @staticmethod
- def _clear_db():
- clear_db_import_errors()
- @pytest.fixture(autouse=True)
- @provide_session
- def setup(self, session=None) -> dict[str, ParseImportError]:
- """
- Setup method which is run before every test.
- """
- self._clear_db()
- import_error1 = ParseImportError(
- bundle_name=BUNDLE_NAME,
- filename=FILENAME1,
- stacktrace=STACKTRACE1,
- timestamp=TIMESTAMP1,
- )
- import_error2 = ParseImportError(
[email protected](scope="class")
+@provide_session
+def not_permitted_dag_model(session: Session = NEW_SESSION) -> DagModel:
+ dag_model = DagModel(fileloc=FILENAME1, dag_id="dag_id4", is_paused=False)
+ session.add(dag_model)
+ session.commit()
+ return dag_model
+
+
[email protected](scope="class", autouse=True)
+def clear_db():
+ clear_db_import_errors()
+ clear_db_dags()
+
+ yield
+
+ clear_db_import_errors()
+ clear_db_dags()
+
+
[email protected](autouse=True, scope="class")
+@provide_session
+def import_errors(session: Session = NEW_SESSION) -> list[ParseImportError]:
+ _import_errors = [
+ ParseImportError(
bundle_name=BUNDLE_NAME,
- filename=FILENAME2,
- stacktrace=STACKTRACE2,
- timestamp=TIMESTAMP2,
+ filename=filename,
+ stacktrace=stacktrace,
+ timestamp=timestamp,
)
- import_error3 = ParseImportError(
- bundle_name=BUNDLE_NAME,
- filename=FILENAME3,
- stacktrace=STACKTRACE3,
- timestamp=TIMESTAMP3,
+ for filename, stacktrace, timestamp in zip(
+ (FILENAME1, FILENAME2, FILENAME3),
+ (STACKTRACE1, STACKTRACE2, STACKTRACE3),
+ (TIMESTAMP1, TIMESTAMP2, TIMESTAMP3),
)
- session.add_all([import_error1, import_error2, import_error3])
- session.commit()
- return {FILENAME1: import_error1, FILENAME2: import_error2, FILENAME3:
import_error3}
+ ]
+
+ session.add_all(_import_errors)
+ return _import_errors
+
- def teardown_method(self) -> None:
- self._clear_db()
+def set_mock_auth_manager__is_authorized_dag(
+ mock_auth_manager: mock.Mock, is_authorized_dag_return_value: bool = False
+) -> mock.Mock:
+ mock_is_authorized_dag = mock_auth_manager.return_value.is_authorized_dag
+ mock_is_authorized_dag.return_value = is_authorized_dag_return_value
+ return mock_is_authorized_dag
-class TestGetImportError(TestImportErrorEndpoint):
+def set_mock_auth_manager__get_authorized_dag_ids(
+ mock_auth_manager: mock.Mock, get_authorized_dag_ids_return_value:
set[str] | None = None
+) -> mock.Mock:
+ if get_authorized_dag_ids_return_value is None:
+ get_authorized_dag_ids_return_value = set()
+ mock_get_authorized_dag_ids =
mock_auth_manager.return_value.get_authorized_dag_ids
+ mock_get_authorized_dag_ids.return_value =
get_authorized_dag_ids_return_value
+ return mock_get_authorized_dag_ids
+
+
+def set_mock_auth_manager__batch_is_authorized_dag(
+ mock_auth_manager: mock.Mock, batch_is_authorized_dag_return_value: bool =
False
+) -> mock.Mock:
+ mock_batch_is_authorized_dag =
mock_auth_manager.return_value.batch_is_authorized_dag
+ mock_batch_is_authorized_dag.return_value =
batch_is_authorized_dag_return_value
+ return mock_batch_is_authorized_dag
+
+
+class TestGetImportError:
@pytest.mark.parametrize(
- "import_error_key, expected_status_code, expected_body",
+ "prepared_import_error_idx, expected_status_code, expected_body",
[
(
- FILENAME1,
+ 0,
200,
{
- "import_error_id": 1,
- "timestamp": TIMESTAMP1,
+ "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1),
"filename": FILENAME1,
"stack_trace": STACKTRACE1,
"bundle_name": BUNDLE_NAME,
},
),
(
- FILENAME2,
+ 1,
200,
{
- "import_error_id": 2,
- "timestamp": TIMESTAMP2,
+ "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP2),
"filename": FILENAME2,
"stack_trace": STACKTRACE2,
"bundle_name": BUNDLE_NAME,
},
),
- (IMPORT_ERROR_NON_EXISTED_KEY, 404, {}),
+ (None, 404, {}),
],
)
def test_get_import_error(
- self, test_client, setup, import_error_key, expected_status_code,
expected_body
+ self, prepared_import_error_idx, expected_status_code, expected_body,
test_client, import_errors
):
- import_error: ParseImportError | None = setup.get(import_error_key)
+ import_error: ParseImportError | None = (
+ import_errors[prepared_import_error_idx] if
prepared_import_error_idx is not None else None
+ )
import_error_id = import_error.id if import_error else
IMPORT_ERROR_NON_EXISTED_ID
response = test_client.get(f"/public/importErrors/{import_error_id}")
assert response.status_code == expected_status_code
if expected_status_code != 200:
return
- expected_json = {
+
+ expected_body.update({"import_error_id": import_error_id})
+ assert response.json() == expected_body
+
+ def test_should_raises_401_unauthenticated(self,
unauthenticated_test_client, import_errors):
+ import_error_id = import_errors[0].id
+ response =
unauthenticated_test_client.get(f"/public/importErrors/{import_error_id}")
+ assert response.status_code == 401
+
+ def test_should_raises_403_unauthorized(self, unauthorized_test_client,
import_errors):
+ import_error_id = import_errors[0].id
+ response =
unauthorized_test_client.get(f"/public/importErrors/{import_error_id}")
+ assert response.status_code == 403
+
+
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
+ def
test_should_raises_403_unauthorized__user_can_not_read_any_dags_in_file(
+ self, mock_get_auth_manager, test_client, import_errors
+ ):
+ import_error_id = import_errors[0].id
+ # Mock auth_manager
+ mock_is_authorized_dag =
set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
+ mock_get_authorized_dag_ids =
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager)
+ # Act
+ response = test_client.get(f"/public/importErrors/{import_error_id}")
+ # Assert
+ mock_is_authorized_dag.assert_called_once_with(method="GET",
user=mock.ANY)
+ mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY)
+ assert response.status_code == 403
+ assert response.json() == {"detail": "You do not have read permission
on any of the DAGs in the file"}
+
+
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
+ def
test_get_import_error__user_dont_have_read_permission_to_read_all_dags_in_file(
+ self, mock_get_auth_manager, test_client, permitted_dag_model,
not_permitted_dag_model, import_errors
+ ):
+ import_error_id = import_errors[0].id
+ set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
+ set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager,
{permitted_dag_model.dag_id})
+ # Act
+ response = test_client.get(f"/public/importErrors/{import_error_id}")
+ # Assert
+ assert response.status_code == 200
+ assert response.json() == {
"import_error_id": import_error_id,
- "timestamp":
from_datetime_to_zulu_without_ms(expected_body["timestamp"]),
- "filename": expected_body["filename"],
- "stack_trace": expected_body["stack_trace"],
+ "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1),
+ "filename": FILENAME1,
+ "stack_trace": "REDACTED - you do not have read permission on all
DAGs in the file",
"bundle_name": BUNDLE_NAME,
}
- assert response.json() == expected_json
-class TestGetImportErrors(TestImportErrorEndpoint):
+class TestGetImportErrors:
@pytest.mark.parametrize(
"query_params, expected_status_code, expected_total_entries,
expected_filenames",
[
@@ -225,3 +307,59 @@ class TestGetImportErrors(TestImportErrorEndpoint):
assert [
import_error["filename"] for import_error in
response_json["import_errors"]
] == expected_filenames
+
+ def test_should_raises_401_unauthenticated(self,
unauthenticated_test_client):
+ response = unauthenticated_test_client.get("/public/importErrors")
+ assert response.status_code == 401
+
+ def test_should_raises_403_unauthorized(self, unauthorized_test_client):
+ response = unauthorized_test_client.get("/public/importErrors")
+ assert response.status_code == 403
+
+ @pytest.mark.parametrize(
+ "batch_is_authorized_dag_return_value, expected_stack_trace",
+ [
+ pytest.param(True, STACKTRACE1,
id="user_has_read_access_to_all_dags_in_current_file"),
+ pytest.param(
+ False,
+ "REDACTED - you do not have read permission on all DAGs in the
file",
+
id="user_does_not_have_read_access_to_all_dags_in_current_file",
+ ),
+ ],
+ )
+ @pytest.mark.usefixtures("permitted_dag_model")
+
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
+ def test_user_can_not_read_all_dags_in_file(
+ self,
+ mock_get_auth_manager,
+ test_client,
+ batch_is_authorized_dag_return_value,
+ expected_stack_trace,
+ permitted_dag_model,
+ import_errors,
+ ):
+ set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
+ mock_get_authorized_dag_ids =
set_mock_auth_manager__get_authorized_dag_ids(
+ mock_get_auth_manager, {permitted_dag_model.dag_id}
+ )
+ set_mock_auth_manager__batch_is_authorized_dag(
+ mock_get_auth_manager, batch_is_authorized_dag_return_value
+ )
+ # Act
+ response = test_client.get("/public/importErrors")
+ # Assert
+ mock_get_authorized_dag_ids.assert_called_once_with(method="GET",
user=mock.ANY)
+ assert response.status_code == 200
+ response_json = response.json()
+ assert response_json == {
+ "total_entries": 1,
+ "import_errors": [
+ {
+ "import_error_id": import_errors[0].id,
+ "timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP1),
+ "filename": FILENAME1,
+ "stack_trace": expected_stack_trace,
+ "bundle_name": BUNDLE_NAME,
+ }
+ ],
+ }