This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 47348ce66c Implement `filter_permitted_dag_ids` in AWS auth manager (#37666) 47348ce66c is described below commit 47348ce66c1f4a90aec87215b8a237c4bffcddca Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Fri Mar 8 14:01:45 2024 -0500 Implement `filter_permitted_dag_ids` in AWS auth manager (#37666) --- airflow/auth/managers/base_auth_manager.py | 23 ++++++- .../amazon/aws/auth_manager/aws_auth_manager.py | 75 ++++++++++++++++++---- .../aws/auth_manager/test_aws_auth_manager.py | 59 +++++++++++++++++ 3 files changed, 144 insertions(+), 13 deletions(-) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index a0176fc833..e378e5e10b 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -344,12 +344,31 @@ class BaseAuthManager(LoggingMixin): By default, reads all the DAGs and check individually if the user has permissions to access the DAG. Can lead to some poor performance. It is recommended to override this method in the auth manager implementation to provide a more efficient implementation. + + :param methods: whether filter readable or writable + :param user: the current user + :param session: the session + """ + dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + return self.filter_permitted_dag_ids(dag_ids=dag_ids, methods=methods, user=user) + + def filter_permitted_dag_ids( + self, + *, + dag_ids: set[str], + methods: Container[ResourceMethod] | None = None, + user=None, + ): + """ + Filter readable or writable DAGs for user. + + :param dag_ids: the list of DAG ids + :param methods: whether filter readable or writable + :param user: the current user """ if not methods: methods = ["PUT", "GET"] - dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or ( "PUT" in methods and self.is_authorized_dag(method="PUT", user=user) ): diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 0aa1fdb2c5..c17234c047 100644 --- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -17,8 +17,9 @@ from __future__ import annotations import argparse +from collections import defaultdict from functools import cached_property -from typing import TYPE_CHECKING, Sequence, cast +from typing import TYPE_CHECKING, Container, Sequence, cast from flask import session, url_for @@ -443,6 +444,60 @@ class AwsAuthManager(BaseAuthManager): ] return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user()) + def filter_permitted_dag_ids( + self, + *, + dag_ids: set[str], + methods: Container[ResourceMethod] | None = None, + user=None, + ): + """ + Filter readable or writable DAGs for user. + + :param dag_ids: the list of DAG ids + :param methods: whether filter readable or writable + :param user: the current user + """ + if not methods: + methods = ["PUT", "GET"] + + if not user: + user = self.get_user() + + requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict) + requests_list: list[IsAuthorizedRequest] = [] + for dag_id in dag_ids: + for method in ["GET", "PUT"]: + if method in methods: + request: IsAuthorizedRequest = { + "method": cast(ResourceMethod, method), + "entity_type": AvpEntities.DAG, + "entity_id": dag_id, + } + requests[dag_id][cast(ResourceMethod, method)] = request + requests_list.append(request) + + batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results( + requests=requests_list, user=user + ) + + def _has_access_to_dag(request: IsAuthorizedRequest): + result = self.avp_facade.get_batch_is_authorized_single_result( + batch_is_authorized_results=batch_is_authorized_results, request=request, user=user + ) + return result["decision"] == "ALLOW" + + return { + dag_id + for dag_id in dag_ids + if ( + "GET" in methods + and _has_access_to_dag(requests[dag_id]["GET"]) + or "PUT" in methods + and _has_access_to_dag(requests[dag_id]["PUT"]) + ) + } + def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]: """ Filter menu items based on user permissions. @@ -465,19 +520,25 @@ class AwsAuthManager(BaseAuthManager): requests=list(requests.values()), user=user ) + def _has_access_to_menu_item(request: IsAuthorizedRequest): + result = self.avp_facade.get_batch_is_authorized_single_result( + batch_is_authorized_results=batch_is_authorized_results, request=request, user=user + ) + return result["decision"] == "ALLOW" + accessible_items = [] for menu_item in menu_items: if menu_item.childs: accessible_children = [] for child in menu_item.childs: - if self._has_access_to_menu_item(batch_is_authorized_results, requests[child.name], user): + if _has_access_to_menu_item(requests[child.name]): accessible_children.append(child) menu_item.childs = accessible_children # Display the menu if the user has access to at least one sub item if len(accessible_children) > 0: accessible_items.append(menu_item) - elif self._has_access_to_menu_item(batch_is_authorized_results, requests[menu_item.name], user): + elif _has_access_to_menu_item(requests[menu_item.name]): accessible_items.append(menu_item) return accessible_items @@ -511,14 +572,6 @@ class AwsAuthManager(BaseAuthManager): else: raise AirflowException(f"Unknown resource name {fab_resource_name}") - def _has_access_to_menu_item( - self, batch_is_authorized_results: list[dict], request: IsAuthorizedRequest, user: AwsAuthManagerUser - ): - result = self.avp_facade.get_batch_is_authorized_single_result( - batch_is_authorized_results=batch_is_authorized_results, request=request, user=user - ) - return result["decision"] == "ALLOW" - def get_parser() -> argparse.ArgumentParser: """Generate documentation; used by Sphinx argparse.""" diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py index daa21f21c1..9b654199c3 100644 --- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -642,6 +642,65 @@ class TestAwsAuthManager: ] ) + @pytest.mark.parametrize( + "methods, user", + [ + (None, None), + (["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id", groups=[])), + ], + ) + @patch.object(AwsAuthManager, "get_user") + def test_filter_permitted_dag_ids(self, mock_get_user, methods, user, auth_manager, test_user): + dag_ids = {"dag_1", "dag_2"} + batch_is_authorized_output = [ + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"}, + }, + "decision": "DENY", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"}, + }, + "decision": "DENY", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"}, + }, + "decision": "DENY", + }, + { + "request": { + "principal": {"entityType": "Airflow::User", "entityId": "test_user_id"}, + "action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"}, + "resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"}, + }, + "decision": "ALLOW", + }, + ] + auth_manager.avp_facade.get_batch_is_authorized_results = Mock( + return_value=batch_is_authorized_output + ) + + mock_get_user.return_value = test_user + + result = auth_manager.filter_permitted_dag_ids( + dag_ids=dag_ids, + methods=methods, + user=user, + ) + + auth_manager.avp_facade.get_batch_is_authorized_results.assert_called() + assert result == {"dag_2"} + @patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for") def test_get_url_login(self, mock_url_for, auth_manager): auth_manager.get_url_login()