This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f36d17ce75 Implement `filter_permitted_menu_items` in AWS auth manager
(#37627)
f36d17ce75 is described below
commit f36d17ce757e5c39ca4dd35a637a090ef6105744
Author: Vincent <[email protected]>
AuthorDate: Fri Feb 23 14:56:09 2024 -0500
Implement `filter_permitted_menu_items` in AWS auth manager (#37627)
---
airflow/auth/managers/base_auth_manager.py | 2 +-
.../amazon/aws/auth_manager/avp/facade.py | 134 +++++++++----
.../amazon/aws/auth_manager/aws_auth_manager.py | 213 ++++++++++++++++++++-
airflow/www/templates/appbuilder/navbar_menu.html | 2 +-
docs/apache-airflow/core-concepts/auth-manager.rst | 2 +-
newsfragments/37627.significant.rst | 1 +
tests/auth/managers/test_base_auth_manager.py | 4 +-
.../amazon/aws/auth_manager/avp/test_facade.py | 69 +++++++
.../aws/auth_manager/test_aws_auth_manager.py | 133 +++++++++++++
9 files changed, 517 insertions(+), 43 deletions(-)
diff --git a/airflow/auth/managers/base_auth_manager.py
b/airflow/auth/managers/base_auth_manager.py
index b9dd459657..88164b644e 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -367,7 +367,7 @@ class BaseAuthManager(LoggingMixin):
if _is_permitted_dag_id("GET", methods, dag_id) or
_is_permitted_dag_id("PUT", methods, dag_id)
}
- def get_permitted_menu_items(self, menu_items: list[MenuItem]) ->
list[MenuItem]:
+ def filter_permitted_menu_items(self, menu_items: list[MenuItem]) ->
list[MenuItem]:
"""
Filter menu items based on user permissions.
diff --git a/airflow/providers/amazon/aws/auth_manager/avp/facade.py
b/airflow/providers/amazon/aws/auth_manager/avp/facade.py
index c13233b54e..18231b3d0f 100644
--- a/airflow/providers/amazon/aws/auth_manager/avp/facade.py
+++ b/airflow/providers/amazon/aws/auth_manager/avp/facade.py
@@ -37,6 +37,11 @@ if TYPE_CHECKING:
from airflow.providers.amazon.aws.auth_manager.user import
AwsAuthManagerUser
+# Amazon Verified Permissions allows only up to 30 requests per
batch_is_authorized call. See
+#
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/verifiedpermissions/client/batch_is_authorized.html
+NB_REQUESTS_PER_BATCH = 30
+
+
class IsAuthorizedRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized`` method in AVP facade."""
@@ -125,62 +130,97 @@ class
AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
return resp["decision"] == "ALLOW"
- def batch_is_authorized(
+ def get_batch_is_authorized_results(
self,
*,
requests: Sequence[IsAuthorizedRequest],
- user: AwsAuthManagerUser | None,
- ) -> bool:
+ user: AwsAuthManagerUser,
+ ) -> list[dict]:
"""
Make a batch authorization decision against Amazon Verified
Permissions.
- Check whether the user has permissions to access given resources.
+ Return a list of results for each request.
:param requests: the list of requests containing the method, the
entity_type and the entity ID
:param user: the user
"""
- if user is None:
- return False
-
entity_list = self._get_user_role_entities(user)
self.log.debug("Making batch authorization request for user=%s,
requests=%s", user.get_id(), requests)
- avp_requests = [
- prune_dict(
- {
- "principal": {"entityType":
get_entity_type(AvpEntities.USER), "entityId": user.get_id()},
- "action": {
- "actionType": get_entity_type(AvpEntities.ACTION),
- "actionId": get_action_id(request["entity_type"],
request["method"]),
- },
- "resource": {
- "entityType": get_entity_type(request["entity_type"]),
- "entityId": request.get("entity_id", "*"),
- },
- "context": self._build_context(request.get("context")),
- }
- )
- for request in requests
+ avp_requests = [self._build_is_authorized_request_payload(request,
user) for request in requests]
+ avp_requests_chunks = [
+ avp_requests[i : i + NB_REQUESTS_PER_BATCH]
+ for i in range(0, len(avp_requests), NB_REQUESTS_PER_BATCH)
]
- resp = self.avp_client.batch_is_authorized(
- policyStoreId=self.avp_policy_store_id,
- requests=avp_requests,
- entities={"entityList": entity_list},
- )
+ results = []
+ for avp_requests in avp_requests_chunks:
+ resp = self.avp_client.batch_is_authorized(
+ policyStoreId=self.avp_policy_store_id,
+ requests=avp_requests,
+ entities={"entityList": entity_list},
+ )
- self.log.debug("Authorization response: %s", resp)
+ self.log.debug("Authorization response: %s", resp)
- has_errors = any(len(result.get("errors", [])) > 0 for result in
resp["results"])
+ has_errors = any(len(result.get("errors", [])) > 0 for result in
resp["results"])
- if has_errors:
- self.log.error(
- "Error occurred while making a batch authorization decision.
Result: %s", resp["results"]
- )
- raise AirflowException("Error occurred while making a batch
authorization decision.")
+ if has_errors:
+ self.log.error(
+ "Error occurred while making a batch authorization
decision. Result: %s", resp["results"]
+ )
+ raise AirflowException("Error occurred while making a batch
authorization decision.")
+
+ results.extend(resp["results"])
+
+ return results
+
+ def batch_is_authorized(
+ self,
+ *,
+ requests: Sequence[IsAuthorizedRequest],
+ user: AwsAuthManagerUser | None,
+ ) -> bool:
+ """
+ Make a batch authorization decision against Amazon Verified
Permissions.
+
+ Check whether the user has permissions to access all resources.
+
+ :param requests: the list of requests containing the method, the
entity_type and the entity ID
+ :param user: the user
+ """
+ if user is None:
+ return False
+ results = self.get_batch_is_authorized_results(requests=requests,
user=user)
+ return all(result["decision"] == "ALLOW" for result in results)
+
+ def get_batch_is_authorized_single_result(
+ self,
+ *,
+ batch_is_authorized_results: list[dict],
+ request: IsAuthorizedRequest,
+ user: AwsAuthManagerUser,
+ ) -> dict:
+ """
+ Get a specific authorization result from the output of
``get_batch_is_authorized_results``.
- return all(result["decision"] == "ALLOW" for result in resp["results"])
+ :param batch_is_authorized_results: the response from the
``batch_is_authorized`` API
+ :param request: the request information. Used to find the result in
the response.
+ :param user: the user
+ """
+ request_payload = self._build_is_authorized_request_payload(request,
user)
+
+ for result in batch_is_authorized_results:
+ if result["request"] == request_payload:
+ return result
+
+ self.log.error(
+ "Could not find the authorization result for request %s in results
%s.",
+ request_payload,
+ batch_is_authorized_results,
+ )
+ raise AirflowException("Could not find the authorization result.")
@staticmethod
def _get_user_role_entities(user: AwsAuthManagerUser) -> list[dict]:
@@ -205,3 +245,25 @@ class
AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
return {
"contextMap": context,
}
+
+ def _build_is_authorized_request_payload(self, request:
IsAuthorizedRequest, user: AwsAuthManagerUser):
+ """
+ Build a payload of an individual authorization request that could be
sent through the ``batch_is_authorized`` API.
+
+ :param request: the request information
+ :param user: the user
+ """
+ return prune_dict(
+ {
+ "principal": {"entityType": get_entity_type(AvpEntities.USER),
"entityId": user.get_id()},
+ "action": {
+ "actionType": get_entity_type(AvpEntities.ACTION),
+ "actionId": get_action_id(request["entity_type"],
request["method"]),
+ },
+ "resource": {
+ "entityType": get_entity_type(request["entity_type"]),
+ "entityId": request.get("entity_id", "*"),
+ },
+ "context": self._build_context(request.get("context")),
+ }
+ )
diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 5ed09d3810..0aa1fdb2c5 100644
--- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -24,7 +24,7 @@ from flask import session, url_for
from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
from airflow.configuration import conf
-from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.exceptions import AirflowException,
AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.avp.facade import (
AwsAuthManagerAmazonVerifiedPermissionsFacade,
@@ -40,10 +40,33 @@ from airflow.providers.amazon.aws.auth_manager.constants
import (
from
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
import (
AwsSecurityManagerOverride,
)
+from airflow.security.permissions import (
+ RESOURCE_AUDIT_LOG,
+ RESOURCE_CLUSTER_ACTIVITY,
+ RESOURCE_CONFIG,
+ RESOURCE_CONNECTION,
+ RESOURCE_DAG,
+ RESOURCE_DAG_CODE,
+ RESOURCE_DAG_DEPENDENCIES,
+ RESOURCE_DAG_RUN,
+ RESOURCE_DATASET,
+ RESOURCE_DOCS,
+ RESOURCE_JOB,
+ RESOURCE_PLUGIN,
+ RESOURCE_POOL,
+ RESOURCE_PROVIDER,
+ RESOURCE_SLA_MISS,
+ RESOURCE_TASK_INSTANCE,
+ RESOURCE_TASK_RESCHEDULE,
+ RESOURCE_TRIGGER,
+ RESOURCE_VARIABLE,
+ RESOURCE_XCOM,
+)
try:
from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
from airflow.auth.managers.models.resource_details import (
+ AccessView,
ConnectionDetails,
DagAccessEntity,
DagDetails,
@@ -56,6 +79,8 @@ except ImportError:
)
if TYPE_CHECKING:
+ from flask_appbuilder.menu import MenuItem
+
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
@@ -64,7 +89,6 @@ if TYPE_CHECKING:
IsAuthorizedVariableRequest,
)
from airflow.auth.managers.models.resource_details import (
- AccessView,
ConfigurationDetails,
DatasetDetails,
)
@@ -72,6 +96,136 @@ if TYPE_CHECKING:
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
+_MENU_ITEM_REQUESTS: dict[str, IsAuthorizedRequest] = {
+ RESOURCE_AUDIT_LOG: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.AUDIT_LOG.value,
+ },
+ },
+ },
+ RESOURCE_CLUSTER_ACTIVITY: {
+ "method": "GET",
+ "entity_type": AvpEntities.VIEW,
+ "entity_id": AccessView.CLUSTER_ACTIVITY.value,
+ },
+ RESOURCE_CONFIG: {
+ "method": "GET",
+ "entity_type": AvpEntities.CONFIGURATION,
+ },
+ RESOURCE_CONNECTION: {
+ "method": "GET",
+ "entity_type": AvpEntities.CONNECTION,
+ },
+ RESOURCE_DAG: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ },
+ RESOURCE_DAG_CODE: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.CODE.value,
+ },
+ },
+ },
+ RESOURCE_DAG_DEPENDENCIES: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.DEPENDENCIES.value,
+ },
+ },
+ },
+ RESOURCE_DAG_RUN: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.RUN.value,
+ },
+ },
+ },
+ RESOURCE_DATASET: {
+ "method": "GET",
+ "entity_type": AvpEntities.DATASET,
+ },
+ RESOURCE_DOCS: {
+ "method": "GET",
+ "entity_type": AvpEntities.VIEW,
+ "entity_id": AccessView.DOCS.value,
+ },
+ RESOURCE_PLUGIN: {
+ "method": "GET",
+ "entity_type": AvpEntities.VIEW,
+ "entity_id": AccessView.PLUGINS.value,
+ },
+ RESOURCE_JOB: {
+ "method": "GET",
+ "entity_type": AvpEntities.VIEW,
+ "entity_id": AccessView.JOBS.value,
+ },
+ RESOURCE_POOL: {
+ "method": "GET",
+ "entity_type": AvpEntities.POOL,
+ },
+ RESOURCE_PROVIDER: {
+ "method": "GET",
+ "entity_type": AvpEntities.VIEW,
+ "entity_id": AccessView.PROVIDERS.value,
+ },
+ RESOURCE_SLA_MISS: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.SLA_MISS.value,
+ },
+ },
+ },
+ RESOURCE_TASK_INSTANCE: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.TASK_INSTANCE.value,
+ },
+ },
+ },
+ RESOURCE_TASK_RESCHEDULE: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.TASK_RESCHEDULE.value,
+ },
+ },
+ },
+ RESOURCE_TRIGGER: {
+ "method": "GET",
+ "entity_type": AvpEntities.VIEW,
+ "entity_id": AccessView.TRIGGERS.value,
+ },
+ RESOURCE_VARIABLE: {
+ "method": "GET",
+ "entity_type": AvpEntities.VARIABLE,
+ },
+ RESOURCE_XCOM: {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.XCOM.value,
+ },
+ },
+ },
+}
+
+
class AwsAuthManager(BaseAuthManager):
"""
AWS auth manager.
@@ -289,6 +443,45 @@ class AwsAuthManager(BaseAuthManager):
]
return self.avp_facade.batch_is_authorized(requests=facade_requests,
user=self.get_user())
+ def filter_permitted_menu_items(self, menu_items: list[MenuItem]) ->
list[MenuItem]:
+ """
+ Filter menu items based on user permissions.
+
+ :param menu_items: list of all menu items
+ """
+ user = self.get_user()
+ if not user:
+ return []
+
+ requests: dict[str, IsAuthorizedRequest] = {}
+ for menu_item in menu_items:
+ if menu_item.childs:
+ for child in menu_item.childs:
+ requests[child.name] =
self._get_menu_item_request(child.name)
+ else:
+ requests[menu_item.name] =
self._get_menu_item_request(menu_item.name)
+
+ batch_is_authorized_results =
self.avp_facade.get_batch_is_authorized_results(
+ requests=list(requests.values()), user=user
+ )
+
+ accessible_items = []
+ for menu_item in menu_items:
+ if menu_item.childs:
+ accessible_children = []
+ for child in menu_item.childs:
+ if
self._has_access_to_menu_item(batch_is_authorized_results,
requests[child.name], user):
+ accessible_children.append(child)
+ menu_item.childs = accessible_children
+
+ # Display the menu if the user has access to at least one sub
item
+ if len(accessible_children) > 0:
+ accessible_items.append(menu_item)
+ elif self._has_access_to_menu_item(batch_is_authorized_results,
requests[menu_item.name], user):
+ accessible_items.append(menu_item)
+
+ return accessible_items
+
def get_url_login(self, **kwargs) -> str:
return url_for("AwsAuthManagerAuthenticationViews.login")
@@ -310,6 +503,22 @@ class AwsAuthManager(BaseAuthManager):
),
]
+ @staticmethod
+ def _get_menu_item_request(fab_resource_name: str) -> IsAuthorizedRequest:
+ menu_item_request = _MENU_ITEM_REQUESTS.get(fab_resource_name)
+ if menu_item_request:
+ return menu_item_request
+ else:
+ raise AirflowException(f"Unknown resource name
{fab_resource_name}")
+
+ def _has_access_to_menu_item(
+ self, batch_is_authorized_results: list[dict], request:
IsAuthorizedRequest, user: AwsAuthManagerUser
+ ):
+ result = self.avp_facade.get_batch_is_authorized_single_result(
+ batch_is_authorized_results=batch_is_authorized_results,
request=request, user=user
+ )
+ return result["decision"] == "ALLOW"
+
def get_parser() -> argparse.ArgumentParser:
"""Generate documentation; used by Sphinx argparse."""
diff --git a/airflow/www/templates/appbuilder/navbar_menu.html
b/airflow/www/templates/appbuilder/navbar_menu.html
index 72855df270..d7abc8472f 100644
--- a/airflow/www/templates/appbuilder/navbar_menu.html
+++ b/airflow/www/templates/appbuilder/navbar_menu.html
@@ -21,7 +21,7 @@
<a href="{{item.get_url()}}">{{_(item.label)}}</a>
{% endmacro %}
-{% for item1 in auth_manager.get_permitted_menu_items(menu.get_list()) %}
+{% for item1 in auth_manager.filter_permitted_menu_items(menu.get_list()) %}
{% if item1 %}
{% if item1.childs %}
<li class="dropdown">
diff --git a/docs/apache-airflow/core-concepts/auth-manager.rst
b/docs/apache-airflow/core-concepts/auth-manager.rst
index bb54795f18..4e3446acaa 100644
--- a/docs/apache-airflow/core-concepts/auth-manager.rst
+++ b/docs/apache-airflow/core-concepts/auth-manager.rst
@@ -124,7 +124,7 @@ The following methods aren't required to override to have a
functional Airflow a
* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If
not overridden, it will call ``is_authorized_pool`` for every single item.
* ``batch_is_authorized_variable``: Batch version of
``is_authorized_variable``. If not overridden, it will call
``is_authorized_variable`` for every single item.
* ``get_permitted_dag_ids``: Return the list of DAG IDs the user has access
to. If not overridden, it will call ``is_authorized_dag`` for every single DAG
available in the environment.
-* ``get_permitted_menu_items``: Return the menu items the user has access to.
If not overridden, it will call ``has_access`` in
:class:`~airflow.www.security_manager.AirflowSecurityManagerV2` for every
single menu item.
+* ``filter_permitted_menu_items``: Return the menu items the user has access
to. If not overridden, it will call ``has_access`` in
:class:`~airflow.www.security_manager.AirflowSecurityManagerV2` for every
single menu item.
CLI
^^^
diff --git a/newsfragments/37627.significant.rst
b/newsfragments/37627.significant.rst
new file mode 100644
index 0000000000..886b3e25a0
--- /dev/null
+++ b/newsfragments/37627.significant.rst
@@ -0,0 +1 @@
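+The method ``get_permitted_menu_items`` in ``BaseAuthManager`` has been
renamed to ``filter_permitted_menu_items``. Auth managers that override
``get_permitted_menu_items`` must rename their override, since the navbar
template now calls ``filter_permitted_menu_items``.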
diff --git a/tests/auth/managers/test_base_auth_manager.py
b/tests/auth/managers/test_base_auth_manager.py
index 9fac23a062..d05bb50dd4 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -299,7 +299,7 @@ class TestBaseAuthManager:
assert result == expected
@patch.object(EmptyAuthManager, "security_manager")
- def test_get_permitted_menu_items(self, mock_security_manager,
auth_manager):
+ def test_filter_permitted_menu_items(self, mock_security_manager,
auth_manager):
mock_security_manager.has_access.side_effect = [True, False, True,
True, False]
menu = Menu()
@@ -309,7 +309,7 @@ class TestBaseAuthManager:
menu.add_link("item3.1", category="item3")
menu.add_link("item3.2", category="item3")
- result = auth_manager.get_permitted_menu_items(menu.get_list())
+ result = auth_manager.filter_permitted_menu_items(menu.get_list())
assert len(result) == 2
assert result[0].name == "item1"
diff --git a/tests/providers/amazon/aws/auth_manager/avp/test_facade.py
b/tests/providers/amazon/aws/auth_manager/avp/test_facade.py
index fca4961dfe..437f731b9f 100644
--- a/tests/providers/amazon/aws/auth_manager/avp/test_facade.py
+++ b/tests/providers/amazon/aws/auth_manager/avp/test_facade.py
@@ -280,3 +280,72 @@ class TestAwsAuthManagerAmazonVerifiedPermissionsFacade:
],
user=test_user,
)
+
+ def test_get_batch_is_authorized_single_result_successful(self, facade):
+ single_result = {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Connection.GET"},
+ "resource": {"entityType": "Airflow::Connection", "entityId":
"*"},
+ },
+ "decision": "ALLOW",
+ }
+
+ with conf_vars(
+ {
+ ("aws_auth_manager", "avp_policy_store_id"):
AVP_POLICY_STORE_ID,
+ }
+ ):
+ result = facade.get_batch_is_authorized_single_result(
+ batch_is_authorized_results=[
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User",
"entityId": "test_user"},
+ "action": {"actionType": "Airflow::Action",
"actionId": "Variable.GET"},
+ "resource": {"entityType": "Airflow::Variable",
"entityId": "*"},
+ },
+ "decision": "ALLOW",
+ },
+ single_result,
+ ],
+ request={
+ "method": "GET",
+ "entity_type": AvpEntities.CONNECTION,
+ },
+ user=test_user,
+ )
+
+ assert result == single_result
+
+ def test_get_batch_is_authorized_single_result_unsuccessful(self, facade):
+ with conf_vars(
+ {
+ ("aws_auth_manager", "avp_policy_store_id"):
AVP_POLICY_STORE_ID,
+ }
+ ):
+ with pytest.raises(AirflowException, match="Could not find the
authorization result."):
+ facade.get_batch_is_authorized_single_result(
+ batch_is_authorized_results=[
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User",
"entityId": "test_user"},
+ "action": {"actionType": "Airflow::Action",
"actionId": "Variable.GET"},
+ "resource": {"entityType":
"Airflow::Variable", "entityId": "*"},
+ },
+ "decision": "ALLOW",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User",
"entityId": "test_user"},
+ "action": {"actionType": "Airflow::Action",
"actionId": "Variable.POST"},
+ "resource": {"entityType":
"Airflow::Variable", "entityId": "*"},
+ },
+ "decision": "ALLOW",
+ },
+ ],
+ request={
+ "method": "GET",
+ "entity_type": AvpEntities.CONNECTION,
+ },
+ user=test_user,
+ )
diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
index df4c45255f..daa21f21c1 100644
--- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -21,6 +21,7 @@ from unittest.mock import ANY, Mock, patch
import pytest
from flask import Flask, session
+from flask_appbuilder.menu import MenuItem
from airflow.auth.managers.models.resource_details import (
AccessView,
@@ -32,12 +33,20 @@ from airflow.auth.managers.models.resource_details import (
PoolDetails,
VariableDetails,
)
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import
AwsAuthManager
from
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
import (
AwsSecurityManagerOverride,
)
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
+from airflow.security.permissions import (
+ RESOURCE_AUDIT_LOG,
+ RESOURCE_CLUSTER_ACTIVITY,
+ RESOURCE_CONNECTION,
+ RESOURCE_DATASET,
+ RESOURCE_VARIABLE,
+)
from airflow.www.extensions.init_appbuilder import init_appbuilder
from tests.test_utils.config import conf_vars
@@ -509,6 +518,130 @@ class TestAwsAuthManager:
)
assert result
+ @patch.object(AwsAuthManager, "get_user")
+ def test_filter_permitted_menu_items(self, mock_get_user, auth_manager,
test_user):
+ batch_is_authorized_output = [
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Connection.GET"},
+ "resource": {"entityType": "Airflow::Connection",
"entityId": "*"},
+ },
+ "decision": "DENY",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Variable.GET"},
+ "resource": {"entityType": "Airflow::Variable",
"entityId": "*"},
+ },
+ "decision": "ALLOW",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Dataset.GET"},
+ "resource": {"entityType": "Airflow::Dataset", "entityId":
"*"},
+ },
+ "decision": "DENY",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"View.GET"},
+ "resource": {"entityType": "Airflow::View", "entityId":
"CLUSTER_ACTIVITY"},
+ },
+ "decision": "DENY",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Dag.GET"},
+ "resource": {"entityType": "Airflow::Dag", "entityId":
"*"},
+ "context": {
+ "contextMap": {
+ "dag_entity": {
+ "string": "AUDIT_LOG",
+ }
+ }
+ },
+ },
+ "decision": "ALLOW",
+ },
+ ]
+ auth_manager.avp_facade.get_batch_is_authorized_results = Mock(
+ return_value=batch_is_authorized_output
+ )
+
+ mock_get_user.return_value = test_user
+
+ result = auth_manager.filter_permitted_menu_items(
+ [
+ MenuItem("Category1", childs=[MenuItem(RESOURCE_CONNECTION),
MenuItem(RESOURCE_VARIABLE)]),
+ MenuItem("Category2", childs=[MenuItem(RESOURCE_DATASET)]),
+ MenuItem(RESOURCE_CLUSTER_ACTIVITY),
+ MenuItem(RESOURCE_AUDIT_LOG),
+ ]
+ )
+
+
auth_manager.avp_facade.get_batch_is_authorized_results.assert_called_once_with(
+ requests=[
+ {
+ "method": "GET",
+ "entity_type": AvpEntities.CONNECTION,
+ },
+ {
+ "method": "GET",
+ "entity_type": AvpEntities.VARIABLE,
+ },
+ {
+ "method": "GET",
+ "entity_type": AvpEntities.DATASET,
+ },
+ {
+ "method": "GET",
+ "entity_type": AvpEntities.VIEW,
+ "entity_id": AccessView.CLUSTER_ACTIVITY.value,
+ },
+ {
+ "method": "GET",
+ "entity_type": AvpEntities.DAG,
+ "context": {
+ "dag_entity": {
+ "string": DagAccessEntity.AUDIT_LOG.value,
+ },
+ },
+ },
+ ],
+ user=test_user,
+ )
+ assert len(result) == 2
+ assert result[0].name == "Category1"
+ assert len(result[0].childs) == 1
+ assert result[0].childs[0].name == RESOURCE_VARIABLE
+ assert result[1].name == RESOURCE_AUDIT_LOG
+
+ @patch.object(AwsAuthManager, "get_user")
+ def test_filter_permitted_menu_items_logged_out(self, mock_get_user,
auth_manager):
+ mock_get_user.return_value = None
+ result = auth_manager.filter_permitted_menu_items(
+ [
+ MenuItem(RESOURCE_AUDIT_LOG),
+ ]
+ )
+
+ assert result == []
+
+ @patch.object(AwsAuthManager, "get_user")
+ def test_filter_permitted_menu_items_wrong_menu_item(self, mock_get_user,
auth_manager, test_user):
+ mock_get_user.return_value = test_user
+ with pytest.raises(AirflowException, match="Unknown resource name"):
+ auth_manager.filter_permitted_menu_items(
+ [
+ MenuItem("Test"),
+ ]
+ )
+
@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
def test_get_url_login(self, mock_url_for, auth_manager):
auth_manager.get_url_login()