This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ac45d92b7d3 Fix permissions check in import error APIs (#60801)
ac45d92b7d3 is described below
commit ac45d92b7d3d39c7c3e1784824e30308bc47b09d
Author: Vincent <[email protected]>
AuthorDate: Wed Jan 21 10:37:43 2026 -0500
Fix permissions check in import error APIs (#60801)
---
.../core_api/routes/public/import_error.py | 24 -------
.../core_api/routes/public/test_import_error.py | 76 ++++++++++++++++------
2 files changed, 57 insertions(+), 43 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
index e9d6eaddb02..97fe524dcf7 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
@@ -79,11 +79,6 @@ def get_import_error(
session.expunge(error)
auth_manager = get_auth_manager()
- can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
- if can_read_all_dags:
- # Early return if the user has access to all DAGs
- return error
-
readable_dag_ids = auth_manager.get_authorized_dag_ids(user=user)
# We need file_dag_ids as a set for intersection, issubset operations
file_dag_ids = set(
@@ -132,26 +127,7 @@ def get_import_errors(
user: GetUserDep,
) -> ImportErrorCollectionResponse:
"""Get all import errors."""
- import_errors_select, total_entries = paginated_select(
- statement=select(ParseImportError),
- filters=[filename_pattern],
- order_by=order_by,
- offset=offset,
- limit=limit,
- session=session,
- )
-
auth_manager = get_auth_manager()
- can_read_all_dags = auth_manager.is_authorized_dag(method="GET", user=user)
- if can_read_all_dags:
- # Early return if the user has access to all DAGs
- import_errors = session.scalars(import_errors_select).all()
- return ImportErrorCollectionResponse(
- import_errors=import_errors,
- total_entries=total_entries,
- )
-
- # if the user doesn't have access to all DAGs, only display errors from
visible DAGs
readable_dag_ids = auth_manager.get_authorized_dag_ids(method="GET",
user=user)
# Build a cte that fetches dag_ids for each file location
visible_files_cte = (
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
index 5c4aa788121..790bdf87a56 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_import_error.py
@@ -66,6 +66,37 @@ def permitted_dag_model(testing_dag_bundle, session: Session
= NEW_SESSION) -> D
return dag_model
+@pytest.fixture
+@provide_session
+def permitted_dag_model_all(testing_dag_bundle, session: Session =
NEW_SESSION) -> set[str]:
+ dag_model1 = DagModel(
+ fileloc=FILENAME1,
+ relative_fileloc=FILENAME1,
+ dag_id="dag_id1",
+ is_paused=False,
+ bundle_name=BUNDLE_NAME,
+ )
+ dag_model2 = DagModel(
+ fileloc=FILENAME2,
+ relative_fileloc=FILENAME2,
+ dag_id="dag_id2",
+ is_paused=False,
+ bundle_name=BUNDLE_NAME,
+ )
+ dag_model3 = DagModel(
+ fileloc=FILENAME3,
+ relative_fileloc=FILENAME3,
+ dag_id="dag_id3",
+ is_paused=False,
+ bundle_name=BUNDLE_NAME,
+ )
+ session.add(dag_model1)
+ session.add(dag_model2)
+ session.add(dag_model3)
+ session.commit()
+ return {dag_model1.dag_id, dag_model2.dag_id, dag_model3.dag_id}
+
+
@pytest.fixture
@provide_session
def not_permitted_dag_model(testing_dag_bundle, session: Session =
NEW_SESSION) -> DagModel:
@@ -105,7 +136,7 @@ def import_errors(session: Session = NEW_SESSION) ->
list[ParseImportError]:
timestamp=timestamp,
)
for bundle, filename, stacktrace, timestamp in zip(
- (BUNDLE_NAME, BUNDLE_NAME, None),
+ (BUNDLE_NAME, BUNDLE_NAME, BUNDLE_NAME),
(FILENAME1, FILENAME2, FILENAME3),
(STACKTRACE1, STACKTRACE2, STACKTRACE3),
(TIMESTAMP1, TIMESTAMP2, TIMESTAMP3),
@@ -116,14 +147,6 @@ def import_errors(session: Session = NEW_SESSION) ->
list[ParseImportError]:
return _import_errors
-def set_mock_auth_manager__is_authorized_dag(
- mock_auth_manager: mock.Mock, is_authorized_dag_return_value: bool = False
-) -> mock.Mock:
- mock_is_authorized_dag = mock_auth_manager.return_value.is_authorized_dag
- mock_is_authorized_dag.return_value = is_authorized_dag_return_value
- return mock_is_authorized_dag
-
-
def set_mock_auth_manager__get_authorized_dag_ids(
mock_auth_manager: mock.Mock, get_authorized_dag_ids_return_value:
set[str] | None = None
) -> mock.Mock:
@@ -173,19 +196,28 @@ class TestGetImportError:
"timestamp": from_datetime_to_zulu_without_ms(TIMESTAMP3),
"filename": FILENAME3,
"stack_trace": STACKTRACE3,
- "bundle_name": None,
+ "bundle_name": BUNDLE_NAME,
},
),
(None, 404, {}),
],
)
+@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_get_import_error(
- self, prepared_import_error_idx, expected_status_code, expected_body,
test_client, import_errors
+ self,
+ mock_get_auth_manager,
+ prepared_import_error_idx,
+ expected_status_code,
+ expected_body,
+ test_client,
+ permitted_dag_model_all,
+ import_errors,
):
import_error: ParseImportError | None = (
import_errors[prepared_import_error_idx] if
prepared_import_error_idx is not None else None
)
import_error_id = import_error.id if import_error else
IMPORT_ERROR_NON_EXISTED_ID
+ set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager,
permitted_dag_model_all)
response = test_client.get(f"/importErrors/{import_error_id}")
assert response.status_code == expected_status_code
if expected_status_code != 200:
@@ -210,23 +242,25 @@ class TestGetImportError:
):
import_error_id = import_errors[0].id
# Mock auth_manager
- mock_is_authorized_dag =
set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
mock_get_authorized_dag_ids =
set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager)
# Act
response = test_client.get(f"/importErrors/{import_error_id}")
# Assert
- mock_is_authorized_dag.assert_called_once_with(method="GET",
user=mock.ANY)
mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY)
assert response.status_code == 403
assert response.json() == {"detail": "You do not have read permission
on any of the DAGs in the file"}
@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def
test_get_import_error__user_dont_have_read_permission_to_read_all_dags_in_file(
- self, mock_get_auth_manager, test_client, permitted_dag_model,
not_permitted_dag_model, import_errors
+ self,
+ mock_get_auth_manager,
+ test_client,
+ permitted_dag_model_all,
+ not_permitted_dag_model,
+ import_errors,
):
import_error_id = import_errors[0].id
- set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
- set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager,
{permitted_dag_model.dag_id})
+ set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager,
permitted_dag_model_all)
# Act
response = test_client.get(f"/importErrors/{import_error_id}")
# Assert
@@ -316,15 +350,21 @@ class TestGetImportErrors:
),
],
)
+@mock.patch("airflow.api_fastapi.core_api.routes.public.import_error.get_auth_manager")
def test_get_import_errors(
self,
+ mock_get_auth_manager,
test_client,
query_params,
expected_status_code,
expected_total_entries,
expected_filenames,
+ permitted_dag_model_all,
):
- with assert_queries_count(2):
+ set_mock_auth_manager__get_authorized_dag_ids(mock_get_auth_manager,
permitted_dag_model_all)
+ set_mock_auth_manager__batch_is_authorized_dag(mock_get_auth_manager,
True)
+
+ with assert_queries_count(5):
response = test_client.get("/importErrors", params=query_params)
assert response.status_code == expected_status_code
@@ -380,7 +420,6 @@ class TestGetImportErrors:
import_errors,
):
mock_get_dag_id_to_team_name_mapping.return_value =
{permitted_dag_model.dag_id: team}
- set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
mock_get_authorized_dag_ids =
set_mock_auth_manager__get_authorized_dag_ids(
mock_get_auth_manager, {permitted_dag_model.dag_id}
)
@@ -422,7 +461,6 @@ class TestGetImportErrors:
self, mock_get_auth_manager, test_client, permitted_dag_model,
import_errors, session
):
"""Test that the bundle_name join condition works correctly."""
- set_mock_auth_manager__is_authorized_dag(mock_get_auth_manager)
mock_get_authorized_dag_ids =
set_mock_auth_manager__get_authorized_dag_ids(
mock_get_auth_manager, {permitted_dag_model.dag_id}
)