This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f10c4314aab Fix and simplify `get_permitted_dag_ids` in auth manager (#47458)
f10c4314aab is described below
commit f10c4314aab2ab98c94e6c277d9c8019eba3a9f6
Author: Vincent <[email protected]>
AuthorDate: Thu Mar 6 14:24:49 2025 -0500
Fix and simplify `get_permitted_dag_ids` in auth manager (#47458)
---
airflow/auth/managers/base_auth_manager.py | 32 +++------
.../amazon/aws/auth_manager/aws_auth_manager.py | 34 +++------
.../aws/auth_manager/test_aws_auth_manager.py | 56 ++++++++++++---
.../providers/fab/auth_manager/fab_auth_manager.py | 40 +++++------
.../fab/auth_manager/security_manager/override.py | 4 +-
.../unit/fab/auth_manager/test_fab_auth_manager.py | 80 +++++++++++++++++++++-
.../tests/unit/fab/auth_manager/test_security.py | 2 +-
tests/auth/managers/test_base_auth_manager.py | 17 +----
8 files changed, 167 insertions(+), 98 deletions(-)
diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py
index 0ca18db8121..f3b86600a05 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -34,7 +34,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
if TYPE_CHECKING:
- from collections.abc import Container, Sequence
+ from collections.abc import Sequence
from fastapi import FastAPI
from sqlalchemy.orm import Session
@@ -331,7 +331,7 @@ class BaseAuthManager(Generic[T], LoggingMixin):
self,
*,
user: T,
- methods: Container[ResourceMethod] | None = None,
+ method: ResourceMethod = "GET",
session: Session = NEW_SESSION,
) -> set[str]:
"""
@@ -342,45 +342,31 @@ class BaseAuthManager(Generic[T], LoggingMixin):
implementation to provide a more efficient implementation.
:param user: the user
- :param methods: whether filter readable or writable
+ :param method: the method to filter on
:param session: the session
"""
dag_ids = {dag.dag_id for dag in
session.execute(select(DagModel.dag_id))}
- return self.filter_permitted_dag_ids(dag_ids=dag_ids, methods=methods,
user=user)
+ return self.filter_permitted_dag_ids(dag_ids=dag_ids, method=method,
user=user)
def filter_permitted_dag_ids(
self,
*,
dag_ids: set[str],
user: T,
- methods: Container[ResourceMethod] | None = None,
+ method: ResourceMethod = "GET",
) -> set[str]:
"""
Filter readable or writable DAGs for user.
:param dag_ids: the list of DAG ids
:param user: the user
- :param methods: whether filter readable or writable
+ :param method: the method to filter on
"""
- if not methods:
- methods = ["PUT", "GET"]
- if ("GET" in methods and self.is_authorized_dag(method="GET",
user=user)) or (
- "PUT" in methods and self.is_authorized_dag(method="PUT",
user=user)
- ):
- # If user is authorized to read/edit all DAGs, return all DAGs
- return dag_ids
+ def _is_permitted_dag_id(method: ResourceMethod, dag_id: str):
+ return self.is_authorized_dag(method=method,
details=DagDetails(id=dag_id), user=user)
- def _is_permitted_dag_id(method: ResourceMethod, methods:
Container[ResourceMethod], dag_id: str):
- return method in methods and self.is_authorized_dag(
- method=method, details=DagDetails(id=dag_id), user=user
- )
-
- return {
- dag_id
- for dag_id in dag_ids
- if _is_permitted_dag_id("GET", methods, dag_id) or
_is_permitted_dag_id("PUT", methods, dag_id)
- }
+ return {dag_id for dag_id in dag_ids if _is_permitted_dag_id(method,
dag_id)}
@staticmethod
def get_cli_commands() -> list[CLICommand]:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 6c992438a96..04b547b2979 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import argparse
from collections import defaultdict
-from collections.abc import Container, Sequence
+from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, cast
@@ -283,23 +283,18 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
*,
dag_ids: set[str],
user: AwsAuthManagerUser,
- methods: Container[ResourceMethod] | None = None,
+ method: ResourceMethod = "GET",
):
- if not methods:
- methods = ["PUT", "GET"]
-
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] =
defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []
for dag_id in dag_ids:
- for method in ["GET", "PUT"]:
- if method in methods:
- request: IsAuthorizedRequest = {
- "method": cast("ResourceMethod", method),
- "entity_type": AvpEntities.DAG,
- "entity_id": dag_id,
- }
- requests[dag_id][cast("ResourceMethod", method)] = request
- requests_list.append(request)
+ request: IsAuthorizedRequest = {
+ "method": method,
+ "entity_type": AvpEntities.DAG,
+ "entity_id": dag_id,
+ }
+ requests[dag_id][method] = request
+ requests_list.append(request)
batch_is_authorized_results =
self.avp_facade.get_batch_is_authorized_results(
requests=requests_list, user=user
@@ -311,16 +306,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
)
return result["decision"] == "ALLOW"
- return {
- dag_id
- for dag_id in dag_ids
- if (
- "GET" in methods
- and _has_access_to_dag(requests[dag_id]["GET"])
- or "PUT" in methods
- and _has_access_to_dag(requests[dag_id]["PUT"])
- )
- }
+ return {dag_id for dag_id in dag_ids if
_has_access_to_dag(requests[dag_id][method])}
def get_url_login(self, **kwargs) -> str:
return f"{self.apiserver_endpoint}/auth/login"
diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
index be3485d40c1..45935e4b686 100644
--- a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -445,26 +445,62 @@ class TestAwsAuthManager:
assert result
@pytest.mark.parametrize(
- "methods, user",
+ "method, user, expected_result",
[
- (None, AwsAuthManagerUser(user_id="test_user_id", groups=[])),
- (["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id",
groups=[])),
+ ("GET", AwsAuthManagerUser(user_id="test_user_id1", groups=[]),
{"dag_1"}),
+ ("PUT", AwsAuthManagerUser(user_id="test_user_id1", groups=[]),
set()),
+ ("GET", AwsAuthManagerUser(user_id="test_user_id2", groups=[]),
set()),
+ ("PUT", AwsAuthManagerUser(user_id="test_user_id2", groups=[]),
{"dag_2"}),
],
)
- def test_filter_permitted_dag_ids(self, methods, user, auth_manager,
test_user):
+ def test_filter_permitted_dag_ids(self, method, user, auth_manager,
test_user, expected_result):
dag_ids = {"dag_1", "dag_2"}
+ # test_user_id1 has GET permissions on dag_1
+ # test_user_id2 has PUT permissions on dag_2
batch_is_authorized_output = [
{
"request": {
- "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId":
"Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId":
"dag_1"},
},
+ "decision": "ALLOW",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id1"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Dag.PUT"},
+ "resource": {"entityType": "Airflow::Dag", "entityId":
"dag_1"},
+ },
"decision": "DENY",
},
{
"request": {
- "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id1"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Dag.GET"},
+ "resource": {"entityType": "Airflow::Dag", "entityId":
"dag_2"},
+ },
+ "decision": "DENY",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id1"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Dag.PUT"},
+ "resource": {"entityType": "Airflow::Dag", "entityId":
"dag_2"},
+ },
+ "decision": "DENY",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id2"},
+ "action": {"actionType": "Airflow::Action", "actionId":
"Dag.GET"},
+ "resource": {"entityType": "Airflow::Dag", "entityId":
"dag_1"},
+ },
+ "decision": "DENY",
+ },
+ {
+ "request": {
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId":
"Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId":
"dag_1"},
},
@@ -472,7 +508,7 @@ class TestAwsAuthManager:
},
{
"request": {
- "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId":
"Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId":
"dag_2"},
},
@@ -480,7 +516,7 @@ class TestAwsAuthManager:
},
{
"request": {
- "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id"},
+ "principal": {"entityType": "Airflow::User", "entityId":
"test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId":
"Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId":
"dag_2"},
},
@@ -493,12 +529,12 @@ class TestAwsAuthManager:
result = auth_manager.filter_permitted_dag_ids(
dag_ids=dag_ids,
- methods=methods,
+ method=method,
user=user,
)
auth_manager.avp_facade.get_batch_is_authorized_results.assert_called()
- assert result == {"dag_2"}
+ assert result == expected_result
def test_get_url_login(self, auth_manager):
result = auth_manager.get_url_login()
diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index b416d31e212..eb3ef724ab6 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import argparse
-from collections.abc import Container
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -58,6 +57,7 @@ from airflow.providers.fab.auth_manager.cli_commands.definition import (
USERS_COMMANDS,
)
from airflow.providers.fab.auth_manager.models import Permission, Role, User
+from airflow.providers.fab.auth_manager.models.anonymous_user import
AnonymousUser
from airflow.providers.fab.www.app import create_app
from airflow.providers.fab.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED
from airflow.providers.fab.www.extensions.init_views import
_CustomErrorRequestBodyValidator, _LazyResolver
@@ -355,30 +355,24 @@ class FabAuthManager(BaseAuthManager[User]):
self,
*,
user: User,
- methods: Container[ResourceMethod] | None = None,
+ method: ResourceMethod = "GET",
session: Session = NEW_SESSION,
) -> set[str]:
- if not methods:
- methods = ["PUT", "GET"]
-
- if not self.is_logged_in():
- roles = user.roles
- else:
- if ("GET" in methods and self.is_authorized_dag(method="GET",
user=user)) or (
- "PUT" in methods and self.is_authorized_dag(method="PUT",
user=user)
- ):
- # If user is authorized to read/edit all DAGs, return all DAGs
- return {dag.dag_id for dag in
session.execute(select(DagModel.dag_id))}
- user_query = session.scalar(
- select(User)
- .options(
- joinedload(User.roles)
- .subqueryload(Role.permissions)
- .options(joinedload(Permission.action),
joinedload(Permission.resource))
- )
- .where(User.id == user.id)
+ if self._is_authorized(method=method, resource_type=RESOURCE_DAG,
user=user):
+ # If user is authorized to access all DAGs, return all DAGs
+ return {dag.dag_id for dag in
session.execute(select(DagModel.dag_id))}
+ if isinstance(user, AnonymousUser):
+ return set()
+ user_query = session.scalar(
+ select(User)
+ .options(
+ joinedload(User.roles)
+ .subqueryload(Role.permissions)
+ .options(joinedload(Permission.action),
joinedload(Permission.resource))
)
- roles = user_query.roles
+ .where(User.id == user.id)
+ )
+ roles = user_query.roles
map_fab_action_name_to_method_name = get_method_from_fab_action_map()
resources = set()
@@ -387,7 +381,7 @@ class FabAuthManager(BaseAuthManager[User]):
action = permission.action.name
if (
action in map_fab_action_name_to_method_name
- and map_fab_action_name_to_method_name[action] in methods
+ and map_fab_action_name_to_method_name[action] == method
):
resource = permission.resource.name
if resource == permissions.RESOURCE_DAG:
diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
index adbdfe14397..ef36355d2ad 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
@@ -973,12 +973,12 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2):
@staticmethod
def get_readable_dag_ids(user=None) -> set[str]:
"""Get the DAG IDs readable by authenticated user."""
- return get_auth_manager().get_permitted_dag_ids(methods=["GET"],
user=user)
+ return get_auth_manager().get_permitted_dag_ids(user=user)
@staticmethod
def get_editable_dag_ids(user=None) -> set[str]:
"""Get the DAG IDs editable by authenticated user."""
- return get_auth_manager().get_permitted_dag_ids(methods=["PUT"],
user=user)
+ return get_auth_manager().get_permitted_dag_ids(method="PUT",
user=user)
def can_access_some_dags(self, action: str, dag_id: str | None = None) ->
bool:
"""Check if user has read or write access to some dags."""
diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
index 399937e63a5..45f200e09e5 100644
--- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
+++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py
@@ -27,7 +27,8 @@ from flask import Flask, g
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.providers.fab.www.extensions.init_appbuilder import
init_appbuilder
-from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user
+from airflow.providers.standard.operators.empty import EmptyOperator
+from unit.fab.auth_manager.api_endpoints.api_connexion_utils import
create_user, delete_user
try:
from airflow.auth.managers.models.resource_details import AccessView,
DagAccessEntity, DagDetails
@@ -449,6 +450,83 @@ class TestFabAuthManager:
result = auth_manager.is_authorized_custom_view(method=method,
resource_name=resource_name, user=user)
assert result == expected_result
+ @pytest.mark.parametrize(
+ "method, user_permissions, expected_results",
+ [
+ # Scenario 1
+ # With global read permissions on Dags
+ (
+ "GET",
+ [(ACTION_CAN_READ, RESOURCE_DAG)],
+ {"test_dag1", "test_dag2"},
+ ),
+ # Scenario 2
+ # With global edit permissions on Dags
+ (
+ "PUT",
+ [(ACTION_CAN_EDIT, RESOURCE_DAG)],
+ {"test_dag1", "test_dag2"},
+ ),
+ # Scenario 3
+ # With DAG-specific permissions
+ (
+ "GET",
+ [(ACTION_CAN_READ, "DAG:test_dag1")],
+ {"test_dag1"},
+ ),
+ # Scenario 4
+ # With no permissions
+ (
+ "GET",
+ [],
+ set(),
+ ),
+ # Scenario 5
+ # With read permissions but edit is requested
+ (
+ "PUT",
+ [(ACTION_CAN_READ, RESOURCE_DAG)],
+ set(),
+ ),
+ # Scenario 7
+ # With read permissions but edit is requested
+ (
+ "PUT",
+ [(ACTION_CAN_READ, "DAG:test_dag1")],
+ set(),
+ ),
+ # Scenario 8
+ # With DAG-specific permissions
+ (
+ "PUT",
+ [(ACTION_CAN_EDIT, "DAG:test_dag1"), (ACTION_CAN_EDIT,
"DAG:test_dag2")],
+ {"test_dag1", "test_dag2"},
+ ),
+ ],
+ )
+ def test_get_permitted_dag_ids(
+ self, method, user_permissions, expected_results,
auth_manager_with_appbuilder, dag_maker, flask_app
+ ):
+ with dag_maker("test_dag1"):
+ EmptyOperator(task_id="task1")
+ with dag_maker("test_dag2"):
+ EmptyOperator(task_id="task1")
+
+
auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag1")
+
auth_manager_with_appbuilder.security_manager.sync_perm_for_dag("test_dag2")
+
+ user = create_user(
+ flask_app,
+ username="username",
+ role_name="test",
+ permissions=user_permissions,
+ )
+
+ results =
auth_manager_with_appbuilder.get_permitted_dag_ids(user=user, method=method)
+ assert results == expected_results
+
+ delete_user(flask_app, "username")
+
@pytest.mark.db_test
def test_security_manager_return_fab_security_manager_override(self,
auth_manager_with_appbuilder):
assert isinstance(auth_manager_with_appbuilder.security_manager,
FabAirflowSecurityManagerOverride)
diff --git a/providers/fab/tests/unit/fab/auth_manager/test_security.py b/providers/fab/tests/unit/fab/auth_manager/test_security.py
index 4ceca5cb3ac..a2aa51e3291 100644
--- a/providers/fab/tests/unit/fab/auth_manager/test_security.py
+++ b/providers/fab/tests/unit/fab/auth_manager/test_security.py
@@ -544,7 +544,7 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission(
dag_id, access_control={role_name: permission_action}
)
- assert get_auth_manager().get_permitted_dag_ids(methods=["GET"],
user=user) == set()
+ assert get_auth_manager().get_permitted_dag_ids(user=user) == set()
def test_has_access(security_manager):
diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py
index c228ad91584..a7bb0322d01 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -258,34 +258,23 @@ class TestBaseAuthManager:
assert result == expected
@pytest.mark.parametrize(
- "access_all, access_per_dag, dag_ids, expected",
+ "access_per_dag, dag_ids, expected",
[
- # Access to all dags
- (
- True,
- {},
- ["dag1", "dag2"],
- {"dag1", "dag2"},
- ),
# No access to any dag
(
- False,
{},
["dag1", "dag2"],
set(),
),
# Access to specific dags
(
- False,
{"dag1": True},
["dag1", "dag2"],
{"dag1"},
),
],
)
- def test_get_permitted_dag_ids(
- self, auth_manager, access_all: bool, access_per_dag: dict, dag_ids:
list, expected: set
- ):
+ def test_get_permitted_dag_ids(self, auth_manager, access_per_dag: dict,
dag_ids: list, expected: set):
def side_effect_func(
*,
method: ResourceMethod,
@@ -294,7 +283,7 @@ class TestBaseAuthManager:
user: BaseAuthManagerUserTest | None = None,
):
if not details:
- return access_all
+ return False
else:
return access_per_dag.get(details.id, False)