This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f36d17ce75 Implement `filter_permitted_menu_items` in AWS auth manager (#37627)
f36d17ce75 is described below

commit f36d17ce757e5c39ca4dd35a637a090ef6105744
Author: Vincent <[email protected]>
AuthorDate: Fri Feb 23 14:56:09 2024 -0500

    Implement `filter_permitted_menu_items` in AWS auth manager (#37627)
---
 airflow/auth/managers/base_auth_manager.py         |   2 +-
 .../amazon/aws/auth_manager/avp/facade.py          | 134 +++++++++----
 .../amazon/aws/auth_manager/aws_auth_manager.py    | 213 ++++++++++++++++++++-
 airflow/www/templates/appbuilder/navbar_menu.html  |   2 +-
 docs/apache-airflow/core-concepts/auth-manager.rst |   2 +-
 newsfragments/37627.significant.rst                |   1 +
 tests/auth/managers/test_base_auth_manager.py      |   4 +-
 .../amazon/aws/auth_manager/avp/test_facade.py     |  69 +++++++
 .../aws/auth_manager/test_aws_auth_manager.py      | 133 +++++++++++++
 9 files changed, 517 insertions(+), 43 deletions(-)

diff --git a/airflow/auth/managers/base_auth_manager.py 
b/airflow/auth/managers/base_auth_manager.py
index b9dd459657..88164b644e 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -367,7 +367,7 @@ class BaseAuthManager(LoggingMixin):
             if _is_permitted_dag_id("GET", methods, dag_id) or 
_is_permitted_dag_id("PUT", methods, dag_id)
         }
 
-    def get_permitted_menu_items(self, menu_items: list[MenuItem]) -> 
list[MenuItem]:
+    def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> 
list[MenuItem]:
         """
         Filter menu items based on user permissions.
 
diff --git a/airflow/providers/amazon/aws/auth_manager/avp/facade.py 
b/airflow/providers/amazon/aws/auth_manager/avp/facade.py
index c13233b54e..18231b3d0f 100644
--- a/airflow/providers/amazon/aws/auth_manager/avp/facade.py
+++ b/airflow/providers/amazon/aws/auth_manager/avp/facade.py
@@ -37,6 +37,11 @@ if TYPE_CHECKING:
     from airflow.providers.amazon.aws.auth_manager.user import 
AwsAuthManagerUser
 
 
+# Amazon Verified Permissions allows only up to 30 requests per 
batch_is_authorized call. See
+# 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/verifiedpermissions/client/batch_is_authorized.html
+NB_REQUESTS_PER_BATCH = 30
+
+
 class IsAuthorizedRequest(TypedDict, total=False):
     """Represent the parameters of ``is_authorized`` method in AVP facade."""
 
@@ -125,62 +130,97 @@ class 
AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
 
         return resp["decision"] == "ALLOW"
 
-    def batch_is_authorized(
+    def get_batch_is_authorized_results(
         self,
         *,
         requests: Sequence[IsAuthorizedRequest],
-        user: AwsAuthManagerUser | None,
-    ) -> bool:
+        user: AwsAuthManagerUser,
+    ) -> list[dict]:
         """
         Make a batch authorization decision against Amazon Verified 
Permissions.
 
-        Check whether the user has permissions to access given resources.
+        Return a list of results for each request.
 
         :param requests: the list of requests containing the method, the 
entity_type and the entity ID
         :param user: the user
         """
-        if user is None:
-            return False
-
         entity_list = self._get_user_role_entities(user)
 
         self.log.debug("Making batch authorization request for user=%s, 
requests=%s", user.get_id(), requests)
 
-        avp_requests = [
-            prune_dict(
-                {
-                    "principal": {"entityType": 
get_entity_type(AvpEntities.USER), "entityId": user.get_id()},
-                    "action": {
-                        "actionType": get_entity_type(AvpEntities.ACTION),
-                        "actionId": get_action_id(request["entity_type"], 
request["method"]),
-                    },
-                    "resource": {
-                        "entityType": get_entity_type(request["entity_type"]),
-                        "entityId": request.get("entity_id", "*"),
-                    },
-                    "context": self._build_context(request.get("context")),
-                }
-            )
-            for request in requests
+        avp_requests = [self._build_is_authorized_request_payload(request, 
user) for request in requests]
+        avp_requests_chunks = [
+            avp_requests[i : i + NB_REQUESTS_PER_BATCH]
+            for i in range(0, len(avp_requests), NB_REQUESTS_PER_BATCH)
         ]
 
-        resp = self.avp_client.batch_is_authorized(
-            policyStoreId=self.avp_policy_store_id,
-            requests=avp_requests,
-            entities={"entityList": entity_list},
-        )
+        results = []
+        for avp_requests in avp_requests_chunks:
+            resp = self.avp_client.batch_is_authorized(
+                policyStoreId=self.avp_policy_store_id,
+                requests=avp_requests,
+                entities={"entityList": entity_list},
+            )
 
-        self.log.debug("Authorization response: %s", resp)
+            self.log.debug("Authorization response: %s", resp)
 
-        has_errors = any(len(result.get("errors", [])) > 0 for result in 
resp["results"])
+            has_errors = any(len(result.get("errors", [])) > 0 for result in 
resp["results"])
 
-        if has_errors:
-            self.log.error(
-                "Error occurred while making a batch authorization decision. 
Result: %s", resp["results"]
-            )
-            raise AirflowException("Error occurred while making a batch 
authorization decision.")
+            if has_errors:
+                self.log.error(
+                    "Error occurred while making a batch authorization 
decision. Result: %s", resp["results"]
+                )
+                raise AirflowException("Error occurred while making a batch 
authorization decision.")
+
+            results.extend(resp["results"])
+
+        return results
+
+    def batch_is_authorized(
+        self,
+        *,
+        requests: Sequence[IsAuthorizedRequest],
+        user: AwsAuthManagerUser | None,
+    ) -> bool:
+        """
+        Make a batch authorization decision against Amazon Verified 
Permissions.
+
+        Check whether the user has permissions to access all resources.
+
+        :param requests: the list of requests containing the method, the 
entity_type and the entity ID
+        :param user: the user
+        """
+        if user is None:
+            return False
+        results = self.get_batch_is_authorized_results(requests=requests, 
user=user)
+        return all(result["decision"] == "ALLOW" for result in results)
+
+    def get_batch_is_authorized_single_result(
+        self,
+        *,
+        batch_is_authorized_results: list[dict],
+        request: IsAuthorizedRequest,
+        user: AwsAuthManagerUser,
+    ) -> dict:
+        """
+        Get a specific authorization result from the output of 
``get_batch_is_authorized_results``.
 
-        return all(result["decision"] == "ALLOW" for result in resp["results"])
+        :param batch_is_authorized_results: the response from the 
``batch_is_authorized`` API
+        :param request: the request information. Used to find the result in 
the response.
+        :param user: the user
+        """
+        request_payload = self._build_is_authorized_request_payload(request, 
user)
+
+        for result in batch_is_authorized_results:
+            if result["request"] == request_payload:
+                return result
+
+        self.log.error(
+            "Could not find the authorization result for request %s in results 
%s.",
+            request_payload,
+            batch_is_authorized_results,
+        )
+        raise AirflowException("Could not find the authorization result.")
 
     @staticmethod
     def _get_user_role_entities(user: AwsAuthManagerUser) -> list[dict]:
@@ -205,3 +245,25 @@ class 
AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
         return {
             "contextMap": context,
         }
+
+    def _build_is_authorized_request_payload(self, request: 
IsAuthorizedRequest, user: AwsAuthManagerUser):
+        """
+        Build a payload of an individual authorization request that could be 
sent through the ``batch_is_authorized`` API.
+
+        :param request: the request information
+        :param user: the user
+        """
+        return prune_dict(
+            {
+                "principal": {"entityType": get_entity_type(AvpEntities.USER), 
"entityId": user.get_id()},
+                "action": {
+                    "actionType": get_entity_type(AvpEntities.ACTION),
+                    "actionId": get_action_id(request["entity_type"], 
request["method"]),
+                },
+                "resource": {
+                    "entityType": get_entity_type(request["entity_type"]),
+                    "entityId": request.get("entity_id", "*"),
+                },
+                "context": self._build_context(request.get("context")),
+            }
+        )
diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py 
b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 5ed09d3810..0aa1fdb2c5 100644
--- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -24,7 +24,7 @@ from flask import session, url_for
 
 from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
 from airflow.configuration import conf
-from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.exceptions import AirflowException, 
AirflowOptionalProviderFeatureException
 from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
 from airflow.providers.amazon.aws.auth_manager.avp.facade import (
     AwsAuthManagerAmazonVerifiedPermissionsFacade,
@@ -40,10 +40,33 @@ from airflow.providers.amazon.aws.auth_manager.constants 
import (
 from 
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
 import (
     AwsSecurityManagerOverride,
 )
+from airflow.security.permissions import (
+    RESOURCE_AUDIT_LOG,
+    RESOURCE_CLUSTER_ACTIVITY,
+    RESOURCE_CONFIG,
+    RESOURCE_CONNECTION,
+    RESOURCE_DAG,
+    RESOURCE_DAG_CODE,
+    RESOURCE_DAG_DEPENDENCIES,
+    RESOURCE_DAG_RUN,
+    RESOURCE_DATASET,
+    RESOURCE_DOCS,
+    RESOURCE_JOB,
+    RESOURCE_PLUGIN,
+    RESOURCE_POOL,
+    RESOURCE_PROVIDER,
+    RESOURCE_SLA_MISS,
+    RESOURCE_TASK_INSTANCE,
+    RESOURCE_TASK_RESCHEDULE,
+    RESOURCE_TRIGGER,
+    RESOURCE_VARIABLE,
+    RESOURCE_XCOM,
+)
 
 try:
     from airflow.auth.managers.base_auth_manager import BaseAuthManager, 
ResourceMethod
     from airflow.auth.managers.models.resource_details import (
+        AccessView,
         ConnectionDetails,
         DagAccessEntity,
         DagDetails,
@@ -56,6 +79,8 @@ except ImportError:
     )
 
 if TYPE_CHECKING:
+    from flask_appbuilder.menu import MenuItem
+
     from airflow.auth.managers.models.base_user import BaseUser
     from airflow.auth.managers.models.batch_apis import (
         IsAuthorizedConnectionRequest,
@@ -64,7 +89,6 @@ if TYPE_CHECKING:
         IsAuthorizedVariableRequest,
     )
     from airflow.auth.managers.models.resource_details import (
-        AccessView,
         ConfigurationDetails,
         DatasetDetails,
     )
@@ -72,6 +96,136 @@ if TYPE_CHECKING:
     from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
 
 
+_MENU_ITEM_REQUESTS: dict[str, IsAuthorizedRequest] = {
+    RESOURCE_AUDIT_LOG: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.AUDIT_LOG.value,
+            },
+        },
+    },
+    RESOURCE_CLUSTER_ACTIVITY: {
+        "method": "GET",
+        "entity_type": AvpEntities.VIEW,
+        "entity_id": AccessView.CLUSTER_ACTIVITY.value,
+    },
+    RESOURCE_CONFIG: {
+        "method": "GET",
+        "entity_type": AvpEntities.CONFIGURATION,
+    },
+    RESOURCE_CONNECTION: {
+        "method": "GET",
+        "entity_type": AvpEntities.CONNECTION,
+    },
+    RESOURCE_DAG: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+    },
+    RESOURCE_DAG_CODE: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.CODE.value,
+            },
+        },
+    },
+    RESOURCE_DAG_DEPENDENCIES: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.DEPENDENCIES.value,
+            },
+        },
+    },
+    RESOURCE_DAG_RUN: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.RUN.value,
+            },
+        },
+    },
+    RESOURCE_DATASET: {
+        "method": "GET",
+        "entity_type": AvpEntities.DATASET,
+    },
+    RESOURCE_DOCS: {
+        "method": "GET",
+        "entity_type": AvpEntities.VIEW,
+        "entity_id": AccessView.DOCS.value,
+    },
+    RESOURCE_PLUGIN: {
+        "method": "GET",
+        "entity_type": AvpEntities.VIEW,
+        "entity_id": AccessView.PLUGINS.value,
+    },
+    RESOURCE_JOB: {
+        "method": "GET",
+        "entity_type": AvpEntities.VIEW,
+        "entity_id": AccessView.JOBS.value,
+    },
+    RESOURCE_POOL: {
+        "method": "GET",
+        "entity_type": AvpEntities.POOL,
+    },
+    RESOURCE_PROVIDER: {
+        "method": "GET",
+        "entity_type": AvpEntities.VIEW,
+        "entity_id": AccessView.PROVIDERS.value,
+    },
+    RESOURCE_SLA_MISS: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.SLA_MISS.value,
+            },
+        },
+    },
+    RESOURCE_TASK_INSTANCE: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.TASK_INSTANCE.value,
+            },
+        },
+    },
+    RESOURCE_TASK_RESCHEDULE: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.TASK_RESCHEDULE.value,
+            },
+        },
+    },
+    RESOURCE_TRIGGER: {
+        "method": "GET",
+        "entity_type": AvpEntities.VIEW,
+        "entity_id": AccessView.TRIGGERS.value,
+    },
+    RESOURCE_VARIABLE: {
+        "method": "GET",
+        "entity_type": AvpEntities.VARIABLE,
+    },
+    RESOURCE_XCOM: {
+        "method": "GET",
+        "entity_type": AvpEntities.DAG,
+        "context": {
+            "dag_entity": {
+                "string": DagAccessEntity.XCOM.value,
+            },
+        },
+    },
+}
+
+
 class AwsAuthManager(BaseAuthManager):
     """
     AWS auth manager.
@@ -289,6 +443,45 @@ class AwsAuthManager(BaseAuthManager):
         ]
         return self.avp_facade.batch_is_authorized(requests=facade_requests, 
user=self.get_user())
 
+    def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> 
list[MenuItem]:
+        """
+        Filter menu items based on user permissions.
+
+        :param menu_items: list of all menu items
+        """
+        user = self.get_user()
+        if not user:
+            return []
+
+        requests: dict[str, IsAuthorizedRequest] = {}
+        for menu_item in menu_items:
+            if menu_item.childs:
+                for child in menu_item.childs:
+                    requests[child.name] = 
self._get_menu_item_request(child.name)
+            else:
+                requests[menu_item.name] = 
self._get_menu_item_request(menu_item.name)
+
+        batch_is_authorized_results = 
self.avp_facade.get_batch_is_authorized_results(
+            requests=list(requests.values()), user=user
+        )
+
+        accessible_items = []
+        for menu_item in menu_items:
+            if menu_item.childs:
+                accessible_children = []
+                for child in menu_item.childs:
+                    if 
self._has_access_to_menu_item(batch_is_authorized_results, 
requests[child.name], user):
+                        accessible_children.append(child)
+                menu_item.childs = accessible_children
+
+                # Display the menu if the user has access to at least one sub 
item
+                if len(accessible_children) > 0:
+                    accessible_items.append(menu_item)
+            elif self._has_access_to_menu_item(batch_is_authorized_results, 
requests[menu_item.name], user):
+                accessible_items.append(menu_item)
+
+        return accessible_items
+
     def get_url_login(self, **kwargs) -> str:
         return url_for("AwsAuthManagerAuthenticationViews.login")
 
@@ -310,6 +503,22 @@ class AwsAuthManager(BaseAuthManager):
             ),
         ]
 
+    @staticmethod
+    def _get_menu_item_request(fab_resource_name: str) -> IsAuthorizedRequest:
+        menu_item_request = _MENU_ITEM_REQUESTS.get(fab_resource_name)
+        if menu_item_request:
+            return menu_item_request
+        else:
+            raise AirflowException(f"Unknown resource name 
{fab_resource_name}")
+
+    def _has_access_to_menu_item(
+        self, batch_is_authorized_results: list[dict], request: 
IsAuthorizedRequest, user: AwsAuthManagerUser
+    ):
+        result = self.avp_facade.get_batch_is_authorized_single_result(
+            batch_is_authorized_results=batch_is_authorized_results, 
request=request, user=user
+        )
+        return result["decision"] == "ALLOW"
+
 
 def get_parser() -> argparse.ArgumentParser:
     """Generate documentation; used by Sphinx argparse."""
diff --git a/airflow/www/templates/appbuilder/navbar_menu.html 
b/airflow/www/templates/appbuilder/navbar_menu.html
index 72855df270..d7abc8472f 100644
--- a/airflow/www/templates/appbuilder/navbar_menu.html
+++ b/airflow/www/templates/appbuilder/navbar_menu.html
@@ -21,7 +21,7 @@
   <a href="{{item.get_url()}}">{{_(item.label)}}</a>
 {% endmacro %}
 
-{% for item1 in auth_manager.get_permitted_menu_items(menu.get_list()) %}
+{% for item1 in auth_manager.filter_permitted_menu_items(menu.get_list()) %}
   {% if item1 %}
     {% if item1.childs %}
       <li class="dropdown">
diff --git a/docs/apache-airflow/core-concepts/auth-manager.rst 
b/docs/apache-airflow/core-concepts/auth-manager.rst
index bb54795f18..4e3446acaa 100644
--- a/docs/apache-airflow/core-concepts/auth-manager.rst
+++ b/docs/apache-airflow/core-concepts/auth-manager.rst
@@ -124,7 +124,7 @@ The following methods aren't required to override to have a 
functional Airflow a
 * ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If 
not overridden, it will call ``is_authorized_pool`` for every single item.
 * ``batch_is_authorized_variable``: Batch version of 
``is_authorized_variable``. If not overridden, it will call 
``is_authorized_variable`` for every single item.
 * ``get_permitted_dag_ids``: Return the list of DAG IDs the user has access 
to.  If not overridden, it will call ``is_authorized_dag`` for every single DAG 
available in the environment.
-* ``get_permitted_menu_items``: Return the menu items the user has access to.  
If not overridden, it will call ``has_access`` in 
:class:`~airflow.www.security_manager.AirflowSecurityManagerV2` for every 
single menu item.
+* ``filter_permitted_menu_items``: Filter the given menu items and return only those the user has access 
to. If not overridden, it will call ``has_access`` in 
:class:`~airflow.www.security_manager.AirflowSecurityManagerV2` for every 
single menu item.
 
 CLI
 ^^^
diff --git a/newsfragments/37627.significant.rst 
b/newsfragments/37627.significant.rst
new file mode 100644
index 0000000000..886b3e25a0
--- /dev/null
+++ b/newsfragments/37627.significant.rst
@@ -0,0 +1 @@
+The method ``get_permitted_menu_items`` in ``BaseAuthManager`` has been 
renamed to ``filter_permitted_menu_items``.
diff --git a/tests/auth/managers/test_base_auth_manager.py 
b/tests/auth/managers/test_base_auth_manager.py
index 9fac23a062..d05bb50dd4 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -299,7 +299,7 @@ class TestBaseAuthManager:
         assert result == expected
 
     @patch.object(EmptyAuthManager, "security_manager")
-    def test_get_permitted_menu_items(self, mock_security_manager, 
auth_manager):
+    def test_filter_permitted_menu_items(self, mock_security_manager, 
auth_manager):
         mock_security_manager.has_access.side_effect = [True, False, True, 
True, False]
 
         menu = Menu()
@@ -309,7 +309,7 @@ class TestBaseAuthManager:
         menu.add_link("item3.1", category="item3")
         menu.add_link("item3.2", category="item3")
 
-        result = auth_manager.get_permitted_menu_items(menu.get_list())
+        result = auth_manager.filter_permitted_menu_items(menu.get_list())
 
         assert len(result) == 2
         assert result[0].name == "item1"
diff --git a/tests/providers/amazon/aws/auth_manager/avp/test_facade.py 
b/tests/providers/amazon/aws/auth_manager/avp/test_facade.py
index fca4961dfe..437f731b9f 100644
--- a/tests/providers/amazon/aws/auth_manager/avp/test_facade.py
+++ b/tests/providers/amazon/aws/auth_manager/avp/test_facade.py
@@ -280,3 +280,72 @@ class TestAwsAuthManagerAmazonVerifiedPermissionsFacade:
                     ],
                     user=test_user,
                 )
+
+    def test_get_batch_is_authorized_single_result_successful(self, facade):
+        single_result = {
+            "request": {
+                "principal": {"entityType": "Airflow::User", "entityId": 
"test_user"},
+                "action": {"actionType": "Airflow::Action", "actionId": 
"Connection.GET"},
+                "resource": {"entityType": "Airflow::Connection", "entityId": 
"*"},
+            },
+            "decision": "ALLOW",
+        }
+
+        with conf_vars(
+            {
+                ("aws_auth_manager", "avp_policy_store_id"): 
AVP_POLICY_STORE_ID,
+            }
+        ):
+            result = facade.get_batch_is_authorized_single_result(
+                batch_is_authorized_results=[
+                    {
+                        "request": {
+                            "principal": {"entityType": "Airflow::User", 
"entityId": "test_user"},
+                            "action": {"actionType": "Airflow::Action", 
"actionId": "Variable.GET"},
+                            "resource": {"entityType": "Airflow::Variable", 
"entityId": "*"},
+                        },
+                        "decision": "ALLOW",
+                    },
+                    single_result,
+                ],
+                request={
+                    "method": "GET",
+                    "entity_type": AvpEntities.CONNECTION,
+                },
+                user=test_user,
+            )
+
+        assert result == single_result
+
+    def test_get_batch_is_authorized_single_result_unsuccessful(self, facade):
+        with conf_vars(
+            {
+                ("aws_auth_manager", "avp_policy_store_id"): 
AVP_POLICY_STORE_ID,
+            }
+        ):
+            with pytest.raises(AirflowException, match="Could not find the 
authorization result."):
+                facade.get_batch_is_authorized_single_result(
+                    batch_is_authorized_results=[
+                        {
+                            "request": {
+                                "principal": {"entityType": "Airflow::User", 
"entityId": "test_user"},
+                                "action": {"actionType": "Airflow::Action", 
"actionId": "Variable.GET"},
+                                "resource": {"entityType": 
"Airflow::Variable", "entityId": "*"},
+                            },
+                            "decision": "ALLOW",
+                        },
+                        {
+                            "request": {
+                                "principal": {"entityType": "Airflow::User", 
"entityId": "test_user"},
+                                "action": {"actionType": "Airflow::Action", 
"actionId": "Variable.POST"},
+                                "resource": {"entityType": 
"Airflow::Variable", "entityId": "*"},
+                            },
+                            "decision": "ALLOW",
+                        },
+                    ],
+                    request={
+                        "method": "GET",
+                        "entity_type": AvpEntities.CONNECTION,
+                    },
+                    user=test_user,
+                )
diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py 
b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
index df4c45255f..daa21f21c1 100644
--- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -21,6 +21,7 @@ from unittest.mock import ANY, Mock, patch
 
 import pytest
 from flask import Flask, session
+from flask_appbuilder.menu import MenuItem
 
 from airflow.auth.managers.models.resource_details import (
     AccessView,
@@ -32,12 +33,20 @@ from airflow.auth.managers.models.resource_details import (
     PoolDetails,
     VariableDetails,
 )
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
 from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import 
AwsAuthManager
 from 
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
 import (
     AwsSecurityManagerOverride,
 )
 from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
+from airflow.security.permissions import (
+    RESOURCE_AUDIT_LOG,
+    RESOURCE_CLUSTER_ACTIVITY,
+    RESOURCE_CONNECTION,
+    RESOURCE_DATASET,
+    RESOURCE_VARIABLE,
+)
 from airflow.www.extensions.init_appbuilder import init_appbuilder
 from tests.test_utils.config import conf_vars
 
@@ -509,6 +518,130 @@ class TestAwsAuthManager:
         )
         assert result
 
+    @patch.object(AwsAuthManager, "get_user")
+    def test_filter_permitted_menu_items(self, mock_get_user, auth_manager, 
test_user):
+        batch_is_authorized_output = [
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Connection.GET"},
+                    "resource": {"entityType": "Airflow::Connection", 
"entityId": "*"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Variable.GET"},
+                    "resource": {"entityType": "Airflow::Variable", 
"entityId": "*"},
+                },
+                "decision": "ALLOW",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dataset.GET"},
+                    "resource": {"entityType": "Airflow::Dataset", "entityId": 
"*"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"View.GET"},
+                    "resource": {"entityType": "Airflow::View", "entityId": 
"CLUSTER_ACTIVITY"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.GET"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"*"},
+                    "context": {
+                        "contextMap": {
+                            "dag_entity": {
+                                "string": "AUDIT_LOG",
+                            }
+                        }
+                    },
+                },
+                "decision": "ALLOW",
+            },
+        ]
+        auth_manager.avp_facade.get_batch_is_authorized_results = Mock(
+            return_value=batch_is_authorized_output
+        )
+
+        mock_get_user.return_value = test_user
+
+        result = auth_manager.filter_permitted_menu_items(
+            [
+                MenuItem("Category1", childs=[MenuItem(RESOURCE_CONNECTION), 
MenuItem(RESOURCE_VARIABLE)]),
+                MenuItem("Category2", childs=[MenuItem(RESOURCE_DATASET)]),
+                MenuItem(RESOURCE_CLUSTER_ACTIVITY),
+                MenuItem(RESOURCE_AUDIT_LOG),
+            ]
+        )
+
+        
auth_manager.avp_facade.get_batch_is_authorized_results.assert_called_once_with(
+            requests=[
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.CONNECTION,
+                },
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.VARIABLE,
+                },
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.DATASET,
+                },
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.VIEW,
+                    "entity_id": AccessView.CLUSTER_ACTIVITY.value,
+                },
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.DAG,
+                    "context": {
+                        "dag_entity": {
+                            "string": DagAccessEntity.AUDIT_LOG.value,
+                        },
+                    },
+                },
+            ],
+            user=test_user,
+        )
+        assert len(result) == 2
+        assert result[0].name == "Category1"
+        assert len(result[0].childs) == 1
+        assert result[0].childs[0].name == RESOURCE_VARIABLE
+        assert result[1].name == RESOURCE_AUDIT_LOG
+
+    @patch.object(AwsAuthManager, "get_user")
+    def test_filter_permitted_menu_items_logged_out(self, mock_get_user, 
auth_manager):
+        mock_get_user.return_value = None
+        result = auth_manager.filter_permitted_menu_items(
+            [
+                MenuItem(RESOURCE_AUDIT_LOG),
+            ]
+        )
+
+        assert result == []
+
+    @patch.object(AwsAuthManager, "get_user")
+    def test_filter_permitted_menu_items_wrong_menu_item(self, mock_get_user, 
auth_manager, test_user):
+        mock_get_user.return_value = test_user
+        with pytest.raises(AirflowException, match="Unknown resource name"):
+            auth_manager.filter_permitted_menu_items(
+                [
+                    MenuItem("Test"),
+                ]
+            )
+
     
@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
     def test_get_url_login(self, mock_url_for, auth_manager):
         auth_manager.get_url_login()

Reply via email to