vincbeck commented on code in PR #37468:
URL: https://github.com/apache/airflow/pull/37468#discussion_r1491829265
##########
airflow/api_connexion/endpoints/import_error_endpoint.py:
##########
@@ -16,39 +16,59 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence
from sqlalchemy import func, select
from airflow.api_connexion import security
-from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.api_connexion.parameters import apply_sorting, check_limit,
format_parameters
from airflow.api_connexion.schemas.error_schema import (
ImportErrorCollection,
import_error_collection_schema,
import_error_schema,
)
-from airflow.auth.managers.models.resource_details import AccessView
+from airflow.auth.managers.models.resource_details import AccessView,
DagDetails
+from airflow.models.dag import DagModel
from airflow.models.errors import ImportError as ImportErrorModel
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.www.extensions.init_auth_manager import get_auth_manager
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.api_connexion.types import APIResponse
+ from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
@security.requires_access_view(AccessView.IMPORT_ERRORS)
@provide_session
def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION)
-> APIResponse:
"""Get an import error."""
error = session.get(ImportErrorModel, import_error_id)
-
if error is None:
raise NotFound(
"Import error not found",
detail=f"The ImportError with import_error_id: `{import_error_id}`
was not found",
)
+ session.expunge(error)
+
+ can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
+ if not can_read_all_dags:
+ readable_dag_ids = security.get_readable_dags()
+ file_dag_ids = {
+ dag_id[0]
+ for dag_id in
session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
+ }
Review Comment:
We already do this check here:
https://github.com/apache/airflow/blob/main/airflow/auth/managers/base_auth_manager.py#L364.
I don't think it is necessary to do it twice. I would just do:
```suggestion
readable_dag_ids = security.get_readable_dags()
file_dag_ids = {
dag_id[0]
for dag_id in
session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
}
....
```
##########
airflow/api_connexion/endpoints/import_error_endpoint.py:
##########
@@ -65,10 +85,42 @@ def get_import_errors(
"""Get all import errors."""
to_replace = {"import_error_id": "id"}
allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
- total_entries = session.scalars(func.count(ImportErrorModel.id)).one()
+ count_query = select(func.count(ImportErrorModel.id))
query = select(ImportErrorModel)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
+
+ can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
+
+ if not can_read_all_dags:
Review Comment:
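Same as above: this `can_read_all_dags` check is already done in `base_auth_manager.py`, so it can be dropped here too.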
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]