This is an automated email from the ASF dual-hosted git repository.
aminghadersohi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new b6f545e61e1 feat(mcp): resolve call_tool proxy name and capture error_type in logging (#38915)
b6f545e61e1 is described below
commit b6f545e61e1fbbc943e9dd5cec491c69a37a13c7
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Tue May 26 11:37:37 2026 -0700
feat(mcp): resolve call_tool proxy name and capture error_type in logging (#38915)
Co-authored-by: Amin Ghadersohi <[email protected]>
---
superset/mcp_service/middleware.py | 85 +++++++++---
.../mcp_service/test_middleware_logging.py | 153 ++++++++++++++++++---
2 files changed, 202 insertions(+), 36 deletions(-)
diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py
index 844b5ff59d0..685022c3a9f 100644
--- a/superset/mcp_service/middleware.py
+++ b/superset/mcp_service/middleware.py
@@ -188,10 +188,15 @@ def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]:
"""Remove sensitive fields from params before logging."""
if not isinstance(params, dict):
return params
- return {
- k: "[REDACTED]" if k.lower() in _SENSITIVE_PARAM_KEYS else v
- for k, v in params.items()
- }
+ result: dict[str, Any] = {}
+ for k, v in params.items():
+ if k.lower() in _SENSITIVE_PARAM_KEYS:
+ result[k] = "[REDACTED]"
+ elif k == "arguments" and isinstance(v, dict):
+ result[k] = _sanitize_params(v)
+ else:
+ result[k] = v
+ return result
class LoggingMiddleware(Middleware):
@@ -204,8 +209,17 @@ class LoggingMiddleware(Middleware):
Tool calls are handled in on_call_tool() which wraps execution to capture
duration_ms. Non-tool messages (resource reads, prompts, etc.) are handled
in on_message().
+
+ When tool search is enabled (progressive discovery), the MCP client calls
+ ``call_tool`` proxies instead of individual tools. This middleware
resolves
+ the underlying tool name from ``call_tool`` arguments so that analytics
+ queries can filter by the actual tool (stored as ``mcp_tool`` in the
curated
+ payload).
"""
+ #: Proxy name used by FastMCP tool-search transforms.
+ _CALL_TOOL_PROXY = "call_tool"
+
def _is_error_response(self, result: ToolResult) -> bool:
"""Check if a tool result contains an error schema response.
@@ -244,6 +258,28 @@ class LoggingMiddleware(Middleware):
dataset_id = params.get("dataset_id")
return agent_id, user_id, dashboard_id, slice_id, dataset_id, params
+ @staticmethod
+ def _resolve_tool_name(tool_name: str | None, params: Any) -> str | None:
+ """Resolve the underlying tool name from call_tool proxy arguments.
+
+ When tool search is enabled, the MCP client uses the ``call_tool``
+ proxy and passes the real tool name as the ``name`` argument. This
+ helper extracts that value so we can log which tool was actually
+ executed rather than just ``"call_tool"``.
+
+ Returns:
+ The resolved tool name if *tool_name* is the call_tool proxy and
+ ``params["name"]`` is a non-empty string, otherwise ``None``.
+ """
+ if (
+ tool_name == LoggingMiddleware._CALL_TOOL_PROXY
+ and isinstance(params, dict)
+ and isinstance(params.get("name"), str)
+ and params["name"]
+ ):
+ return params["name"]
+ return None
+
async def on_call_tool(
self,
context: MiddlewareContext,
@@ -254,11 +290,13 @@ class LoggingMiddleware(Middleware):
self._extract_context_info(context)
)
tool_name = getattr(context.message, "name", None)
+ mcp_tool = self._resolve_tool_name(tool_name, params)
mcp_call_id = secrets.token_hex(16)
_mcp_call_id_var.set(mcp_call_id)
start_time = time.time()
success = False
+ error_type: str | None = None
try:
result = await call_next(context)
success = not self._is_error_response(result)
@@ -270,11 +308,27 @@ class LoggingMiddleware(Middleware):
structured_content=result.structured_content,
)
return result
- except Exception:
+ except Exception as exc:
+ error_type = type(exc).__name__
success = False
raise
finally:
duration_ms = int((time.time() - start_time) * 1000)
+ payload: dict[str, Any] = {
+ "mcp_call_id": mcp_call_id,
+ "tool": tool_name,
+ "agent_id": agent_id,
+ "params": _sanitize_params(params),
+ "method": context.method,
+ "dashboard_id": dashboard_id,
+ "slice_id": slice_id,
+ "dataset_id": dataset_id,
+ "success": success,
+ }
+ if mcp_tool is not None:
+ payload["mcp_tool"] = mcp_tool
+ if error_type is not None:
+ payload["error_type"] = error_type
if has_app_context():
event_logger.log(
user_id=user_id,
@@ -283,22 +337,18 @@ class LoggingMiddleware(Middleware):
duration_ms=duration_ms,
slice_id=slice_id,
referrer=None,
- curated_payload={
- "mcp_call_id": mcp_call_id,
- "tool": tool_name,
- "agent_id": agent_id,
- "params": _sanitize_params(params),
- "method": context.method,
- "dashboard_id": dashboard_id,
- "slice_id": slice_id,
- "dataset_id": dataset_id,
- "success": success,
- },
+ curated_payload=payload,
)
+ extra_parts = []
+ if mcp_tool is not None:
+ extra_parts.append(f"mcp_tool={mcp_tool}")
+ if error_type is not None:
+ extra_parts.append(f"error_type={error_type}")
+ extra = (", " + ", ".join(extra_parts)) if extra_parts else ""
logger.info(
"MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, "
"dashboard_id=%s, slice_id=%s, dataset_id=%s, duration_ms=%s, "
- "success=%s, mcp_call_id=%s",
+ "success=%s, mcp_call_id=%s%s",
tool_name,
agent_id,
user_id,
@@ -309,6 +359,7 @@ class LoggingMiddleware(Middleware):
duration_ms,
success,
mcp_call_id,
+ extra,
)
async def on_message(
diff --git a/tests/unit_tests/mcp_service/test_middleware_logging.py b/tests/unit_tests/mcp_service/test_middleware_logging.py
index 3f81dc3ae9b..de8f246b599 100644
--- a/tests/unit_tests/mcp_service/test_middleware_logging.py
+++ b/tests/unit_tests/mcp_service/test_middleware_logging.py
@@ -20,6 +20,8 @@ Unit tests for LoggingMiddleware on_call_tool() and on_message() methods.
Tests verify that:
- on_call_tool() captures duration_ms and success status
+- on_call_tool() resolves call_tool proxy to actual tool name (mcp_tool)
+- on_call_tool() captures error_type on failure
- on_message() logs non-tool messages without duration
- _extract_context_info() extracts entity IDs from params
"""
@@ -65,7 +67,7 @@ class TestLoggingMiddlewareOnCallTool:
@pytest.mark.asyncio
async def test_on_call_tool_logs_duration_and_success(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_call_tool records duration_ms and success=True on normal
return."""
middleware = LoggingMiddleware()
ctx = _make_context(name="list_charts")
@@ -91,8 +93,8 @@ class TestLoggingMiddlewareOnCallTool:
@pytest.mark.asyncio
async def test_on_call_tool_logs_failure_on_exception(
self, mock_get_user_id, mock_event_logger
- ):
- """on_call_tool records success=False when tool raises."""
+ ) -> None:
+ """on_call_tool records success=False and error_type when tool
raises."""
middleware = LoggingMiddleware()
ctx = _make_context(name="execute_sql")
call_next = AsyncMock(side_effect=ValueError("boom"))
@@ -104,6 +106,7 @@ class TestLoggingMiddlewareOnCallTool:
mock_event_logger.log.assert_called_once()
call_kwargs = mock_event_logger.log.call_args[1]
assert call_kwargs["curated_payload"]["success"] is False
+ assert call_kwargs["curated_payload"]["error_type"] == "ValueError"
assert call_kwargs["duration_ms"] >= 0
@patch("superset.mcp_service.middleware.event_logger")
@@ -111,7 +114,7 @@ class TestLoggingMiddlewareOnCallTool:
@pytest.mark.asyncio
async def test_on_call_tool_logs_failure_on_tool_error(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_call_tool records success=False when GlobalErrorHandler raises
ToolError.
This simulates the real middleware chain: GlobalErrorHandler catches
@@ -137,7 +140,7 @@ class TestLoggingMiddlewareOnCallTool:
@pytest.mark.asyncio
async def test_on_call_tool_includes_mcp_call_id_in_curated_payload(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_call_tool adds mcp_call_id to curated_payload."""
middleware = LoggingMiddleware()
ctx = _make_context(name="list_charts")
@@ -155,7 +158,7 @@ class TestLoggingMiddlewareOnCallTool:
@pytest.mark.asyncio
async def test_on_call_tool_injects_mcp_call_id_into_tool_result_meta(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_call_tool injects mcp_call_id into ToolResult.meta."""
middleware = LoggingMiddleware()
ctx = _make_context(name="list_charts")
@@ -173,7 +176,7 @@ class TestLoggingMiddlewareOnCallTool:
@pytest.mark.asyncio
async def test_on_call_tool_preserves_existing_meta(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_call_tool merges mcp_call_id with existing ToolResult.meta."""
middleware = LoggingMiddleware()
ctx = _make_context(name="list_charts")
@@ -193,7 +196,7 @@ class TestLoggingMiddlewareOnCallTool:
@pytest.mark.asyncio
async def test_on_call_tool_extracts_entity_ids(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_call_tool extracts dashboard_id, chart_id, dataset_id from
params."""
middleware = LoggingMiddleware()
ctx = _make_context(
@@ -222,7 +225,7 @@ class TestLoggingMiddlewareOnMessage:
@pytest.mark.asyncio
async def test_on_message_logs_without_duration(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_message logs with action=mcp_message and duration_ms=None."""
middleware = LoggingMiddleware()
ctx = _make_context(method="resources/read", name="instance/metadata")
@@ -240,12 +243,124 @@ class TestLoggingMiddlewareOnMessage:
# on_message should NOT have success field
assert "success" not in call_kwargs["curated_payload"]
+ @patch("superset.mcp_service.middleware.event_logger")
+ @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+ @pytest.mark.asyncio
+ async def test_on_call_tool_no_error_type_on_success(
+ self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+ ) -> None:
+ """on_call_tool omits error_type from payload on success."""
+ middleware = LoggingMiddleware()
+ ctx = _make_context(name="list_charts")
+ call_next = AsyncMock(return_value="ok")
+
+ await middleware.on_call_tool(ctx, call_next)
+
+ payload = mock_event_logger.log.call_args[1]["curated_payload"]
+ assert "error_type" not in payload
+
+ @patch("superset.mcp_service.middleware.event_logger")
+ @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+ @pytest.mark.asyncio
+ async def test_on_call_tool_resolves_call_tool_proxy(
+ self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+ ) -> None:
+ """call_tool proxy is resolved to the actual tool name via mcp_tool."""
+ middleware = LoggingMiddleware()
+ ctx = _make_context(
+ name="call_tool",
+ params={"name": "list_datasets", "arguments": {"page": 1}},
+ )
+ call_next = AsyncMock(return_value="datasets")
+
+ await middleware.on_call_tool(ctx, call_next)
+
+ payload = mock_event_logger.log.call_args[1]["curated_payload"]
+ assert payload["tool"] == "call_tool"
+ assert payload["mcp_tool"] == "list_datasets"
+
+ @patch("superset.mcp_service.middleware.event_logger")
+ @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+ @pytest.mark.asyncio
+ async def test_on_call_tool_no_mcp_tool_for_direct_calls(
+ self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+ ) -> None:
+ """Direct tool calls (not via proxy) omit mcp_tool from payload."""
+ middleware = LoggingMiddleware()
+ ctx = _make_context(name="list_charts")
+ call_next = AsyncMock(return_value="charts")
+
+ await middleware.on_call_tool(ctx, call_next)
+
+ payload = mock_event_logger.log.call_args[1]["curated_payload"]
+ assert payload["tool"] == "list_charts"
+ assert "mcp_tool" not in payload
+
+ @patch("superset.mcp_service.middleware.event_logger")
+ @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
+ @pytest.mark.asyncio
+ async def test_on_call_tool_proxy_failure_captures_both_fields(
+ self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock
+ ) -> None:
+ """call_tool proxy failure captures mcp_tool and error_type."""
+ middleware = LoggingMiddleware()
+ ctx = _make_context(
+ name="call_tool",
+ params={"name": "get_chart_data", "arguments": {"chart_id": 1}},
+ )
+ call_next = AsyncMock(side_effect=PermissionError("access denied"))
+
+ with pytest.raises(PermissionError):
+ await middleware.on_call_tool(ctx, call_next)
+
+ payload = mock_event_logger.log.call_args[1]["curated_payload"]
+ assert payload["tool"] == "call_tool"
+ assert payload["mcp_tool"] == "get_chart_data"
+ assert payload["success"] is False
+ assert payload["error_type"] == "PermissionError"
+
+
+class TestResolveToolName:
+ """Tests for LoggingMiddleware._resolve_tool_name()."""
+
+ def test_resolves_call_tool_proxy(self) -> None:
+ """Returns the real tool name when call_tool proxy is used."""
+ assert (
+ LoggingMiddleware._resolve_tool_name(
+ "call_tool", {"name": "list_datasets", "arguments": {}}
+ )
+ == "list_datasets"
+ )
+
+ def test_returns_none_for_direct_tool(self) -> None:
+ """Returns None for direct tool calls (not via proxy)."""
+ assert LoggingMiddleware._resolve_tool_name("list_charts", {"page":
1}) is None
+
+ def test_returns_none_when_name_missing(self) -> None:
+ """Returns None when call_tool params lack 'name'."""
+ assert LoggingMiddleware._resolve_tool_name("call_tool", {"foo":
"bar"}) is None
+
+ def test_returns_none_for_empty_name(self) -> None:
+ """Returns None when call_tool params have empty 'name'."""
+ assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": ""})
is None
+
+ def test_returns_none_for_non_string_name(self) -> None:
+ """Returns None when call_tool name param is not a string."""
+ assert LoggingMiddleware._resolve_tool_name("call_tool", {"name":
123}) is None
+
+ def test_returns_none_for_search_tools(self) -> None:
+ """search_tools proxy is not resolved (no underlying tool name)."""
+ assert (
+ LoggingMiddleware._resolve_tool_name("search_tools", {"query":
"datasets"})
+ is None
+ )
+
class TestExtractContextInfo:
"""Tests for LoggingMiddleware._extract_context_info()."""
@patch("superset.mcp_service.middleware.get_user_id", return_value=99)
- def test_extract_with_metadata_agent_id(self, mock_get_user_id):
+ def test_extract_with_metadata_agent_id(self, mock_get_user_id) -> None:
"""Extracts agent_id from context.metadata."""
middleware = LoggingMiddleware()
ctx = _make_context(metadata={"agent_id": "agent-123"})
@@ -261,7 +376,7 @@ class TestExtractContextInfo:
"superset.mcp_service.middleware.get_user_id",
side_effect=RuntimeError("no Flask request context"),
)
- def test_extract_handles_missing_user(self, mock_get_user_id):
+ def test_extract_handles_missing_user(self, mock_get_user_id) -> None:
"""Gracefully handles missing user context."""
middleware = LoggingMiddleware()
ctx = _make_context()
@@ -273,7 +388,7 @@ class TestExtractContextInfo:
assert user_id is None
@patch("superset.mcp_service.middleware.get_user_id", return_value=1)
- def test_extract_slice_id_from_chart_id(self, mock_get_user_id):
+ def test_extract_slice_id_from_chart_id(self, mock_get_user_id) -> None:
"""Extracts slice_id from chart_id param (alias)."""
middleware = LoggingMiddleware()
ctx = _make_context(params={"chart_id": 55})
@@ -283,7 +398,7 @@ class TestExtractContextInfo:
assert slice_id == 55
@patch("superset.mcp_service.middleware.get_user_id", return_value=1)
- def test_extract_slice_id_from_slice_id(self, mock_get_user_id):
+ def test_extract_slice_id_from_slice_id(self, mock_get_user_id) -> None:
"""Extracts slice_id from slice_id param (fallback)."""
middleware = LoggingMiddleware()
ctx = _make_context(params={"slice_id": 66})
@@ -296,7 +411,7 @@ class TestExtractContextInfo:
class TestIsErrorResponse:
"""Tests for LoggingMiddleware._is_error_response()."""
- def test_detects_error_schema_response(self):
+ def test_detects_error_schema_response(self) -> None:
"""Detects ToolResult containing a serialized error schema
(ChartError, DashboardError, etc.) via "error_type" field."""
middleware = LoggingMiddleware()
@@ -308,7 +423,7 @@ class TestIsErrorResponse:
result = ToolResult(content=[mt.TextContent(type="text",
text=error_json)])
assert middleware._is_error_response(result) is True
- def test_success_response_not_detected_as_error(self):
+ def test_success_response_not_detected_as_error(self) -> None:
"""Normal ToolResult is not detected as error."""
middleware = LoggingMiddleware()
result = ToolResult(
@@ -316,7 +431,7 @@ class TestIsErrorResponse:
)
assert middleware._is_error_response(result) is False
- def test_empty_content_not_detected_as_error(self):
+ def test_empty_content_not_detected_as_error(self) -> None:
"""ToolResult with empty content is not detected as error."""
middleware = LoggingMiddleware()
assert middleware._is_error_response(ToolResult(content=[])) is False
@@ -326,7 +441,7 @@ class TestIsErrorResponse:
@pytest.mark.asyncio
async def test_on_call_tool_logs_failure_for_error_schema(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""on_call_tool logs success=False when tool returns an
error schema (e.g. ChartError)."""
middleware = LoggingMiddleware()
@@ -366,7 +481,7 @@ class TestMiddlewareChainOrder:
@pytest.mark.asyncio
async def test_real_middleware_chain_logs_exception_as_failure(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""Tool exception is logged as success=False through the
real middleware chain from build_middleware_list()."""
from superset.mcp_service.server import build_middleware_list
@@ -413,7 +528,7 @@ class TestMiddlewareChainOrder:
@pytest.mark.asyncio
async def test_real_middleware_chain_error_result_has_mcp_call_id(
self, mock_get_user_id, mock_event_logger
- ):
+ ) -> None:
"""When a tool raises, the error ToolResult from
StructuredContentStripper still carries mcp_call_id in meta."""
from superset.mcp_service.server import build_middleware_list