This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e9a895cee53 Pass team name to `is_authorized_connection`, `is_authorized_variable` and `is_authorized_pool` in Airflow API (#55193)
e9a895cee53 is described below

commit e9a895cee533dd91556f84baba06f60bbd75ece6
Author: Vincent <97131062+vincb...@users.noreply.github.com>
AuthorDate: Thu Sep 4 08:54:31 2025 -0400

    Pass team name to `is_authorized_connection`, `is_authorized_variable` and `is_authorized_pool` in Airflow API (#55193)
---
 .../auth/managers/models/resource_details.py       |  3 +
 .../src/airflow/api_fastapi/core_api/security.py   | 12 ++-
 airflow-core/src/airflow/models/connection.py      | 16 +++-
 airflow-core/src/airflow/models/pool.py            |  7 ++
 airflow-core/src/airflow/models/variable.py        | 14 +++-
 .../unit/api_fastapi/core_api/test_security.py     | 92 +++++++++++++++++++++-
 airflow-core/tests/unit/models/test_connection.py  | 18 +++++
 airflow-core/tests/unit/models/test_dag.py         | 14 ----
 airflow-core/tests/unit/models/test_pool.py        | 13 +++
 airflow-core/tests/unit/models/test_variable.py    | 12 +++
 devel-common/src/tests_common/pytest_plugin.py     | 18 +++++
 11 files changed, 197 insertions(+), 22 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py b/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py
index aed68f27aa8..afce15e6b0a 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py
@@ -35,6 +35,7 @@ class ConnectionDetails:
     """Represents the details of a connection."""
 
     conn_id: str | None = None
+    team_name: str | None = None
 
 
 @dataclass
@@ -71,6 +72,7 @@ class PoolDetails:
     """Represents the details of a pool."""
 
     name: str | None = None
+    team_name: str | None = None
 
 
 @dataclass
@@ -78,6 +80,7 @@ class VariableDetails:
     """Represents the details of a variable."""
 
     key: str | None = None
+    team_name: str | None = None
 
 
 class AccessView(Enum):
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py
index 900ad3fbb6a..6053900ce33 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/security.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py
@@ -42,6 +42,7 @@ from 
airflow.api_fastapi.auth.managers.models.resource_details import (
 )
 from airflow.api_fastapi.core_api.base import OrmClause
 from airflow.configuration import conf
+from airflow.models import Connection, Pool, Variable
 from airflow.models.dag import DagModel, DagRun, DagTag
 from airflow.models.dagwarning import DagWarning
 from airflow.models.taskinstance import TaskInstance as TI
@@ -223,10 +224,11 @@ def requires_access_pool(method: ResourceMethod) -> 
Callable[[Request, BaseUser]
         user: GetUserDep,
     ) -> None:
         pool_name = request.path_params.get("pool_name")
+        team_name = Pool.get_team_name(pool_name) if pool_name else None
 
         _requires_access(
             is_authorized_callback=lambda: 
get_auth_manager().is_authorized_pool(
-                method=method, details=PoolDetails(name=pool_name), user=user
+                method=method, details=PoolDetails(name=pool_name, 
team_name=team_name), user=user
             )
         )
 
@@ -239,10 +241,13 @@ def requires_access_connection(method: ResourceMethod) -> 
Callable[[Request, Bas
         user: GetUserDep,
     ) -> None:
         connection_id = request.path_params.get("connection_id")
+        team_name = Connection.get_team_name(connection_id) if connection_id 
else None
 
         _requires_access(
             is_authorized_callback=lambda: 
get_auth_manager().is_authorized_connection(
-                method=method, 
details=ConnectionDetails(conn_id=connection_id), user=user
+                method=method,
+                details=ConnectionDetails(conn_id=connection_id, 
team_name=team_name),
+                user=user,
             )
         )
 
@@ -273,10 +278,11 @@ def requires_access_variable(method: ResourceMethod) -> 
Callable[[Request, BaseU
         user: GetUserDep,
     ) -> None:
         variable_key: str | None = request.path_params.get("variable_key")
+        team_name = Variable.get_team_name(variable_key) if variable_key else 
None
 
         _requires_access(
             is_authorized_callback=lambda: 
get_auth_manager().is_authorized_variable(
-                method=method, details=VariableDetails(key=variable_key), 
user=user
+                method=method, details=VariableDetails(key=variable_key, 
team_name=team_name), user=user
             ),
         )
 
diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py
index 33d292a645a..1567cab1dd2 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -27,7 +27,7 @@ from json import JSONDecodeError
 from typing import Any
 from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
 
-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text
+from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, 
select
 from sqlalchemy.orm import declared_attr, reconstructor, synonym
 from sqlalchemy_utils import UUIDType
 
@@ -36,10 +36,12 @@ from airflow.configuration import ensure_secrets_loaded
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.models.base import ID_LEN, Base
 from airflow.models.crypto import get_fernet
+from airflow.models.team import Team
 from airflow.sdk import SecretCache
 from airflow.utils.helpers import prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
+from airflow.utils.session import NEW_SESSION, provide_session
 
 log = logging.getLogger(__name__)
 # sanitize the `conn_id` pattern by allowing alphanumeric characters plus
@@ -150,6 +152,7 @@ class Connection(Base, LoggingMixin):
         port: int | None = None,
         extra: str | dict | None = None,
         uri: str | None = None,
+        team_id: str | None = None,
     ):
         super().__init__()
         self.conn_id = sanitize_conn_id(conn_id)
@@ -178,6 +181,7 @@ class Connection(Base, LoggingMixin):
         if self.password:
             mask_secret(self.password)
             mask_secret(quote(self.password))
+        self.team_id = team_id
 
     @staticmethod
     def _validate_extra(extra, conn_id) -> None:
@@ -584,3 +588,13 @@ class Connection(Base, LoggingMixin):
         conn_repr = self.to_dict(prune_empty=True, validate=False)
         conn_repr.pop("conn_id", None)
         return json.dumps(conn_repr)
+
+    @staticmethod
+    @provide_session
+    def get_team_name(connection_id: str, session=NEW_SESSION) -> str | None:
+        stmt = (
+            select(Team.name)
+            .join(Connection, Team.id == Connection.team_id)
+            .where(Connection.conn_id == connection_id)
+        )
+        return session.scalar(stmt)
diff --git a/airflow-core/src/airflow/models/pool.py b/airflow-core/src/airflow/models/pool.py
index 2a8b9b57c3f..7a205036530 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -24,6 +24,7 @@ from sqlalchemy_utils import UUIDType
 
 from airflow.exceptions import AirflowException, PoolNotFound
 from airflow.models.base import Base
+from airflow.models.team import Team
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.utils.db import exists_query
 from airflow.utils.session import NEW_SESSION, provide_session
@@ -352,3 +353,9 @@ class Pool(Base):
         if self.slots == -1:
             return float("inf")
         return self.slots - self.occupied_slots(session)
+
+    @staticmethod
+    @provide_session
+    def get_team_name(pool_name: str, session=NEW_SESSION) -> str | None:
+        stmt = select(Team.name).join(Pool, Team.id == 
Pool.team_id).where(Pool.pool == pool_name)
+        return session.scalar(stmt)
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index 6e9a7537821..62bb268e89d 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -33,10 +33,11 @@ from airflow._shared.secrets_masker import mask_secret
 from airflow.configuration import ensure_secrets_loaded
 from airflow.models.base import ID_LEN, Base
 from airflow.models.crypto import get_fernet
+from airflow.models.team import Team
 from airflow.sdk import SecretCache
 from airflow.secrets.metastore import MetastoreBackend
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import create_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -57,11 +58,12 @@ class Variable(Base, LoggingMixin):
     is_encrypted = Column(Boolean, unique=False, default=False)
     team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), 
nullable=True)
 
-    def __init__(self, key=None, val=None, description=None):
+    def __init__(self, key=None, val=None, description=None, team_id=None):
         super().__init__()
         self.key = key
         self.val = val
         self.description = description
+        self.team_id = team_id
 
     @reconstructor
     def on_db_load(self):
@@ -452,3 +454,11 @@ class Variable(Base, LoggingMixin):
 
         SecretCache.save_variable(key, var_val)  # we save None as well
         return var_val
+
+    @staticmethod
+    @provide_session
+    def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None:
+        stmt = (
+            select(Team.name).join(Variable, Team.id == 
Variable.team_id).where(Variable.key == variable_key)
+        )
+        return session.scalar(stmt)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
index f407fbc7f8f..20cf233c1c5 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
@@ -23,9 +23,22 @@ from fastapi import HTTPException
 from jwt import ExpiredSignatureError, InvalidTokenError
 
 from airflow.api_fastapi.app import create_app
-from airflow.api_fastapi.auth.managers.models.resource_details import 
DagAccessEntity
+from airflow.api_fastapi.auth.managers.models.resource_details import (
+    ConnectionDetails,
+    DagAccessEntity,
+    PoolDetails,
+    VariableDetails,
+)
 from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
-from airflow.api_fastapi.core_api.security import is_safe_url, 
requires_access_dag, resolve_user_from_token
+from airflow.api_fastapi.core_api.security import (
+    is_safe_url,
+    requires_access_connection,
+    requires_access_dag,
+    requires_access_pool,
+    requires_access_variable,
+    resolve_user_from_token,
+)
+from airflow.models import Connection, Pool, Variable
 
 from tests_common.test_utils.config import conf_vars
 
@@ -141,3 +154,78 @@ class TestFastApiSecurity:
         request = Mock()
         request.base_url = "https://requesting_server_base_url.com/prefix2"
         assert is_safe_url(url, request=request) == expected_is_safe
+
+    @pytest.mark.db_test
+    @pytest.mark.parametrize(
+        "team_name",
+        [None, "team1"],
+    )
+    @patch.object(Connection, "get_team_name")
+    @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+    async def test_requires_access_connection(self, mock_get_auth_manager, 
mock_get_team_name, team_name):
+        auth_manager = Mock()
+        auth_manager.is_authorized_connection.return_value = True
+        mock_get_auth_manager.return_value = auth_manager
+        fastapi_request = Mock()
+        fastapi_request.path_params = {"connection_id": "conn_id"}
+        mock_get_team_name.return_value = team_name
+        user = Mock()
+
+        requires_access_connection("GET")(fastapi_request, user)
+
+        auth_manager.is_authorized_connection.assert_called_once_with(
+            method="GET",
+            details=ConnectionDetails(conn_id="conn_id", team_name=team_name),
+            user=user,
+        )
+        mock_get_team_name.assert_called_once_with("conn_id")
+
+    @pytest.mark.db_test
+    @pytest.mark.parametrize(
+        "team_name",
+        [None, "team1"],
+    )
+    @patch.object(Variable, "get_team_name")
+    @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+    async def test_requires_access_variable(self, mock_get_auth_manager, 
mock_get_team_name, team_name):
+        auth_manager = Mock()
+        auth_manager.is_authorized_variable.return_value = True
+        mock_get_auth_manager.return_value = auth_manager
+        fastapi_request = Mock()
+        fastapi_request.path_params = {"variable_key": "var_key"}
+        mock_get_team_name.return_value = team_name
+        user = Mock()
+
+        requires_access_variable("GET")(fastapi_request, user)
+
+        auth_manager.is_authorized_variable.assert_called_once_with(
+            method="GET",
+            details=VariableDetails(key="var_key", team_name=team_name),
+            user=user,
+        )
+        mock_get_team_name.assert_called_once_with("var_key")
+
+    @pytest.mark.db_test
+    @pytest.mark.parametrize(
+        "team_name",
+        [None, "team1"],
+    )
+    @patch.object(Pool, "get_team_name")
+    @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+    async def test_requires_access_pool(self, mock_get_auth_manager, 
mock_get_team_name, team_name):
+        auth_manager = Mock()
+        auth_manager.is_authorized_pool.return_value = True
+        mock_get_auth_manager.return_value = auth_manager
+        fastapi_request = Mock()
+        fastapi_request.path_params = {"pool_name": "pool"}
+        mock_get_team_name.return_value = team_name
+        user = Mock()
+
+        requires_access_pool("GET")(fastapi_request, user)
+
+        auth_manager.is_authorized_pool.assert_called_once_with(
+            method="GET",
+            details=PoolDetails(name="pool", team_name=team_name),
+            user=user,
+        )
+        mock_get_team_name.assert_called_once_with("pool")
diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py
index 62c9377e2b6..6fdea1b2d05 100644
--- a/airflow-core/tests/unit/models/test_connection.py
+++ b/airflow-core/tests/unit/models/test_connection.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import re
 import sys
+from typing import TYPE_CHECKING
 from unittest import mock
 
 import pytest
@@ -28,6 +29,12 @@ from airflow.models import Connection
 from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
 from airflow.sdk.execution_time.comms import ErrorResponse
 
+from tests_common.test_utils.db import clear_db_connections
+
+if TYPE_CHECKING:
+    from airflow.models.team import Team
+    from airflow.settings import Session
+
 
 class TestConnection:
     @pytest.mark.parametrize(
@@ -355,3 +362,14 @@ class TestConnection:
         # Verify the backends were called
         mock_env_backend.assert_called_once_with(conn_id="test_conn")
         mock_db_backend.assert_called_once_with(conn_id="test_conn")
+
+    @pytest.mark.db_test
+    def test_get_team_name(self, testing_team: Team, session: Session):
+        clear_db_connections()
+
+        connection = Connection(conn_id="test_conn", conn_type="test_type", 
team_id=testing_team.id)
+        session.add(connection)
+        session.flush()
+
+        assert Connection.get_team_name("test_conn", session=session) == 
"testing"
+        clear_db_connections()
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index cf3814fa9fd..796ed807df8 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -21,7 +21,6 @@ import datetime
 import logging
 import os
 import pickle
-import uuid
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -58,7 +57,6 @@ from airflow.models.dagbundle import DagBundleModel
 from airflow.models.dagrun import DagRun
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance as TI
-from airflow.models.team import Team
 from airflow.providers.standard.operators.bash import BashOperator
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.python import PythonOperator
@@ -143,18 +141,6 @@ def test_dags_bundle(configure_testing_dag_bundle):
         yield
 
 
-@pytest.fixture
-def testing_team():
-    from airflow.utils.session import create_session
-
-    with create_session() as session:
-        team = session.query(Team).filter_by(name="testing").one_or_none()
-        if not team:
-            team = Team(id=uuid.uuid4(), name="testing")
-            session.add(team)
-        yield team
-
-
 def _create_dagrun(
     dag: DAG,
     *,
diff --git a/airflow-core/tests/unit/models/test_pool.py b/airflow-core/tests/unit/models/test_pool.py
index 81d4ce66dc7..5f53474ae24 100644
--- a/airflow-core/tests/unit/models/test_pool.py
+++ b/airflow-core/tests/unit/models/test_pool.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import pytest
 
 from airflow import settings
@@ -36,6 +38,10 @@ from tests_common.test_utils.db import (
     set_default_pool_slots,
 )
 
+if TYPE_CHECKING:
+    from airflow.models.team import Team
+    from airflow.settings import Session
+
 pytestmark = pytest.mark.db_test
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -319,3 +325,10 @@ class TestPool:
         default_pool = Pool.get_default_pool()
         assert not Pool.is_default_pool(id=pool.id)
         assert Pool.is_default_pool(str(default_pool.id))
+
+    def test_get_team_name(self, testing_team: Team, session: Session):
+        pool = Pool(pool="test", include_deferred=False, 
team_id=testing_team.id)
+        session.add(pool)
+        session.flush()
+
+        assert Pool.get_team_name("test", session=session) == "testing"
diff --git a/airflow-core/tests/unit/models/test_variable.py b/airflow-core/tests/unit/models/test_variable.py
index d7a035bbb8d..91fb045e2c5 100644
--- a/airflow-core/tests/unit/models/test_variable.py
+++ b/airflow-core/tests/unit/models/test_variable.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import logging
 import os
+from typing import TYPE_CHECKING
 from unittest import mock
 
 import pytest
@@ -31,6 +32,10 @@ from airflow.secrets.metastore import MetastoreBackend
 from tests_common.test_utils import db
 from tests_common.test_utils.config import conf_vars
 
+if TYPE_CHECKING:
+    from airflow.models.team import Team
+    from airflow.settings import Session
+
 pytestmark = pytest.mark.db_test
 
 
@@ -311,6 +316,13 @@ class TestVariable:
 
         assert c != b
 
+    def test_get_team_name(self, testing_team: Team, session: Session):
+        var = Variable(key="key", val="value", team_id=testing_team.id)
+        session.add(var)
+        session.flush()
+
+        assert Variable.get_team_name("key", session=session) == "testing"
+
 
 @pytest.mark.parametrize(
     "variable_value, deserialize_json, expected_masked_values",
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index ce588352d4e..9f05f0ec8b0 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -25,6 +25,7 @@ import platform
 import re
 import subprocess
 import sys
+import uuid
 import warnings
 from collections.abc import Callable, Generator
 from contextlib import ExitStack, suppress
@@ -2672,6 +2673,23 @@ def testing_dag_bundle():
                 session.add(testing)
 
 
+@pytest.fixture
+def testing_team():
+    from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+    if AIRFLOW_V_3_0_PLUS:
+        from airflow.models.team import Team
+        from airflow.utils.session import create_session
+
+        with create_session() as session:
+            team = session.query(Team).filter_by(name="testing").one_or_none()
+            if not team:
+                team = Team(id=uuid.uuid4(), name="testing")
+                session.add(team)
+                session.flush()
+            yield team
+
+
 @pytest.fixture
 def create_connection_without_db(monkeypatch):
     """

Reply via email to