This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new d72131f952 Use auth manager `is_authorized_` APIs to check user permissions in Rest API (#34317) d72131f952 is described below commit d72131f952836a3134c90805ef7c3bcf82ea93e9 Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Tue Oct 17 12:43:09 2023 -0400 Use auth manager `is_authorized_` APIs to check user permissions in Rest API (#34317) --- airflow/api_connexion/endpoints/config_endpoint.py | 7 +- .../api_connexion/endpoints/connection_endpoint.py | 12 +- airflow/api_connexion/endpoints/dag_endpoint.py | 18 +- .../api_connexion/endpoints/dag_run_endpoint.py | 74 ++----- .../api_connexion/endpoints/dag_source_endpoint.py | 4 +- .../endpoints/dag_warning_endpoint.py | 9 +- .../api_connexion/endpoints/dataset_endpoint.py | 9 +- .../api_connexion/endpoints/event_log_endpoint.py | 6 +- .../api_connexion/endpoints/extra_link_endpoint.py | 10 +- .../endpoints/import_error_endpoint.py | 6 +- airflow/api_connexion/endpoints/log_endpoint.py | 10 +- airflow/api_connexion/endpoints/plugin_endpoint.py | 3 +- airflow/api_connexion/endpoints/pool_endpoint.py | 11 +- .../api_connexion/endpoints/provider_endpoint.py | 3 +- airflow/api_connexion/endpoints/task_endpoint.py | 16 +- .../endpoints/task_instance_endpoint.py | 84 ++------ .../api_connexion/endpoints/variable_endpoint.py | 10 +- airflow/api_connexion/endpoints/xcom_endpoint.py | 25 +-- airflow/api_connexion/security.py | 217 ++++++++++++++++++++- airflow/auth/managers/base_auth_manager.py | 82 +++++++- airflow/auth/managers/fab/decorators/auth.py | 30 +++ airflow/auth/managers/fab/fab_auth_manager.py | 187 ++++++++++++++---- .../auth/managers/fab/security_manager/override.py | 122 +++++++++++- airflow/auth/managers/models/resource_details.py | 36 +++- airflow/www/auth.py | 12 +- airflow/www/extensions/init_jinja_globals.py | 4 +- airflow/www/security_manager.py | 181 +---------------- airflow/www/templates/airflow/dag.html | 7 +- airflow/www/views.py | 66 ++++--- .../endpoints/test_event_log_endpoint.py | 17 +- tests/api_connexion/endpoints/test_log_endpoint.py | 3 +- .../api_connexion/endpoints/test_xcom_endpoint.py | 4 - tests/auth/managers/fab/test_fab_auth_manager.py | 32 ++- tests/auth/managers/test_base_auth_manager.py | 37 +++- tests/www/test_security.py | 84 +++++--- tests/www/views/test_views_acl.py | 5 + tests/www/views/test_views_decorators.py | 44 +---- tests/www/views/test_views_tasks.py | 1 + 38 files changed, 887 insertions(+), 601 deletions(-) diff --git a/airflow/api_connexion/endpoints/config_endpoint.py b/airflow/api_connexion/endpoints/config_endpoint.py index a6fc67beb7..cbb8acdfce 100644 --- a/airflow/api_connexion/endpoints/config_endpoint.py +++ b/airflow/api_connexion/endpoints/config_endpoint.py @@ -24,7 +24,6 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import NotFound, PermissionDenied from airflow.api_connexion.schemas.config_schema import Config, ConfigOption, ConfigSection, config_schema from airflow.configuration import conf -from airflow.security import permissions from airflow.settings import json LINE_SEP = "\n" # `\n` cannot appear in f-strings @@ -66,7 +65,7 @@ def _config_to_json(config: Config) -> str: return json.dumps(config_schema.dump(config), indent=4) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)]) +@security.requires_access_configuration("GET") def get_config(*, section: str | None = None) -> Response: """Get current configuration.""" serializer = { @@ -103,8 +102,8 @@ def get_config(*, section: str | None = None) -> Response: ) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)]) -def get_value(section: str, option: str) -> Response: +@security.requires_access_configuration("GET") +def get_value(*, section: str, option: str) -> Response: serializer = { "text/plain": _config_to_text, "application/json": _config_to_json, diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index 1444421a84..16d9afb5b9 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -53,7 +53,7 @@ if TYPE_CHECKING: RESOURCE_EVENT_PREFIX = "connection" -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("DELETE") @provide_session @action_logging( event=action_event_from_permission( @@ -73,7 +73,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) -> return NoContent, HTTPStatus.NO_CONTENT -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("GET") @provide_session def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get a connection entry.""" @@ -86,7 +86,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API return connection_schema.dump(connection) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("GET") @format_parameters({"limit": check_limit}) @provide_session def get_connections( @@ -109,7 +109,7 @@ def get_connections( ) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("PUT") @provide_session @action_logging( event=action_event_from_permission( @@ -147,7 +147,7 @@ def patch_connection( return connection_schema.dump(connection) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("POST") @provide_session @action_logging( event=action_event_from_permission( @@ -176,7 +176,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"Connection already exist. ID: {conn_id}") -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("POST") def test_connection() -> APIResponse: """ Test an API connection. diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 5aac030ecb..21a61a0ddd 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -36,10 +36,10 @@ from airflow.api_connexion.schemas.dag_schema import ( ) from airflow.exceptions import AirflowException, DagNotFound from airflow.models.dag import DagModel, DagTag -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -48,7 +48,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse, UpdateMask -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") @provide_session def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get basic information about a DAG.""" @@ -60,7 +60,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: return dag_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") def get_dag_details(*, dag_id: str) -> APIResponse: """Get details of DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) @@ -69,7 +69,7 @@ def get_dag_details(*, dag_id: str) -> APIResponse: return dag_detail_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") @format_parameters({"limit": check_limit}) @provide_session def get_dags( @@ -96,7 +96,7 @@ def get_dags( if dag_id_pattern: dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) - readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags)) if tags: @@ -110,7 +110,7 @@ def get_dags( return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("PUT") @provide_session def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse: """Update the specific DAG.""" @@ -132,7 +132,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = return dag_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("PUT") @format_parameters({"limit": check_limit}) @provide_session def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None): @@ -156,7 +156,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat if dag_id_pattern == "~": dag_id_pattern = "%" dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) - editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user) + editable_dags = get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=g.user) dags_query = dags_query.where(DagModel.dag_id.in_(editable_dags)) if tags: @@ -180,7 +180,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("DELETE") @provide_session def delete_dag(dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Delete the specific DAG.""" diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 1a9cb03418..45e064764c 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -56,6 +56,7 @@ from airflow.api_connexion.schemas.task_instance_schema import ( TaskInstanceReferenceCollection, task_instance_reference_collection_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app @@ -76,12 +77,7 @@ if TYPE_CHECKING: RESOURCE_EVENT_PREFIX = "dag_run" -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("DELETE", DagAccessEntity.RUN) @provide_session def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Delete a DAG Run.""" @@ -93,12 +89,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI return NoContent, HTTPStatus.NO_CONTENT -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get a DAG Run.""" @@ -111,13 +102,8 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) +@security.requires_access_dataset("GET") @provide_session def get_upstream_dataset_events( *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION @@ -194,12 +180,7 @@ def _fetch_dag_runs( return session.scalars(query.offset(offset).limit(limit)).all(), total_entries -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @format_parameters( { "start_date_gte": format_datetime, @@ -236,8 +217,9 @@ def get_dag_runs( # This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs. if dag_id == "~": - appbuilder = get_airflow_app().appbuilder - query = query.where(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user))) + query = query.where( + DagRun.dag_id.in_(get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)) + ) else: query = query.where(DagRun.dag_id == dag_id) @@ -262,12 +244,7 @@ def get_dag_runs( return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: """Get list of DAG Runs.""" @@ -277,8 +254,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: except ValidationError as err: raise BadRequest(detail=str(err.messages)) - appbuilder = get_airflow_app().appbuilder - readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) + readable_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user) query = select(DagRun) if data.get("dag_ids"): dag_ids = set(data["dag_ids"]) & set(readable_dag_ids) @@ -307,12 +283,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("POST", DagAccessEntity.RUN) @provide_session @action_logging( event=action_event_from_permission( @@ -378,12 +349,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists") -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of a dag run.""" @@ -410,12 +376,7 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear a dag run.""" @@ -461,12 +422,7 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set the note for a dag run.""" diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py b/airflow/api_connexion/endpoints/dag_source_endpoint.py index b191630815..3ee80ee857 100644 --- a/airflow/api_connexion/endpoints/dag_source_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py @@ -24,11 +24,11 @@ from itsdangerous import BadSignature, URLSafeSerializer from airflow.api_connexion import security from airflow.api_connexion.exceptions import NotFound from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.dagcode import DagCode -from airflow.security import permissions -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)]) +@security.requires_access_dag("GET", DagAccessEntity.CODE) def get_dag_source(*, file_token: str) -> Response: """Get source code using file token.""" secret_key = current_app.config["SECRET_KEY"] diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index c9d8207b0f..3e0db58dc9 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -27,8 +27,8 @@ from airflow.api_connexion.schemas.dag_warning_schema import ( DagWarningCollection, dag_warning_collection_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.dagwarning import DagWarning as DagWarningModel -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -39,12 +39,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - ] -) +@security.requires_access_dag("GET", DagAccessEntity.WARNING) @format_parameters({"limit": check_limit}) @provide_session def get_dag_warnings( diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py index 81fe872fca..152ac6eecb 100644 --- a/airflow/api_connexion/endpoints/dataset_endpoint.py +++ b/airflow/api_connexion/endpoints/dataset_endpoint.py @@ -32,7 +32,6 @@ from airflow.api_connexion.schemas.dataset_schema import ( dataset_schema, ) from airflow.models.dataset import DatasetEvent, DatasetModel -from airflow.security import permissions from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -42,9 +41,9 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @provide_session -def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse: +def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: """Get a Dataset.""" dataset = session.scalar( select(DatasetModel) @@ -59,7 +58,7 @@ def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse: return dataset_schema.dump(dataset) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @format_parameters({"limit": check_limit}) @provide_session def get_datasets( @@ -86,7 +85,7 @@ def get_datasets( return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @provide_session @format_parameters({"limit": check_limit}) def get_dataset_events( diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py index 99ec8eedae..b5bca5cc23 100644 --- a/airflow/api_connexion/endpoints/event_log_endpoint.py +++ b/airflow/api_connexion/endpoints/event_log_endpoint.py @@ -28,8 +28,8 @@ from airflow.api_connexion.schemas.event_log_schema import ( event_log_collection_schema, event_log_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import Log -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import NEW_SESSION, provide_session @@ -40,7 +40,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)]) +@security.requires_access_dag("GET", DagAccessEntity.AUDIT_LOG) @provide_session def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIResponse: """Get a log entry.""" @@ -50,7 +50,7 @@ def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIRe return event_log_schema.dump(event_log) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)]) +@security.requires_access_dag("GET", DagAccessEntity.AUDIT_LOG) @format_parameters({"limit": check_limit}) @provide_session def get_event_logs( diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py index ec92dd51ee..2e9954587c 100644 --- a/airflow/api_connexion/endpoints/extra_link_endpoint.py +++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py @@ -22,8 +22,8 @@ from sqlalchemy import select from airflow.api_connexion import security from airflow.api_connexion.exceptions import NotFound +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -35,13 +35,7 @@ if TYPE_CHECKING: from airflow.models.dagbag import DagBag -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_extra_links( *, diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index f2b9a88311..81459b604e 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -28,8 +28,8 @@ from airflow.api_connexion.schemas.error_schema import ( import_error_collection_schema, import_error_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.errors import ImportError as ImportErrorModel -from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -38,7 +38,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) +@security.requires_access_dag("GET", DagAccessEntity.IMPORT_ERRORS) @provide_session def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse: """Get an import error.""" @@ -52,7 +52,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> return import_error_schema.dump(error) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) +@security.requires_access_dag("GET", DagAccessEntity.IMPORT_ERRORS) @format_parameters({"limit": check_limit}) @provide_session def get_import_errors( diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index 126b8634e3..239f08ecda 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -27,9 +27,9 @@ from sqlalchemy.orm import joinedload from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.schemas.log_schema import LogResponseObject, logs_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound from airflow.models import TaskInstance, Trigger -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.session import NEW_SESSION, provide_session @@ -40,13 +40,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_LOGS) @provide_session def get_log( *, diff --git a/airflow/api_connexion/endpoints/plugin_endpoint.py b/airflow/api_connexion/endpoints/plugin_endpoint.py index 02ba435d52..500bd65749 100644 --- a/airflow/api_connexion/endpoints/plugin_endpoint.py +++ b/airflow/api_connexion/endpoints/plugin_endpoint.py @@ -22,13 +22,12 @@ from airflow.api_connexion import security from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.plugin_schema import PluginCollection, plugin_collection_schema from airflow.plugins_manager import get_plugin_info -from airflow.security import permissions if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)]) +@security.requires_access_website() @format_parameters({"limit": check_limit}) def get_plugins(*, limit: int, offset: int = 0) -> APIResponse: """Get plugins endpoint.""" diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 735d777e4c..0fbb2c8a23 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -30,7 +30,6 @@ from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema from airflow.models.pool import Pool -from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -39,7 +38,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse, UpdateMask -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("DELETE") @provide_session def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: """Delete a pool.""" @@ -52,7 +51,7 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons return Response(status=HTTPStatus.NO_CONTENT) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("GET") @provide_session def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: """Get a pool.""" @@ -62,7 +61,7 @@ def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: return pool_schema.dump(obj) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("GET") @format_parameters({"limit": check_limit}) @provide_session def get_pools( @@ -82,7 +81,7 @@ def get_pools( return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("PUT") @provide_session def patch_pool( *, @@ -138,7 +137,7 @@ def patch_pool( return pool_schema.dump(pool) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("POST") @provide_session def post_pool(*, session: Session = NEW_SESSION) -> APIResponse: """Create a pool.""" diff --git a/airflow/api_connexion/endpoints/provider_endpoint.py b/airflow/api_connexion/endpoints/provider_endpoint.py index 75bba31218..a64368dce3 100644 --- a/airflow/api_connexion/endpoints/provider_endpoint.py +++ b/airflow/api_connexion/endpoints/provider_endpoint.py @@ -27,7 +27,6 @@ from airflow.api_connexion.schemas.provider_schema import ( provider_collection_schema, ) from airflow.providers_manager import ProvidersManager -from airflow.security import permissions if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse @@ -46,7 +45,7 @@ def _provider_mapper(provider: ProviderInfo) -> Provider: ) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)]) +@security.requires_access_website() def get_providers() -> APIResponse: """Get providers.""" providers = [_provider_mapper(d) for d in ProvidersManager().providers.values()] diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index 70b6e4b8ab..4c5954d2ac 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -22,8 +22,8 @@ from typing import TYPE_CHECKING from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.schemas.task_schema import TaskCollection, task_collection_schema, task_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app if TYPE_CHECKING: @@ -31,12 +31,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK) def get_task(*, dag_id: str, task_id: str) -> APIResponse: """Get simplified representation of a task.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) @@ -50,12 +45,7 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse: return task_schema.dump(task) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK) def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse: """Get tasks for DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 612167d3d7..0b942134ac 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -42,11 +42,11 @@ from airflow.api_connexion.schemas.task_instance_schema import ( task_instance_schema, ) from airflow.api_connexion.security import get_readable_dags +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.models import SlaMiss from airflow.models.dagrun import DagRun as DR from airflow.models.operator import needs_expansion from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -62,13 +62,7 @@ if TYPE_CHECKING: T = TypeVar("T") -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instance( *, @@ -110,13 +104,7 @@ def get_task_instance( return task_instance_schema.dump(task_instance) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instance( *, @@ -162,13 +150,7 @@ def get_mapped_task_instance( "updated_at_lte": format_datetime, }, ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instances( *, @@ -306,13 +288,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T, "updated_at_lte": format_datetime, }, ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances( *, @@ -389,13 +365,7 @@ def get_task_instances( ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: """Get list of task instances.""" @@ -408,7 +378,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: if dag_ids: cannot_access_dag_ids = set() for id in dag_ids: - if not get_airflow_app().appbuilder.sm.can_read_dag(id, g.user): + if not get_auth_manager().is_authorized_dag(method="GET", details=DagDetails(id=id), user=g.user): cannot_access_dag_ids.add(id) if cannot_access_dag_ids: raise PermissionDenied( @@ -464,13 +434,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear task instances.""" @@ -530,13 +494,7 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of task instances.""" @@ -603,13 +561,7 @@ def set_mapped_task_instance_note( return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_task_instance( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION @@ -649,13 +601,7 @@ def patch_task_instance( return task_instance_reference_schema.dump(ti) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_mapped_task_instance( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: Session = NEW_SESSION @@ -666,13 +612,7 @@ def patch_mapped_task_instance( ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def set_task_instance_note( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 54d5ac744b..05157298e7 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -43,7 +43,7 @@ if TYPE_CHECKING: RESOURCE_EVENT_PREFIX = "variable" -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("DELETE") @action_logging( event=action_event_from_permission( prefix=RESOURCE_EVENT_PREFIX, @@ -57,7 +57,7 @@ def delete_variable(*, variable_key: str) -> Response: return Response(status=HTTPStatus.NO_CONTENT) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("DELETE") @provide_session def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Response: """Get a variable by key.""" @@ -67,7 +67,7 @@ def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Respon return variable_schema.dump(var) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("GET") @format_parameters({"limit": check_limit}) @provide_session def get_variables( @@ -92,7 +92,7 @@ def get_variables( ) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("PUT") @provide_session @action_logging( event=action_event_from_permission( @@ -126,7 +126,7 @@ def patch_variable( return variable_schema.dump(variable) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("POST") @action_logging( event=action_event_from_permission( prefix=RESOURCE_EVENT_PREFIX, diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 73bdd8562e..d5eb6ed19b 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -26,12 +26,12 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.xcom_schema import XComCollection, xcom_collection_schema, xcom_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import DagRun as DR, XCom -from airflow.security import permissions from airflow.settings import conf -from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -39,14 +39,7 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.XCOM) @format_parameters({"limit": check_limit}) @provide_session def get_xcom_entries( @@ -63,8 +56,7 @@ def get_xcom_entries( """Get all XCom values.""" query = select(XCom) if dag_id == "~": - appbuilder = get_airflow_app().appbuilder - readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) + readable_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user) query = query.where(XCom.dag_id.in_(readable_dag_ids)) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) else: @@ -85,14 +77,7 @@ def get_xcom_entries( return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.XCOM) @provide_session def get_xcom_entry( *, diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index b19f15257c..6da171aa62 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -16,13 +16,28 @@ # under the License. from __future__ import annotations +import warnings from functools import wraps -from typing import Callable, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Callable, Sequence, TypeVar, cast from flask import Response, g from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated +from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, +) +from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils.airflow_flask_app import get_airflow_app +from airflow.www.extensions.init_auth_manager import get_auth_manager + +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod T = TypeVar("T", bound=Callable) @@ -39,18 +54,202 @@ def check_authentication() -> None: def requires_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: - """Check current user's permissions against required permissions.""" - appbuilder = get_airflow_app().appbuilder - if appbuilder.update_perms: - appbuilder.sm.sync_resource_permissions(permissions) + """ + Check current user's permissions against required permissions. + + Deprecated. Do not use this decorator, use one of the decorator `has_access_*` defined in + airflow/api_connexion/security.py instead. + This decorator will only work with FAB authentication and not with other auth providers. + + This decorator might be used in user plugins, do not remove it. + """ + warnings.warn( + "The 'requires_access' decorator is deprecated. Please use one of the decorator `requires_access_*`" + "defined in airflow/api_connexion/security.py instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + from airflow.auth.managers.fab.decorators.auth import _requires_access_fab + + return _requires_access_fab(permissions) + + +def _requires_access(*, is_authorized_callback: Callable[[], bool], func: Callable, args, kwargs) -> bool: + """ + Define the behavior whether the user is authorized to access the resource. + + :param is_authorized_callback: callback to execute to figure whether the user is authorized to access + the resource + :param func: the function to call if the user is authorized + :param args: the arguments of ``func`` + :param kwargs: the keyword arguments ``func`` + + :meta private: + """ + check_authentication() + if is_authorized_callback(): + return func(*args, **kwargs) + raise PermissionDenied() + + +def requires_authentication(func: T): + """Decorator for functions that require authentication.""" + + @wraps(func) + def decorated(*args, **kwargs): + check_authentication() + return func(*args, **kwargs) + + return cast(T, decorated) + + +def requires_access_configuration(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + section: str | None = kwargs.get("section") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_configuration( + method=method, details=ConfigurationDetails(section=section) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_connection(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + connection_id: str | None = kwargs.get("connection_id") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_connection( + method=method, details=ConnectionDetails(conn_id=connection_id) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_dag( + method: ResourceMethod, access_entity: DagAccessEntity | None = None +) -> Callable[[T], T]: + def _is_authorized_callback(dag_id: str): + def callback(): + access = get_auth_manager().is_authorized_dag( + method=method, + access_entity=access_entity, + details=DagDetails(id=dag_id), + ) + + # ``access`` means here: + # - if a DAG id is provided (``dag_id`` not None): is the user authorized to access this DAG + # - if no DAG id is provided: is the user authorized to access all DAGs + if dag_id or access: + return access + + # No DAG id is provided and the user is not authorized to access all DAGs + # If method is "GET", return whether the user has read access to any DAGs + # If method is "PUT", return whether the user has edit access to any DAGs + return (method == "GET" and any(get_auth_manager().get_permitted_dag_ids(methods=["GET"]))) or ( + method == "PUT" and any(get_auth_manager().get_permitted_dag_ids(methods=["PUT"])) + ) + + return callback + + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + dag_id: str | None = kwargs.get("dag_id") if kwargs.get("dag_id") != "~" else None + return _requires_access( + is_authorized_callback=_is_authorized_callback(dag_id), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_dataset(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + uri: str | None = kwargs.get("uri") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_dataset( + method=method, details=DatasetDetails(uri=uri) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_pool(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + pool_name: str | None = kwargs.get("pool_name") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_pool( + method=method, details=PoolDetails(name=pool_name) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_variable(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + variable_key: str | None = kwargs.get("variable_key") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_variable( + method=method, details=VariableDetails(key=variable_key) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + +def requires_access_website() -> Callable[[T], T]: def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): - check_authentication() - if appbuilder.sm.check_authorization(permissions, kwargs.get("dag_id")): - return func(*args, **kwargs) - raise PermissionDenied() + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_website(), + func=func, + args=args, + kwargs=kwargs, + ) return cast(T, decorated) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 29695dae12..0700338069 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -18,19 +18,28 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Container, Literal + +from sqlalchemy import select from airflow.exceptions import AirflowException +from airflow.models import DagModel from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from flask import Flask + from sqlalchemy.orm import Session from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, ConnectionDetails, DagAccessEntity, DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, ) from airflow.cli.cli_config import CLICommand from airflow.www.security_manager import AirflowSecurityManagerV2 @@ -82,12 +91,14 @@ class BaseAuthManager(LoggingMixin): self, *, method: ResourceMethod, + details: ConfigurationDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on configuration. :param method: the method to perform + :param details: optional details about the configuration :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -110,14 +121,14 @@ class BaseAuthManager(LoggingMixin): self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a connection. :param method: the method to perform - :param connection_details: optional details about the connection + :param details: optional details about the connection :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -126,17 +137,17 @@ class BaseAuthManager(LoggingMixin): self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform - :param dag_access_entity: the kind of DAG information the authorization request is about. + :param access_entity: the kind of DAG information the authorization request is about. If not provided, the authorization request is about the DAG itself - :param dag_details: optional details about the DAG + :param details: optional details about the DAG :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -145,12 +156,30 @@ class BaseAuthManager(LoggingMixin): self, *, method: ResourceMethod, + details: DatasetDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a dataset. :param method: the method to perform + :param details: optional details about the dataset + :param user: the user to perform the action on. If not provided (or None), it uses the current user + """ + + @abstractmethod + def is_authorized_pool( + self, + *, + method: ResourceMethod, + details: PoolDetails | None = None, + user: BaseUser | None = None, + ) -> bool: + """ + Return whether the user is authorized to perform a given action on a pool. + + :param method: the method to perform + :param details: optional details about the pool :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -159,12 +188,14 @@ class BaseAuthManager(LoggingMixin): self, *, method: ResourceMethod, + details: VariableDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a variable. :param method: the method to perform + :param details: optional details about the variable :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -182,6 +213,43 @@ class BaseAuthManager(LoggingMixin): :param user: the user to perform the action on. If not provided (or None), it uses the current user """ + @provide_session + def get_permitted_dag_ids( + self, + *, + methods: Container[ResourceMethod] | None = None, + user=None, + session: Session = NEW_SESSION, + ) -> set[str]: + """ + Get readable or writable DAGs for user. + + By default, reads all the DAGs and check individually if the user has permissions to access the DAG. + Can lead to some poor performance. It is recommended to override this method in the auth manager + implementation to provide a more efficient implementation. + """ + if not methods: + methods = ["PUT", "GET"] + + dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + + if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or ( + "PUT" in methods and self.is_authorized_dag(method="PUT", user=user) + ): + # If user is authorized to read/edit all DAGs, return all DAGs + return dag_ids + + def _is_permitted_dag_id(method: ResourceMethod, methods: Container[ResourceMethod], dag_id: str): + return method in methods and self.is_authorized_dag( + method=method, details=DagDetails(id=dag_id), user=user + ) + + return { + dag_id + for dag_id in dag_ids + if _is_permitted_dag_id("GET", methods, dag_id) or _is_permitted_dag_id("PUT", methods, dag_id) + } + @abstractmethod def get_url_login(self, **kwargs) -> str: """Return the login page url.""" diff --git a/airflow/auth/managers/fab/decorators/auth.py b/airflow/auth/managers/fab/decorators/auth.py index 5f0f161470..583e18e2a7 100644 --- a/airflow/auth/managers/fab/decorators/auth.py +++ b/airflow/auth/managers/fab/decorators/auth.py @@ -23,7 +23,10 @@ from typing import Callable, Sequence, TypeVar, cast from flask import current_app, render_template, request +from airflow.api_connexion.exceptions import PermissionDenied +from airflow.api_connexion.security import check_authentication from airflow.configuration import conf +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.net import get_hostname from airflow.www.auth import _has_access from airflow.www.extensions.init_auth_manager import get_auth_manager @@ -33,6 +36,33 @@ T = TypeVar("T", bound=Callable) log = logging.getLogger(__name__) +def _requires_access_fab(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: + """ + Check current user's permissions against required permissions. + + This decorator is only kept for backward compatible reasons. The decorator + ``airflow.api_connexion.security.requires_access``, which redirects to this decorator, might be used in + user plugins. Thus, we need to keep it. + + :meta private: + """ + appbuilder = get_airflow_app().appbuilder + if appbuilder.update_perms: + appbuilder.sm.sync_resource_permissions(permissions) + + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + check_authentication() + if appbuilder.sm.check_authorization(permissions, kwargs.get("dag_id")): + return func(*args, **kwargs) + raise PermissionDenied() + + return cast(T, decorated) + + return requires_access_decorator + + def _has_access_fab(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: """ Factory for decorator that checks current user's permissions against required permissions. diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 9c2c5643b6..6c942babbf 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -18,10 +18,11 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Container from flask import url_for from sqlalchemy import select +from sqlalchemy.orm import Session, joinedload from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod from airflow.auth.managers.fab.cli_commands.definition import ( @@ -29,12 +30,22 @@ from airflow.auth.managers.fab.cli_commands.definition import ( SYNC_PERM_COMMAND, USERS_COMMANDS, ) -from airflow.auth.managers.models.resource_details import ConnectionDetails, DagAccessEntity, DagDetails +from airflow.auth.managers.fab.models import Permission, Role, User +from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, +) from airflow.cli.cli_config import ( GroupCommand, ) from airflow.exceptions import AirflowException from airflow.models import DagModel +from airflow.security import permissions from airflow.security.permissions import ( ACTION_CAN_ACCESS_MENU, ACTION_CAN_CREATE, @@ -50,37 +61,51 @@ from airflow.security.permissions import ( RESOURCE_DAG_DEPENDENCIES, RESOURCE_DAG_PREFIX, RESOURCE_DAG_RUN, + RESOURCE_DAG_WARNING, RESOURCE_DATASET, + RESOURCE_IMPORT_ERROR, + RESOURCE_PLUGIN, + RESOURCE_POOL, + RESOURCE_PROVIDER, RESOURCE_TASK_INSTANCE, RESOURCE_TASK_LOG, + RESOURCE_TRIGGER, RESOURCE_VARIABLE, RESOURCE_WEBSITE, RESOURCE_XCOM, ) +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: - from airflow.auth.managers.fab.models import User + from airflow.auth.managers.models.base_user import BaseUser from airflow.cli.cli_config import ( CLICommand, ) -_MAP_METHOD_NAME_TO_FAB_ACTION_NAME: dict[ResourceMethod, str] = { +MAP_METHOD_NAME_TO_FAB_ACTION_NAME: dict[ResourceMethod, str] = { "POST": ACTION_CAN_CREATE, "GET": ACTION_CAN_READ, "PUT": ACTION_CAN_EDIT, "DELETE": ACTION_CAN_DELETE, } -_MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE = { - DagAccessEntity.AUDIT_LOG: RESOURCE_AUDIT_LOG, - DagAccessEntity.CODE: RESOURCE_DAG_CODE, - DagAccessEntity.DATASET: RESOURCE_DATASET, - DagAccessEntity.DEPENDENCIES: RESOURCE_DAG_DEPENDENCIES, - DagAccessEntity.RUN: RESOURCE_DAG_RUN, - DagAccessEntity.TASK_INSTANCE: RESOURCE_TASK_INSTANCE, - DagAccessEntity.TASK_LOGS: RESOURCE_TASK_LOG, - DagAccessEntity.XCOM: RESOURCE_XCOM, +_MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE: dict[DagAccessEntity, tuple[str, ...]] = { + DagAccessEntity.AUDIT_LOG: (RESOURCE_AUDIT_LOG,), + DagAccessEntity.CODE: (RESOURCE_DAG_CODE,), + DagAccessEntity.DEPENDENCIES: (RESOURCE_DAG_DEPENDENCIES,), + DagAccessEntity.IMPORT_ERRORS: (RESOURCE_IMPORT_ERROR,), + DagAccessEntity.RUN: (RESOURCE_DAG_RUN,), + # RESOURCE_TASK_INSTANCE has been originally misused. RESOURCE_TASK_INSTANCE referred to task definition + # AND task instances without making the difference + # To be backward compatible, we translate DagAccessEntity.TASK_INSTANCE to RESOURCE_TASK_INSTANCE AND + # RESOURCE_DAG_RUN + # See https://github.com/apache/airflow/pull/34317#discussion_r1355917769 + DagAccessEntity.TASK: (RESOURCE_TASK_INSTANCE,), + DagAccessEntity.TASK_INSTANCE: (RESOURCE_DAG_RUN, RESOURCE_TASK_INSTANCE), + DagAccessEntity.TASK_LOGS: (RESOURCE_TASK_LOG,), + DagAccessEntity.WARNING: (RESOURCE_DAG_WARNING,), + DagAccessEntity.XCOM: (RESOURCE_XCOM,), } @@ -139,7 +164,13 @@ class FabAuthManager(BaseAuthManager): """Return whether the user is logged in.""" return not self.get_user().is_anonymous - def is_authorized_configuration(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_configuration( + self, + *, + method: ResourceMethod, + details: ConfigurationDetails | None = None, + user: BaseUser | None = None, + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONFIG, user=user) def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: @@ -149,7 +180,7 @@ class FabAuthManager(BaseAuthManager): self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONNECTION, user=user) @@ -158,8 +189,8 @@ class FabAuthManager(BaseAuthManager): self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ @@ -171,34 +202,111 @@ class FabAuthManager(BaseAuthManager): entity (e.g. DAG runs). 2. ``dag_access`` is provided which means the user wants to access a sub entity of the DAG (e.g. DAG runs). - a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity - b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity + a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity. + b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity. + + However, if no specific DAG is targeted, just check the sub entity. :param method: The method to authorize. - :param dag_access_entity: The dag access entity. - :param dag_details: The dag details. + :param access_entity: The dag access entity. + :param details: The dag details. :param user: The user. """ - if not dag_access_entity: + if not access_entity: # Scenario 1 - return self._is_authorized_dag(method=method, dag_details=dag_details, user=user) + return self._is_authorized_dag(method=method, details=details, user=user) else: # Scenario 2 - resource_type = self._get_fab_resource_type(dag_access_entity) + resource_types = self._get_fab_resource_types(access_entity) dag_method: ResourceMethod = "GET" if method == "GET" else "PUT" - return self._is_authorized_dag( - method=dag_method, dag_details=dag_details, user=user - ) and self._is_authorized(method=method, resource_type=resource_type, user=user) + if (details and details.id) and not self._is_authorized_dag( + method=dag_method, details=details, user=user + ): + return False + + return all( + self._is_authorized(method=method, resource_type=resource_type, user=user) + for resource_type in resource_types + ) - def is_authorized_dataset(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_dataset( + self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_DATASET, user=user) - def is_authorized_variable(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_pool( + self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + ) -> bool: + return self._is_authorized(method=method, resource_type=RESOURCE_POOL, user=user) + + def is_authorized_variable( + self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_VARIABLE, user=user) def is_authorized_website(self, *, user: BaseUser | None = None) -> bool: - return self._is_authorized(method="GET", resource_type=RESOURCE_WEBSITE, user=user) + return ( + self._is_authorized(method="GET", resource_type=RESOURCE_PLUGIN, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_PROVIDER, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_TRIGGER, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_WEBSITE, user=user) + ) + + @provide_session + def get_permitted_dag_ids( + self, + *, + methods: Container[ResourceMethod] | None = None, + user=None, + session: Session = NEW_SESSION, + ) -> set[str]: + if not methods: + methods = ["PUT", "GET"] + + if not user: + user = self.get_user() + + if not self.is_logged_in(): + roles = user.roles + else: + if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or ( + "PUT" in methods and self.is_authorized_dag(method="PUT", user=user) + ): + # If user is authorized to read/edit all DAGs, return all DAGs + return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + user_query = session.scalar( + select(User) + .options( + joinedload(User.roles) + .subqueryload(Role.permissions) + .options(joinedload(Permission.action), joinedload(Permission.resource)) + ) + .where(User.id == user.id) + ) + roles = user_query.roles + + map_fab_action_name_to_method_name = {v: k for k, v in MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()} + map_fab_action_name_to_method_name[ACTION_CAN_ACCESS_MENU] = "GET" + resources = set() + for role in roles: + for permission in role.permissions: + action = permission.action.name + if ( + action in map_fab_action_name_to_method_name + and map_fab_action_name_to_method_name[action] in methods + ): + resource = permission.resource.name + if resource == permissions.RESOURCE_DAG: + return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + if resource.startswith(permissions.RESOURCE_DAG_PREFIX): + resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) + else: + resources.add(resource) + return { + dag.dag_id + for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) + } def get_security_manager_override_class(self) -> type: """Return the security manager override.""" @@ -270,14 +378,14 @@ class FabAuthManager(BaseAuthManager): def _is_authorized_dag( self, method: ResourceMethod, - dag_details: DagDetails | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform - :param dag_details: optional details about the DAG + :param details: optional details about the DAG :param user: the user to perform the action on. If not provided (or None), it uses the current user :meta private: @@ -286,9 +394,9 @@ class FabAuthManager(BaseAuthManager): if is_global_authorized: return True - if dag_details and dag_details.id: + if details and details.id: # Check whether the user has permissions to access a specific DAG - resource_dag_name = self._resource_name_for_dag(dag_details.id) + resource_dag_name = self._resource_name_for_dag(details.id) return self._is_authorized(method=method, resource_type=resource_dag_name, user=user) return False @@ -302,14 +410,14 @@ class FabAuthManager(BaseAuthManager): :meta private: """ - if method not in _MAP_METHOD_NAME_TO_FAB_ACTION_NAME: + if method not in MAP_METHOD_NAME_TO_FAB_ACTION_NAME: raise AirflowException(f"Unknown method: {method}") - return _MAP_METHOD_NAME_TO_FAB_ACTION_NAME[method] + return MAP_METHOD_NAME_TO_FAB_ACTION_NAME[method] @staticmethod - def _get_fab_resource_type(dag_access_entity: DagAccessEntity): + def _get_fab_resource_types(dag_access_entity: DagAccessEntity) -> tuple[str, ...]: """ - Convert a DAG access entity to a FAB resource type. + Convert a DAG access entity to a tuple of FAB resource type. :param dag_access_entity: the DAG access entity @@ -361,8 +469,7 @@ class FabAuthManager(BaseAuthManager): :meta private: """ if "." in dag_id: - dm = self.security_manager.appbuilder.get_session.scalar( + return self.security_manager.appbuilder.get_session.scalar( select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id).limit(1) ) - return dm.root_dag_id or dm.dag_id return dag_id diff --git a/airflow/auth/managers/fab/security_manager/override.py b/airflow/auth/managers/fab/security_manager/override.py index 2e5bf313d9..cd5cb86804 100644 --- a/airflow/auth/managers/fab/security_manager/override.py +++ b/airflow/auth/managers/fab/security_manager/override.py @@ -25,7 +25,7 @@ import random import uuid import warnings from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Container, Iterable, Sequence import re2 from flask import flash, g, session @@ -42,13 +42,21 @@ from sqlalchemy import func, inspect, select from sqlalchemy.exc import MultipleResultsFound from werkzeug.security import generate_password_hash +from airflow.auth.managers.fab.fab_auth_manager import MAP_METHOD_NAME_TO_FAB_ACTION_NAME from airflow.auth.managers.fab.models import Action, Permission, RegisterUser, Resource, Role from airflow.auth.managers.fab.models.anonymous_user import AnonymousUser -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning +from airflow.models import DagModel +from airflow.security import permissions +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.security_manager import AirflowSecurityManagerV2 from airflow.www.session import AirflowDatabaseSessionInterface if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.fab.models import User log = logging.getLogger(__name__) @@ -502,6 +510,91 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): log.error(const.LOGMSG_ERR_SEC_CREATE_DB, e) exit(1) + def get_readable_dags(self, user) -> Iterable[DagModel]: + """Get the DAGs readable by authenticated user.""" + warnings.warn( + "`get_readable_dags` has been deprecated. Please use `get_auth_manager().get_permitted_dag_ids` " + "instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) + + def get_editable_dags(self, user) -> Iterable[DagModel]: + """Get the DAGs editable by authenticated user.""" + warnings.warn( + "`get_editable_dags` has been deprecated. Please use `get_auth_manager().get_permitted_dag_ids` " + "instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) + + @provide_session + def get_accessible_dags( + self, + user_actions: Container[str] | None, + user, + session: Session = NEW_SESSION, + ) -> Iterable[DagModel]: + warnings.warn( + "`get_accessible_dags` has been deprecated. Please use " + "`get_auth_manager().get_permitted_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + + dag_ids = self.get_accessible_dag_ids(user, user_actions, session) + return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) + + @provide_session + def get_accessible_dag_ids( + self, + user, + user_actions: Container[str] | None = None, + session: Session = NEW_SESSION, + ) -> set[str]: + warnings.warn( + "`get_accessible_dag_ids` has been deprecated. Please use " + "`get_auth_manager().get_permitted_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + if not user_actions: + user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] + fab_action_name_to_method_name = {v: k for k, v in MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()} + user_methods: Container[ResourceMethod] = [ + fab_action_name_to_method_name[action] + for action in fab_action_name_to_method_name + if action in user_actions + ] + return get_auth_manager().get_permitted_dag_ids(user=user, methods=user_methods, session=session) + + @staticmethod + def get_readable_dag_ids(user=None) -> set[str]: + """Get the DAG IDs readable by authenticated user.""" + return get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=user) + + @staticmethod + def get_editable_dag_ids(user=None) -> set[str]: + """Get the DAG IDs editable by authenticated user.""" + return get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=user) + + def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: + """Check if user has read or write access to some dags.""" + if dag_id and dag_id != "~": + root_dag_id = self._get_root_dag_id(dag_id) + return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) + + user = g.user + if action == permissions.ACTION_CAN_READ: + return any(self.get_readable_dag_ids(user)) + return any(self.get_editable_dag_ids(user)) + """ ----------- Role entity @@ -1071,6 +1164,31 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): log.debug("Token Get: %s", token) return token + def check_authorization( + self, + perms: Sequence[tuple[str, str]] | None = None, + dag_id: str | None = None, + ) -> bool: + """Checks that the logged in user has the specified permissions.""" + if not perms: + return True + + for perm in perms: + if perm in ( + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), + ): + can_access_all_dags = self.has_access(*perm) + if not can_access_all_dags: + action = perm[0] + if not self.can_access_some_dags(action, dag_id): + return False + elif not self.has_access(*perm): + return False + + return True + @staticmethod def _azure_parse_jwt(token): """ diff --git a/airflow/auth/managers/models/resource_details.py b/airflow/auth/managers/models/resource_details.py index 51cdc59793..1f98ba72ce 100644 --- a/airflow/auth/managers/models/resource_details.py +++ b/airflow/auth/managers/models/resource_details.py @@ -21,18 +21,46 @@ from dataclasses import dataclass from enum import Enum +@dataclass +class ConfigurationDetails: + """Represents the details of a configuration.""" + + section: str | None = None + + @dataclass class ConnectionDetails: """Represents the details of a connection.""" - conn_id: str + conn_id: str | None = None @dataclass class DagDetails: """Represents the details of a DAG.""" - id: str + id: str | None = None + + +@dataclass +class DatasetDetails: + """Represents the details of a dataset.""" + + uri: str | None = None + + +@dataclass +class PoolDetails: + """Represents the details of a pool.""" + + name: str | None = None + + +@dataclass +class VariableDetails: + """Represents the details of a variable.""" + + key: str | None = None class DagAccessEntity(Enum): @@ -40,9 +68,11 @@ class DagAccessEntity(Enum): AUDIT_LOG = "AUDIT_LOG" CODE = "CODE" - DATASET = "DATASET" DEPENDENCIES = "DEPENDENCIES" + IMPORT_ERRORS = "IMPORT_ERRORS" RUN = "RUN" + TASK = "TASK" TASK_INSTANCE = "TASK_INSTANCE" TASK_LOGS = "TASK_LOGS" + WARNING = "WARNING" XCOM = "XCOM" diff --git a/airflow/www/auth.py b/airflow/www/auth.py index ffd80a117c..8fb6ffb435 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -35,7 +35,7 @@ from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod - from airflow.models import Connection + from airflow.models.connection import Connection T = TypeVar("T", bound=Callable) @@ -75,7 +75,7 @@ def _has_access_no_details(is_authorized_callback: Callable[[], bool]) -> Callab This works only for resources with no details. This function is used in some ``has_access_`` functions below. - :param is_authorized_callback: callback to execute to figure whether the user authorized to access + :param is_authorized_callback: callback to execute to figure whether the user is authorized to access the resource? """ @@ -140,9 +140,7 @@ def has_access_connection(method: ResourceMethod) -> Callable[[T], T]: ] is_authorized = all( [ - get_auth_manager().is_authorized_connection( - method=method, connection_details=connection_details - ) + get_auth_manager().is_authorized_connection(method=method, details=connection_details) for connection_details in connections_details ] ) @@ -191,8 +189,8 @@ def has_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None is_authorized = get_auth_manager().is_authorized_dag( method=method, - dag_access_entity=access_entity, - dag_details=None if not dag_id else DagDetails(id=dag_id), + access_entity=access_entity, + details=None if not dag_id else DagDetails(id=dag_id), ) return _has_access( diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index ff5481dd46..95cd9b8c26 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -69,10 +69,12 @@ def init_jinja_globals(app): "git_version": git_version, "k8s_or_k8scelery_executor": IS_K8S_OR_K8SCELERY_EXECUTOR, "rest_api_enabled": False, - "auth_manager": get_auth_manager(), "config_test_connection": conf.get("core", "test_connection", fallback="Disabled"), } + # Extra global specific to auth manager + extra_globals["auth_manager"] = get_auth_manager() + backends = conf.get("api", "auth_backends") if backends and backends[0] != "airflow.api.auth.backend.deny_all": extra_globals["rest_api_enabled"] = True diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index 580191a9cb..065490f668 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -18,13 +18,13 @@ from __future__ import annotations import itertools import warnings -from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Collection, Iterable, Sequence from flask import g from sqlalchemy import or_, select from sqlalchemy.orm import joinedload -from airflow.auth.managers.fab.models import Permission, Resource, Role, User +from airflow.auth.managers.fab.models import Permission, Resource, Role from airflow.auth.managers.fab.views.permissions import ( ActionModelView, PermissionPairModelView, @@ -48,8 +48,6 @@ from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.models import DagBag, DagModel from airflow.security import permissions from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.fab_security.sqla.manager import SecurityManager from airflow.www.utils import CustomSQLAInterface @@ -62,7 +60,8 @@ EXISTING_ROLES = { } if TYPE_CHECKING: - from sqlalchemy.orm import Session + + pass class AirflowSecurityManagerV2(SecurityManager, LoggingMixin): @@ -269,126 +268,6 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin): user = g.user return user.roles - def get_readable_dags(self, user) -> Iterable[DagModel]: - """Get the DAGs readable by authenticated user.""" - warnings.warn( - "`get_readable_dags` has been deprecated. Please use `get_readable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) - - def get_editable_dags(self, user) -> Iterable[DagModel]: - """Get the DAGs editable by authenticated user.""" - warnings.warn( - "`get_editable_dags` has been deprecated. Please use `get_editable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) - - @provide_session - def get_accessible_dags( - self, - user_actions: Container[str] | None, - user, - session: Session = NEW_SESSION, - ) -> Iterable[DagModel]: - warnings.warn( - "`get_accessible_dags` has been deprecated. Please use `get_accessible_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=3, - ) - dag_ids = self.get_accessible_dag_ids(user, user_actions, session) - return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) - - def get_readable_dag_ids(self, user) -> set[str]: - """Get the DAG IDs readable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_READ]) - - def get_editable_dag_ids(self, user) -> set[str]: - """Get the DAG IDs editable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_EDIT]) - - @provide_session - def get_accessible_dag_ids( - self, - user, - user_actions: Container[str] | None = None, - session: Session = NEW_SESSION, - ) -> set[str]: - """Get readable or writable DAGs for user.""" - if not user_actions: - user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] - - if not get_auth_manager().is_logged_in(): - roles = user.roles - else: - if (permissions.ACTION_CAN_EDIT in user_actions and self.can_edit_all_dags(user)) or ( - permissions.ACTION_CAN_READ in user_actions and self.can_read_all_dags(user) - ): - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - user_query = session.scalar( - select(User) - .options( - joinedload(User.roles) - .subqueryload(Role.permissions) - .options(joinedload(Permission.action), joinedload(Permission.resource)) - ) - .where(User.id == user.id) - ) - roles = user_query.roles - - resources = set() - for role in roles: - for permission in role.permissions: - action = permission.action.name - if action in user_actions: - resource = permission.resource.name - if resource == permissions.RESOURCE_DAG: - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - if resource.startswith(permissions.RESOURCE_DAG_PREFIX): - resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) - else: - resources.add(resource) - return { - dag.dag_id - for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) - } - - def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: - """Check if user has read or write access to some dags.""" - if dag_id and dag_id != "~": - root_dag_id = self._get_root_dag_id(dag_id) - return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) - - user = g.user - if action == permissions.ACTION_CAN_READ: - return any(self.get_readable_dag_ids(user)) - return any(self.get_editable_dag_ids(user)) - - def can_read_dag(self, dag_id: str, user=None) -> bool: - """Determine whether a user has DAG read access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_READ, dag_resource_name, user=user) - - def can_edit_dag(self, dag_id: str, user=None) -> bool: - """Determine whether a user has DAG edit access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_EDIT, dag_resource_name, user=user) - - def can_delete_dag(self, dag_id: str, user=None) -> bool: - """Determine whether a user has DAG delete access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_DELETE, dag_resource_name, user=user) - def prefixed_dag_id(self, dag_id: str) -> str: """Return the permission name for a DAG id.""" warnings.warn( @@ -430,36 +309,6 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin): return False - def _has_role(self, role_name_or_list: Container, user) -> bool: - """Whether the user has this role name.""" - if not isinstance(role_name_or_list, list): - role_name_or_list = [role_name_or_list] - return any(r.name in role_name_or_list for r in user.roles) - - def has_all_dags_access(self, user) -> bool: - """ - Has all the dag access in any of the 3 cases. - - 1. Role needs to be in (Admin, Viewer, User, Op). - 2. Has can_read action on dags resource. - 3. Has can_edit action on dags resource. - """ - if not user: - user = g.user - return ( - self._has_role(["Admin", "Viewer", "Op", "User"], user) - or self.can_read_all_dags(user) - or self.can_edit_all_dags(user) - ) - - def can_edit_all_dags(self, user=None) -> bool: - """Has can_edit action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG, user) - - def can_read_all_dags(self, user=None) -> bool: - """Has can_read action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user) - def clean_perms(self) -> None: """FAB leaves faulty permissions that need to be cleaned up.""" self.log.debug("Cleaning faulty perms") @@ -740,22 +589,6 @@ class AirflowSecurityManagerV2(SecurityManager, LoggingMixin): perms: Sequence[tuple[str, str]] | None = None, dag_id: str | None = None, ) -> bool: - """Check that the logged in user has the specified permissions.""" - if not perms: - return True - - for perm in perms: - if perm in ( - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ): - can_access_all_dags = self.has_access(*perm) - if not can_access_all_dags: - action = perm[0] - if not self.can_access_some_dags(action, dag_id): - return False - elif not self.has_access(*perm): - return False - - return True + raise NotImplementedError( + "The method 'check_authorization' is only available with the auth manager FabAuthManager" + ) diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index d324199d1f..40440d3fd6 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -110,16 +110,15 @@ {% if dag.parent_dag is defined and dag.parent_dag %} <span class="text-muted">SUBDAG:</span> {{ dag.dag_id }} {% else %} - {% set can_edit = appbuilder.sm.can_edit_dag(dag.dag_id) %} - {% if appbuilder.sm.can_edit_dag(dag.dag_id) %} + {% if can_edit_dag %} {% set switch_tooltip = 'Pause/Unpause DAG' %} {% else %} {% set switch_tooltip = 'DAG is Paused' if dag_is_paused else 'DAG is Active' %} {% endif %} - <label class="switch-label{{' disabled' if not can_edit else '' }} js-tooltip" title="{{ switch_tooltip }}"> + <label class="switch-label{{' disabled' if not can_edit_dag else '' }} js-tooltip" title="{{ switch_tooltip }}"> <input class="switch-input" id="pause_resume" data-dag-id="{{ dag.dag_id }}" type="checkbox"{{ " checked" if not dag_is_paused else "" }} - {{ " disabled" if not can_edit else "" }}> + {{ " disabled" if not can_edit_dag else "" }}> <span class="switch" aria-hidden="true"></span> </label> <span class="text-muted">DAG:</span> {{ dag.dag_id }} diff --git a/airflow/www/views.py b/airflow/www/views.py index 308578a11d..c675bd2b78 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -83,7 +83,7 @@ from airflow.api.common.mark_tasks import ( set_dag_run_state_to_success, set_state, ) -from airflow.auth.managers.models.resource_details import DagAccessEntity +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.compat.functools import cache from airflow.configuration import AIRFLOW_CONFIG, conf from airflow.datasets import Dataset @@ -699,6 +699,12 @@ class AirflowBaseView(BaseView): # Add triggerer_job only if we need it if TriggererJobRunner.is_needed(): kwargs["triggerer_job"] = lazy_object_proxy.Proxy(TriggererJobRunner.most_recent_job) + + if "dag" in kwargs: + kwargs["can_edit_dag"] = get_auth_manager().is_authorized_dag( + method="PUT", details=DagDetails(id=kwargs["dag"].dag_id) + ) + return super().render_template( *args, # Cache this at most once per request, not for the lifetime of the view instance @@ -768,7 +774,7 @@ class Airflow(AirflowBaseView): end = start + dags_per_page # Get all the dag id the user could access - filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) with create_session() as session: # read orm_dags from the db @@ -896,11 +902,9 @@ class Airflow(AirflowBaseView): .unique() .all() ) - user_permissions = g.user.perms - can_create_dag_run = ( - permissions.ACTION_CAN_CREATE, - permissions.RESOURCE_DAG_RUN, - ) in user_permissions + can_create_dag_run = get_auth_manager().is_authorized_dag( + method="POST", access_entity=DagAccessEntity.RUN, user=g.user + ) dataset_triggered_dag_ids = {dag.dag_id for dag in dags if dag.schedule_interval == "Dataset"} if dataset_triggered_dag_ids: @@ -911,9 +915,13 @@ class Airflow(AirflowBaseView): dataset_triggered_next_run_info = {} for dag in dags: - dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id, g.user) + dag.can_edit = get_auth_manager().is_authorized_dag( + method="PUT", details=DagDetails(id=dag.dag_id), user=g.user + ) dag.can_trigger = dag.can_edit and can_create_dag_run - dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id, g.user) + dag.can_delete = get_auth_manager().is_authorized_dag( + method="DELETE", details=DagDetails(id=dag.dag_id), user=g.user + ) dagtags = session.execute(select(func.distinct(DagTag.name)).order_by(DagTag.name)).all() tags = [ @@ -925,7 +933,7 @@ class Airflow(AirflowBaseView): import_errors = select(errors.ImportError).order_by(errors.ImportError.id) - if (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG) not in user_permissions: + if not get_auth_manager().is_authorized_dag(method="GET"): # if the user doesn't have access to all DAGs, only display errors from visible DAGs import_errors = import_errors.join( DagModel, DagModel.fileloc == errors.ImportError.filename @@ -968,10 +976,9 @@ class Airflow(AirflowBaseView): # Second segment is a version marker that we don't need to show. yield segments[-1], table_name - if ( - permissions.ACTION_CAN_ACCESS_MENU, - permissions.RESOURCE_ADMIN_MENU, - ) in user_permissions and conf.getboolean("webserver", "warn_deployment_exposure"): + if get_auth_manager().is_authorized_configuration(method="GET", user=g.user) and conf.getboolean( + "webserver", "warn_deployment_exposure" + ): robots_file_access_count = ( select(Log) .where(Log.event == "robots") @@ -1057,11 +1064,10 @@ class Airflow(AirflowBaseView): ) @expose("/next_run_datasets_summary", methods=["POST"]) - @auth.has_access_dag("GET") @provide_session def next_run_datasets_summary(self, session: Session = NEW_SESSION): """Next run info for dataset triggered DAGs.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) if not allowed_dag_ids: return flask.json.jsonify({}) @@ -1096,7 +1102,7 @@ class Airflow(AirflowBaseView): @provide_session def dag_stats(self, session: Session = NEW_SESSION): """Dag statistics.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -1128,7 +1134,7 @@ class Airflow(AirflowBaseView): @provide_session def task_stats(self, session: Session = NEW_SESSION): """Task Statistics.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) if not allowed_dag_ids: return flask.json.jsonify({}) @@ -1227,7 +1233,7 @@ class Airflow(AirflowBaseView): @provide_session def last_dagruns(self, session: Session = NEW_SESSION): """Last DAG runs.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -2334,7 +2340,7 @@ class Airflow(AirflowBaseView): @provide_session def blocked(self, session: Session = NEW_SESSION): """Mark Dag Blocked.""" - allowed_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) # Filter by post parameters selected_dag_ids = {unquote(dag_id) for dag_id in request.form.getlist("dag_ids") if dag_id} @@ -3558,7 +3564,8 @@ class Airflow(AirflowBaseView): ) @expose("/object/next_run_datasets/<string:dag_id>") - @auth.has_access_dag("GET", DagAccessEntity.DATASET) + @auth.has_access_dag("GET", DagAccessEntity.RUN) + @auth.has_access_dataset("GET") def next_run_datasets(self, dag_id): """Return datasets necessary, and their status, for the next dag run.""" dag = get_airflow_app().dag_bag.get_dag(dag_id) @@ -3901,9 +3908,11 @@ class DagFilter(BaseFilter): """Filter using DagIDs.""" def apply(self, query, func): - if get_airflow_app().appbuilder.sm.has_all_dags_access(g.user): + if get_auth_manager().is_authorized_dag(method="GET", user=g.user): + return query + if get_auth_manager().is_authorized_dag(method="PUT", user=g.user): return query - filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) return query.where(self.model.dag_id.in_(filter_dag_ids)) @@ -3953,7 +3962,7 @@ class AirflowPrivilegeVerifierModelView(AirflowModelView): @staticmethod def validate_dag_edit_access(item: DagRun | TaskInstance): """Validate whether the user has 'can_edit' access for this specific DAG.""" - if not get_airflow_app().appbuilder.sm.can_edit_dag(item.dag_id): + if not get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=item.dag_id)): raise AirflowException(f"Access denied for dag_id {item.dag_id}") def pre_add(self, item: DagRun | TaskInstance): @@ -3999,7 +4008,7 @@ def action_has_dag_edit_access(action_func: Callable) -> Callable: ) for dag_id in dag_ids: - if not get_airflow_app().appbuilder.sm.can_edit_dag(dag_id): + if not get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag_id)): flash(f"Access denied for dag_id {dag_id}", "danger") logging.warning("User %s tried to modify %s without having access.", g.user.username, dag_id) return redirect(self.get_default_url()) @@ -5694,7 +5703,6 @@ class TaskInstanceModelView(AirflowPrivilegeVerifierModelView): class AutocompleteView(AirflowBaseView): """View to provide autocomplete results.""" - @auth.has_access_dag("GET") @provide_session @expose("/dagmodel/autocomplete") def autocomplete(self, session: Session = NEW_SESSION): @@ -5728,7 +5736,7 @@ class AutocompleteView(AirflowBaseView): dag_ids_query = dag_ids_query.where(DagModel.is_paused) owners_query = owners_query.where(DagModel.is_paused) - filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + filter_dag_ids = get_auth_manager().get_permitted_dag_ids(user=g.user) dag_ids_query = dag_ids_query.where(DagModel.dag_id.in_(filter_dag_ids)) owners_query = owners_query.where(DagModel.dag_id.in_(filter_dag_ids)) @@ -5813,9 +5821,9 @@ def add_user_permissions_to_dag(sender, template, context, **extra): permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN ) - dag.can_edit = get_airflow_app().appbuilder.sm.can_edit_dag(dag.dag_id) + dag.can_edit = get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag.dag_id)) dag.can_trigger = dag.can_edit and can_create_dag_run - dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id) + dag.can_delete = get_auth_manager().is_authorized_dag(method="DELETE", details=DagDetails(id=dag.dag_id)) context["dag"] = dag diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index ad154e57fe..d1e8af2fac 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -36,11 +36,26 @@ def configured_app(minimal_app_for_api): role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore ) + create_user( + app, # type:ignore + username="test_granular", + role_name="TestGranular", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore + ) + app.appbuilder.sm.sync_perm_for_dag( # type: ignore + "TEST_DAG_ID_1", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( # type: ignore + "TEST_DAG_ID_2", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore yield app delete_user(app, username="test") # type: ignore + delete_user(app, username="test_granular") # type: ignore delete_user(app, username="test_no_permissions") # type: ignore @@ -253,7 +268,7 @@ class TestGetEventLogs(TestEventLogEndpoint): for attr in ["dag_id", "task_id", "owner", "event"]: attr_value = f"TEST_{attr}_1".upper() response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} ) assert response.status_code == 200 assert {eventlog[attr] for eventlog in response.json["event_logs"]} == {attr_value} diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index a80175a575..7ed1329d9d 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -47,8 +47,7 @@ def configured_app(minimal_app_for_api): role_name="Test", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), ], ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index cf7716d644..9e175ab488 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -55,8 +55,6 @@ def configured_app(minimal_app_for_api): role_name="Test", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), ], ) @@ -65,8 +63,6 @@ def configured_app(minimal_app_for_api): username="test_granular_permissions", role_name="TestGranularDag", permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), ], ) diff --git a/tests/auth/managers/fab/test_fab_auth_manager.py b/tests/auth/managers/fab/test_fab_auth_manager.py index 85a13ad178..2e545457dd 100644 --- a/tests/auth/managers/fab/test_fab_auth_manager.py +++ b/tests/auth/managers/fab/test_fab_auth_manager.py @@ -253,12 +253,40 @@ class TestFabAuthManager: [(ACTION_CAN_READ, RESOURCE_DAG), (ACTION_CAN_READ, RESOURCE_DAG_RUN)], True, ), - # With read permissions on a specific DAG + # Without read permissions on a specific DAG + ( + "GET", + DagAccessEntity.TASK_INSTANCE, + DagDetails(id="test_dag_id"), + [(ACTION_CAN_READ, RESOURCE_TASK_INSTANCE)], + False, + ), + # With read permissions on a specific DAG but not on the DAG run ( "GET", DagAccessEntity.TASK_INSTANCE, DagDetails(id="test_dag_id"), [(ACTION_CAN_READ, "DAG:test_dag_id"), (ACTION_CAN_READ, RESOURCE_TASK_INSTANCE)], + False, + ), + # With read permissions on a specific DAG but not on the DAG run + ( + "GET", + DagAccessEntity.TASK_INSTANCE, + DagDetails(id="test_dag_id"), + [ + (ACTION_CAN_READ, "DAG:test_dag_id"), + (ACTION_CAN_READ, RESOURCE_TASK_INSTANCE), + (ACTION_CAN_READ, RESOURCE_DAG_RUN), + ], + True, + ), + # With edit permissions on a specific DAG and read on the DAG access entity + ( + "DELETE", + DagAccessEntity.TASK, + DagDetails(id="test_dag_id"), + [(ACTION_CAN_EDIT, "DAG:test_dag_id"), (ACTION_CAN_DELETE, RESOURCE_TASK_INSTANCE)], True, ), # With edit permissions on a specific DAG and read on the DAG access entity @@ -293,7 +321,7 @@ class TestFabAuthManager: user = Mock() user.perms = user_permissions result = auth_manager.is_authorized_dag( - method=method, dag_access_entity=dag_access_entity, dag_details=dag_details, user=user + method=method, access_entity=dag_access_entity, details=dag_details, user=user ) assert result == expected_result diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 1ff8dcfbbf..416fa75e2a 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -27,7 +27,15 @@ from airflow.www.security_manager import AirflowSecurityManagerV2 if TYPE_CHECKING: from airflow.auth.managers.models.base_user import BaseUser - from airflow.auth.managers.models.resource_details import ConnectionDetails, DagAccessEntity, DagDetails + from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, + ) class EmptyAuthManager(BaseAuthManager): @@ -40,7 +48,13 @@ class EmptyAuthManager(BaseAuthManager): def get_user_id(self) -> str: raise NotImplementedError() - def is_authorized_configuration(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_configuration( + self, + *, + method: ResourceMethod, + details: ConfigurationDetails | None = None, + user: BaseUser | None = None, + ) -> bool: raise NotImplementedError() def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: @@ -50,7 +64,7 @@ class EmptyAuthManager(BaseAuthManager): self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: raise NotImplementedError() @@ -59,16 +73,25 @@ class EmptyAuthManager(BaseAuthManager): self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: raise NotImplementedError() - def is_authorized_dataset(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_dataset( + self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + ) -> bool: raise NotImplementedError() - def is_authorized_variable(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_pool( + self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + ) -> bool: + raise NotImplementedError() + + def is_authorized_variable( + self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + ) -> bool: raise NotImplementedError() def is_authorized_website(self, *, user: BaseUser | None = None) -> bool: diff --git a/tests/www/test_security.py b/tests/www/test_security.py index b70aad536b..80a6d6765b 100644 --- a/tests/www/test_security.py +++ b/tests/www/test_security.py @@ -33,6 +33,7 @@ from sqlalchemy import Column, Date, Float, Integer, String from airflow.auth.managers.fab.fab_auth_manager import FabAuthManager from airflow.auth.managers.fab.models import User, assoc_permission_role from airflow.auth.managers.fab.models.anonymous_user import AnonymousUser +from airflow.auth.managers.models.resource_details import DagDetails from airflow.configuration import initialize_config from airflow.exceptions import AirflowException from airflow.models import DagModel @@ -41,6 +42,7 @@ from airflow.models.dag import DAG from airflow.security import permissions from airflow.www import app as application from airflow.www.auth import get_access_denied_message +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.utils import CustomSQLAInterface from tests.test_utils.api_connexion_utils import ( create_user, @@ -118,6 +120,24 @@ def _delete_dag_model(dag_model, session, security_manager): _delete_dag_permissions(dag_model.dag_id, security_manager) +def _can_read_dag(dag_id: str, user) -> bool: + return get_auth_manager().is_authorized_dag(method="GET", details=DagDetails(id=dag_id), user=user) + + +def _can_edit_dag(dag_id: str, user) -> bool: + return get_auth_manager().is_authorized_dag(method="PUT", details=DagDetails(id=dag_id), user=user) + + +def _can_delete_dag(dag_id: str, user) -> bool: + return get_auth_manager().is_authorized_dag(method="DELETE", details=DagDetails(id=dag_id), user=user) + + +def _has_all_dags_access(user) -> bool: + return get_auth_manager().is_authorized_dag( + method="GET", user=user + ) or get_auth_manager().is_authorized_dag(method="PUT", user=user) + + @contextlib.contextmanager def _create_dag_model_context(dag_id, session, security_manager): dag = _create_dag_model(dag_id, session, security_manager) @@ -321,7 +341,7 @@ def test_verify_default_anon_user_has_no_accessible_dag_ids( with _create_dag_model_context("test_dag_id", session, security_manager): security_manager.sync_roles() - assert security_manager.get_accessible_dag_ids(user) == set() + assert get_auth_manager().get_permitted_dag_ids(user=user) == set() def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, security_manager, has_dag_perm): @@ -334,8 +354,8 @@ def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, se with _create_dag_model_context(dag_id, session, security_manager): security_manager.sync_roles() - assert security_manager.can_read_dag(dag_id, user) is False - assert security_manager.can_edit_dag(dag_id, user) is False + assert _can_read_dag(dag_id, user) is False + assert _can_edit_dag(dag_id, user) is False assert has_dag_perm(permissions.ACTION_CAN_READ, dag_id, user) is False assert has_dag_perm(permissions.ACTION_CAN_EDIT, dag_id, user) is False @@ -359,7 +379,7 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( security_manager.sync_roles() - assert security_manager.get_accessible_dag_ids(user) == set(test_dag_ids) + assert get_auth_manager().get_permitted_dag_ids(user=user) == set(test_dag_ids) def test_verify_anon_user_with_admin_role_has_access_to_each_dag( @@ -379,8 +399,8 @@ def test_verify_anon_user_with_admin_role_has_access_to_each_dag( with _create_dag_model_context(dag_id, session, security_manager): security_manager.sync_roles() - assert security_manager.can_read_dag(dag_id, user) is True - assert security_manager.can_edit_dag(dag_id, user) is True + assert _can_read_dag(dag_id, user) is True + assert _can_edit_dag(dag_id, user) is True assert has_dag_perm(permissions.ACTION_CAN_READ, dag_id, user) is True assert has_dag_perm(permissions.ACTION_CAN_EDIT, dag_id, user) is True @@ -487,7 +507,7 @@ def test_get_accessible_dag_ids(mock_is_logged_in, app, security_manager, sessio dag_id, access_control={role_name: permission_action} ) - assert security_manager.get_accessible_dag_ids(user) == {"dag_id"} + assert get_auth_manager().get_permitted_dag_ids(user=user) == {"dag_id"} @patch.object(FabAuthManager, "is_logged_in") @@ -495,7 +515,7 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission( mock_is_logged_in, app, security_manager, session ): # In this test case, - # get_readable_dag_ids() don't return DAGs to which the user has CAN_EDIT action + # get_permitted_dag_ids() don't return DAGs to which the user has CAN_EDIT action username = "Monsieur User" role_name = "MyRole1" permission_action = [permissions.ACTION_CAN_EDIT] @@ -518,7 +538,7 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission( dag_id, access_control={role_name: permission_action} ) - assert security_manager.get_readable_dag_ids(user) == set() + assert get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=user) == set() def test_has_access(security_manager): @@ -551,9 +571,9 @@ def test_sync_perm_for_dag_creates_permissions_for_specified_roles(app, security security_manager.sync_perm_for_dag( test_dag_id, access_control={test_role: {"can_read", "can_edit"}} ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) - assert not security_manager.can_delete_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) + assert not _can_delete_dag(test_dag_id, user) def test_sync_perm_for_dag_removes_existing_permissions_if_empty(app, security_manager): @@ -581,18 +601,18 @@ def test_sync_perm_for_dag_removes_existing_permissions_if_empty(app, security_m ] ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) - assert security_manager.can_delete_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) + assert _can_delete_dag(test_dag_id, user) # Need to clear cache on user perms user._perms = None security_manager.sync_perm_for_dag(test_dag_id, access_control={test_role: {}}) - assert not security_manager.can_read_dag(test_dag_id, user) - assert not security_manager.can_edit_dag(test_dag_id, user) - assert not security_manager.can_delete_dag(test_dag_id, user) + assert not _can_read_dag(test_dag_id, user) + assert not _can_edit_dag(test_dag_id, user) + assert not _can_delete_dag(test_dag_id, user) def test_sync_perm_for_dag_removes_permissions_from_other_roles(app, security_manager): @@ -621,18 +641,18 @@ def test_sync_perm_for_dag_removes_permissions_from_other_roles(app, security_ma ] ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) - assert security_manager.can_delete_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) + assert _can_delete_dag(test_dag_id, user) # Need to clear cache on user perms user._perms = None security_manager.sync_perm_for_dag(test_dag_id, access_control={"other_role": {"can_read"}}) - assert not security_manager.can_read_dag(test_dag_id, user) - assert not security_manager.can_edit_dag(test_dag_id, user) - assert not security_manager.can_delete_dag(test_dag_id, user) + assert not _can_read_dag(test_dag_id, user) + assert not _can_edit_dag(test_dag_id, user) + assert not _can_delete_dag(test_dag_id, user) def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, security_manager): @@ -659,16 +679,16 @@ def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, s ] ) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) # Need to clear cache on user perms user._perms = None security_manager.sync_perm_for_dag(test_dag_id, access_control=None) - assert security_manager.can_read_dag(test_dag_id, user) - assert security_manager.can_edit_dag(test_dag_id, user) + assert _can_read_dag(test_dag_id, user) + assert _can_edit_dag(test_dag_id, user) def test_has_all_dag_access(app, security_manager): @@ -679,7 +699,7 @@ def test_has_all_dag_access(app, security_manager): username="user", role_name=role_name, ) as user: - assert security_manager.has_all_dags_access(user) + assert _has_all_dags_access(user) with app.app_context(): with create_user_scope( @@ -688,7 +708,7 @@ def test_has_all_dag_access(app, security_manager): role_name="read_all", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)], ) as user: - assert security_manager.has_all_dags_access(user) + assert _has_all_dags_access(user) with app.app_context(): with create_user_scope( @@ -697,7 +717,7 @@ def test_has_all_dag_access(app, security_manager): role_name="edit_all", permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], ) as user: - assert security_manager.has_all_dags_access(user) + assert _has_all_dags_access(user) with app.app_context(): with create_user_scope( @@ -706,7 +726,7 @@ def test_has_all_dag_access(app, security_manager): role_name="nada", permissions=[], ) as user: - assert not security_manager.has_all_dags_access(user) + assert not _has_all_dags_access(user) def test_access_control_with_non_existent_role(security_manager): diff --git a/tests/www/views/test_views_acl.py b/tests/www/views/test_views_acl.py index 90e37d5583..e7f48a9cb0 100644 --- a/tests/www/views/test_views_acl.py +++ b/tests/www/views/test_views_acl.py @@ -488,6 +488,7 @@ def user_all_dags_tis(acl_app): role_name="role_all_dags_tis", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), ], @@ -537,6 +538,7 @@ def user_dags_tis_logs(acl_app): role_name="role_dags_tis_logs", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), @@ -693,6 +695,7 @@ def user_all_dags_edit_tis(acl_app): role_name="role_all_dags_edit_tis", permissions=[ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), ], @@ -855,6 +858,8 @@ def user_dag_level_access_with_ti_edit(acl_app): permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_EDIT, permissions.resource_name_for_dag("example_bash_operator")), diff --git a/tests/www/views/test_views_decorators.py b/tests/www/views/test_views_decorators.py index 80eb588f29..227193aaf8 100644 --- a/tests/www/views/test_views_decorators.py +++ b/tests/www/views/test_views_decorators.py @@ -18,15 +18,13 @@ from __future__ import annotations import urllib.parse -from unittest import mock import pytest -from airflow.models import DagBag, DagRun, TaskInstance, Variable +from airflow.models import DagBag, Variable from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.types import DagRunType -from airflow.www import app from airflow.www.views import action_has_dag_edit_access from tests.test_utils.db import clear_db_runs, clear_db_variables from tests.test_utils.www import _check_last_log, _check_last_log_masked_variable, check_content_in_response @@ -187,46 +185,6 @@ def test_calendar(admin_client, dagruns): check_content_in_response(expected, resp) -@pytest.mark.parametrize( - "class_type, no_instances, no_unique_dags", - [ - (None, 0, 0), - (TaskInstance, 0, 0), - (TaskInstance, 1, 1), - (TaskInstance, 10, 1), - (TaskInstance, 10, 5), - (DagRun, 0, 0), - (DagRun, 1, 1), - (DagRun, 10, 1), - (DagRun, 10, 9), - ], -) -def test_action_has_dag_edit_access(create_task_instance, class_type, no_instances, no_unique_dags): - unique_dag_ids = [f"test_dag_id_{nr}" for nr in range(no_unique_dags)] - tis: list[TaskInstance] = [ - create_task_instance( - task_id=f"test_task_instance_{nr}", - execution_date=timezone.datetime(2021, 1, 1 + nr), - dag_id=unique_dag_ids[nr % len(unique_dag_ids)], - run_id=f"test_run_id_{nr}", - ) - for nr in range(no_instances) - ] - if class_type is None: - test_items = None - else: - test_items = tis if class_type == TaskInstance else [ti.get_dagrun() for ti in tis] - test_items = test_items[0] if len(test_items) == 1 else test_items - application = app.create_app(testing=True) - with application.app_context(): - with mock.patch.object(application.appbuilder.sm, "can_edit_dag") as mocked_can_edit: - mocked_can_edit.return_value = True - assert not isinstance(test_items, list) or len(test_items) == no_instances - assert some_view_action_which_requires_dag_edit_access(None, test_items) is True - assert mocked_can_edit.call_count == no_unique_dags - clear_db_runs() - - def test_action_has_dag_edit_access_exception(): with pytest.raises(ValueError): some_view_action_which_requires_dag_edit_access(None, "some_incorrect_value") diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 44ef85d96a..32dad8c431 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -687,6 +687,7 @@ def one_dag_perm_user_client(app): username=username, role_name="User with permission to access only one dag", permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE),