This is an automated email from the ASF dual-hosted git repository.

aminghadersohi pushed a commit to branch mcp-rbac-tool-visibility
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 842df5ee774fb251a9bc91eaad164baae20e689d
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Wed May 20 17:55:27 2026 +0000

    fix(mcp): fix 4 failing unit tests and ruff import error in RBAC tool visibility
    
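    - Fix ruff error: replace `import contextlib` plus
      `from contextlib import AbstractContextManager` with a single
      `from contextlib import AbstractContextManager, nullcontext`
    - Fix test patch targets: now that middleware.py imports
      get_user_from_request and is_tool_visible_to_current_user at module
      level, the middleware tests must patch them in
      superset.mcp_service.middleware instead of superset.mcp_service.auth
    - Fix _tool_allowed_for_current_user: when user resolution fails, let
      public tools (those without _class_permission_name) through and hide
      only permission-protected tools
    - Change @pytest.mark.asyncio markers to @pytest.mark.asyncio() in
      test_middleware.py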
    
    Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
---
 superset/mcp_service/auth.py                    |  5 +-
 superset/mcp_service/server.py                  |  5 +-
 tests/unit_tests/mcp_service/test_middleware.py | 83 +++++++++++++------------
 3 files changed, 49 insertions(+), 44 deletions(-)

diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py
index e8cb675228a..69d8e625531 100644
--- a/superset/mcp_service/auth.py
+++ b/superset/mcp_service/auth.py
@@ -44,9 +44,8 @@ Configuration:
 - MCP_DEV_USERNAME: Fallback username for development
 """
 
-import contextlib
 import logging
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from typing import Any, Callable, TYPE_CHECKING, TypeVar
 
 from flask import current_app, g, has_app_context, has_request_context
@@ -659,7 +658,7 @@ def _get_app_context_manager() -> 
AbstractContextManager[None]:
     ``RBACToolVisibilityMiddleware`` (tools/list filtering).
     """
     if has_request_context():
-        return contextlib.nullcontext()
+        return nullcontext()
     if has_app_context():
         # Push a new context for the CURRENT app (not get_flask_app()
         # which may return a different instance in test environments).
diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py
index 9d3c8d5f350..92166606e43 100644
--- a/superset/mcp_service/server.py
+++ b/superset/mcp_service/server.py
@@ -411,7 +411,10 @@ def _tool_allowed_for_current_user(tool: Any) -> bool:
             try:
                 g.user = get_user_from_request()
             except (ValueError, PermissionError):
-                return False
+                # Can't resolve user; only hide protected tools. Public tools
+                # (no _class_permission_name) pass through regardless.
+                func = getattr(tool, "fn", tool)
+                return not getattr(func, "_class_permission_name", None)
 
         return is_tool_visible_to_current_user(tool)
     except (AttributeError, RuntimeError, ValueError):
diff --git a/tests/unit_tests/mcp_service/test_middleware.py 
b/tests/unit_tests/mcp_service/test_middleware.py
index 00e3f9457f0..4056a1f6b0b 100644
--- a/tests/unit_tests/mcp_service/test_middleware.py
+++ b/tests/unit_tests/mcp_service/test_middleware.py
@@ -73,7 +73,7 @@ class TestResponseSizeGuardMiddleware:
         )
         assert middleware.excluded_tools == {"health_check"}
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_allows_small_response(self) -> None:
         """Should allow responses under token limit."""
         middleware = ResponseSizeGuardMiddleware(token_limit=25000)
@@ -96,7 +96,7 @@ class TestResponseSizeGuardMiddleware:
         assert result == small_response
         call_next.assert_called_once_with(context)
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_blocks_large_response(self) -> None:
         """Should block responses over token limit."""
         middleware = ResponseSizeGuardMiddleware(token_limit=100)  # Very low 
limit
@@ -124,7 +124,7 @@ class TestResponseSizeGuardMiddleware:
         assert "Response too large" in error_message
         assert "limit" in error_message.lower()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_skips_excluded_tools(self) -> None:
         """Should skip checking for excluded tools."""
         middleware = ResponseSizeGuardMiddleware(
@@ -144,7 +144,7 @@ class TestResponseSizeGuardMiddleware:
         result = await middleware.on_call_tool(context, call_next)
         assert result == large_response
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_logs_warning_at_threshold(self) -> None:
         """Should log warning when approaching limit.
 
@@ -180,7 +180,7 @@ class TestResponseSizeGuardMiddleware:
         # Should log warning
         mock_logger.warning.assert_called()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_error_includes_suggestions(self) -> None:
         """Should include suggestions in error message."""
         middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -205,7 +205,7 @@ class TestResponseSizeGuardMiddleware:
         # Should suggest reducing page_size
         assert "page_size" in error_message.lower() or "limit" in 
error_message.lower()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_logs_size_exceeded_event(self) -> None:
         """Should log to event logger when size exceeded."""
         middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -229,7 +229,7 @@ class TestResponseSizeGuardMiddleware:
         call_args = mock_event_logger.log.call_args
         assert call_args.kwargs["action"] == "mcp_response_size_exceeded"
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_truncates_info_tool_instead_of_blocking(self) -> None:
         """Should truncate info tool responses instead of blocking them."""
         middleware = ResponseSizeGuardMiddleware(token_limit=500)
@@ -258,7 +258,7 @@ class TestResponseSizeGuardMiddleware:
         assert result["_response_truncated"] is True
         assert "[truncated" in result["description"]
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_truncates_chart_info_with_large_form_data(self) -> None:
         """Should truncate get_chart_info with large form_data."""
         middleware = ResponseSizeGuardMiddleware(token_limit=500)
@@ -284,7 +284,7 @@ class TestResponseSizeGuardMiddleware:
         assert result["id"] == 1
         assert result["_response_truncated"] is True
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_still_blocks_non_info_tools(self) -> None:
         """Should still block non-info tools that exceed limit."""
         middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -303,7 +303,7 @@ class TestResponseSizeGuardMiddleware:
         ):
             await middleware.on_call_tool(context, call_next)
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_logs_truncation_event(self) -> None:
         """Should log mcp_response_truncated event on successful truncation."""
         middleware = ResponseSizeGuardMiddleware(token_limit=500)
@@ -591,7 +591,7 @@ class TestToolResultWrapping:
             meta=meta,
         )
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
         """Truncate a ToolResult-wrapped info response and return a 
ToolResult."""
         from fastmcp.tools.tool import ToolResult
@@ -620,7 +620,7 @@ class TestToolResultWrapping:
         assert reparsed["_response_truncated"] is True
         assert "[truncated" in reparsed["description"]
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_small_tool_result_passes_through_unchanged(self) -> None:
         """Should return the original ToolResult when within the token 
limit."""
 
@@ -641,7 +641,7 @@ class TestToolResultWrapping:
 
         assert result is tool_result
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_large_non_info_tool_result_is_blocked(self) -> None:
         """Should raise ToolError for a non-info ToolResult that exceeds the 
limit."""
         middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -662,7 +662,7 @@ class TestToolResultWrapping:
         ):
             await middleware.on_call_tool(context, call_next)
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_meta_preserved_after_truncation(self) -> None:
         """Should preserve the original ToolResult meta through truncation."""
         from fastmcp.tools.tool import ToolResult
@@ -694,7 +694,7 @@ class TestToolResultWrapping:
 class TestMiddlewareIntegration:
     """Integration tests for middleware behavior."""
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_pydantic_model_response(self) -> None:
         """Should handle Pydantic model responses."""
         from pydantic import BaseModel
@@ -720,7 +720,7 @@ class TestMiddlewareIntegration:
 
         assert result == response
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_list_response(self) -> None:
         """Should handle list responses."""
         middleware = ResponseSizeGuardMiddleware(token_limit=25000)
@@ -740,7 +740,7 @@ class TestMiddlewareIntegration:
 
         assert result == response
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_string_response(self) -> None:
         """Should handle string responses."""
         middleware = ResponseSizeGuardMiddleware(token_limit=25000)
@@ -859,7 +859,7 @@ class TestIsUserError:
 class TestGlobalErrorHandlerLogLevels:
     """Test that GlobalErrorHandlerMiddleware logs at correct levels."""
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_user_error_logs_warning(self) -> None:
         """User errors (e.g. ValueError) should log at WARNING."""
         middleware = GlobalErrorHandlerMiddleware()
@@ -882,7 +882,7 @@ class TestGlobalErrorHandlerLogLevels:
         mock_logger.warning.assert_called()
         mock_logger.error.assert_not_called()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_system_error_logs_error(self) -> None:
         """System errors (OperationalError, generic Exception) should log at 
ERROR."""
         middleware = GlobalErrorHandlerMiddleware()
@@ -904,7 +904,7 @@ class TestGlobalErrorHandlerLogLevels:
         # Should log at ERROR
         mock_logger.error.assert_called()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_unexpected_error_logs_error(self) -> None:
         """Truly unexpected errors should log at ERROR with error_id."""
         middleware = GlobalErrorHandlerMiddleware()
@@ -926,7 +926,7 @@ class TestGlobalErrorHandlerLogLevels:
         # Should log at ERROR (both the classification log and the error_id 
log)
         assert mock_logger.error.call_count >= 1
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_event_logger_includes_severity(self) -> None:
         """Event logger payload should include severity field."""
         middleware = GlobalErrorHandlerMiddleware()
@@ -949,7 +949,7 @@ class TestGlobalErrorHandlerLogLevels:
         payload = mock_event_logger.log.call_args.kwargs["curated_payload"]
         assert payload["severity"] == "warning"
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_permission_error_logs_warning(self) -> None:
         """PermissionError should log at WARNING — agents are expected to
         try tools they lack access to."""
@@ -972,7 +972,7 @@ class TestGlobalErrorHandlerLogLevels:
         mock_logger.warning.assert_called()
         mock_logger.error.assert_not_called()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_connection_error_logs_error(self) -> None:
         """ConnectionError should log at ERROR — infrastructure issue."""
         middleware = GlobalErrorHandlerMiddleware()
@@ -993,7 +993,7 @@ class TestGlobalErrorHandlerLogLevels:
 
         mock_logger.error.assert_called()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_superset_exception_4xx_logs_warning(self) -> None:
         """SupersetException with 4xx status should log at WARNING."""
         middleware = GlobalErrorHandlerMiddleware()
@@ -1017,7 +1017,7 @@ class TestGlobalErrorHandlerLogLevels:
         mock_logger.warning.assert_called()
         mock_logger.error.assert_not_called()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_superset_exception_5xx_logs_error(self) -> None:
         """SupersetException with 5xx status should log at ERROR."""
         middleware = GlobalErrorHandlerMiddleware()
@@ -1041,7 +1041,7 @@ class TestGlobalErrorHandlerLogLevels:
 
         mock_logger.error.assert_called()
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_mcp_permission_denied_error_becomes_tool_error(self) -> 
None:
         """MCPPermissionDeniedError must convert to ToolError, not a generic 
error."""
         from superset.mcp_service.auth import MCPPermissionDeniedError
@@ -1070,7 +1070,7 @@ class TestGlobalErrorHandlerLogLevels:
         assert "can_write" in str(exc_info.value)
         assert "Dashboard" in str(exc_info.value)
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_mcp_permission_denied_error_is_user_error(self) -> None:
         """MCPPermissionDeniedError must be classified as a user error 
(WARNING)."""
         from superset.mcp_service.auth import MCPPermissionDeniedError
@@ -1081,7 +1081,7 @@ class TestGlobalErrorHandlerLogLevels:
         )
         assert _is_user_error(error) is True
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_mcp_permission_denied_error_logs_at_warning(self) -> None:
         """MCPPermissionDeniedError should log at WARNING, not ERROR."""
         from superset.mcp_service.auth import MCPPermissionDeniedError
@@ -1121,7 +1121,7 @@ class TestRBACToolVisibilityMiddleware:
         tool.name = name
         return tool
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_fails_open_on_exception(self) -> None:
         """Returns all tools when get_flask_app raises (fail open)."""
         from superset.mcp_service.middleware import 
RBACToolVisibilityMiddleware
@@ -1138,7 +1138,7 @@ class TestRBACToolVisibilityMiddleware:
 
         assert result == tools
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_fails_open_when_user_is_none(self, app) -> None:
         """Returns all tools when get_user_from_request returns None."""
         from superset.mcp_service.middleware import 
RBACToolVisibilityMiddleware
@@ -1151,13 +1151,16 @@ class TestRBACToolVisibilityMiddleware:
             patch(
                 "superset.mcp_service.flask_singleton.get_flask_app", 
return_value=app
             ),
-            patch("superset.mcp_service.auth.get_user_from_request", 
return_value=None),
+            patch(
+                "superset.mcp_service.middleware.get_user_from_request",
+                return_value=None,
+            ),
         ):
             result = await middleware.on_list_tools(MagicMock(), call_next)
 
         assert result == tools
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_filters_tools_by_rbac(self, app) -> None:
         """Tools denied by is_tool_visible_to_current_user are removed."""
         from superset.mcp_service.middleware import 
RBACToolVisibilityMiddleware
@@ -1178,11 +1181,11 @@ class TestRBACToolVisibilityMiddleware:
                 "superset.mcp_service.flask_singleton.get_flask_app", 
return_value=app
             ),
             patch(
-                "superset.mcp_service.auth.get_user_from_request",
+                "superset.mcp_service.middleware.get_user_from_request",
                 return_value=mock_user,
             ),
             patch(
-                "superset.mcp_service.auth.is_tool_visible_to_current_user",
+                
"superset.mcp_service.middleware.is_tool_visible_to_current_user",
                 side_effect=_visible,
             ),
         ):
@@ -1191,7 +1194,7 @@ class TestRBACToolVisibilityMiddleware:
         assert read_tool in result
         assert write_tool not in result
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_fails_closed_on_permission_error(self, app) -> None:
         """Returns empty list when credentials are invalid 
(PermissionError)."""
         from superset.mcp_service.middleware import 
RBACToolVisibilityMiddleware
@@ -1205,7 +1208,7 @@ class TestRBACToolVisibilityMiddleware:
                 "superset.mcp_service.flask_singleton.get_flask_app", 
return_value=app
             ),
             patch(
-                "superset.mcp_service.auth.get_user_from_request",
+                "superset.mcp_service.middleware.get_user_from_request",
                 side_effect=PermissionError("Invalid API key"),
             ),
         ):
@@ -1213,7 +1216,7 @@ class TestRBACToolVisibilityMiddleware:
 
         assert result == []
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_fails_closed_on_bad_credentials_value_error(self, app) -> 
None:
         """Returns empty list when auth was attempted but user not found."""
         from superset.mcp_service.middleware import 
RBACToolVisibilityMiddleware
@@ -1227,7 +1230,7 @@ class TestRBACToolVisibilityMiddleware:
                 "superset.mcp_service.flask_singleton.get_flask_app", 
return_value=app
             ),
             patch(
-                "superset.mcp_service.auth.get_user_from_request",
+                "superset.mcp_service.middleware.get_user_from_request",
                 side_effect=ValueError("User 'ghost' not found in database"),
             ),
         ):
@@ -1235,7 +1238,7 @@ class TestRBACToolVisibilityMiddleware:
 
         assert result == []
 
-    @pytest.mark.asyncio
+    @pytest.mark.asyncio()
     async def test_fails_open_when_no_auth_configured(self, app) -> None:
         """Returns all tools when no auth source is configured at all."""
         from superset.mcp_service.middleware import 
RBACToolVisibilityMiddleware
@@ -1249,7 +1252,7 @@ class TestRBACToolVisibilityMiddleware:
                 "superset.mcp_service.flask_singleton.get_flask_app", 
return_value=app
             ),
             patch(
-                "superset.mcp_service.auth.get_user_from_request",
+                "superset.mcp_service.middleware.get_user_from_request",
                 side_effect=ValueError("No authenticated user found"),
             ),
         ):

Reply via email to