This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f10c4314aab Fix and simplify `get_permitted_dag_ids` in auth manager (#47458)
f10c4314aab is described below

commit f10c4314aab2ab98c94e6c277d9c8019eba3a9f6
Author: Vincent <[email protected]>
AuthorDate: Thu Mar 6 14:24:49 2025 -0500

    Fix and simplify `get_permitted_dag_ids` in auth manager (#47458)
---
 airflow/auth/managers/base_auth_manager.py         | 32 +++------
 .../amazon/aws/auth_manager/aws_auth_manager.py    | 34 +++------
 .../aws/auth_manager/test_aws_auth_manager.py      | 56 ++++++++++++---
 .../providers/fab/auth_manager/fab_auth_manager.py | 40 +++++------
 .../fab/auth_manager/security_manager/override.py  |  4 +-
 .../unit/fab/auth_manager/test_fab_auth_manager.py | 80 +++++++++++++++++++++-
 .../tests/unit/fab/auth_manager/test_security.py   |  2 +-
 tests/auth/managers/test_base_auth_manager.py      | 17 +----
 8 files changed, 167 insertions(+), 98 deletions(-)

diff --git a/airflow/auth/managers/base_auth_manager.py 
b/airflow/auth/managers/base_auth_manager.py
index 0ca18db8121..f3b86600a05 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -34,7 +34,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
-    from collections.abc import Container, Sequence
+    from collections.abc import Sequence
 
     from fastapi import FastAPI
     from sqlalchemy.orm import Session
@@ -331,7 +331,7 @@ class BaseAuthManager(Generic[T], LoggingMixin):
         self,
         *,
         user: T,
-        methods: Container[ResourceMethod] | None = None,
+        method: ResourceMethod = "GET",
         session: Session = NEW_SESSION,
     ) -> set[str]:
         """
@@ -342,45 +342,31 @@ class BaseAuthManager(Generic[T], LoggingMixin):
         implementation to provide a more efficient implementation.
 
         :param user: the user
-        :param methods: whether filter readable or writable
+        :param method: the method to filter on
         :param session: the session
         """
         dag_ids = {dag.dag_id for dag in 
session.execute(select(DagModel.dag_id))}
-        return self.filter_permitted_dag_ids(dag_ids=dag_ids, methods=methods, 
user=user)
+        return self.filter_permitted_dag_ids(dag_ids=dag_ids, method=method, 
user=user)
 
     def filter_permitted_dag_ids(
         self,
         *,
         dag_ids: set[str],
         user: T,
-        methods: Container[ResourceMethod] | None = None,
+        method: ResourceMethod = "GET",
     ) -> set[str]:
         """
         Filter readable or writable DAGs for user.
 
         :param dag_ids: the list of DAG ids
         :param user: the user
-        :param methods: whether filter readable or writable
+        :param method: the method to filter on
         """
-        if not methods:
-            methods = ["PUT", "GET"]
 
-        if ("GET" in methods and self.is_authorized_dag(method="GET", 
user=user)) or (
-            "PUT" in methods and self.is_authorized_dag(method="PUT", 
user=user)
-        ):
-            # If user is authorized to read/edit all DAGs, return all DAGs
-            return dag_ids
+        def _is_permitted_dag_id(method: ResourceMethod, dag_id: str):
+            return self.is_authorized_dag(method=method, 
details=DagDetails(id=dag_id), user=user)
 
-        def _is_permitted_dag_id(method: ResourceMethod, methods: 
Container[ResourceMethod], dag_id: str):
-            return method in methods and self.is_authorized_dag(
-                method=method, details=DagDetails(id=dag_id), user=user
-            )
-
-        return {
-            dag_id
-            for dag_id in dag_ids
-            if _is_permitted_dag_id("GET", methods, dag_id) or 
_is_permitted_dag_id("PUT", methods, dag_id)
-        }
+        return {dag_id for dag_id in dag_ids if _is_permitted_dag_id(method, 
dag_id)}
 
     @staticmethod
     def get_cli_commands() -> list[CLICommand]:
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
 
b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 6c992438a96..04b547b2979 100644
--- 
a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ 
b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import argparse
 from collections import defaultdict
-from collections.abc import Container, Sequence
+from collections.abc import Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, cast
 
@@ -283,23 +283,18 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         *,
         dag_ids: set[str],
         user: AwsAuthManagerUser,
-        methods: Container[ResourceMethod] | None = None,
+        method: ResourceMethod = "GET",
     ):
-        if not methods:
-            methods = ["PUT", "GET"]
-
         requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = 
defaultdict(dict)
         requests_list: list[IsAuthorizedRequest] = []
         for dag_id in dag_ids:
-            for method in ["GET", "PUT"]:
-                if method in methods:
-                    request: IsAuthorizedRequest = {
-                        "method": cast("ResourceMethod", method),
-                        "entity_type": AvpEntities.DAG,
-                        "entity_id": dag_id,
-                    }
-                    requests[dag_id][cast("ResourceMethod", method)] = request
-                    requests_list.append(request)
+            request: IsAuthorizedRequest = {
+                "method": method,
+                "entity_type": AvpEntities.DAG,
+                "entity_id": dag_id,
+            }
+            requests[dag_id][method] = request
+            requests_list.append(request)
 
         batch_is_authorized_results = 
self.avp_facade.get_batch_is_authorized_results(
             requests=requests_list, user=user
@@ -311,16 +306,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
             )
             return result["decision"] == "ALLOW"
 
-        return {
-            dag_id
-            for dag_id in dag_ids
-            if (
-                "GET" in methods
-                and _has_access_to_dag(requests[dag_id]["GET"])
-                or "PUT" in methods
-                and _has_access_to_dag(requests[dag_id]["PUT"])
-            )
-        }
+        return {dag_id for dag_id in dag_ids if 
_has_access_to_dag(requests[dag_id][method])}
 
     def get_url_login(self, **kwargs) -> str:
         return f"{self.apiserver_endpoint}/auth/login"
diff --git 
a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py 
b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
index be3485d40c1..45935e4b686 100644
--- 
a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ 
b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -445,26 +445,62 @@ class TestAwsAuthManager:
         assert result
 
     @pytest.mark.parametrize(
-        "methods, user",
+        "method, user, expected_result",
         [
-            (None, AwsAuthManagerUser(user_id="test_user_id", groups=[])),
-            (["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id", 
groups=[])),
+            ("GET", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), 
{"dag_1"}),
+            ("PUT", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), 
set()),
+            ("GET", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), 
set()),
+            ("PUT", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), 
{"dag_2"}),
         ],
     )
-    def test_filter_permitted_dag_ids(self, methods, user, auth_manager, 
test_user):
+    def test_filter_permitted_dag_ids(self, method, user, auth_manager, 
test_user, expected_result):
         dag_ids = {"dag_1", "dag_2"}
+        # test_user_id1 has GET permissions on dag_1
+        # test_user_id2 has PUT permissions on dag_2
         batch_is_authorized_output = [
             {
                 "request": {
-                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id1"},
                     "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.GET"},
                     "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_1"},
                 },
+                "decision": "ALLOW",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id1"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.PUT"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_1"},
+                },
                 "decision": "DENY",
             },
             {
                 "request": {
-                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id1"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.GET"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_2"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id1"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.PUT"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_2"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id2"},
+                    "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.GET"},
+                    "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_1"},
+                },
+                "decision": "DENY",
+            },
+            {
+                "request": {
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id2"},
                     "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.PUT"},
                     "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_1"},
                 },
@@ -472,7 +508,7 @@ class TestAwsAuthManager:
             },
             {
                 "request": {
-                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id2"},
                     "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.GET"},
                     "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_2"},
                 },
@@ -480,7 +516,7 @@ class TestAwsAuthManager:
             },
             {
                 "request": {
-                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id"},
+                    "principal": {"entityType": "Airflow::User", "entityId": 
"test_user_id2"},
                     "action": {"actionType": "Airflow::Action", "actionId": 
"Dag.PUT"},
                     "resource": {"entityType": "Airflow::Dag", "entityId": 
"dag_2"},
                 },
@@ -493,12 +529,12 @@ class TestAwsAuthManager:
 
         result = auth_manager.filter_permitted_dag_ids(
             dag_ids=dag_ids,
-            methods=methods,
+            method=method,
             user=user,
         )
 
         auth_manager.avp_facade.get_batch_is_authorized_results.assert_called()
-        assert result == {"dag_2"}
+        assert result == expected_result
 
     def test_get_url_login(self, auth_manager):
         result = auth_manager.get_url_login()
diff --git 
a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py 
b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index b416d31e212..eb3ef724ab6 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 import argparse
-from collections.abc import Container
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
@@ -58,6 +57,7 @@ from 
airflow.providers.fab.auth_manager.cli_commands.definition import (
     USERS_COMMANDS,
 )
 from airflow.providers.fab.auth_manager.models import Permission, Role, User
+from airflow.providers.fab.auth_manager.models.anonymous_user import 
AnonymousUser
 from airflow.providers.fab.www.app import create_app
 from airflow.providers.fab.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED
 from airflow.providers.fab.www.extensions.init_views import 
_CustomErrorRequestBodyValidator, _LazyResolver
@@ -355,30 +355,24 @@ class FabAuthManager(BaseAuthManager[User]):
         self,
         *,
         user: User,
-        methods: Container[ResourceMethod] | None = None,
+        method: ResourceMethod = "GET",
         session: Session = NEW_SESSION,
     ) -> set[str]:
-        if not methods:
-            methods = ["PUT", "GET"]
-
-        if not self.is_logged_in():
-            roles = user.roles
-        else:
-            if ("GET" in methods and self.is_authorized_dag(method="GET", 
user=user)) or (
-                "PUT" in methods and self.is_authorized_dag(method="PUT", 
user=user)
-            ):
-                # If user is authorized to read/edit all DAGs, return all DAGs
-                return {dag.dag_id for dag in 
session.execute(select(DagModel.dag_id))}
-            user_query = session.scalar(
-                select(User)
-                .options(
-                    joinedload(User.roles)
-                    .subqueryload(Role.permissions)
-                    .options(joinedload(Permission.action), 
joinedload(Permission.resource))
-                )
-                .where(User.id == user.id)
+        if self._is_authorized(method=method, resource_type=RESOURCE_DAG, 
user=user):
+            # If user is authorized to access all DAGs, return all DAGs
+            return {dag.dag_id for dag in 
session.execute(select(DagModel.dag_id))}
+        if isinstance(user, AnonymousUser):
+            return set()
+        user_query = session.scalar(
+            select(User)
+            .options(
+                joinedload(User.roles)
+                .subqueryload(Role.permissions)
+                .options(joinedload(Permission.action), 
joinedload(Permission.resource))
             )
-            roles = user_query.roles
+            .where(User.id == user.id)
+        )
+        roles = user_query.roles
 
         map_fab_action_name_to_method_name = get_method_from_fab_action_map()
         resources = set()
@@ -387,7 +381,7 @@ class FabAuthManager(BaseAuthManager[User]):
                 action = permission.action.name
                 if (
                     action in map_fab_action_name_to_method_name
-                    and map_fab_action_name_to_method_name[action] in methods
+                    and map_fab_action_name_to_method_name[action] == method
                 ):
                     resource = permission.resource.name
                     if resource == permissions.RESOURCE_DAG:
diff --git 
a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
 
b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
index adbdfe14397..ef36355d2ad 100644
--- 
a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
+++ 
b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
@@ -973,12 +973,12 @@ class 
FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2):
     @staticmethod
     def get_readable_dag_ids(user=None) -> set[str]:
         """Get the DAG IDs readable by authenticated user."""
-        return get_auth_manager().get_permitted_dag_ids(methods=["GET"], 
user=user)
+        return get_auth_manager().get_permitted_dag_ids(user=user)
 
     @staticmethod
     def get_editable_dag_ids(user=None) -> set[str]:
         """Get the DAG IDs editable by authenticated user."""
-        return get_auth_manager().get_permitted_dag_ids(methods=["PUT"], 
user=user)
+        return get_auth_manager().get_permitted_dag_ids(method="PUT", 
user=user)
 
     def can_access_some_dags(self, action: str, dag_id: str | None = None) -> 
bool:
         """Check if user has read or write access to some dags."""
diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py 
b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
index 399937e63a5..45f200e09e5 100644
--- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
+++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
@@ -27,7 +27,8 @@ from flask import Flask, g
 
 from airflow.exceptions import AirflowConfigException, AirflowException
 from airflow.providers.fab.www.extensions.init_appbuilder import 
init_appbuilder
-from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user
+from airflow.providers.standard.operators.empty import EmptyOperator
+from unit.fab.auth_manager.api_endpoints.api_connexion_utils import 
create_user, delete_user
 
 try:
     from airflow.auth.managers.models.resource_details import AccessView, 
DagAccessEntity, DagDetails
@@ -449,6 +450,83 @@ class TestFabAuthManager:
         result = auth_manager.is_authorized_custom_view(method=method, 
resource_name=resource_name, user=user)
         assert result == expected_result
 
+    @pytest.mark.parametrize(
+        "method, user_permissions, expected_results",
+        [
+            # Scenario 1
+            # With global read permissions on Dags
+            (
+                "GET",
+                [(ACTION_CAN_READ, RESOURCE_DAG)],
+                {"test_dag1", "test_dag2"},
+            ),
+            # Scenario 2
+            # With global edit permissions on Dags
+            (
+                "PUT",
+                [(ACTION_CAN_EDIT, RESOURCE_DAG)],
+                {"test_dag1", "test_dag2"},
+            ),
+            # Scenario 3
+            # With DAG-specific permissions
+            (
+                "GET",
+                [(ACTION_CAN_READ, "DAG:test_dag1")],
+                {"test_dag1"},
+            ),
+            # Scenario 4
+            # With no permissions
+            (
+                "GET",
+                [],
+                set(),
+            ),
+            # Scenario 5
+            # With read permissions but edit is requested
+            (
+                "PUT",
+                [(ACTION_CAN_READ, RESOURCE_DAG)],
+                set(),
+            ),
+            # Scenario 7
+            # With read permissions but edit is requested
+            (
+                "PUT",
+                [(ACTION_CAN_READ, "DAG:test_dag1")],
+                set(),
+            ),
+            # Scenario 8
+            # With DAG-specific permissions
+            (
+                "PUT",
+                [(ACTION_CAN_EDIT, "DAG:test_dag1"), (ACTION_CAN_EDIT, 
"DAG:test_dag2")],
+                {"test_dag1", "test_dag2"},
+            ),
+        ],
+    )
+    def test_get_permitted_dag_ids(
+        self, method, user_permissions, expected_results, 
auth_manager_with_appbuilder, dag_maker, flask_app
+    ):
+        with dag_maker("test_dag1"):
+            EmptyOperator(task_id="task1")
+        with dag_maker("test_dag2"):
+            EmptyOperator(task_id="task1")
+
+        
auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag1")
+        
auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag2")
+
+        user = create_user(
+            flask_app,
+            username="username",
+            role_name="test",
+            permissions=user_permissions,
+        )
+
+        results = 
auth_manager_with_appbuilder.get_permitted_dag_ids(user=user, method=method)
+        assert results == expected_results
+
+        delete_user(flask_app, "username")
+
     @pytest.mark.db_test
     def test_security_manager_return_fab_security_manager_override(self, 
auth_manager_with_appbuilder):
         assert isinstance(auth_manager_with_appbuilder.security_manager, 
FabAirflowSecurityManagerOverride)
diff --git a/providers/fab/tests/unit/fab/auth_manager/test_security.py 
b/providers/fab/tests/unit/fab/auth_manager/test_security.py
index 4ceca5cb3ac..a2aa51e3291 100644
--- a/providers/fab/tests/unit/fab/auth_manager/test_security.py
+++ b/providers/fab/tests/unit/fab/auth_manager/test_security.py
@@ -544,7 +544,7 @@ def 
test_dont_get_inaccessible_dag_ids_for_dag_resource_permission(
                 dag_id, access_control={role_name: permission_action}
             )
 
-            assert get_auth_manager().get_permitted_dag_ids(methods=["GET"], 
user=user) == set()
+            assert get_auth_manager().get_permitted_dag_ids(user=user) == set()
 
 
 def test_has_access(security_manager):
diff --git a/tests/auth/managers/test_base_auth_manager.py 
b/tests/auth/managers/test_base_auth_manager.py
index c228ad91584..a7bb0322d01 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -258,34 +258,23 @@ class TestBaseAuthManager:
         assert result == expected
 
     @pytest.mark.parametrize(
-        "access_all, access_per_dag, dag_ids, expected",
+        "access_per_dag, dag_ids, expected",
         [
-            # Access to all dags
-            (
-                True,
-                {},
-                ["dag1", "dag2"],
-                {"dag1", "dag2"},
-            ),
             # No access to any dag
             (
-                False,
                 {},
                 ["dag1", "dag2"],
                 set(),
             ),
             # Access to specific dags
             (
-                False,
                 {"dag1": True},
                 ["dag1", "dag2"],
                 {"dag1"},
             ),
         ],
     )
-    def test_get_permitted_dag_ids(
-        self, auth_manager, access_all: bool, access_per_dag: dict, dag_ids: 
list, expected: set
-    ):
+    def test_get_permitted_dag_ids(self, auth_manager, access_per_dag: dict, 
dag_ids: list, expected: set):
         def side_effect_func(
             *,
             method: ResourceMethod,
@@ -294,7 +283,7 @@ class TestBaseAuthManager:
             user: BaseAuthManagerUserTest | None = None,
         ):
             if not details:
-                return access_all
+                return False
             else:
                 return access_per_dag.get(details.id, False)
 

Reply via email to