This is an automated email from the ASF dual-hosted git repository.

aminghadersohi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new b6f545e61e1 feat(mcp): resolve call_tool proxy name and capture error_type in logging (#38915)
b6f545e61e1 is described below

commit b6f545e61e1fbbc943e9dd5cec491c69a37a13c7
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Tue May 26 11:37:37 2026 -0700

    feat(mcp): resolve call_tool proxy name and capture error_type in logging (#38915)
    
    Co-authored-by: Amin Ghadersohi <[email protected]>
---
 superset/mcp_service/middleware.py                 |  85 +++++++++---
 .../mcp_service/test_middleware_logging.py         | 153 ++++++++++++++++++---
 2 files changed, 202 insertions(+), 36 deletions(-)

diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py
index 844b5ff59d0..685022c3a9f 100644
--- a/superset/mcp_service/middleware.py
+++ b/superset/mcp_service/middleware.py
@@ -188,10 +188,15 @@ def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
     """Remove sensitive fields from params before logging."""
     if not isinstance(params, dict):
         return params
-    return {
-        k: "[REDACTED]" if k.lower() in _SENSITIVE_PARAM_KEYS else v
-        for k, v in params.items()
-    }
+    result: dict[str, Any] = {}
+    for k, v in params.items():
+        if k.lower() in _SENSITIVE_PARAM_KEYS:
+            result[k] = "[REDACTED]"
+        elif k == "arguments" and isinstance(v, dict):
+            result[k] = _sanitize_params(v)
+        else:
+            result[k] = v
+    return result
 
 
 class LoggingMiddleware(Middleware):
@@ -204,8 +209,17 @@ class LoggingMiddleware(Middleware):
     Tool calls are handled in on_call_tool() which wraps execution to capture
     duration_ms. Non-tool messages (resource reads, prompts, etc.) are handled
     in on_message().
+
+    When tool search is enabled (progressive discovery), the MCP client calls
+    ``call_tool`` proxies instead of individual tools.  This middleware resolves
+    the underlying tool name from ``call_tool`` arguments so that analytics
+    queries can filter by the actual tool (stored as ``mcp_tool`` in the curated
+    payload).
     """
 
+    #: Proxy name used by FastMCP tool-search transforms.
+    _CALL_TOOL_PROXY = "call_tool"
+
     def _is_error_response(self, result: ToolResult) -> bool:
         """Check if a tool result contains an error schema response.
 
@@ -244,6 +258,28 @@ class LoggingMiddleware(Middleware):
             dataset_id = params.get("dataset_id")
         return agent_id, user_id, dashboard_id, slice_id, dataset_id, params
 
+    @staticmethod
+    def _resolve_tool_name(tool_name: str | None, params: Any) -> str | None:
+        """Resolve the underlying tool name from call_tool proxy arguments.
+
+        When tool search is enabled, the MCP client uses the ``call_tool``
+        proxy and passes the real tool name as the ``name`` argument.  This
+        helper extracts that value so we can log which tool was actually
+        executed rather than just ``"call_tool"``.
+
+        Returns:
+            The resolved tool name if *tool_name* is the call_tool proxy and
+            ``params["name"]`` is a non-empty string, otherwise ``None``.
+        """
+        if (
+            tool_name == LoggingMiddleware._CALL_TOOL_PROXY
+            and isinstance(params, dict)
+            and isinstance(params.get("name"), str)
+            and params["name"]
+        ):
+            return params["name"]
+        return None
+
     async def on_call_tool(
         self,
         context: MiddlewareContext,
@@ -254,11 +290,13 @@ class LoggingMiddleware(Middleware):
             self._extract_context_info(context)
         )
         tool_name = getattr(context.message, "name", None)
+        mcp_tool = self._resolve_tool_name(tool_name, params)
 
         mcp_call_id = secrets.token_hex(16)
         _mcp_call_id_var.set(mcp_call_id)
         start_time = time.time()
         success = False
+        error_type: str | None = None
         try:
             result = await call_next(context)
             success = not self._is_error_response(result)
@@ -270,11 +308,27 @@ class LoggingMiddleware(Middleware):
                     structured_content=result.structured_content,
                 )
             return result
-        except Exception:
+        except Exception as exc:
+            error_type = type(exc).__name__
             success = False
             raise
         finally:
             duration_ms = int((time.time() - start_time) * 1000)
+            payload: dict[str, Any] = {
+                "mcp_call_id": mcp_call_id,
+                "tool": tool_name,
+                "agent_id": agent_id,
+                "params": _sanitize_params(params),
+                "method": context.method,
+                "dashboard_id": dashboard_id,
+                "slice_id": slice_id,
+                "dataset_id": dataset_id,
+                "success": success,
+            }
+            if mcp_tool is not None:
+                payload["mcp_tool"] = mcp_tool
+            if error_type is not None:
+                payload["error_type"] = error_type
             if has_app_context():
                 event_logger.log(
                     user_id=user_id,
@@ -283,22 +337,18 @@ class LoggingMiddleware(Middleware):
                     duration_ms=duration_ms,
                     slice_id=slice_id,
                     referrer=None,
-                    curated_payload={
-                        "mcp_call_id": mcp_call_id,
-                        "tool": tool_name,
-                        "agent_id": agent_id,
-                        "params": _sanitize_params(params),
-                        "method": context.method,
-                        "dashboard_id": dashboard_id,
-                        "slice_id": slice_id,
-                        "dataset_id": dataset_id,
-                        "success": success,
-                    },
+                    curated_payload=payload,
                 )
+            extra_parts = []
+            if mcp_tool is not None:
+                extra_parts.append(f"mcp_tool={mcp_tool}")
+            if error_type is not None:
+                extra_parts.append(f"error_type={error_type}")
+            extra = (", " + ", ".join(extra_parts)) if extra_parts else ""
             logger.info(
                 "MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, "
                 "dashboard_id=%s, slice_id=%s, dataset_id=%s, duration_ms=%s, "
-                "success=%s, mcp_call_id=%s",
+                "success=%s, mcp_call_id=%s%s",
                 tool_name,
                 agent_id,
                 user_id,
@@ -309,6 +359,7 @@ class LoggingMiddleware(Middleware):
                 duration_ms,
                 success,
                 mcp_call_id,
+                extra,
             )
 
     async def on_message(
diff --git a/tests/unit_tests/mcp_service/test_middleware_logging.py b/tests/unit_tests/mcp_service/test_middleware_logging.py
index 3f81dc3ae9b..de8f246b599 100644
--- a/tests/unit_tests/mcp_service/test_middleware_logging.py
+++ b/tests/unit_tests/mcp_service/test_middleware_logging.py
@@ -20,6 +20,8 @@ Unit tests for LoggingMiddleware on_call_tool() and on_message() methods.
 
 Tests verify that:
 - on_call_tool() captures duration_ms and success status
+- on_call_tool() resolves call_tool proxy to actual tool name (mcp_tool)
+- on_call_tool() captures error_type on failure
 - on_message() logs non-tool messages without duration
 - _extract_context_info() extracts entity IDs from params
 """
@@ -65,7 +67,7 @@ class TestLoggingMiddlewareOnCallTool:
     @pytest.mark.asyncio
     async def test_on_call_tool_logs_duration_and_success(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_call_tool records duration_ms and success=True on normal return."""
         middleware = LoggingMiddleware()
         ctx = _make_context(name="list_charts")
@@ -91,8 +93,8 @@ class TestLoggingMiddlewareOnCallTool:
     @pytest.mark.asyncio
     async def test_on_call_tool_logs_failure_on_exception(
         self, mock_get_user_id, mock_event_logger
-    ):
-        """on_call_tool records success=False when tool raises."""
+    ) -> None:
+        """on_call_tool records success=False and error_type when tool raises."""
         middleware = LoggingMiddleware()
         ctx = _make_context(name="execute_sql")
         call_next = AsyncMock(side_effect=ValueError("boom"))
@@ -104,6 +106,7 @@ class TestLoggingMiddlewareOnCallTool:
         mock_event_logger.log.assert_called_once()
         call_kwargs = mock_event_logger.log.call_args[1]
         assert call_kwargs["curated_payload"]["success"] is False
+        assert call_kwargs["curated_payload"]["error_type"] == "ValueError"
         assert call_kwargs["duration_ms"] >= 0
 
     @patch("superset.mcp_service.middleware.event_logger")
@@ -111,7 +114,7 @@ class TestLoggingMiddlewareOnCallTool:
     @pytest.mark.asyncio
     async def test_on_call_tool_logs_failure_on_tool_error(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_call_tool records success=False when GlobalErrorHandler raises ToolError.
 
         This simulates the real middleware chain: GlobalErrorHandler catches
@@ -137,7 +140,7 @@ class TestLoggingMiddlewareOnCallTool:
     @pytest.mark.asyncio
     async def test_on_call_tool_includes_mcp_call_id_in_curated_payload(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_call_tool adds mcp_call_id to curated_payload."""
         middleware = LoggingMiddleware()
         ctx = _make_context(name="list_charts")
@@ -155,7 +158,7 @@ class TestLoggingMiddlewareOnCallTool:
     @pytest.mark.asyncio
     async def test_on_call_tool_injects_mcp_call_id_into_tool_result_meta(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_call_tool injects mcp_call_id into ToolResult.meta."""
         middleware = LoggingMiddleware()
         ctx = _make_context(name="list_charts")
@@ -173,7 +176,7 @@ class TestLoggingMiddlewareOnCallTool:
     @pytest.mark.asyncio
     async def test_on_call_tool_preserves_existing_meta(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_call_tool merges mcp_call_id with existing ToolResult.meta."""
         middleware = LoggingMiddleware()
         ctx = _make_context(name="list_charts")
@@ -193,7 +196,7 @@ class TestLoggingMiddlewareOnCallTool:
     @pytest.mark.asyncio
     async def test_on_call_tool_extracts_entity_ids(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_call_tool extracts dashboard_id, chart_id, dataset_id from params."""
         middleware = LoggingMiddleware()
         ctx = _make_context(
@@ -222,7 +225,7 @@ class TestLoggingMiddlewareOnMessage:
     @pytest.mark.asyncio
     async def test_on_message_logs_without_duration(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_message logs with action=mcp_message and duration_ms=None."""
         middleware = LoggingMiddleware()
         ctx = _make_context(method="resources/read", name="instance/metadata")
@@ -240,12 +243,124 @@ class TestLoggingMiddlewareOnMessage:
         # on_message should NOT have success field
         assert "success" not in call_kwargs["curated_payload"]
 
+    @patch("superset.mcp_service.middleware.event_logger")
+    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+    @pytest.mark.asyncio
+    async def test_on_call_tool_no_error_type_on_success(
+        self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+    ) -> None:
+        """on_call_tool omits error_type from payload on success."""
+        middleware = LoggingMiddleware()
+        ctx = _make_context(name="list_charts")
+        call_next = AsyncMock(return_value="ok")
+
+        await middleware.on_call_tool(ctx, call_next)
+
+        payload = mock_event_logger.log.call_args[1]["curated_payload"]
+        assert "error_type" not in payload
+
+    @patch("superset.mcp_service.middleware.event_logger")
+    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+    @pytest.mark.asyncio
+    async def test_on_call_tool_resolves_call_tool_proxy(
+        self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+    ) -> None:
+        """call_tool proxy is resolved to the actual tool name via mcp_tool."""
+        middleware = LoggingMiddleware()
+        ctx = _make_context(
+            name="call_tool",
+            params={"name": "list_datasets", "arguments": {"page": 1}},
+        )
+        call_next = AsyncMock(return_value="datasets")
+
+        await middleware.on_call_tool(ctx, call_next)
+
+        payload = mock_event_logger.log.call_args[1]["curated_payload"]
+        assert payload["tool"] == "call_tool"
+        assert payload["mcp_tool"] == "list_datasets"
+
+    @patch("superset.mcp_service.middleware.event_logger")
+    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+    @pytest.mark.asyncio
+    async def test_on_call_tool_no_mcp_tool_for_direct_calls(
+        self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+    ) -> None:
+        """Direct tool calls (not via proxy) omit mcp_tool from payload."""
+        middleware = LoggingMiddleware()
+        ctx = _make_context(name="list_charts")
+        call_next = AsyncMock(return_value="charts")
+
+        await middleware.on_call_tool(ctx, call_next)
+
+        payload = mock_event_logger.log.call_args[1]["curated_payload"]
+        assert payload["tool"] == "list_charts"
+        assert "mcp_tool" not in payload
+
+    @patch("superset.mcp_service.middleware.event_logger")
+    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+    @pytest.mark.asyncio
+    async def test_on_call_tool_proxy_failure_captures_both_fields(
+        self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+    ) -> None:
+        """call_tool proxy failure captures mcp_tool and error_type."""
+        middleware = LoggingMiddleware()
+        ctx = _make_context(
+            name="call_tool",
+            params={"name": "get_chart_data", "arguments": {"chart_id": 1}},
+        )
+        call_next = AsyncMock(side_effect=PermissionError("access denied"))
+
+        with pytest.raises(PermissionError):
+            await middleware.on_call_tool(ctx, call_next)
+
+        payload = mock_event_logger.log.call_args[1]["curated_payload"]
+        assert payload["tool"] == "call_tool"
+        assert payload["mcp_tool"] == "get_chart_data"
+        assert payload["success"] is False
+        assert payload["error_type"] == "PermissionError"
+
+
+class TestResolveToolName:
+    """Tests for LoggingMiddleware._resolve_tool_name()."""
+
+    def test_resolves_call_tool_proxy(self) -> None:
+        """Returns the real tool name when call_tool proxy is used."""
+        assert (
+            LoggingMiddleware._resolve_tool_name(
+                "call_tool", {"name": "list_datasets", "arguments": {}}
+            )
+            == "list_datasets"
+        )
+
+    def test_returns_none_for_direct_tool(self) -> None:
+        """Returns None for direct tool calls (not via proxy)."""
+        assert LoggingMiddleware._resolve_tool_name("list_charts", {"page": 1}) is None
+
+    def test_returns_none_when_name_missing(self) -> None:
+        """Returns None when call_tool params lack 'name'."""
+        assert LoggingMiddleware._resolve_tool_name("call_tool", {"foo": "bar"}) is None
+
+    def test_returns_none_for_empty_name(self) -> None:
+        """Returns None when call_tool params have empty 'name'."""
+        assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": ""}) is None
+
+    def test_returns_none_for_non_string_name(self) -> None:
+        """Returns None when call_tool name param is not a string."""
+        assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": 123}) is None
+
+    def test_returns_none_for_search_tools(self) -> None:
+        """search_tools proxy is not resolved (no underlying tool name)."""
+        assert (
+            LoggingMiddleware._resolve_tool_name("search_tools", {"query": "datasets"})
+            is None
+        )
+
 
 class TestExtractContextInfo:
     """Tests for LoggingMiddleware._extract_context_info()."""
 
     @patch("superset.mcp_service.middleware.get_user_id", return_value=99)
-    def test_extract_with_metadata_agent_id(self, mock_get_user_id):
+    def test_extract_with_metadata_agent_id(self, mock_get_user_id) -> None:
         """Extracts agent_id from context.metadata."""
         middleware = LoggingMiddleware()
         ctx = _make_context(metadata={"agent_id": "agent-123"})
@@ -261,7 +376,7 @@ class TestExtractContextInfo:
         "superset.mcp_service.middleware.get_user_id",
         side_effect=RuntimeError("no Flask request context"),
     )
-    def test_extract_handles_missing_user(self, mock_get_user_id):
+    def test_extract_handles_missing_user(self, mock_get_user_id) -> None:
         """Gracefully handles missing user context."""
         middleware = LoggingMiddleware()
         ctx = _make_context()
@@ -273,7 +388,7 @@ class TestExtractContextInfo:
         assert user_id is None
 
     @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
-    def test_extract_slice_id_from_chart_id(self, mock_get_user_id):
+    def test_extract_slice_id_from_chart_id(self, mock_get_user_id) -> None:
         """Extracts slice_id from chart_id param (alias)."""
         middleware = LoggingMiddleware()
         ctx = _make_context(params={"chart_id": 55})
@@ -283,7 +398,7 @@ class TestExtractContextInfo:
         assert slice_id == 55
 
     @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
-    def test_extract_slice_id_from_slice_id(self, mock_get_user_id):
+    def test_extract_slice_id_from_slice_id(self, mock_get_user_id) -> None:
         """Extracts slice_id from slice_id param (fallback)."""
         middleware = LoggingMiddleware()
         ctx = _make_context(params={"slice_id": 66})
@@ -296,7 +411,7 @@ class TestExtractContextInfo:
 class TestIsErrorResponse:
     """Tests for LoggingMiddleware._is_error_response()."""
 
-    def test_detects_error_schema_response(self):
+    def test_detects_error_schema_response(self) -> None:
         """Detects ToolResult containing a serialized error schema
         (ChartError, DashboardError, etc.) via "error_type" field."""
         middleware = LoggingMiddleware()
@@ -308,7 +423,7 @@ class TestIsErrorResponse:
         result = ToolResult(content=[mt.TextContent(type="text", text=error_json)])
         assert middleware._is_error_response(result) is True
 
-    def test_success_response_not_detected_as_error(self):
+    def test_success_response_not_detected_as_error(self) -> None:
         """Normal ToolResult is not detected as error."""
         middleware = LoggingMiddleware()
         result = ToolResult(
@@ -316,7 +431,7 @@ class TestIsErrorResponse:
         )
         assert middleware._is_error_response(result) is False
 
-    def test_empty_content_not_detected_as_error(self):
+    def test_empty_content_not_detected_as_error(self) -> None:
         """ToolResult with empty content is not detected as error."""
         middleware = LoggingMiddleware()
         assert middleware._is_error_response(ToolResult(content=[])) is False
@@ -326,7 +441,7 @@ class TestIsErrorResponse:
     @pytest.mark.asyncio
     async def test_on_call_tool_logs_failure_for_error_schema(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """on_call_tool logs success=False when tool returns an
         error schema (e.g. ChartError)."""
         middleware = LoggingMiddleware()
@@ -366,7 +481,7 @@ class TestMiddlewareChainOrder:
     @pytest.mark.asyncio
     async def test_real_middleware_chain_logs_exception_as_failure(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """Tool exception is logged as success=False through the
         real middleware chain from build_middleware_list()."""
         from superset.mcp_service.server import build_middleware_list
@@ -413,7 +528,7 @@ class TestMiddlewareChainOrder:
     @pytest.mark.asyncio
     async def test_real_middleware_chain_error_result_has_mcp_call_id(
         self, mock_get_user_id, mock_event_logger
-    ):
+    ) -> None:
         """When a tool raises, the error ToolResult from
         StructuredContentStripper still carries mcp_call_id in meta."""
         from superset.mcp_service.server import build_middleware_list

Reply via email to