This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fc28db1727 Add batch is_authorized APIs to auth manager (#35562)
fc28db1727 is described below

commit fc28db1727d18d43cf90f435f10bd0450ab1ce25
Author: Vincent <[email protected]>
AuthorDate: Fri Nov 17 08:46:12 2023 -0500

    Add batch is_authorized APIs to auth manager (#35562)
---
 .../endpoints/task_instance_endpoint.py            | 21 ++---
 airflow/auth/managers/base_auth_manager.py         | 91 ++++++++++++++++++++-
 airflow/auth/managers/models/batch_apis.py         | 64 +++++++++++++++
 airflow/www/auth.py                                | 70 ++++++++--------
 .../endpoints/test_task_instance_endpoint.py       |  2 +-
 tests/auth/managers/test_base_auth_manager.py      | 92 ++++++++++++++++++++--
 tests/www/test_auth.py                             | 14 ++--
 7 files changed, 302 insertions(+), 52 deletions(-)

diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 946f8fc0dc..b5e7c273bb 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Iterable, TypeVar
+from typing import TYPE_CHECKING, Any, Iterable, Sequence, TypeVar
 
 from flask import g
 from marshmallow import ValidationError
@@ -59,6 +59,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql import ClauseElement, Select
 
     from airflow.api_connexion.types import APIResponse
+    from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
 
 T = TypeVar("T")
 
@@ -378,14 +379,16 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
         raise BadRequest(detail=str(err.messages))
     dag_ids = data["dag_ids"]
     if dag_ids:
-        cannot_access_dag_ids = set()
-        for id in dag_ids:
-            if not get_auth_manager().is_authorized_dag(method="GET", 
details=DagDetails(id=id), user=g.user):
-                cannot_access_dag_ids.add(id)
-        if cannot_access_dag_ids:
-            raise PermissionDenied(
-                detail=f"User not allowed to access these DAGs: 
{list(cannot_access_dag_ids)}"
-            )
+        requests: Sequence[IsAuthorizedDagRequest] = [
+            {
+                "method": "GET",
+                "details": DagDetails(id=id),
+                "user": g.user,
+            }
+            for id in dag_ids
+        ]
+        if not get_auth_manager().batch_is_authorized_dag(requests):
+            raise PermissionDenied(detail=f"User not allowed to access some of 
these DAGs: {list(dag_ids)}")
     else:
         dag_ids = 
get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
 
diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py
index 3b25098ab8..c158ea8481 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from abc import abstractmethod
 from functools import cached_property
-from typing import TYPE_CHECKING, Container, Literal
+from typing import TYPE_CHECKING, Container, Literal, Sequence
 
 from sqlalchemy import select
 
@@ -37,6 +37,12 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.auth.managers.models.base_user import BaseUser
+    from airflow.auth.managers.models.batch_apis import (
+        IsAuthorizedConnectionRequest,
+        IsAuthorizedDagRequest,
+        IsAuthorizedPoolRequest,
+        IsAuthorizedVariableRequest,
+    )
     from airflow.auth.managers.models.resource_details import (
         AccessView,
         ConfigurationDetails,
@@ -250,6 +256,89 @@ class BaseAuthManager(LoggingMixin):
         """
         raise AirflowException(f"The resource `{fab_resource_name}` does not 
exist in the environment.")
 
+    def batch_is_authorized_dag(
+        self,
+        requests: Sequence[IsAuthorizedDagRequest],
+    ) -> bool:
+        """
+        Batch version of ``is_authorized_dag``.
+
+        By default, calls individually the ``is_authorized_dag`` API on each 
item in the list of requests.
+        Can lead to some poor performance. It is recommended to override this 
method in the auth manager
+        implementation to provide a more efficient implementation.
+
+        :param requests: a list of requests containing the parameters for 
``is_authorized_dag``
+        """
+        return all(
+            self.is_authorized_dag(
+                method=request["method"],
+                access_entity=request.get("access_entity"),
+                details=request.get("details"),
+                user=request.get("user"),
+            )
+            for request in requests
+        )
+
+    def batch_is_authorized_connection(
+        self,
+        requests: Sequence[IsAuthorizedConnectionRequest],
+    ) -> bool:
+        """
+        Batch version of ``is_authorized_connection``.
+
+        By default, calls individually the ``is_authorized_connection`` API on 
each item in the list of
+        requests. Can lead to some poor performance. It is recommended to 
override this method in the auth
+        manager implementation to provide a more efficient implementation.
+
+        :param requests: a list of requests containing the parameters for 
``is_authorized_connection``
+        """
+        return all(
+            self.is_authorized_connection(
+                method=request["method"], details=request.get("details"), 
user=request.get("user")
+            )
+            for request in requests
+        )
+
+    def batch_is_authorized_pool(
+        self,
+        requests: Sequence[IsAuthorizedPoolRequest],
+    ) -> bool:
+        """
+        Batch version of ``is_authorized_pool``.
+
+        By default, calls individually the ``is_authorized_pool`` API on each 
item in the list of
+        requests. Can lead to some poor performance. It is recommended to 
override this method in the auth
+        manager implementation to provide a more efficient implementation.
+
+        :param requests: a list of requests containing the parameters for 
``is_authorized_pool``
+        """
+        return all(
+            self.is_authorized_pool(
+                method=request["method"], details=request.get("details"), 
user=request.get("user")
+            )
+            for request in requests
+        )
+
+    def batch_is_authorized_variable(
+        self,
+        requests: Sequence[IsAuthorizedVariableRequest],
+    ) -> bool:
+        """
+        Batch version of ``is_authorized_variable``.
+
+        By default, calls individually the ``is_authorized_variable`` API on 
each item in the list of
+        requests. Can lead to some poor performance. It is recommended to 
override this method in the auth
+        manager implementation to provide a more efficient implementation.
+
+        :param requests: a list of requests containing the parameters for 
``is_authorized_variable``
+        """
+        return all(
+            self.is_authorized_variable(
+                method=request["method"], details=request.get("details"), 
user=request.get("user")
+            )
+            for request in requests
+        )
+
     @provide_session
     def get_permitted_dag_ids(
         self,
diff --git a/airflow/auth/managers/models/batch_apis.py b/airflow/auth/managers/models/batch_apis.py
new file mode 100644
index 0000000000..7cb16339a7
--- /dev/null
+++ b/airflow/auth/managers/models/batch_apis.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypedDict
+
+if TYPE_CHECKING:
+    from airflow.auth.managers.base_auth_manager import ResourceMethod
+    from airflow.auth.managers.models.base_user import BaseUser
+    from airflow.auth.managers.models.resource_details import (
+        ConnectionDetails,
+        DagAccessEntity,
+        DagDetails,
+        PoolDetails,
+        VariableDetails,
+    )
+
+
+class IsAuthorizedConnectionRequest(TypedDict, total=False):
+    """Represent the parameters of ``is_authorized_connection`` API in the 
auth manager."""
+
+    method: ResourceMethod
+    details: ConnectionDetails | None
+    user: BaseUser | None
+
+
+class IsAuthorizedDagRequest(TypedDict, total=False):
+    """Represent the parameters of ``is_authorized_dag`` API in the auth 
manager."""
+
+    method: ResourceMethod
+    access_entity: DagAccessEntity | None
+    details: DagDetails | None
+    user: BaseUser | None
+
+
+class IsAuthorizedPoolRequest(TypedDict, total=False):
+    """Represent the parameters of ``is_authorized_pool`` API in the auth 
manager."""
+
+    method: ResourceMethod
+    details: PoolDetails | None
+    user: BaseUser | None
+
+
+class IsAuthorizedVariableRequest(TypedDict, total=False):
+    """Represent the parameters of ``is_authorized_variable`` API in the auth 
manager."""
+
+    method: ResourceMethod
+    details: VariableDetails | None
+    user: BaseUser | None
diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index e285d517ba..1ad6e6dab5 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -45,6 +45,12 @@ from airflow.www.extensions.init_auth_manager import get_auth_manager
 
 if TYPE_CHECKING:
     from airflow.auth.managers.base_auth_manager import ResourceMethod
+    from airflow.auth.managers.models.batch_apis import (
+        IsAuthorizedConnectionRequest,
+        IsAuthorizedDagRequest,
+        IsAuthorizedPoolRequest,
+        IsAuthorizedVariableRequest,
+    )
     from airflow.models import DagRun, Pool, SlaMiss, TaskInstance, Variable
     from airflow.models.connection import Connection
     from airflow.models.xcom import BaseXCom
@@ -190,15 +196,14 @@ def has_access_connection(method: ResourceMethod) -> Callable[[T], T]:
         @wraps(func)
         def decorated(*args, **kwargs):
             connections: set[Connection] = set(args[1])
-            connections_details = [
-                ConnectionDetails(conn_id=connection.conn_id) for connection 
in connections
+            requests: Sequence[IsAuthorizedConnectionRequest] = [
+                {
+                    "method": method,
+                    "details": ConnectionDetails(conn_id=connection.conn_id),
+                }
+                for connection in connections
             ]
-            is_authorized = all(
-                [
-                    get_auth_manager().is_authorized_connection(method=method, 
details=connection_details)
-                    for connection_details in connections_details
-                ]
-            )
+            is_authorized = 
get_auth_manager().batch_is_authorized_connection(requests)
             return _has_access(
                 is_authorized=is_authorized,
                 func=func,
@@ -265,15 +270,16 @@ def has_access_dag_entities(method: ResourceMethod, access_entity: DagAccessEnti
         @wraps(func)
         def decorated(*args, **kwargs):
             items: set[SlaMiss | BaseXCom | DagRun | TaskInstance] = 
set(args[1])
-            dags_details = [DagDetails(id=item.dag_id) for item in items if 
item is not None]
-            is_authorized = all(
-                [
-                    get_auth_manager().is_authorized_dag(
-                        method=method, access_entity=access_entity, 
details=dag_details
-                    )
-                    for dag_details in dags_details
-                ]
-            )
+            requests: Sequence[IsAuthorizedDagRequest] = [
+                {
+                    "method": method,
+                    "access_entity": access_entity,
+                    "details": DagDetails(id=item.dag_id),
+                }
+                for item in items
+                if item is not None
+            ]
+            is_authorized = 
get_auth_manager().batch_is_authorized_dag(requests)
             return _has_access(
                 is_authorized=is_authorized,
                 func=func,
@@ -296,13 +302,14 @@ def has_access_pool(method: ResourceMethod) -> Callable[[T], T]:
         @wraps(func)
         def decorated(*args, **kwargs):
             pools: set[Pool] = set(args[1])
-            pools_details = [PoolDetails(name=pool.pool) for pool in pools]
-            is_authorized = all(
-                [
-                    get_auth_manager().is_authorized_pool(method=method, 
details=pool_details)
-                    for pool_details in pools_details
-                ]
-            )
+            requests: Sequence[IsAuthorizedPoolRequest] = [
+                {
+                    "method": method,
+                    "details": PoolDetails(name=pool.pool),
+                }
+                for pool in pools
+            ]
+            is_authorized = 
get_auth_manager().batch_is_authorized_pool(requests)
             return _has_access(
                 is_authorized=is_authorized,
                 func=func,
@@ -324,13 +331,14 @@ def has_access_variable(method: ResourceMethod) -> Callable[[T], T]:
                 is_authorized = 
get_auth_manager().is_authorized_variable(method=method)
             else:
                 variables: set[Variable] = set(args[1])
-                variables_details = [VariableDetails(key=variable.key) for 
variable in variables]
-                is_authorized = all(
-                    [
-                        
get_auth_manager().is_authorized_variable(method=method, 
details=variable_details)
-                        for variable_details in variables_details
-                    ]
-                )
+                requests: Sequence[IsAuthorizedVariableRequest] = [
+                    {
+                        "method": method,
+                        "details": VariableDetails(key=variable.key),
+                    }
+                    for variable in variables
+                ]
+                is_authorized = 
get_auth_manager().batch_is_authorized_variable(requests)
             return _has_access(
                 is_authorized=is_authorized,
                 func=func,
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 9f913ac9fe..3125031f56 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -994,7 +994,7 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
         )
         assert response.status_code == 403
         assert response.json == {
-            "detail": "User not allowed to access these DAGs: 
['example_skip_dag']",
+            "detail": "User not allowed to access some of these DAGs: 
['example_python_operator', 'example_skip_dag']",
             "status": 403,
             "title": "Forbidden",
             "type": EXCEPTIONS_LINK_MAP[403],
diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py
index f009fb055b..9b1b5659f4 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -17,12 +17,18 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
-from unittest.mock import MagicMock, Mock
+from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 from flask import Flask
 
 from airflow.auth.managers.base_auth_manager import BaseAuthManager, 
ResourceMethod
+from airflow.auth.managers.models.resource_details import (
+    ConnectionDetails,
+    DagDetails,
+    PoolDetails,
+    VariableDetails,
+)
 from airflow.exceptions import AirflowException
 from airflow.security import permissions
 from airflow.www.extensions.init_appbuilder import init_appbuilder
@@ -33,12 +39,8 @@ if TYPE_CHECKING:
     from airflow.auth.managers.models.resource_details import (
         AccessView,
         ConfigurationDetails,
-        ConnectionDetails,
         DagAccessEntity,
-        DagDetails,
         DatasetDetails,
-        PoolDetails,
-        VariableDetails,
     )
 
 
@@ -143,6 +145,86 @@ class TestBaseAuthManager:
                 fab_resource_name=permissions.RESOURCE_MY_PASSWORD,
             )
 
+    @pytest.mark.parametrize(
+        "return_values, expected",
+        [
+            ([False, False], False),
+            ([True, False], False),
+            ([True, True], True),
+        ],
+    )
+    @patch.object(EmptyAuthManager, "is_authorized_dag")
+    def test_batch_is_authorized_dag(self, mock_is_authorized_dag, 
auth_manager, return_values, expected):
+        mock_is_authorized_dag.side_effect = return_values
+        result = auth_manager.batch_is_authorized_dag(
+            [
+                {"method": "GET", "details": DagDetails(id="dag1")},
+                {"method": "GET", "details": DagDetails(id="dag2")},
+            ]
+        )
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        "return_values, expected",
+        [
+            ([False, False], False),
+            ([True, False], False),
+            ([True, True], True),
+        ],
+    )
+    @patch.object(EmptyAuthManager, "is_authorized_connection")
+    def test_batch_is_authorized_connection(
+        self, mock_is_authorized_connection, auth_manager, return_values, 
expected
+    ):
+        mock_is_authorized_connection.side_effect = return_values
+        result = auth_manager.batch_is_authorized_connection(
+            [
+                {"method": "GET", "details": 
ConnectionDetails(conn_id="conn1")},
+                {"method": "GET", "details": 
ConnectionDetails(conn_id="conn2")},
+            ]
+        )
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        "return_values, expected",
+        [
+            ([False, False], False),
+            ([True, False], False),
+            ([True, True], True),
+        ],
+    )
+    @patch.object(EmptyAuthManager, "is_authorized_pool")
+    def test_batch_is_authorized_pool(self, mock_is_authorized_pool, 
auth_manager, return_values, expected):
+        mock_is_authorized_pool.side_effect = return_values
+        result = auth_manager.batch_is_authorized_pool(
+            [
+                {"method": "GET", "details": PoolDetails(name="pool1")},
+                {"method": "GET", "details": PoolDetails(name="pool2")},
+            ]
+        )
+        assert result == expected
+
+    @pytest.mark.parametrize(
+        "return_values, expected",
+        [
+            ([False, False], False),
+            ([True, False], False),
+            ([True, True], True),
+        ],
+    )
+    @patch.object(EmptyAuthManager, "is_authorized_variable")
+    def test_batch_is_authorized_variable(
+        self, mock_is_authorized_variable, auth_manager, return_values, 
expected
+    ):
+        mock_is_authorized_variable.side_effect = return_values
+        result = auth_manager.batch_is_authorized_variable(
+            [
+                {"method": "GET", "details": VariableDetails(key="var1")},
+                {"method": "GET", "details": VariableDetails(key="var2")},
+            ]
+        )
+        assert result == expected
+
     @pytest.mark.db_test
     def test_security_manager_return_default_security_manager(self, 
auth_manager_with_appbuilder):
         assert isinstance(auth_manager_with_appbuilder.security_manager, 
AirflowSecurityManagerV2)
diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py
index 83fb0eebdf..dd0188974b 100644
--- a/tests/www/test_auth.py
+++ b/tests/www/test_auth.py
@@ -115,9 +115,13 @@ class TestHasAccessNoDetails:
 @pytest.mark.parametrize(
     "decorator_name, is_authorized_method_name, items",
     [
-        ("has_access_connection", "is_authorized_connection", 
[Connection("conn_1"), Connection("conn_2")]),
-        ("has_access_pool", "is_authorized_pool", [Pool(pool="pool_1"), 
Pool(pool="pool_2")]),
-        ("has_access_variable", "is_authorized_variable", [Variable("var_1"), 
Variable("var_2")]),
+        (
+            "has_access_connection",
+            "batch_is_authorized_connection",
+            [Connection("conn_1"), Connection("conn_2")],
+        ),
+        ("has_access_pool", "batch_is_authorized_pool", [Pool(pool="pool_1"), 
Pool(pool="pool_2")]),
+        ("has_access_variable", "batch_is_authorized_variable", 
[Variable("var_1"), Variable("var_2")]),
     ],
 )
 class TestHasAccessWithDetails:
@@ -181,7 +185,7 @@ class TestHasAccessDagEntities:
     @patch("airflow.www.auth.get_auth_manager")
     def test_has_access_dag_entities_when_authorized(self, 
mock_get_auth_manager, dag_access_entity):
         auth_manager = Mock()
-        auth_manager.is_authorized_dag.return_value = True
+        auth_manager.batch_is_authorized_dag.return_value = True
         mock_get_auth_manager.return_value = auth_manager
         items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")]
 
@@ -194,7 +198,7 @@ class TestHasAccessDagEntities:
     @patch("airflow.www.auth.get_auth_manager")
     def test_has_access_dag_entities_when_unauthorized(self, 
mock_get_auth_manager, app, dag_access_entity):
         auth_manager = Mock()
-        auth_manager.is_authorized_dag.return_value = False
+        auth_manager.batch_is_authorized_dag.return_value = False
         mock_get_auth_manager.return_value = auth_manager
         items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")]
 

Reply via email to