This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 47348ce66c Implement `filter_permitted_dag_ids` in AWS auth manager (#37666)
47348ce66c is described below

commit 47348ce66c1f4a90aec87215b8a237c4bffcddca
Author: Vincent <97131062+vincb...@users.noreply.github.com>
AuthorDate: Fri Mar 8 14:01:45 2024 -0500

    Implement `filter_permitted_dag_ids` in AWS auth manager (#37666)
---
 airflow/auth/managers/base_auth_manager.py         | 23 ++++++-
 .../amazon/aws/auth_manager/aws_auth_manager.py    | 75 ++++++++++++++++++----
 .../aws/auth_manager/test_aws_auth_manager.py      | 59 +++++++++++++++++
 3 files changed, 144 insertions(+), 13 deletions(-)

diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py
index a0176fc833..e378e5e10b 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -344,12 +344,31 @@ class BaseAuthManager(LoggingMixin):
         By default, reads all the DAGs and check individually if the user has 
permissions to access the DAG.
         Can lead to some poor performance. It is recommended to override this 
method in the auth manager
         implementation to provide a more efficient implementation.
+
+        :param methods: whether filter readable or writable
+        :param user: the current user
+        :param session: the session
+        """
+        dag_ids = {dag.dag_id for dag in 
session.execute(select(DagModel.dag_id))}
+        return self.filter_permitted_dag_ids(dag_ids=dag_ids, methods=methods, 
user=user)
+
+    def filter_permitted_dag_ids(
+        self,
+        *,
+        dag_ids: set[str],
+        methods: Container[ResourceMethod] | None = None,
+        user=None,
+    ):
+        """
+        Filter readable or writable DAGs for user.
+
+        :param dag_ids: the list of DAG ids
+        :param methods: whether filter readable or writable
+        :param user: the current user
         """
         if not methods:
             methods = ["PUT", "GET"]
 
-        dag_ids = {dag.dag_id for dag in 
session.execute(select(DagModel.dag_id))}
-
         if ("GET" in methods and self.is_authorized_dag(method="GET", 
user=user)) or (
             "PUT" in methods and self.is_authorized_dag(method="PUT", 
user=user)
         ):
diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 0aa1fdb2c5..c17234c047 100644
--- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -17,8 +17,9 @@
 from __future__ import annotations
 
 import argparse
+from collections import defaultdict
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import TYPE_CHECKING, Container, Sequence, cast
 
 from flask import session, url_for
 
@@ -443,6 +444,60 @@ class AwsAuthManager(BaseAuthManager):
         ]
         return self.avp_facade.batch_is_authorized(requests=facade_requests, 
user=self.get_user())
 
+    def filter_permitted_dag_ids(
+        self,
+        *,
+        dag_ids: set[str],
+        methods: Container[ResourceMethod] | None = None,
+        user=None,
+    ):
+        """
+        Filter readable or writable DAGs for user.
+
+        :param dag_ids: the list of DAG ids
+        :param methods: whether filter readable or writable
+        :param user: the current user
+        """
+        if not methods:
+            methods = ["PUT", "GET"]
+
+        if not user:
+            user = self.get_user()
+
+        requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = 
defaultdict(dict)
+        requests_list: list[IsAuthorizedRequest] = []
+        for dag_id in dag_ids:
+            for method in ["GET", "PUT"]:
+                if method in methods:
+                    request: IsAuthorizedRequest = {
+                        "method": cast(ResourceMethod, method),
+                        "entity_type": AvpEntities.DAG,
+                        "entity_id": dag_id,
+                    }
+                    requests[dag_id][cast(ResourceMethod, method)] = request
+                    requests_list.append(request)
+
+        batch_is_authorized_results = 
self.avp_facade.get_batch_is_authorized_results(
+            requests=requests_list, user=user
+        )
+
+        def _has_access_to_dag(request: IsAuthorizedRequest):
+            result = self.avp_facade.get_batch_is_authorized_single_result(
+                batch_is_authorized_results=batch_is_authorized_results, 
request=request, user=user
+            )
+            return result["decision"] == "ALLOW"
+
+        return {
+            dag_id
+            for dag_id in dag_ids
+            if (
+                "GET" in methods
+                and _has_access_to_dag(requests[dag_id]["GET"])
+                or "PUT" in methods
+                and _has_access_to_dag(requests[dag_id]["PUT"])
+            )
+        }
+
     def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> 
list[MenuItem]:
         """
         Filter menu items based on user permissions.
@@ -465,19 +520,25 @@ class AwsAuthManager(BaseAuthManager):
             requests=list(requests.values()), user=user
         )
 
+        def _has_access_to_menu_item(request: IsAuthorizedRequest):
+            result = self.avp_facade.get_batch_is_authorized_single_result(
+                batch_is_authorized_results=batch_is_authorized_results, 
request=request, user=user
+            )
+            return result["decision"] == "ALLOW"
+
         accessible_items = []
         for menu_item in menu_items:
             if menu_item.childs:
                 accessible_children = []
                 for child in menu_item.childs:
-                    if 
self._has_access_to_menu_item(batch_is_authorized_results, 
requests[child.name], user):
+                    if _has_access_to_menu_item(requests[child.name]):
                         accessible_children.append(child)
                 menu_item.childs = accessible_children
 
                 # Display the menu if the user has access to at least one sub 
item
                 if len(accessible_children) > 0:
                     accessible_items.append(menu_item)
-            elif self._has_access_to_menu_item(batch_is_authorized_results, 
requests[menu_item.name], user):
+            elif _has_access_to_menu_item(requests[menu_item.name]):
                 accessible_items.append(menu_item)
 
         return accessible_items
@@ -511,14 +572,6 @@ class AwsAuthManager(BaseAuthManager):
         else:
             raise AirflowException(f"Unknown resource name 
{fab_resource_name}")
 
-    def _has_access_to_menu_item(
-        self, batch_is_authorized_results: list[dict], request: 
IsAuthorizedRequest, user: AwsAuthManagerUser
-    ):
-        result = self.avp_facade.get_batch_is_authorized_single_result(
-            batch_is_authorized_results=batch_is_authorized_results, 
request=request, user=user
-        )
-        return result["decision"] == "ALLOW"
-
 
 def get_parser() -> argparse.ArgumentParser:
     """Generate documentation; used by Sphinx argparse."""
diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
index daa21f21c1..9b654199c3 100644
--- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -642,6 +642,65 @@ class TestAwsAuthManager:
                 ]
             )
 
+    @pytest.mark.parametrize(
+        "methods, user",
+        [
+            (None, None),
+            (["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id", 
groups=[])),
+        ],
+    )
+    @patch.object(AwsAuthManager, "get_user")
+    def test_filter_permitted_dag_ids(self, mock_get_user, methods, user, 
auth_manager, test_user):
+        dag_ids = {"dag_1", "dag_2"}
+        batch_is_authorized_output = [
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.GET"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_1"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.PUT"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_1"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.GET"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_2"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.PUT"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_2"},
+                },
+                "decision": "ALLOW",
+            },
+        ]
+        auth_manager.avp_facade.get_batch_is_authorized_results = Mock(
+            return_value=batch_is_authorized_output
+        )
+
+        mock_get_user.return_value = test_user
+
+        result = auth_manager.filter_permitted_dag_ids(
+            dag_ids=dag_ids,
+            methods=methods,
+            user=user,
+        )
+
+        auth_manager.avp_facade.get_batch_is_authorized_results.assert_called()
+        assert result == {"dag_2"}
+
     
@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
     def test_get_url_login(self, mock_url_for, auth_manager):
         auth_manager.get_url_login()

Reply via email to