This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7114f79dab7 List only connections, pools and variables the user has
access to (#55298)
7114f79dab7 is described below
commit 7114f79dab7cfc058e0bbd8a393a309bb6d86ef2
Author: Vincent <[email protected]>
AuthorDate: Mon Sep 15 10:38:23 2025 -0400
List only connections, pools and variables the user has access to (#55298)
---
.../docs/core-concepts/auth-manager/index.rst | 16 +-
.../api_fastapi/auth/managers/base_auth_manager.py | 243 +++++++++++++++++++--
.../core_api/routes/public/connections.py | 9 +-
.../api_fastapi/core_api/routes/public/pools.py | 9 +-
.../core_api/routes/public/variables.py | 11 +-
.../src/airflow/api_fastapi/core_api/security.py | 94 ++++++++
.../auth/managers/test_base_auth_manager.py | 208 +++++++++++++++++-
.../core_api/routes/public/test_connections.py | 14 ++
.../core_api/routes/public/test_dags.py | 2 +-
.../core_api/routes/public/test_pools.py | 14 ++
.../core_api/routes/public/test_variables.py | 14 ++
.../amazon/aws/auth_manager/aws_auth_manager.py | 1 +
12 files changed, 592 insertions(+), 43 deletions(-)
diff --git a/airflow-core/docs/core-concepts/auth-manager/index.rst
b/airflow-core/docs/core-concepts/auth-manager/index.rst
index fffc89fc4d5..13a20eeaa3c 100644
--- a/airflow-core/docs/core-concepts/auth-manager/index.rst
+++ b/airflow-core/docs/core-concepts/auth-manager/index.rst
@@ -178,14 +178,14 @@ Optional methods recommended to override for optimization
The following methods aren't required to override to have a functional Airflow
auth manager. However, it is recommended to override these to make your auth
manager faster (and potentially less costly):
-* ``batch_is_authorized_connection``: Batch version of
``is_authorized_connection``. If not overridden, it will call
``is_authorized_connection`` for every single item.
-* ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not
overridden, it will call ``is_authorized_dag`` for every single item.
-* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If
not overridden, it will call ``is_authorized_pool`` for every single item.
-* ``batch_is_authorized_variable``: Batch version of
``is_authorized_variable``. If not overridden, it will call
``is_authorized_variable`` for every single item.
-* ``get_authorized_dag_ids``: Return the list of Dag IDs the user has access
to. If not overridden, it will call ``is_authorized_dag`` for every single Dag
available in the environment.
-
- * Note: To filter the results of ``get_authorized_dag_ids``, it is
recommended that you define the filtering logic in your
``filter_authorized_dag_ids`` method. For example, this may be useful if you
rely on per-Dag access controls derived from one or more fields on a given Dag
(e.g. Dag tags).
- * This method requires an active session with the Airflow metadata database.
As such, overriding the ``get_authorized_dag_ids`` method is an advanced use
case, which should be considered carefully -- it is recommended you refer to
the :doc:`../../database-erd-ref`.
+* ``batch_is_authorized_connection``: Batch version of
``is_authorized_connection``. If not overridden, it calls
``is_authorized_connection`` for every single item.
+* ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not
overridden, it calls ``is_authorized_dag`` for every single item.
+* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If
not overridden, it calls ``is_authorized_pool`` for every single item.
+* ``batch_is_authorized_variable``: Batch version of
``is_authorized_variable``. If not overridden, it calls
``is_authorized_variable`` for every single item.
+* ``filter_authorized_connections``: Given a list of connection IDs
(``conn_id``), return the list of connection IDs the user has access to. If
not overridden, it calls ``is_authorized_connection`` for every single
connection passed as parameter.
+* ``filter_authorized_dag_ids``: Given a list of Dag IDs, return the list of
Dag IDs the user has access to. If not overridden, it calls
``is_authorized_dag`` for every single Dag passed as parameter.
+* ``filter_authorized_pools``: Given a list of pool names, return the list of
pool names the user has access to. If not overridden, it calls
``is_authorized_pool`` for every single pool passed as parameter.
+* ``filter_authorized_variables``: Given a list of variable keys, return the
list of variable keys the user has access to. If not overridden, it calls
``is_authorized_variable`` for every single variable passed as parameter.
CLI
^^^
diff --git
a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
index d57dd3cdc39..456b29c0cf4 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import logging
from abc import ABCMeta, abstractmethod
+from collections import defaultdict
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
@@ -26,7 +27,13 @@ from jwt import InvalidTokenError
from sqlalchemy import select
from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
-from airflow.api_fastapi.auth.managers.models.resource_details import
BackfillDetails, DagDetails
+from airflow.api_fastapi.auth.managers.models.resource_details import (
+ BackfillDetails,
+ ConnectionDetails,
+ DagDetails,
+ PoolDetails,
+ VariableDetails,
+)
from airflow.api_fastapi.auth.tokens import (
JWTGenerator,
JWTValidator,
@@ -35,7 +42,9 @@ from airflow.api_fastapi.auth.tokens import (
)
from airflow.api_fastapi.common.types import ExtraMenuItem, MenuItem
from airflow.configuration import conf
-from airflow.models import DagModel
+from airflow.models import Connection, DagModel, Pool, Variable
+from airflow.models.dagbundle import DagBundleModel
+from airflow.models.team import Team, dag_bundle_team_association_table
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
@@ -56,10 +65,7 @@ if TYPE_CHECKING:
AssetAliasDetails,
AssetDetails,
ConfigurationDetails,
- ConnectionDetails,
DagAccessEntity,
- PoolDetails,
- VariableDetails,
)
from airflow.cli.cli_config import CLICommand
@@ -427,16 +433,34 @@ class BaseAuthManager(Generic[T], LoggingMixin,
metaclass=ABCMeta):
"""
Get DAGs the user has access to.
- By default, reads all the DAGs and check individually if the user has
permissions to access the DAG.
- Can lead to some poor performance. It is recommended to override this
method in the auth manager
- implementation to provide a more efficient implementation.
-
:param user: the user
:param method: the method to filter on
:param session: the session
"""
- dag_ids = {dag.dag_id for dag in
session.execute(select(DagModel.dag_id))}
- return self.filter_authorized_dag_ids(dag_ids=dag_ids, method=method,
user=user)
+ stmt = (
+ select(DagModel.dag_id, Team.name)
+ .join(DagBundleModel, DagModel.bundle_name == DagBundleModel.name)
+ .join(
+ dag_bundle_team_association_table,
+ DagBundleModel.name ==
dag_bundle_team_association_table.c.dag_bundle_name,
+ isouter=True,
+ )
+ .join(Team, Team.id ==
dag_bundle_team_association_table.c.team_id, isouter=True)
+ )
+ rows = session.execute(stmt).all()
+ dags_by_team: dict[str | None, set[str]] = defaultdict(set)
+ for dag_id, team_name in rows:
+ dags_by_team[team_name].add(dag_id)
+
+ dag_ids: set[str] = set()
+ for team_name, team_dag_ids in dags_by_team.items():
+ dag_ids.update(
+ self.filter_authorized_dag_ids(
+ dag_ids=team_dag_ids, user=user, method=method,
team_name=team_name
+ )
+ )
+
+ return dag_ids
def filter_authorized_dag_ids(
self,
@@ -444,19 +468,208 @@ class BaseAuthManager(Generic[T], LoggingMixin,
metaclass=ABCMeta):
dag_ids: set[str],
user: T,
method: ResourceMethod = "GET",
+ team_name: str | None = None,
) -> set[str]:
"""
Filter DAGs the user has access to.
- :param dag_ids: the list of DAG ids
+ By default, check individually if the user has permissions to access
the DAG.
+ Can lead to some poor performance. It is recommended to override this
method in the auth manager
+ implementation to provide a more efficient implementation.
+
+ :param dag_ids: the set of DAG ids
+ :param user: the user
+ :param method: the method to filter on
+ :param team_name: the name of the team associated to the Dags if
Airflow environment runs in
+ multi-team mode
+ """
+
+ def _is_authorized_dag_id(dag_id: str):
+ return self.is_authorized_dag(
+ method=method, details=DagDetails(id=dag_id,
team_name=team_name), user=user
+ )
+
+ return {dag_id for dag_id in dag_ids if _is_authorized_dag_id(dag_id)}
+
+ @provide_session
+ def get_authorized_connections(
+ self,
+ *,
+ user: T,
+ method: ResourceMethod = "GET",
+ session: Session = NEW_SESSION,
+ ) -> set[str]:
+ """
+ Get connection ids (``conn_id``) the user has access to.
+
+ :param user: the user
+ :param method: the method to filter on
+ :param session: the session
+ """
+ stmt = select(Connection.conn_id, Team.name).join(Team,
Connection.team_id == Team.id, isouter=True)
+ rows = session.execute(stmt).all()
+ connections_by_team: dict[str | None, set[str]] = defaultdict(set)
+ for conn_id, team_name in rows:
+ connections_by_team[team_name].add(conn_id)
+
+ conn_ids: set[str] = set()
+ for team_name, team_conn_ids in connections_by_team.items():
+ conn_ids.update(
+ self.filter_authorized_connections(
+ conn_ids=team_conn_ids, user=user, method=method,
team_name=team_name
+ )
+ )
+
+ return conn_ids
+
+ def filter_authorized_connections(
+ self,
+ *,
+ conn_ids: set[str],
+ user: T,
+ method: ResourceMethod = "GET",
+ team_name: str | None = None,
+ ) -> set[str]:
+ """
+ Filter connections the user has access to.
+
+ By default, check individually if the user has permissions to access
the connection.
+ Can lead to some poor performance. It is recommended to override this
method in the auth manager
+ implementation to provide a more efficient implementation.
+
+ :param conn_ids: the set of connection ids (``conn_id``)
:param user: the user
:param method: the method to filter on
+ :param team_name: the name of the team associated to the connections
if Airflow environment runs in
+ multi-team mode
+ """
+
+ def _is_authorized_connection(conn_id: str):
+ return self.is_authorized_connection(
+ method=method, details=ConnectionDetails(conn_id=conn_id,
team_name=team_name), user=user
+ )
+
+ return {conn_id for conn_id in conn_ids if
_is_authorized_connection(conn_id)}
+
+ @provide_session
+ def get_authorized_variables(
+ self,
+ *,
+ user: T,
+ method: ResourceMethod = "GET",
+ session: Session = NEW_SESSION,
+ ) -> set[str]:
"""
+ Get variable keys the user has access to.
- def _is_authorized_dag_id(method: ResourceMethod, dag_id: str):
- return self.is_authorized_dag(method=method,
details=DagDetails(id=dag_id), user=user)
+ :param user: the user
+ :param method: the method to filter on
+ :param session: the session
+ """
+ stmt = select(Variable.key, Team.name).join(Team, Variable.team_id ==
Team.id, isouter=True)
+ rows = session.execute(stmt).all()
+ variables_by_team: dict[str | None, set[str]] = defaultdict(set)
+ for var_key, team_name in rows:
+ variables_by_team[team_name].add(var_key)
+
+ var_keys: set[str] = set()
+ for team_name, team_var_keys in variables_by_team.items():
+ var_keys.update(
+ self.filter_authorized_variables(
+ variable_keys=team_var_keys, user=user, method=method,
team_name=team_name
+ )
+ )
+
+ return var_keys
+
+ def filter_authorized_variables(
+ self,
+ *,
+ variable_keys: set[str],
+ user: T,
+ method: ResourceMethod = "GET",
+ team_name: str | None = None,
+ ) -> set[str]:
+ """
+ Filter variables the user has access to.
+
+ By default, check individually if the user has permissions to access
the variable.
+ Can lead to some poor performance. It is recommended to override this
method in the auth manager
+ implementation to provide a more efficient implementation.
+
+ :param variable_keys: the set of variable keys
+ :param user: the user
+ :param method: the method to filter on
+ :param team_name: the name of the team associated to the variables
if Airflow environment runs in
+ multi-team mode
+ """
+
+ def _is_authorized_variable(var_key: str):
+ return self.is_authorized_variable(
+ method=method, details=VariableDetails(key=var_key,
team_name=team_name), user=user
+ )
+
+ return {var_key for var_key in variable_keys if
_is_authorized_variable(var_key)}
+
+ @provide_session
+ def get_authorized_pools(
+ self,
+ *,
+ user: T,
+ method: ResourceMethod = "GET",
+ session: Session = NEW_SESSION,
+ ) -> set[str]:
+ """
+ Get pools the user has access to.
+
+ :param user: the user
+ :param method: the method to filter on
+ :param session: the session
+ """
+ stmt = select(Pool.pool, Team.name).join(Team, Pool.team_id ==
Team.id, isouter=True)
+ rows = session.execute(stmt).all()
+ pools_by_team: dict[str | None, set[str]] = defaultdict(set)
+ for pool_name, team_name in rows:
+ pools_by_team[team_name].add(pool_name)
+
+ pool_names: set[str] = set()
+ for team_name, team_pool_names in pools_by_team.items():
+ pool_names.update(
+ self.filter_authorized_pools(
+ pool_names=team_pool_names, user=user, method=method,
team_name=team_name
+ )
+ )
+
+ return pool_names
+
+ def filter_authorized_pools(
+ self,
+ *,
+ pool_names: set[str],
+ user: T,
+ method: ResourceMethod = "GET",
+ team_name: str | None = None,
+ ) -> set[str]:
+ """
+ Filter pools the user has access to.
+
+ By default, check individually if the user has permissions to access
the pool.
+ Can lead to some poor performance. It is recommended to override this
method in the auth manager
+ implementation to provide a more efficient implementation.
+
+ :param pool_names: the set of pool names
+ :param user: the user
+ :param method: the method to filter on
+ :param team_name: the name of the team associated to the pools
if Airflow environment runs in
+ multi-team mode
+ """
+
+ def _is_authorized_pool(name: str):
+ return self.is_authorized_pool(
+ method=method, details=PoolDetails(name=name,
team_name=team_name), user=user
+ )
- return {dag_id for dag_id in dag_ids if _is_authorized_dag_id(method,
dag_id)}
+ return {pool_name for pool_name in pool_names if
_is_authorized_pool(pool_name)}
@staticmethod
def get_cli_commands() -> list[CLICommand]:
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
index 295c6f156f0..28db5ae4fef 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -43,7 +43,11 @@ from airflow.api_fastapi.core_api.datamodels.connections
import (
ConnectionTestResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.security import requires_access_connection,
requires_access_connection_bulk
+from airflow.api_fastapi.core_api.security import (
+ ReadableConnectionsFilterDep,
+ requires_access_connection,
+ requires_access_connection_bulk,
+)
from airflow.api_fastapi.core_api.services.public.connections import (
BulkConnectionService,
update_orm_from_pydantic,
@@ -117,13 +121,14 @@ def get_connections(
).dynamic_depends()
),
],
+ readable_connections_filter: ReadableConnectionsFilterDep,
session: SessionDep,
connection_id_pattern: QueryConnectionIdPatternSearch,
) -> ConnectionCollectionResponse:
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
statement=select(Connection),
- filters=[connection_id_pattern],
+ filters=[connection_id_pattern, readable_connections_filter],
order_by=order_by,
offset=offset,
limit=limit,
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py
index 835c2a62c3f..6a7e2f646d9 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/pools.py
@@ -40,7 +40,11 @@ from airflow.api_fastapi.core_api.datamodels.pools import (
PoolResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.security import requires_access_pool,
requires_access_pool_bulk
+from airflow.api_fastapi.core_api.security import (
+ ReadablePoolsFilterDep,
+ requires_access_pool,
+ requires_access_pool_bulk,
+)
from airflow.api_fastapi.core_api.services.public.pools import BulkPoolService
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.models.pool import Pool
@@ -103,12 +107,13 @@ def get_pools(
Depends(SortParam(["id", "pool"], Pool, to_replace={"name":
"pool"}).dynamic_depends()),
],
pool_name_pattern: QueryPoolNamePatternSearch,
+ readable_pools_filter: ReadablePoolsFilterDep,
session: SessionDep,
) -> PoolCollectionResponse:
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
statement=select(Pool),
- filters=[pool_name_pattern],
+ filters=[pool_name_pattern, readable_pools_filter],
order_by=order_by,
offset=offset,
limit=limit,
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py
index eb111c0c6af..36fc5be44b1 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/variables.py
@@ -38,7 +38,11 @@ from airflow.api_fastapi.core_api.datamodels.variables
import (
VariableResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.security import requires_access_variable,
requires_access_variable_bulk
+from airflow.api_fastapi.core_api.security import (
+ ReadableVariablesFilterDep,
+ requires_access_variable,
+ requires_access_variable_bulk,
+)
from airflow.api_fastapi.core_api.services.public.variables import
BulkVariableService
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.models.variable import Variable
@@ -99,13 +103,14 @@ def get_variables(
).dynamic_depends()
),
],
+ readable_variables_filter: ReadableVariablesFilterDep,
session: SessionDep,
- varaible_key_pattern: QueryVariableKeyPatternSearch,
+ variable_key_pattern: QueryVariableKeyPatternSearch,
) -> VariableCollectionResponse:
"""Get all Variables entries."""
variable_select, total_entries = paginated_select(
statement=select(Variable),
- filters=[varaible_key_pattern],
+ filters=[variable_key_pattern, readable_variables_filter],
order_by=order_by,
offset=offset,
limit=limit,
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py
b/airflow-core/src/airflow/api_fastapi/core_api/security.py
index 154e040ed68..b24313319a9 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/security.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py
@@ -233,6 +233,36 @@ def requires_access_backfill(method: ResourceMethod) ->
Callable[[Request, BaseU
return inner
+class PermittedPoolFilter(OrmClause[set[str]]):
+ """A parameter that filters the permitted pools for the user."""
+
+ def to_orm(self, select: Select) -> Select:
+ return select.where(Pool.pool.in_(self.value))
+
+
+def permitted_pool_filter_factory(
+ method: ResourceMethod,
+) -> Callable[[Request, BaseUser], PermittedPoolFilter]:
+ """
+ Create a callable for Depends in FastAPI that returns a filter of the
permitted pools for the user.
+
+ :param method: whether to filter readable or writable pools.
+ """
+
+ def depends_permitted_pools_filter(
+ request: Request,
+ user: GetUserDep,
+ ) -> PermittedPoolFilter:
+ auth_manager: BaseAuthManager = request.app.state.auth_manager
+ authorized_pools: set[str] =
auth_manager.get_authorized_pools(user=user, method=method)
+ return PermittedPoolFilter(authorized_pools)
+
+ return depends_permitted_pools_filter
+
+
+ReadablePoolsFilterDep = Annotated[PermittedPoolFilter,
Depends(permitted_pool_filter_factory("GET"))]
+
+
def requires_access_pool(method: ResourceMethod) -> Callable[[Request,
BaseUser], None]:
def inner(
request: Request,
@@ -294,6 +324,38 @@ def requires_access_pool_bulk() ->
Callable[[BulkBody[PoolBody], BaseUser], None
return inner
+class PermittedConnectionFilter(OrmClause[set[str]]):
+ """A parameter that filters the permitted connections for the user."""
+
+ def to_orm(self, select: Select) -> Select:
+ return select.where(Connection.conn_id.in_(self.value))
+
+
+def permitted_connection_filter_factory(
+ method: ResourceMethod,
+) -> Callable[[Request, BaseUser], PermittedConnectionFilter]:
+ """
+ Create a callable for Depends in FastAPI that returns a filter of the
permitted connections for the user.
+
+ :param method: whether to filter readable or writable connections.
+ """
+
+ def depends_permitted_connections_filter(
+ request: Request,
+ user: GetUserDep,
+ ) -> PermittedConnectionFilter:
+ auth_manager: BaseAuthManager = request.app.state.auth_manager
+ authorized_connections: set[str] =
auth_manager.get_authorized_connections(user=user, method=method)
+ return PermittedConnectionFilter(authorized_connections)
+
+ return depends_permitted_connections_filter
+
+
+ReadableConnectionsFilterDep = Annotated[
+ PermittedConnectionFilter,
Depends(permitted_connection_filter_factory("GET"))
+]
+
+
def requires_access_connection(method: ResourceMethod) -> Callable[[Request,
BaseUser], None]:
def inner(
request: Request,
@@ -379,6 +441,38 @@ def requires_access_configuration(method: ResourceMethod)
-> Callable[[Request,
return inner
+class PermittedVariableFilter(OrmClause[set[str]]):
+ """A parameter that filters the permitted variables for the user."""
+
+ def to_orm(self, select: Select) -> Select:
+ return select.where(Variable.key.in_(self.value))
+
+
+def permitted_variable_filter_factory(
+ method: ResourceMethod,
+) -> Callable[[Request, BaseUser], PermittedVariableFilter]:
+ """
+ Create a callable for Depends in FastAPI that returns a filter of the
permitted variables for the user.
+
+ :param method: whether to filter readable or writable variables.
+ """
+
+ def depends_permitted_variables_filter(
+ request: Request,
+ user: GetUserDep,
+ ) -> PermittedVariableFilter:
+ auth_manager: BaseAuthManager = request.app.state.auth_manager
+ authorized_variables: set[str] =
auth_manager.get_authorized_variables(user=user, method=method)
+ return PermittedVariableFilter(authorized_variables)
+
+ return depends_permitted_variables_filter
+
+
+ReadableVariablesFilterDep = Annotated[
+ PermittedVariableFilter, Depends(permitted_variable_filter_factory("GET"))
+]
+
+
def requires_access_variable(method: ResourceMethod) -> Callable[[Request,
BaseUser], None]:
def inner(
request: Request,
diff --git
a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
index fde84665609..010cf24481f 100644
---
a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
+++
b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
@@ -326,42 +326,226 @@ class TestBaseAuthManager:
assert result == expected
@pytest.mark.parametrize(
- "access_per_dag, dag_ids, expected",
+ "access_per_dag, access_per_team, rows, expected",
[
+ # Without teams
# No access to any dag
(
{},
- ["dag1", "dag2"],
+ {},
+ [("dag1", None), ("dag2", None)],
set(),
),
# Access to specific dags
(
{"dag1": True},
- ["dag1", "dag2"],
+ {},
+ [("dag1", None), ("dag2", None)],
{"dag1"},
),
+ # With teams
+ # No access to any dag
+ (
+ {},
+ {},
+ [("dag1", "team1"), ("dag2", "team2")],
+ set(),
+ ),
+ # Access to a specific team
+ (
+ {},
+ {"team1": True},
+ [("dag1", "team1"), ("dag2", "team1"), ("dag3", "team2")],
+ {"dag1", "dag2"},
+ ),
],
)
- def test_get_authorized_dag_ids(self, auth_manager, access_per_dag: dict,
dag_ids: list, expected: set):
+ def test_get_authorized_dag_ids(
+ self, auth_manager, access_per_dag: dict, access_per_team: dict, rows:
list, expected: set
+ ):
def side_effect_func(
*,
method: ResourceMethod,
+ user: BaseAuthManagerUserTest,
access_entity: DagAccessEntity | None = None,
details: DagDetails | None = None,
- user: BaseAuthManagerUserTest | None = None,
):
if not details:
return False
- return access_per_dag.get(details.id, False)
+ return access_per_dag.get(details.id, False) or
access_per_team.get(details.team_name, False)
auth_manager.is_authorized_dag =
MagicMock(side_effect=side_effect_func)
user = Mock()
session = Mock()
- dags = []
- for dag_id in dag_ids:
- mock = Mock()
- mock.dag_id = dag_id
- dags.append(mock)
- session.execute.return_value = dags
+ session.execute.return_value.all.return_value = rows
result = auth_manager.get_authorized_dag_ids(user=user,
session=session)
assert result == expected
+
+ @pytest.mark.parametrize(
+ "access_per_connection, access_per_team, rows, expected",
+ [
+ # Without teams
+ # No access to any connection
+ (
+ {},
+ {},
+ [("conn1", None), ("conn2", None)],
+ set(),
+ ),
+ # Access to specific connections
+ (
+ {"conn1": True},
+ {},
+ [("conn1", None), ("conn2", None)],
+ {"conn1"},
+ ),
+ # With teams
+ # No access to any connection
+ (
+ {},
+ {},
+ [("conn1", "team1"), ("conn2", "team2")],
+ set(),
+ ),
+ # Access to a specific team
+ (
+ {},
+ {"team1": True},
+ [("conn1", "team1"), ("conn2", "team1"), ("conn3", "team2")],
+ {"conn1", "conn2"},
+ ),
+ ],
+ )
+ def test_get_authorized_connections(
+ self, auth_manager, access_per_connection: dict, access_per_team:
dict, rows: list, expected: set
+ ):
+ def side_effect_func(
+ *,
+ method: ResourceMethod,
+ user: BaseAuthManagerUserTest,
+ details: ConnectionDetails | None = None,
+ ):
+ if not details:
+ return False
+ return access_per_connection.get(details.conn_id, False) or
access_per_team.get(
+ details.team_name, False
+ )
+
+ auth_manager.is_authorized_connection =
MagicMock(side_effect=side_effect_func)
+ user = Mock()
+ session = Mock()
+ session.execute.return_value.all.return_value = rows
+ result = auth_manager.get_authorized_connections(user=user,
session=session)
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ "access_per_variable, access_per_team, rows, expected",
+ [
+ # Without teams
+ # No access to any variable
+ (
+ {},
+ {},
+ [("var1", None), ("var2", None)],
+ set(),
+ ),
+ # Access to specific variables
+ (
+ {"var1": True},
+ {},
+ [("var1", None), ("var2", None)],
+ {"var1"},
+ ),
+ # With teams
+ # No access to any variable
+ (
+ {},
+ {},
+ [("var1", "team1"), ("var2", "team2")],
+ set(),
+ ),
+ # Access to a specific team
+ (
+ {},
+ {"team1": True},
+ [("var1", "team1"), ("var2", "team1"), ("var3", "team2")],
+ {"var1", "var2"},
+ ),
+ ],
+ )
+ def test_get_authorized_variables(
+ self, auth_manager, access_per_variable: dict, access_per_team: dict,
rows: list, expected: set
+ ):
+ def side_effect_func(
+ *,
+ method: ResourceMethod,
+ user: BaseAuthManagerUserTest,
+ details: VariableDetails | None = None,
+ ):
+ if not details:
+ return False
+ return access_per_variable.get(details.key, False) or
access_per_team.get(
+ details.team_name, False
+ )
+
+ auth_manager.is_authorized_variable =
MagicMock(side_effect=side_effect_func)
+ user = Mock()
+ session = Mock()
+ session.execute.return_value.all.return_value = rows
+ result = auth_manager.get_authorized_variables(user=user,
session=session)
+ assert result == expected
+
+ @pytest.mark.parametrize(
+ "access_per_pool, access_per_team, rows, expected",
+ [
+ # Without teams
+ # No access to any pool
+ (
+ {},
+ {},
+ [("pool1", None), ("pool2", None)],
+ set(),
+ ),
+ # Access to specific pools
+ (
+ {"pool1": True},
+ {},
+ [("pool1", None), ("pool2", None)],
+ {"pool1"},
+ ),
+ # With teams
+ # No access to any pool
+ (
+ {},
+ {},
+ [("pool1", "team1"), ("pool2", "team2")],
+ set(),
+ ),
+ # Access to a specific team
+ (
+ {},
+ {"team1": True},
+ [("pool1", "team1"), ("pool2", "team1"), ("pool3", "team2")],
+ {"pool1", "pool2"},
+ ),
+ ],
+ )
+ def test_get_authorized_pools(
+ self, auth_manager, access_per_pool: dict, access_per_team: dict,
rows: list, expected: set
+ ):
+ def side_effect_func(
+ *,
+ method: ResourceMethod,
+ user: BaseAuthManagerUserTest,
+ details: PoolDetails | None = None,
+ ):
+ if not details:
+ return False
+ return access_per_pool.get(details.name, False) or
access_per_team.get(details.team_name, False)
+
+ auth_manager.is_authorized_pool =
MagicMock(side_effect=side_effect_func)
+ user = Mock()
+ session = Mock()
+ session.execute.return_value.all.return_value = rows
+ result = auth_manager.get_authorized_pools(user=user, session=session)
+ assert result == expected
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
index 1dc77cede75..b8f60a86762 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
@@ -232,6 +232,20 @@ class TestGetConnections(TestConnectionEndpoint):
response = unauthorized_test_client.get("/connections", params={})
assert response.status_code == 403
+ @mock.patch(
+
"airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_connections"
+ )
+ def test_should_call_get_authorized_connections(self,
mock_get_authorized_connections, test_client):
+ self.create_connections()
+ mock_get_authorized_connections.return_value = {TEST_CONN_ID}
+ response = test_client.get("/connections")
+ mock_get_authorized_connections.assert_called_once_with(user=mock.ANY,
method="GET")
+ assert response.status_code == 200
+ body = response.json()
+
+ assert body["total_entries"] == 1
+ assert [connection["connection_id"] for connection in
body["connections"]] == [TEST_CONN_ID]
+
class TestPostConnection(TestConnectionEndpoint):
@pytest.mark.parametrize(
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
index 32f337b2690..564a4b27348 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py
@@ -450,7 +450,7 @@ class TestGetDags(TestDagEndpoint):
assert actual_ids == expected_ids
@mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_dag_ids")
- def test_get_dags_should_call_authorized_dag_ids(self,
mock_get_authorized_dag_ids, test_client):
+ def test_get_dags_should_call_get_authorized_dag_ids(self,
mock_get_authorized_dag_ids, test_client):
mock_get_authorized_dag_ids.return_value = {DAG1_ID, DAG2_ID}
response = test_client.get("/dags")
mock_get_authorized_dag_ids.assert_called_once_with(user=mock.ANY,
method="GET")
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
index 8112ed06d76..d4d16abfb8f 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
import pytest
from airflow.models.pool import Pool
@@ -202,6 +204,18 @@ class TestGetPools(TestPoolsEndpoint):
response = unauthorized_test_client.get("/pools",
params={"pool_name_pattern": "~"})
assert response.status_code == 403
+
@mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_pools")
+ def test_should_call_get_authorized_pools(self, mock_get_authorized_pools,
test_client):
+ self.create_pools()
+ mock_get_authorized_pools.return_value = {Pool.DEFAULT_POOL_NAME,
POOL1_NAME}
+ response = test_client.get("/pools")
+ mock_get_authorized_pools.assert_called_once_with(user=mock.ANY,
method="GET")
+ assert response.status_code == 200
+ body = response.json()
+
+ assert body["total_entries"] == 2
+ assert [pool["name"] for pool in body["pools"]] ==
[Pool.DEFAULT_POOL_NAME, POOL1_NAME]
+
class TestPatchPool(TestPoolsEndpoint):
@pytest.mark.parametrize(
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
index 3b26fa87794..076905ea03b 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py
@@ -308,6 +308,20 @@ class TestGetVariables(TestVariableEndpoint):
response = unauthorized_test_client.get("/variables")
assert response.status_code == 403
+ @mock.patch(
+
"airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_authorized_variables"
+ )
+ def test_should_call_get_authorized_variables(self,
mock_get_authorized_variables, test_client):
+ self.create_variables()
+ mock_get_authorized_variables.return_value = {TEST_VARIABLE_KEY,
TEST_VARIABLE_KEY2}
+ response = test_client.get("/variables")
+ mock_get_authorized_variables.assert_called_once_with(user=mock.ANY,
method="GET")
+ assert response.status_code == 200
+ body = response.json()
+
+ assert body["total_entries"] == 2
+ assert [variable["key"] for variable in body["variables"]] ==
[TEST_VARIABLE_KEY, TEST_VARIABLE_KEY2]
+
class TestPatchVariable(TestVariableEndpoint):
@pytest.mark.enable_redact
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 6f3e5851c2f..913ea04fb92 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -344,6 +344,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
dag_ids: set[str],
user: AwsAuthManagerUser,
method: ResourceMethod = "GET",
+ team_name: str | None = None,
):
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] =
defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []