vincbeck commented on code in PR #61351:
URL: https://github.com/apache/airflow/pull/61351#discussion_r2817415023


##########
providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py:
##########
@@ -411,12 +492,249 @@ def test_is_authorized_dag(
 
         token_url = auth_manager._get_token_url("server_url", "realm")
         payload = auth_manager._get_payload("client_id", permission, 
attributes)
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
         assert result == expected
 
+    _TEAM_SCOPED_PERMISSION_PARAMS = (
+        [
+            ("is_authorized_dag", DagDetails(id="test", team_name="team-a"), 
"Dag:team-a#GET"),
+            (
+                "is_authorized_connection",
+                ConnectionDetails(conn_id="test", team_name="team-a"),
+                "Connection:team-a#GET",
+            ),
+            (
+                "is_authorized_variable",
+                VariableDetails(key="test", team_name="team-a"),
+                "Variable:team-a#GET",
+            ),
+            ("is_authorized_pool", PoolDetails(name="test", 
team_name="team-a"), "Pool:team-a#GET"),
+        ]
+        if AIRFLOW_V_3_2_PLUS
+        else [
+            pytest.param(
+                None,
+                None,
+                None,
+                marks=pytest.mark.skip(reason="multi_team not supported before 
3.2.0"),
+            )
+        ]
+    )
+
+    @pytest.mark.parametrize(
+        ("function", "details", "permission"),
+        _TEAM_SCOPED_PERMISSION_PARAMS,
+    )
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="multi_team not 
supported before 3.2.0")
+    def test_is_authorized_team_scoped_permission(
+        self, auth_manager_multi_team, user, function, details, permission
+    ):
+        if function is None:
+            pytest.skip("multi_team not supported before 3.2.0")
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        result = getattr(auth_manager_multi_team, function)(method="GET", 
user=user, details=details)
+
+        token_url = auth_manager_multi_team._get_token_url("server_url", 
"realm")
+        payload = auth_manager_multi_team._get_payload(
+            "client_id",
+            permission,
+            {RESOURCE_ID_ATTRIBUTE_NAME: "test"},
+        )
+        headers = auth_manager_multi_team._get_headers(user.access_token)
+        auth_manager_multi_team.http_session.post.assert_called_once_with(
+            token_url, data=payload, headers=headers, timeout=5
+        )
+        assert result is True
+
+    @pytest.mark.parametrize(
+        ("function", "details", "permission", "resource_id"),
+        [
+            ("is_authorized_dag", DagDetails(id="test"), "Dag#GET", "test"),
+            ("is_authorized_connection", ConnectionDetails(conn_id="test"), 
"Connection#GET", "test"),
+            ("is_authorized_variable", VariableDetails(key="test"), 
"Variable#GET", "test"),
+            ("is_authorized_pool", PoolDetails(name="test"), "Pool#GET", 
"test"),
+        ],
+    )
+    def test_is_authorized_team_scoped_no_team_uses_global_permission(
+        self, auth_manager_multi_team, user, function, details, permission, 
resource_id
+    ):
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        result = getattr(auth_manager_multi_team, function)(method="GET", 
user=user, details=details)
+
+        token_url = auth_manager_multi_team._get_token_url("server_url", 
"realm")
+        payload = auth_manager_multi_team._get_payload(
+            "client_id",
+            permission,
+            {RESOURCE_ID_ATTRIBUTE_NAME: resource_id},
+        )
+        headers = auth_manager_multi_team._get_headers(user.access_token)
+        auth_manager_multi_team.http_session.post.assert_called_once_with(
+            token_url, data=payload, headers=headers, timeout=5
+        )
+        assert result is True
+
+    _TEAM_SCOPED_LIST_PARAMS = (
+        [
+            ("is_authorized_dag", DagDetails(team_name="team-a"), 
"Dag:team-a#LIST"),
+            ("is_authorized_connection", 
ConnectionDetails(team_name="team-a"), "Connection:team-a#LIST"),
+            ("is_authorized_variable", VariableDetails(team_name="team-a"), 
"Variable:team-a#LIST"),
+            ("is_authorized_pool", PoolDetails(team_name="team-a"), 
"Pool:team-a#LIST"),
+        ]
+        if AIRFLOW_V_3_2_PLUS

Review Comment:
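   Same comment as on `_TEAM_SCOPED_PERMISSION_PARAMS`: the `if AIRFLOW_V_3_2_PLUS` conditional here is confusing.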



##########
providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py:
##########
@@ -411,12 +492,249 @@ def test_is_authorized_dag(
 
         token_url = auth_manager._get_token_url("server_url", "realm")
         payload = auth_manager._get_payload("client_id", permission, 
attributes)
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
         assert result == expected
 
+    _TEAM_SCOPED_PERMISSION_PARAMS = (
+        [
+            ("is_authorized_dag", DagDetails(id="test", team_name="team-a"), 
"Dag:team-a#GET"),
+            (
+                "is_authorized_connection",
+                ConnectionDetails(conn_id="test", team_name="team-a"),
+                "Connection:team-a#GET",
+            ),
+            (
+                "is_authorized_variable",
+                VariableDetails(key="test", team_name="team-a"),
+                "Variable:team-a#GET",
+            ),
+            ("is_authorized_pool", PoolDetails(name="test", 
team_name="team-a"), "Pool:team-a#GET"),
+        ]
+        if AIRFLOW_V_3_2_PLUS
+        else [

Review Comment:
   This is confusing. Do you mean `if not AIRFLOW_V_3_2_PLUS`?



##########
providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py:
##########
@@ -411,12 +492,249 @@ def test_is_authorized_dag(
 
         token_url = auth_manager._get_token_url("server_url", "realm")
         payload = auth_manager._get_payload("client_id", permission, 
attributes)
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
         assert result == expected
 
+    _TEAM_SCOPED_PERMISSION_PARAMS = (
+        [
+            ("is_authorized_dag", DagDetails(id="test", team_name="team-a"), 
"Dag:team-a#GET"),
+            (
+                "is_authorized_connection",
+                ConnectionDetails(conn_id="test", team_name="team-a"),
+                "Connection:team-a#GET",
+            ),
+            (
+                "is_authorized_variable",
+                VariableDetails(key="test", team_name="team-a"),
+                "Variable:team-a#GET",
+            ),
+            ("is_authorized_pool", PoolDetails(name="test", 
team_name="team-a"), "Pool:team-a#GET"),
+        ]
+        if AIRFLOW_V_3_2_PLUS
+        else [

Review Comment:
   Also, I don't quite understand this test. Can you explain it?



##########
providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py:
##########
@@ -411,12 +492,249 @@ def test_is_authorized_dag(
 
         token_url = auth_manager._get_token_url("server_url", "realm")
         payload = auth_manager._get_payload("client_id", permission, 
attributes)
-        headers = auth_manager._get_headers("access_token")
+        headers = auth_manager._get_headers(user.access_token)
         auth_manager.http_session.post.assert_called_once_with(
             token_url, data=payload, headers=headers, timeout=5
         )
         assert result == expected
 
+    _TEAM_SCOPED_PERMISSION_PARAMS = (
+        [
+            ("is_authorized_dag", DagDetails(id="test", team_name="team-a"), 
"Dag:team-a#GET"),
+            (
+                "is_authorized_connection",
+                ConnectionDetails(conn_id="test", team_name="team-a"),
+                "Connection:team-a#GET",
+            ),
+            (
+                "is_authorized_variable",
+                VariableDetails(key="test", team_name="team-a"),
+                "Variable:team-a#GET",
+            ),
+            ("is_authorized_pool", PoolDetails(name="test", 
team_name="team-a"), "Pool:team-a#GET"),
+        ]
+        if AIRFLOW_V_3_2_PLUS
+        else [
+            pytest.param(
+                None,
+                None,
+                None,
+                marks=pytest.mark.skip(reason="multi_team not supported before 
3.2.0"),
+            )
+        ]
+    )
+
+    @pytest.mark.parametrize(
+        ("function", "details", "permission"),
+        _TEAM_SCOPED_PERMISSION_PARAMS,
+    )
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="multi_team not 
supported before 3.2.0")
+    def test_is_authorized_team_scoped_permission(
+        self, auth_manager_multi_team, user, function, details, permission
+    ):
+        if function is None:
+            pytest.skip("multi_team not supported before 3.2.0")
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        result = getattr(auth_manager_multi_team, function)(method="GET", 
user=user, details=details)
+
+        token_url = auth_manager_multi_team._get_token_url("server_url", 
"realm")
+        payload = auth_manager_multi_team._get_payload(
+            "client_id",
+            permission,
+            {RESOURCE_ID_ATTRIBUTE_NAME: "test"},
+        )
+        headers = auth_manager_multi_team._get_headers(user.access_token)
+        auth_manager_multi_team.http_session.post.assert_called_once_with(
+            token_url, data=payload, headers=headers, timeout=5
+        )
+        assert result is True
+
+    @pytest.mark.parametrize(
+        ("function", "details", "permission", "resource_id"),
+        [
+            ("is_authorized_dag", DagDetails(id="test"), "Dag#GET", "test"),
+            ("is_authorized_connection", ConnectionDetails(conn_id="test"), 
"Connection#GET", "test"),
+            ("is_authorized_variable", VariableDetails(key="test"), 
"Variable#GET", "test"),
+            ("is_authorized_pool", PoolDetails(name="test"), "Pool#GET", 
"test"),
+        ],
+    )
+    def test_is_authorized_team_scoped_no_team_uses_global_permission(
+        self, auth_manager_multi_team, user, function, details, permission, 
resource_id
+    ):
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        result = getattr(auth_manager_multi_team, function)(method="GET", 
user=user, details=details)
+
+        token_url = auth_manager_multi_team._get_token_url("server_url", 
"realm")
+        payload = auth_manager_multi_team._get_payload(
+            "client_id",
+            permission,
+            {RESOURCE_ID_ATTRIBUTE_NAME: resource_id},
+        )
+        headers = auth_manager_multi_team._get_headers(user.access_token)
+        auth_manager_multi_team.http_session.post.assert_called_once_with(
+            token_url, data=payload, headers=headers, timeout=5
+        )
+        assert result is True
+
+    _TEAM_SCOPED_LIST_PARAMS = (
+        [
+            ("is_authorized_dag", DagDetails(team_name="team-a"), 
"Dag:team-a#LIST"),
+            ("is_authorized_connection", 
ConnectionDetails(team_name="team-a"), "Connection:team-a#LIST"),
+            ("is_authorized_variable", VariableDetails(team_name="team-a"), 
"Variable:team-a#LIST"),
+            ("is_authorized_pool", PoolDetails(team_name="team-a"), 
"Pool:team-a#LIST"),
+        ]
+        if AIRFLOW_V_3_2_PLUS
+        else [
+            pytest.param(
+                None,
+                None,
+                None,
+                marks=pytest.mark.skip(reason="multi_team not supported before 
3.2.0"),
+            )
+        ]
+    )
+
+    @pytest.mark.parametrize(
+        ("function", "permission"),
+        [
+            ("is_authorized_dag", "Dag#LIST"),
+            ("is_authorized_connection", "Connection#LIST"),
+            ("is_authorized_variable", "Variable#LIST"),
+            ("is_authorized_pool", "Pool#LIST"),
+        ],
+    )
+    def 
test_is_authorized_team_scoped_list_multi_team_without_team_global_list(
+        self, auth_manager_multi_team, user, function, permission
+    ):
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        result = getattr(auth_manager_multi_team, function)(method="GET", 
user=user)
+
+        token_url = auth_manager_multi_team._get_token_url("server_url", 
"realm")
+        payload = auth_manager_multi_team._get_payload("client_id", 
permission, {})
+        headers = auth_manager_multi_team._get_headers(user.access_token)
+        auth_manager_multi_team.http_session.post.assert_called_once_with(
+            token_url, data=payload, headers=headers, timeout=5
+        )
+        assert result is True
+
+    @pytest.mark.parametrize(
+        ("function", "details", "permission"),
+        _TEAM_SCOPED_LIST_PARAMS,
+    )
+    @pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="multi_team not 
supported before 3.2.0")
+    def test_is_authorized_team_scoped_list_team_scoped_permission(
+        self, auth_manager_multi_team, user, function, details, permission
+    ):
+        if function is None:
+            pytest.skip("multi_team not supported before 3.2.0")
+        user.access_token = _build_access_token({"groups": ["team-a"]})
+        mock_response = Mock()
+        mock_response.status_code = 200
+        auth_manager_multi_team.http_session.post = 
Mock(return_value=mock_response)
+
+        result = getattr(auth_manager_multi_team, function)(method="GET", 
user=user, details=details)
+
+        token_url = auth_manager_multi_team._get_token_url("server_url", 
"realm")
+        payload = auth_manager_multi_team._get_payload("client_id", 
permission, {})
+        headers = auth_manager_multi_team._get_headers(user.access_token)
+        auth_manager_multi_team.http_session.post.assert_called_once_with(
+            token_url, data=payload, headers=headers, timeout=5
+        )
+        assert result is True
+
+    def test_filter_authorized_dag_ids_team_mismatch(self, 
auth_manager_multi_team, user):
+        if "team_name" not in 
inspect.signature(auth_manager_multi_team.filter_authorized_dag_ids).parameters:
+            pytest.skip("team_name not supported by filter_authorized_dag_ids 
in this Airflow version.")
+
+        with patch.object(KeycloakAuthManager, "is_authorized_dag", 
return_value=False) as mock_is_authorized:
+            result = auth_manager_multi_team.filter_authorized_dag_ids(
+                dag_ids={"dag-a"}, user=user, team_name="team-b"
+            )
+
+            mock_is_authorized.assert_called_once()
+            assert result == set()
+
+    def test_filter_authorized_dag_ids_team_match(self, 
auth_manager_multi_team, user):
+        if "team_name" not in 
inspect.signature(auth_manager_multi_team.filter_authorized_dag_ids).parameters:
+            pytest.skip("team_name not supported by filter_authorized_dag_ids 
in this Airflow version.")
+
+        with patch.object(KeycloakAuthManager, "is_authorized_dag", 
return_value=True) as mock_is_authorized:
+            result = auth_manager_multi_team.filter_authorized_dag_ids(
+                dag_ids={"dag-a"}, user=user, team_name="team-a"
+            )
+
+            mock_is_authorized.assert_called_once()
+            assert result == {"dag-a"}
+
+    def test_filter_authorized_pools_no_team_returns_empty(self, 
auth_manager_multi_team, user):
+        if not hasattr(auth_manager_multi_team, "filter_authorized_pools"):
+            pytest.skip("filter_authorized_pools not available in this Airflow 
version.")
+
+        with patch.object(
+            KeycloakAuthManager, "is_authorized_pool", return_value=False
+        ) as mock_is_authorized:
+            result = auth_manager_multi_team.filter_authorized_pools(
+                pool_names={"pool-a"}, user=user, team_name=None
+            )
+
+            mock_is_authorized.assert_called_once()
+            assert result == set()
+
+    _TEAM_SCOPED_LIST_MISMATCH_PARAMS = (
+        [
+            ("is_authorized_dag", DagDetails(team_name="team-b")),
+            ("is_authorized_connection", 
ConnectionDetails(team_name="team-b")),
+            ("is_authorized_variable", VariableDetails(team_name="team-b")),
+            ("is_authorized_pool", PoolDetails(team_name="team-b")),
+        ]
+        if AIRFLOW_V_3_2_PLUS

Review Comment:
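   Same comment as on `_TEAM_SCOPED_PERMISSION_PARAMS`: the `if AIRFLOW_V_3_2_PLUS` conditional here is confusing.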



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to