This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new e9a895cee53 Pass team name to `is_authorized_connection`, `is_authorized_variable` and `is_authorized_pool` in Airflow API (#55193) e9a895cee53 is described below commit e9a895cee533dd91556f84baba06f60bbd75ece6 Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Thu Sep 4 08:54:31 2025 -0400 Pass team name to `is_authorized_connection`, `is_authorized_variable` and `is_authorized_pool` in Airflow API (#55193) --- .../auth/managers/models/resource_details.py | 3 + .../src/airflow/api_fastapi/core_api/security.py | 12 ++- airflow-core/src/airflow/models/connection.py | 16 +++- airflow-core/src/airflow/models/pool.py | 7 ++ airflow-core/src/airflow/models/variable.py | 14 +++- .../unit/api_fastapi/core_api/test_security.py | 92 +++++++++++++++++++++- airflow-core/tests/unit/models/test_connection.py | 18 +++++ airflow-core/tests/unit/models/test_dag.py | 14 ---- airflow-core/tests/unit/models/test_pool.py | 13 +++ airflow-core/tests/unit/models/test_variable.py | 12 +++ devel-common/src/tests_common/pytest_plugin.py | 18 +++++ 11 files changed, 197 insertions(+), 22 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py b/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py index aed68f27aa8..afce15e6b0a 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py @@ -35,6 +35,7 @@ class ConnectionDetails: """Represents the details of a connection.""" conn_id: str | None = None + team_name: str | None = None @dataclass @@ -71,6 +72,7 @@ class PoolDetails: """Represents the details of a pool.""" name: str | None = None + team_name: str | None = None @dataclass @@ -78,6 +80,7 @@ class VariableDetails: """Represents the details of a variable.""" key: str | None = None + team_name: str | None = None class AccessView(Enum): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py index 900ad3fbb6a..6053900ce33 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py @@ -42,6 +42,7 @@ from airflow.api_fastapi.auth.managers.models.resource_details import ( ) from airflow.api_fastapi.core_api.base import OrmClause from airflow.configuration import conf +from airflow.models import Connection, Pool, Variable from airflow.models.dag import DagModel, DagRun, DagTag from airflow.models.dagwarning import DagWarning from airflow.models.taskinstance import TaskInstance as TI @@ -223,10 +224,11 @@ def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser] user: GetUserDep, ) -> None: pool_name = request.path_params.get("pool_name") + team_name = Pool.get_team_name(pool_name) if pool_name else None _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_pool( - method=method, details=PoolDetails(name=pool_name), user=user + method=method, details=PoolDetails(name=pool_name, team_name=team_name), user=user ) ) @@ -239,10 +241,13 @@ def requires_access_connection(method: ResourceMethod) -> Callable[[Request, Bas user: GetUserDep, ) -> None: connection_id = request.path_params.get("connection_id") + team_name = Connection.get_team_name(connection_id) if connection_id else None _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_connection( - method=method, details=ConnectionDetails(conn_id=connection_id), user=user + method=method, + details=ConnectionDetails(conn_id=connection_id, team_name=team_name), + user=user, ) ) @@ -273,10 +278,11 @@ def requires_access_variable(method: ResourceMethod) -> Callable[[Request, BaseU user: GetUserDep, ) -> None: variable_key: str | None = request.path_params.get("variable_key") + team_name = Variable.get_team_name(variable_key) if variable_key else None _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_variable( - method=method, details=VariableDetails(key=variable_key), user=user + method=method, details=VariableDetails(key=variable_key, team_name=team_name), user=user ), ) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 33d292a645a..1567cab1dd2 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -27,7 +27,7 @@ from json import JSONDecodeError from typing import Any from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit -from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, select from sqlalchemy.orm import declared_attr, reconstructor, synonym from sqlalchemy_utils import UUIDType @@ -36,10 +36,12 @@ from airflow.configuration import ensure_secrets_loaded from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.models.team import Team from airflow.sdk import SecretCache from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string +from airflow.utils.session import NEW_SESSION, provide_session log = logging.getLogger(__name__) # sanitize the `conn_id` pattern by allowing alphanumeric characters plus @@ -150,6 +152,7 @@ class Connection(Base, LoggingMixin): port: int | None = None, extra: str | dict | None = None, uri: str | None = None, + team_id: str | None = None, ): super().__init__() self.conn_id = sanitize_conn_id(conn_id) @@ -178,6 +181,7 @@ class Connection(Base, LoggingMixin): if self.password: mask_secret(self.password) mask_secret(quote(self.password)) + self.team_id = team_id @staticmethod def _validate_extra(extra, conn_id) -> None: @@ -584,3 +588,13 @@ class Connection(Base, LoggingMixin): conn_repr = self.to_dict(prune_empty=True, validate=False) conn_repr.pop("conn_id", None) return json.dumps(conn_repr) + + @staticmethod + @provide_session + def get_team_name(connection_id: str, session=NEW_SESSION) -> str | None: + stmt = ( + select(Team.name) + .join(Connection, Team.id == Connection.team_id) + .where(Connection.conn_id == connection_id) + ) + return session.scalar(stmt) diff --git a/airflow-core/src/airflow/models/pool.py b/airflow-core/src/airflow/models/pool.py index 2a8b9b57c3f..7a205036530 100644 --- a/airflow-core/src/airflow/models/pool.py +++ b/airflow-core/src/airflow/models/pool.py @@ -24,6 +24,7 @@ from sqlalchemy_utils import UUIDType from airflow.exceptions import AirflowException, PoolNotFound from airflow.models.base import Base +from airflow.models.team import Team from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.utils.db import exists_query from airflow.utils.session import NEW_SESSION, provide_session @@ -352,3 +353,9 @@ class Pool(Base): if self.slots == -1: return float("inf") return self.slots - self.occupied_slots(session) + + @staticmethod + @provide_session + def get_team_name(pool_name: str, session=NEW_SESSION) -> str | None: + stmt = select(Team.name).join(Pool, Team.id == Pool.team_id).where(Pool.pool == pool_name) + return session.scalar(stmt) diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 6e9a7537821..62bb268e89d 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -33,10 +33,11 @@ from airflow._shared.secrets_masker import mask_secret from airflow.configuration import ensure_secrets_loaded from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.models.team import Team from airflow.sdk import SecretCache from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import create_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -57,11 +58,12 @@ class Variable(Base, LoggingMixin): is_encrypted = Column(Boolean, unique=False, default=False) team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True) - def __init__(self, key=None, val=None, description=None): + def __init__(self, key=None, val=None, description=None, team_id=None): super().__init__() self.key = key self.val = val self.description = description + self.team_id = team_id @reconstructor def on_db_load(self): @@ -452,3 +454,11 @@ class Variable(Base, LoggingMixin): SecretCache.save_variable(key, var_val) # we save None as well return var_val + + @staticmethod + @provide_session + def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None: + stmt = ( + select(Team.name).join(Variable, Team.id == Variable.team_id).where(Variable.key == variable_key) + ) + return session.scalar(stmt) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py index f407fbc7f8f..20cf233c1c5 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py @@ -23,9 +23,22 @@ from fastapi import HTTPException from jwt import ExpiredSignatureError, InvalidTokenError from airflow.api_fastapi.app import create_app -from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity +from airflow.api_fastapi.auth.managers.models.resource_details import ( + ConnectionDetails, + DagAccessEntity, + PoolDetails, + VariableDetails, +) from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser -from airflow.api_fastapi.core_api.security import is_safe_url, requires_access_dag, resolve_user_from_token +from airflow.api_fastapi.core_api.security import ( + is_safe_url, + requires_access_connection, + requires_access_dag, + requires_access_pool, + requires_access_variable, + resolve_user_from_token, +) +from airflow.models import Connection, Pool, Variable from tests_common.test_utils.config import conf_vars @@ -141,3 +154,78 @@ class TestFastApiSecurity: request = Mock() request.base_url = "https://requesting_server_base_url.com/prefix2" assert is_safe_url(url, request=request) == expected_is_safe + + @pytest.mark.db_test + @pytest.mark.parametrize( + "team_name", + [None, "team1"], + ) + @patch.object(Connection, "get_team_name") + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + async def test_requires_access_connection(self, mock_get_auth_manager, mock_get_team_name, team_name): + auth_manager = Mock() + auth_manager.is_authorized_connection.return_value = True + mock_get_auth_manager.return_value = auth_manager + fastapi_request = Mock() + fastapi_request.path_params = {"connection_id": "conn_id"} + mock_get_team_name.return_value = team_name + user = Mock() + + requires_access_connection("GET")(fastapi_request, user) + + auth_manager.is_authorized_connection.assert_called_once_with( + method="GET", + details=ConnectionDetails(conn_id="conn_id", team_name=team_name), + user=user, + ) + mock_get_team_name.assert_called_once_with("conn_id") + + @pytest.mark.db_test + @pytest.mark.parametrize( + "team_name", + [None, "team1"], + ) + @patch.object(Variable, "get_team_name") + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + async def test_requires_access_variable(self, mock_get_auth_manager, mock_get_team_name, team_name): + auth_manager = Mock() + auth_manager.is_authorized_variable.return_value = True + mock_get_auth_manager.return_value = auth_manager + fastapi_request = Mock() + fastapi_request.path_params = {"variable_key": "var_key"} + mock_get_team_name.return_value = team_name + user = Mock() + + requires_access_variable("GET")(fastapi_request, user) + + auth_manager.is_authorized_variable.assert_called_once_with( + method="GET", + details=VariableDetails(key="var_key", team_name=team_name), + user=user, + ) + mock_get_team_name.assert_called_once_with("var_key") + + @pytest.mark.db_test + @pytest.mark.parametrize( + "team_name", + [None, "team1"], + ) + @patch.object(Pool, "get_team_name") + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + async def test_requires_access_pool(self, mock_get_auth_manager, mock_get_team_name, team_name): + auth_manager = Mock() + auth_manager.is_authorized_pool.return_value = True + mock_get_auth_manager.return_value = auth_manager + fastapi_request = Mock() + fastapi_request.path_params = {"pool_name": "pool"} + mock_get_team_name.return_value = team_name + user = Mock() + + requires_access_pool("GET")(fastapi_request, user) + + auth_manager.is_authorized_pool.assert_called_once_with( + method="GET", + details=PoolDetails(name="pool", team_name=team_name), + user=user, + ) + mock_get_team_name.assert_called_once_with("pool") diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index 62c9377e2b6..6fdea1b2d05 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -19,6 +19,7 @@ from __future__ import annotations import re import sys +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -28,6 +29,12 @@ from airflow.models import Connection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType from airflow.sdk.execution_time.comms import ErrorResponse +from tests_common.test_utils.db import clear_db_connections + +if TYPE_CHECKING: + from airflow.models.team import Team + from airflow.settings import Session + class TestConnection: @pytest.mark.parametrize( @@ -355,3 +362,14 @@ class TestConnection: # Verify the backends were called mock_env_backend.assert_called_once_with(conn_id="test_conn") mock_db_backend.assert_called_once_with(conn_id="test_conn") + + @pytest.mark.db_test + def test_get_team_name(self, testing_team: Team, session: Session): + clear_db_connections() + + connection = Connection(conn_id="test_conn", conn_type="test_type", team_id=testing_team.id) + session.add(connection) + session.flush() + + assert Connection.get_team_name("test_conn", session=session) == "testing" + clear_db_connections() diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index cf3814fa9fd..796ed807df8 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -21,7 +21,6 @@ import datetime import logging import os import pickle -import uuid from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -58,7 +57,6 @@ from airflow.models.dagbundle import DagBundleModel from airflow.models.dagrun import DagRun from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance as TI -from airflow.models.team import Team from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator @@ -143,18 +141,6 @@ def test_dags_bundle(configure_testing_dag_bundle): yield -@pytest.fixture -def testing_team(): - from airflow.utils.session import create_session - - with create_session() as session: - team = session.query(Team).filter_by(name="testing").one_or_none() - if not team: - team = Team(id=uuid.uuid4(), name="testing") - session.add(team) - yield team - - def _create_dagrun( dag: DAG, *, diff --git a/airflow-core/tests/unit/models/test_pool.py b/airflow-core/tests/unit/models/test_pool.py index 81d4ce66dc7..5f53474ae24 100644 --- a/airflow-core/tests/unit/models/test_pool.py +++ b/airflow-core/tests/unit/models/test_pool.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from airflow import settings @@ -36,6 +38,10 @@ from tests_common.test_utils.db import ( set_default_pool_slots, ) +if TYPE_CHECKING: + from airflow.models.team import Team + from airflow.settings import Session + pytestmark = pytest.mark.db_test DEFAULT_DATE = timezone.datetime(2016, 1, 1) @@ -319,3 +325,10 @@ class TestPool: default_pool = Pool.get_default_pool() assert not Pool.is_default_pool(id=pool.id) assert Pool.is_default_pool(str(default_pool.id)) + + def test_get_team_name(self, testing_team: Team, session: Session): + pool = Pool(pool="test", include_deferred=False, team_id=testing_team.id) + session.add(pool) + session.flush() + + assert Pool.get_team_name("test", session=session) == "testing" diff --git a/airflow-core/tests/unit/models/test_variable.py b/airflow-core/tests/unit/models/test_variable.py index d7a035bbb8d..91fb045e2c5 100644 --- a/airflow-core/tests/unit/models/test_variable.py +++ b/airflow-core/tests/unit/models/test_variable.py @@ -19,6 +19,7 @@ from __future__ import annotations import logging import os +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -31,6 +32,10 @@ from airflow.secrets.metastore import MetastoreBackend from tests_common.test_utils import db from tests_common.test_utils.config import conf_vars +if TYPE_CHECKING: + from airflow.models.team import Team + from airflow.settings import Session + pytestmark = pytest.mark.db_test @@ -311,6 +316,13 @@ class TestVariable: assert c != b + def test_get_team_name(self, testing_team: Team, session: Session): + var = Variable(key="key", val="value", team_id=testing_team.id) + session.add(var) + session.flush() + + assert Variable.get_team_name("key", session=session) == "testing" + @pytest.mark.parametrize( "variable_value, deserialize_json, expected_masked_values", diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index ce588352d4e..9f05f0ec8b0 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -25,6 +25,7 @@ import platform import re import subprocess import sys +import uuid import warnings from collections.abc import Callable, Generator from contextlib import ExitStack, suppress @@ -2672,6 +2673,23 @@ def testing_dag_bundle(): session.add(testing) +@pytest.fixture +def testing_team(): + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + + if AIRFLOW_V_3_0_PLUS: + from airflow.models.team import Team + from airflow.utils.session import create_session + + with create_session() as session: + team = session.query(Team).filter_by(name="testing").one_or_none() + if not team: + team = Team(id=uuid.uuid4(), name="testing") + session.add(team) + session.flush() + yield team + + @pytest.fixture def create_connection_without_db(monkeypatch): """