This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch revert-47062-feature/AIP-84/add-auth-for-configuration in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 8275338e35877a1178b6ed170990b49c289650f1 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Wed Mar 5 17:26:50 2025 +0000 Revert "AIP-84 | Add Auth for Dags (#47062)" This reverts commit 6107fcef99112f04d814ebc8f19b110909914e09. --- airflow/api_fastapi/common/db/common.py | 52 ++++++++-------- airflow/api_fastapi/common/parameters.py | 11 ++-- airflow/api_fastapi/core_api/base.py | 23 ------- .../api_fastapi/core_api/openapi/v1-generated.yaml | 12 ---- airflow/api_fastapi/core_api/routes/public/dags.py | 17 +---- airflow/api_fastapi/core_api/security.py | 56 ++--------------- .../core_api/routes/public/test_dags.py | 72 +--------------------- tests/api_fastapi/core_api/test_security.py | 8 +-- 8 files changed, 44 insertions(+), 207 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 363a7fbdc4b..84297724ce0 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -35,7 +35,7 @@ from airflow.utils.session import NEW_SESSION, create_session, create_session_as if TYPE_CHECKING: from sqlalchemy.sql import Select - from airflow.api_fastapi.core_api.base import OrmClause + from airflow.api_fastapi.common.parameters import BaseParam def _get_session() -> Session: @@ -47,7 +47,7 @@ SessionDep = Annotated[Session, Depends(_get_session)] def apply_filters_to_select( - *, statement: Select, filters: Sequence[OrmClause | None] | None = None + *, statement: Select, filters: Sequence[BaseParam | None] | None = None ) -> Select: if filters is None: return statement @@ -71,10 +71,10 @@ AsyncSessionDep = Annotated[AsyncSession, Depends(_get_async_session)] async def paginated_select_async( *, statement: Select, - filters: Sequence[OrmClause] | None = None, - order_by: OrmClause | None = None, - offset: OrmClause | None = None, - limit: OrmClause | None = None, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, session: AsyncSession, return_total_entries: Literal[True] = True, ) -> tuple[Select, int]: ... @@ -84,10 +84,10 @@ async def paginated_select_async( async def paginated_select_async( *, statement: Select, - filters: Sequence[OrmClause] | None = None, - order_by: OrmClause | None = None, - offset: OrmClause | None = None, - limit: OrmClause | None = None, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, session: AsyncSession, return_total_entries: Literal[False], ) -> tuple[Select, None]: ... @@ -96,10 +96,10 @@ async def paginated_select_async( async def paginated_select_async( *, statement: Select, - filters: Sequence[OrmClause | None] | None = None, - order_by: OrmClause | None = None, - offset: OrmClause | None = None, - limit: OrmClause | None = None, + filters: Sequence[BaseParam | None] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, session: AsyncSession, return_total_entries: bool = True, ) -> tuple[Select, int | None]: @@ -129,10 +129,10 @@ async def paginated_select_async( def paginated_select( *, statement: Select, - filters: Sequence[OrmClause] | None = None, - order_by: OrmClause | None = None, - offset: OrmClause | None = None, - limit: OrmClause | None = None, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, session: Session = NEW_SESSION, return_total_entries: Literal[True] = True, ) -> tuple[Select, int]: ... @@ -142,10 +142,10 @@ def paginated_select( def paginated_select( *, statement: Select, - filters: Sequence[OrmClause] | None = None, - order_by: OrmClause | None = None, - offset: OrmClause | None = None, - limit: OrmClause | None = None, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, session: Session = NEW_SESSION, return_total_entries: Literal[False], ) -> tuple[Select, None]: ... @@ -155,10 +155,10 @@ def paginated_select( def paginated_select( *, statement: Select, - filters: Sequence[OrmClause] | None = None, - order_by: OrmClause | None = None, - offset: OrmClause | None = None, - limit: OrmClause | None = None, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, session: Session = NEW_SESSION, return_total_entries: bool = True, ) -> tuple[Select, int | None]: diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index 887cb03244c..55586840fb0 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -40,7 +40,6 @@ from pydantic import AfterValidator, BaseModel, NonNegativeInt from sqlalchemy import Column, and_, case, or_ from sqlalchemy.inspection import inspect -from airflow.api_fastapi.core_api.base import OrmClause from airflow.models import Base from airflow.models.asset import ( AssetAliasModel, @@ -65,14 +64,18 @@ if TYPE_CHECKING: T = TypeVar("T") -class BaseParam(OrmClause[T], ABC): - """Base class for path or query parameters with ORM transformation.""" +class BaseParam(Generic[T], ABC): + """Base class for filters.""" def __init__(self, value: T | None = None, skip_none: bool = True) -> None: - super().__init__(value) + self.value = value self.attribute: ColumnElement | None = None self.skip_none = skip_none + @abstractmethod + def to_orm(self, select: Select) -> Select: + pass + def set_value(self, value: T | None) -> Self: self.value = value return self diff --git a/airflow/api_fastapi/core_api/base.py b/airflow/api_fastapi/core_api/base.py index 887f528f197..d88ec1757eb 100644 --- a/airflow/api_fastapi/core_api/base.py +++ b/airflow/api_fastapi/core_api/base.py @@ -16,16 +16,8 @@ # under the License. from __future__ import annotations -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar - from pydantic import BaseModel as PydanticBaseModel, ConfigDict -if TYPE_CHECKING: - from sqlalchemy.sql import Select - -T = TypeVar("T") - class BaseModel(PydanticBaseModel): """ @@ -47,18 +39,3 @@ class StrictBaseModel(BaseModel): """ model_config = ConfigDict(from_attributes=True, populate_by_name=True, extra="forbid") - - -class OrmClause(Generic[T], ABC): - """ - Base class for filtering clauses with paginated_select. - - The subclasses should implement the `to_orm` method and set the `value` attribute. - """ - - def __init__(self, value: T | None = None): - self.value = value - - @abstractmethod - def to_orm(self, select: Select) -> Select: - pass diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index bb1fb080653..63dd0b7023f 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3014,8 +3014,6 @@ paths: summary: Get Dags description: Get all DAGs. operationId: get_dags - security: - - OAuth2PasswordBearer: [] parameters: - name: limit in: query @@ -3181,8 +3179,6 @@ paths: summary: Patch Dags description: Patch multiple DAGs. operationId: patch_dags - security: - - OAuth2PasswordBearer: [] parameters: - name: update_mask in: query @@ -3318,8 +3314,6 @@ paths: summary: Get Dag description: Get basic information about a DAG. operationId: get_dag - security: - - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3370,8 +3364,6 @@ paths: summary: Patch Dag description: Patch the specific DAG. operationId: patch_dag - security: - - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3438,8 +3430,6 @@ paths: summary: Delete Dag description: Delete the specific DAG. operationId: delete_dag - security: - - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3490,8 +3480,6 @@ paths: summary: Get Dag Details description: Get details of DAG. operationId: get_dag_details - security: - - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index f38f650657c..3eaf4879482 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -57,11 +57,6 @@ from airflow.api_fastapi.core_api.datamodels.dags import ( DAGResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import ( - EditableDagsFilterDep, - ReadableDagsFilterDep, - requires_access_dag, -) from airflow.exceptions import AirflowException, DagNotFound from airflow.models import DAG, DagModel from airflow.models.dagrun import DagRun @@ -69,7 +64,7 @@ from airflow.models.dagrun import DagRun dags_router = AirflowRouter(tags=["DAG"], prefix="/dags") -@dags_router.get("", dependencies=[Depends(requires_access_dag(method="GET"))]) +@dags_router.get("") def get_dags( limit: QueryLimit, offset: QueryOffset, @@ -109,7 +104,6 @@ def get_dags( ).dynamic_depends() ), ], - readable_dags_filter: ReadableDagsFilterDep, session: SessionDep, ) -> DAGCollectionResponse: """Get all DAGs.""" @@ -137,7 +131,6 @@ def get_dags( tags, owners, last_dag_run_state, - readable_dags_filter, ], order_by=order_by, offset=offset, @@ -162,7 +155,6 @@ def get_dags( status.HTTP_422_UNPROCESSABLE_ENTITY, ] ), - dependencies=[Depends(requires_access_dag(method="GET"))], ) def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: """Get basic information about a DAG.""" @@ -189,7 +181,6 @@ def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: status.HTTP_404_NOT_FOUND, ] ), - dependencies=[Depends(requires_access_dag(method="GET"))], ) def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDetailsResponse: """Get details of DAG.""" @@ -216,7 +207,6 @@ def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDe status.HTTP_404_NOT_FOUND, ] ), - dependencies=[Depends(requires_access_dag(method="PUT"))], ) def patch_dag( dag_id: str, @@ -259,7 +249,6 @@ def patch_dag( status.HTTP_404_NOT_FOUND, ] ), - dependencies=[Depends(requires_access_dag(method="PUT"))], ) def patch_dags( patch_body: DAGPatchBody, @@ -271,7 +260,6 @@ def patch_dags( only_active: QueryOnlyActiveFilter, paused: QueryPausedFilter, last_dag_run_state: QueryLastDagRunStateFilter, - editable_dags_filter: EditableDagsFilterDep, session: SessionDep, update_mask: list[str] | None = Query(None), ) -> DAGCollectionResponse: @@ -292,7 +280,7 @@ def patch_dags( dags_select, total_entries = paginated_select( statement=generate_dag_with_latest_run_query(), - filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state, editable_dags_filter], + filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], order_by=None, offset=offset, limit=limit, @@ -322,7 +310,6 @@ def patch_dags( status.HTTP_422_UNPROCESSABLE_ENTITY, ] ), - dependencies=[Depends(requires_access_dag(method="DELETE"))], ) def delete_dag( dag_id: str, diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 1dda21e2e28..1b000afa72c 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -from collections.abc import Container from functools import cache from typing import TYPE_CHECKING, Annotated, Callable @@ -25,7 +24,6 @@ from fastapi.security import OAuth2PasswordBearer from jwt import ExpiredSignatureError, InvalidTokenError from airflow.api_fastapi.app import get_auth_manager -from airflow.api_fastapi.core_api.base import OrmClause from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( ConnectionDetails, @@ -35,13 +33,10 @@ from airflow.auth.managers.models.resource_details import ( VariableDetails, ) from airflow.configuration import conf -from airflow.models.dag import DagModel from airflow.utils.jwt_signer import JWTSigner, get_signing_key if TYPE_CHECKING: - from sqlalchemy.sql import Select - - from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod + from airflow.auth.managers.base_auth_manager import ResourceMethod oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -64,9 +59,6 @@ def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser: raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden") -GetUserDep = Annotated[BaseUser, Depends(get_user)] - - async def get_user_with_exception_handling(request: Request) -> BaseUser | None: # Currently the UI does not support JWT authentication, this method defines a fallback if no token is provided by the UI. # We can remove this method when issue https://github.com/apache/airflow/issues/44884 is done. @@ -84,15 +76,11 @@ async def get_user_with_exception_handling(request: Request) -> BaseUser | None: return get_user(token_str) -def requires_access_dag( - method: ResourceMethod, access_entity: DagAccessEntity | None = None -) -> Callable[[Request, BaseUser], None]: +def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable: def inner( - request: Request, - user: GetUserDep, + user: Annotated[BaseUser, Depends(get_user)], + dag_id: str | None = None, ) -> None: - dag_id: str | None = request.path_params.get("dag_id") - _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_dag( method=method, access_entity=access_entity, details=DagDetails(id=dag_id), user=user @@ -102,42 +90,10 @@ def requires_access_dag( return inner -class PermittedDagFilter(OrmClause[set[str]]): - """A parameter that filters the permitted dags for the user.""" - - def to_orm(self, select: Select) -> Select: - return select.where(DagModel.dag_id.in_(self.value)) - - -def permitted_dag_filter_factory( - methods: Container[ResourceMethod], -) -> Callable[[Request, BaseUser], PermittedDagFilter]: - """ - Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user. - - :param methods: whether filter readable or writable. - :return: The callable that can be used as Depends in FastAPI. - """ - - def depends_permitted_dags_filter( - request: Request, - user: GetUserDep, - ) -> PermittedDagFilter: - auth_manager: BaseAuthManager = request.app.state.auth_manager - permitted_dags: set[str] = auth_manager.get_permitted_dag_ids(user=user, methods=methods) - return PermittedDagFilter(permitted_dags) - - return depends_permitted_dags_filter - - -EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory(["PUT"]))] -ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory(["GET"]))] - - def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: GetUserDep, + user: Annotated[BaseUser, Depends(get_user)], ) -> None: pool_name = request.path_params.get("pool_name") @@ -153,7 +109,7 @@ def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser] def requires_access_connection(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: GetUserDep, + user: Annotated[BaseUser, Depends(get_user)], ) -> None: connection_id = request.path_params.get("connection_id") diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index c1793ed7fa5..9c4650636b2 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -234,31 +234,13 @@ class TestGetDags(TestDagEndpoint): ) def test_get_dags(self, test_client, query_params, expected_total_entries, expected_ids): response = test_client.get("/public/dags", params=query_params) + assert response.status_code == 200 body = response.json() assert body["total_entries"] == expected_total_entries assert [dag["dag_id"] for dag in body["dags"]] == expected_ids - @mock.patch("airflow.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") - def test_get_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): - mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} - response = test_client.get("/public/dags") - mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, methods=["GET"]) - assert response.status_code == 200 - body = response.json() - - assert body["total_entries"] == 2 - assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID] - - def test_get_dags_should_response_401(self, unauthenticated_test_client): - response = unauthenticated_test_client.get("/public/dags") - assert response.status_code == 401 - - def test_get_dags_should_response_403(self, unauthorized_test_client): - response = unauthorized_test_client.get("/public/dags") - assert response.status_code == 403 - class TestPatchDag(TestDagEndpoint): """Unit tests for Patch DAG.""" @@ -284,14 +266,6 @@ class TestPatchDag(TestDagEndpoint): body = response.json() assert body["is_paused"] == expected_is_paused - def test_patch_dag_should_response_401(self, unauthenticated_test_client): - response = unauthenticated_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True}) - assert response.status_code == 401 - - def test_patch_dag_should_response_403(self, unauthorized_test_client): - response = unauthorized_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True}) - assert response.status_code == 403 - class TestPatchDags(TestDagEndpoint): """Unit tests for Patch DAGs.""" @@ -349,26 +323,6 @@ class TestPatchDags(TestDagEndpoint): paused_dag_ids = [dag["dag_id"] for dag in body["dags"] if dag["is_paused"]] assert paused_dag_ids == expected_paused_ids - @mock.patch("airflow.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") - def test_patch_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): - mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} - response = test_client.patch( - "/public/dags", json={"is_paused": False}, params={"only_active": False, "dag_id_pattern": "~"} - ) - mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, methods=["PUT"]) - assert response.status_code == 200 - body = response.json() - - assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID] - - def test_patch_dags_should_response_401(self, unauthenticated_test_client): - response = unauthenticated_test_client.patch("/public/dags", json={"is_paused": True}) - assert response.status_code == 401 - - def test_patch_dags_should_response_403(self, unauthorized_test_client): - response = unauthorized_test_client.patch("/public/dags", json={"is_paused": True}) - assert response.status_code == 403 - class TestDagDetails(TestDagEndpoint): """Unit tests for DAG Details.""" @@ -450,14 +404,6 @@ class TestDagDetails(TestDagEndpoint): } assert res_json == expected - def test_dag_details_should_response_401(self, unauthenticated_test_client): - response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}/details") - assert response.status_code == 401 - - def test_dag_details_should_response_403(self, unauthorized_test_client): - response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}/details") - assert response.status_code == 403 - class TestGetDag(TestDagEndpoint): """Unit tests for Get DAG.""" @@ -506,14 +452,6 @@ class TestGetDag(TestDagEndpoint): } assert res_json == expected - def test_get_dag_should_response_401(self, unauthenticated_test_client): - response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}") - assert response.status_code == 401 - - def test_get_dag_should_response_403(self, unauthorized_test_client): - response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}") - assert response.status_code == 403 - class TestDeleteDAG(TestDagEndpoint): """Unit tests for Delete DAG.""" @@ -572,11 +510,3 @@ class TestDeleteDAG(TestDagEndpoint): details_response = test_client.get(f"{API_PREFIX}/{dag_id}/details") assert details_response.status_code == status_code_details - - def test_delete_dag_should_response_401(self, unauthenticated_test_client): - response = unauthenticated_test_client.delete(f"{API_PREFIX}/{DAG1_ID}") - assert response.status_code == 401 - - def test_delete_dag_should_response_403(self, unauthorized_test_client): - response = unauthorized_test_client.delete(f"{API_PREFIX}/{DAG1_ID}") - assert response.status_code == 403 diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py index 2237e3bce1c..7824ecd171b 100644 --- a/tests/api_fastapi/core_api/test_security.py +++ b/tests/api_fastapi/core_api/test_security.py @@ -88,10 +88,8 @@ class TestFastApiSecurity: auth_manager = Mock() auth_manager.is_authorized_dag.return_value = True mock_get_auth_manager.return_value = auth_manager - fastapi_request = Mock() - fastapi_request.path_params.return_value = {} - requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock()) + requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock()) auth_manager.is_authorized_dag.assert_called_once() @@ -100,10 +98,8 @@ class TestFastApiSecurity: auth_manager = Mock() auth_manager.is_authorized_dag.return_value = False mock_get_auth_manager.return_value = auth_manager - fastapi_request = Mock() - fastapi_request.path_params.return_value = {} with pytest.raises(HTTPException, match="Forbidden"): - requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock()) + requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock()) auth_manager.is_authorized_dag.assert_called_once()
