This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fc28db1727 Add batch is_authorized APIs to auth manager (#35562)
fc28db1727 is described below
commit fc28db1727d18d43cf90f435f10bd0450ab1ce25
Author: Vincent <[email protected]>
AuthorDate: Fri Nov 17 08:46:12 2023 -0500
Add batch is_authorized APIs to auth manager (#35562)
---
.../endpoints/task_instance_endpoint.py | 21 ++---
airflow/auth/managers/base_auth_manager.py | 91 ++++++++++++++++++++-
airflow/auth/managers/models/batch_apis.py | 64 +++++++++++++++
airflow/www/auth.py | 70 ++++++++--------
.../endpoints/test_task_instance_endpoint.py | 2 +-
tests/auth/managers/test_base_auth_manager.py | 92 ++++++++++++++++++++--
tests/www/test_auth.py | 14 ++--
7 files changed, 302 insertions(+), 52 deletions(-)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 946f8fc0dc..b5e7c273bb 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Iterable, TypeVar
+from typing import TYPE_CHECKING, Any, Iterable, Sequence, TypeVar
from flask import g
from marshmallow import ValidationError
@@ -59,6 +59,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql import ClauseElement, Select
from airflow.api_connexion.types import APIResponse
+ from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
T = TypeVar("T")
@@ -378,14 +379,16 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
raise BadRequest(detail=str(err.messages))
dag_ids = data["dag_ids"]
if dag_ids:
- cannot_access_dag_ids = set()
- for id in dag_ids:
- if not get_auth_manager().is_authorized_dag(method="GET",
details=DagDetails(id=id), user=g.user):
- cannot_access_dag_ids.add(id)
- if cannot_access_dag_ids:
- raise PermissionDenied(
- detail=f"User not allowed to access these DAGs: {list(cannot_access_dag_ids)}"
- )
+ requests: Sequence[IsAuthorizedDagRequest] = [
+ {
+ "method": "GET",
+ "details": DagDetails(id=id),
+ "user": g.user,
+ }
+ for id in dag_ids
+ ]
+ if not get_auth_manager().batch_is_authorized_dag(requests):
+ raise PermissionDenied(detail=f"User not allowed to access some of these DAGs: {list(dag_ids)}")
else:
dag_ids =
get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py
index 3b25098ab8..c158ea8481 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from abc import abstractmethod
from functools import cached_property
-from typing import TYPE_CHECKING, Container, Literal
+from typing import TYPE_CHECKING, Container, Literal, Sequence
from sqlalchemy import select
@@ -37,6 +37,12 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.auth.managers.models.base_user import BaseUser
+ from airflow.auth.managers.models.batch_apis import (
+ IsAuthorizedConnectionRequest,
+ IsAuthorizedDagRequest,
+ IsAuthorizedPoolRequest,
+ IsAuthorizedVariableRequest,
+ )
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
@@ -250,6 +256,89 @@ class BaseAuthManager(LoggingMixin):
"""
raise AirflowException(f"The resource `{fab_resource_name}` does not exist in the environment.")
+ def batch_is_authorized_dag(
+ self,
+ requests: Sequence[IsAuthorizedDagRequest],
+ ) -> bool:
+ """
+ Batch version of ``is_authorized_dag``.
+
+ By default, calls individually the ``is_authorized_dag`` API on each
item in the list of requests.
+ Can lead to some poor performance. It is recommended to override this
method in the auth manager
+ implementation to provide a more efficient implementation.
+
+ :param requests: a list of requests containing the parameters for
``is_authorized_dag``
+ """
+ return all(
+ self.is_authorized_dag(
+ method=request["method"],
+ access_entity=request.get("access_entity"),
+ details=request.get("details"),
+ user=request.get("user"),
+ )
+ for request in requests
+ )
+
+ def batch_is_authorized_connection(
+ self,
+ requests: Sequence[IsAuthorizedConnectionRequest],
+ ) -> bool:
+ """
+ Batch version of ``is_authorized_connection``.
+
+ By default, calls individually the ``is_authorized_connection`` API on
each item in the list of
+ requests. Can lead to some poor performance. It is recommended to
override this method in the auth
+ manager implementation to provide a more efficient implementation.
+
+ :param requests: a list of requests containing the parameters for
``is_authorized_connection``
+ """
+ return all(
+ self.is_authorized_connection(
+ method=request["method"], details=request.get("details"),
user=request.get("user")
+ )
+ for request in requests
+ )
+
+ def batch_is_authorized_pool(
+ self,
+ requests: Sequence[IsAuthorizedPoolRequest],
+ ) -> bool:
+ """
+ Batch version of ``is_authorized_pool``.
+
+ By default, calls individually the ``is_authorized_pool`` API on each
item in the list of
+ requests. Can lead to some poor performance. It is recommended to
override this method in the auth
+ manager implementation to provide a more efficient implementation.
+
+ :param requests: a list of requests containing the parameters for
``is_authorized_pool``
+ """
+ return all(
+ self.is_authorized_pool(
+ method=request["method"], details=request.get("details"),
user=request.get("user")
+ )
+ for request in requests
+ )
+
+ def batch_is_authorized_variable(
+ self,
+ requests: Sequence[IsAuthorizedVariableRequest],
+ ) -> bool:
+ """
+ Batch version of ``is_authorized_variable``.
+
+ By default, calls individually the ``is_authorized_variable`` API on
each item in the list of
+ requests. Can lead to some poor performance. It is recommended to
override this method in the auth
+ manager implementation to provide a more efficient implementation.
+
+ :param requests: a list of requests containing the parameters for
``is_authorized_variable``
+ """
+ return all(
+ self.is_authorized_variable(
+ method=request["method"], details=request.get("details"),
user=request.get("user")
+ )
+ for request in requests
+ )
+
@provide_session
def get_permitted_dag_ids(
self,
diff --git a/airflow/auth/managers/models/batch_apis.py b/airflow/auth/managers/models/batch_apis.py
new file mode 100644
index 0000000000..7cb16339a7
--- /dev/null
+++ b/airflow/auth/managers/models/batch_apis.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypedDict
+
+if TYPE_CHECKING:
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
+ from airflow.auth.managers.models.base_user import BaseUser
+ from airflow.auth.managers.models.resource_details import (
+ ConnectionDetails,
+ DagAccessEntity,
+ DagDetails,
+ PoolDetails,
+ VariableDetails,
+ )
+
+
+class IsAuthorizedConnectionRequest(TypedDict, total=False):
+ """Represent the parameters of ``is_authorized_connection`` API in the
auth manager."""
+
+ method: ResourceMethod
+ details: ConnectionDetails | None
+ user: BaseUser | None
+
+
+class IsAuthorizedDagRequest(TypedDict, total=False):
+ """Represent the parameters of ``is_authorized_dag`` API in the auth
manager."""
+
+ method: ResourceMethod
+ access_entity: DagAccessEntity | None
+ details: DagDetails | None
+ user: BaseUser | None
+
+
+class IsAuthorizedPoolRequest(TypedDict, total=False):
+ """Represent the parameters of ``is_authorized_pool`` API in the auth
manager."""
+
+ method: ResourceMethod
+ details: PoolDetails | None
+ user: BaseUser | None
+
+
+class IsAuthorizedVariableRequest(TypedDict, total=False):
+ """Represent the parameters of ``is_authorized_variable`` API in the auth
manager."""
+
+ method: ResourceMethod
+ details: VariableDetails | None
+ user: BaseUser | None
diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index e285d517ba..1ad6e6dab5 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -45,6 +45,12 @@ from airflow.www.extensions.init_auth_manager import get_auth_manager
if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import ResourceMethod
+ from airflow.auth.managers.models.batch_apis import (
+ IsAuthorizedConnectionRequest,
+ IsAuthorizedDagRequest,
+ IsAuthorizedPoolRequest,
+ IsAuthorizedVariableRequest,
+ )
from airflow.models import DagRun, Pool, SlaMiss, TaskInstance, Variable
from airflow.models.connection import Connection
from airflow.models.xcom import BaseXCom
@@ -190,15 +196,14 @@ def has_access_connection(method: ResourceMethod) -> Callable[[T], T]:
@wraps(func)
def decorated(*args, **kwargs):
connections: set[Connection] = set(args[1])
- connections_details = [
- ConnectionDetails(conn_id=connection.conn_id) for connection
in connections
+ requests: Sequence[IsAuthorizedConnectionRequest] = [
+ {
+ "method": method,
+ "details": ConnectionDetails(conn_id=connection.conn_id),
+ }
+ for connection in connections
]
- is_authorized = all(
- [
- get_auth_manager().is_authorized_connection(method=method,
details=connection_details)
- for connection_details in connections_details
- ]
- )
+ is_authorized =
get_auth_manager().batch_is_authorized_connection(requests)
return _has_access(
is_authorized=is_authorized,
func=func,
@@ -265,15 +270,16 @@ def has_access_dag_entities(method: ResourceMethod, access_entity: DagAccessEnti
@wraps(func)
def decorated(*args, **kwargs):
items: set[SlaMiss | BaseXCom | DagRun | TaskInstance] =
set(args[1])
- dags_details = [DagDetails(id=item.dag_id) for item in items if
item is not None]
- is_authorized = all(
- [
- get_auth_manager().is_authorized_dag(
- method=method, access_entity=access_entity,
details=dag_details
- )
- for dag_details in dags_details
- ]
- )
+ requests: Sequence[IsAuthorizedDagRequest] = [
+ {
+ "method": method,
+ "access_entity": access_entity,
+ "details": DagDetails(id=item.dag_id),
+ }
+ for item in items
+ if item is not None
+ ]
+ is_authorized =
get_auth_manager().batch_is_authorized_dag(requests)
return _has_access(
is_authorized=is_authorized,
func=func,
@@ -296,13 +302,14 @@ def has_access_pool(method: ResourceMethod) -> Callable[[T], T]:
@wraps(func)
def decorated(*args, **kwargs):
pools: set[Pool] = set(args[1])
- pools_details = [PoolDetails(name=pool.pool) for pool in pools]
- is_authorized = all(
- [
- get_auth_manager().is_authorized_pool(method=method,
details=pool_details)
- for pool_details in pools_details
- ]
- )
+ requests: Sequence[IsAuthorizedPoolRequest] = [
+ {
+ "method": method,
+ "details": PoolDetails(name=pool.pool),
+ }
+ for pool in pools
+ ]
+ is_authorized =
get_auth_manager().batch_is_authorized_pool(requests)
return _has_access(
is_authorized=is_authorized,
func=func,
@@ -324,13 +331,14 @@ def has_access_variable(method: ResourceMethod) -> Callable[[T], T]:
is_authorized =
get_auth_manager().is_authorized_variable(method=method)
else:
variables: set[Variable] = set(args[1])
- variables_details = [VariableDetails(key=variable.key) for
variable in variables]
- is_authorized = all(
- [
-
get_auth_manager().is_authorized_variable(method=method,
details=variable_details)
- for variable_details in variables_details
- ]
- )
+ requests: Sequence[IsAuthorizedVariableRequest] = [
+ {
+ "method": method,
+ "details": VariableDetails(key=variable.key),
+ }
+ for variable in variables
+ ]
+ is_authorized =
get_auth_manager().batch_is_authorized_variable(requests)
return _has_access(
is_authorized=is_authorized,
func=func,
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 9f913ac9fe..3125031f56 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -994,7 +994,7 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
)
assert response.status_code == 403
assert response.json == {
- "detail": "User not allowed to access these DAGs: ['example_skip_dag']",
+ "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']",
"status": 403,
"title": "Forbidden",
"type": EXCEPTIONS_LINK_MAP[403],
diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py
index f009fb055b..9b1b5659f4 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -17,12 +17,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING
-from unittest.mock import MagicMock, Mock
+from unittest.mock import MagicMock, Mock, patch
import pytest
from flask import Flask
from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
+from airflow.auth.managers.models.resource_details import (
+ ConnectionDetails,
+ DagDetails,
+ PoolDetails,
+ VariableDetails,
+)
from airflow.exceptions import AirflowException
from airflow.security import permissions
from airflow.www.extensions.init_appbuilder import init_appbuilder
@@ -33,12 +39,8 @@ if TYPE_CHECKING:
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
- ConnectionDetails,
DagAccessEntity,
- DagDetails,
DatasetDetails,
- PoolDetails,
- VariableDetails,
)
@@ -143,6 +145,86 @@ class TestBaseAuthManager:
fab_resource_name=permissions.RESOURCE_MY_PASSWORD,
)
+ @pytest.mark.parametrize(
+ "return_values, expected",
+ [
+ ([False, False], False),
+ ([True, False], False),
+ ([True, True], True),
+ ],
+ )
+ @patch.object(EmptyAuthManager, "is_authorized_dag")
+ def test_batch_is_authorized_dag(self, mock_is_authorized_dag,
auth_manager, return_values, expected):
+ mock_is_authorized_dag.side_effect = return_values
+ result = auth_manager.batch_is_authorized_dag(
+ [
+ {"method": "GET", "details": DagDetails(id="dag1")},
+ {"method": "GET", "details": DagDetails(id="dag2")},
+ ]
+ )
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ "return_values, expected",
+ [
+ ([False, False], False),
+ ([True, False], False),
+ ([True, True], True),
+ ],
+ )
+ @patch.object(EmptyAuthManager, "is_authorized_connection")
+ def test_batch_is_authorized_connection(
+ self, mock_is_authorized_connection, auth_manager, return_values,
expected
+ ):
+ mock_is_authorized_connection.side_effect = return_values
+ result = auth_manager.batch_is_authorized_connection(
+ [
+ {"method": "GET", "details":
ConnectionDetails(conn_id="conn1")},
+ {"method": "GET", "details":
ConnectionDetails(conn_id="conn2")},
+ ]
+ )
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ "return_values, expected",
+ [
+ ([False, False], False),
+ ([True, False], False),
+ ([True, True], True),
+ ],
+ )
+ @patch.object(EmptyAuthManager, "is_authorized_pool")
+ def test_batch_is_authorized_pool(self, mock_is_authorized_pool,
auth_manager, return_values, expected):
+ mock_is_authorized_pool.side_effect = return_values
+ result = auth_manager.batch_is_authorized_pool(
+ [
+ {"method": "GET", "details": PoolDetails(name="pool1")},
+ {"method": "GET", "details": PoolDetails(name="pool2")},
+ ]
+ )
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ "return_values, expected",
+ [
+ ([False, False], False),
+ ([True, False], False),
+ ([True, True], True),
+ ],
+ )
+ @patch.object(EmptyAuthManager, "is_authorized_variable")
+ def test_batch_is_authorized_variable(
+ self, mock_is_authorized_variable, auth_manager, return_values,
expected
+ ):
+ mock_is_authorized_variable.side_effect = return_values
+ result = auth_manager.batch_is_authorized_variable(
+ [
+ {"method": "GET", "details": VariableDetails(key="var1")},
+ {"method": "GET", "details": VariableDetails(key="var2")},
+ ]
+ )
+ assert result == expected
+
@pytest.mark.db_test
def test_security_manager_return_default_security_manager(self,
auth_manager_with_appbuilder):
assert isinstance(auth_manager_with_appbuilder.security_manager,
AirflowSecurityManagerV2)
diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py
index 83fb0eebdf..dd0188974b 100644
--- a/tests/www/test_auth.py
+++ b/tests/www/test_auth.py
@@ -115,9 +115,13 @@ class TestHasAccessNoDetails:
@pytest.mark.parametrize(
"decorator_name, is_authorized_method_name, items",
[
- ("has_access_connection", "is_authorized_connection",
[Connection("conn_1"), Connection("conn_2")]),
- ("has_access_pool", "is_authorized_pool", [Pool(pool="pool_1"),
Pool(pool="pool_2")]),
- ("has_access_variable", "is_authorized_variable", [Variable("var_1"),
Variable("var_2")]),
+ (
+ "has_access_connection",
+ "batch_is_authorized_connection",
+ [Connection("conn_1"), Connection("conn_2")],
+ ),
+ ("has_access_pool", "batch_is_authorized_pool", [Pool(pool="pool_1"),
Pool(pool="pool_2")]),
+ ("has_access_variable", "batch_is_authorized_variable",
[Variable("var_1"), Variable("var_2")]),
],
)
class TestHasAccessWithDetails:
@@ -181,7 +185,7 @@ class TestHasAccessDagEntities:
@patch("airflow.www.auth.get_auth_manager")
def test_has_access_dag_entities_when_authorized(self,
mock_get_auth_manager, dag_access_entity):
auth_manager = Mock()
- auth_manager.is_authorized_dag.return_value = True
+ auth_manager.batch_is_authorized_dag.return_value = True
mock_get_auth_manager.return_value = auth_manager
items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")]
@@ -194,7 +198,7 @@ class TestHasAccessDagEntities:
@patch("airflow.www.auth.get_auth_manager")
def test_has_access_dag_entities_when_unauthorized(self,
mock_get_auth_manager, app, dag_access_entity):
auth_manager = Mock()
- auth_manager.is_authorized_dag.return_value = False
+ auth_manager.batch_is_authorized_dag.return_value = False
mock_get_auth_manager.return_value = auth_manager
items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")]