This is an automated email from the ASF dual-hosted git repository.

jli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 48220fb33f8 feat(mcp): add save_sql_query tool for SQL Lab saved queries (#38414)
48220fb33f8 is described below

commit 48220fb33f82d1a8bc5e1356b0949e65a16fb2f1
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Fri Mar 13 22:02:04 2026 +0100

    feat(mcp): add save_sql_query tool for SQL Lab saved queries (#38414)
    
    Co-authored-by: Claude Opus 4.6 <[email protected]>
---
 superset/mcp_service/app.py                        |   5 +-
 superset/mcp_service/sql_lab/schemas.py            |  59 +++
 superset/mcp_service/sql_lab/tool/__init__.py      |   2 +
 .../mcp_service/sql_lab/tool/save_sql_query.py     | 137 ++++++
 .../sql_lab/tool/test_save_sql_query.py            | 467 +++++++++++++++++++++
 5 files changed, 669 insertions(+), 1 deletion(-)

diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py
index 409c7bb2e24..5f7ec6c64df 100644
--- a/superset/mcp_service/app.py
+++ b/superset/mcp_service/app.py
@@ -70,6 +70,7 @@ Chart Management:
 
 SQL Lab Integration:
 - execute_sql: Execute SQL queries and get results (requires database_id)
+- save_sql_query: Save a SQL query to Saved Queries list
 - open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql
 
 Schema Discovery:
@@ -105,7 +106,8 @@ To find your own charts/dashboards:
 To explore data with SQL:
 1. list_datasets -> find a dataset and note its database_id
 2. execute_sql(database_id, sql) -> run query
-3. open_sql_lab_with_context(database_id) -> open SQL Lab UI
+3. save_sql_query(database_id, label, sql) -> save query for later reuse
+4. open_sql_lab_with_context(database_id) -> open SQL Lab UI
 
 generate_explore_link vs generate_chart:
 - Use generate_explore_link for exploration (no permanent chart created)
@@ -415,6 +417,7 @@ from superset.mcp_service.explore.tool import (  # noqa: F401, E402
 from superset.mcp_service.sql_lab.tool import (  # noqa: F401, E402
     execute_sql,
     open_sql_lab_with_context,
+    save_sql_query,
 )
 from superset.mcp_service.system import (  # noqa: F401, E402
     prompts as system_prompts,
diff --git a/superset/mcp_service/sql_lab/schemas.py b/superset/mcp_service/sql_lab/schemas.py
index e55ca53130f..2e0268774cc 100644
--- a/superset/mcp_service/sql_lab/schemas.py
+++ b/superset/mcp_service/sql_lab/schemas.py
@@ -147,6 +147,65 @@ class ExecuteSqlResponse(BaseModel):
     )
 
 
+class SaveSqlQueryRequest(BaseModel):
+    """Request schema for saving a SQL query.
+
+    ``schema_name`` is populated from the ``schema`` key (its alias).
+    ``sql`` and ``label`` are stripped of surrounding whitespace and must
+    not be blank.
+    """
+
+    database_id: int = Field(
+        ..., description="Database connection ID the query runs against"
+    )
+    label: str = Field(
+        ...,
+        description="Name for the saved query (shown in Saved Queries list)",
+        min_length=1,
+        max_length=256,
+    )
+    sql: str = Field(
+        ...,
+        description="SQL query text to save",
+    )
+    schema_name: str | None = Field(
+        None,
+        description="Schema the query targets",
+        alias="schema",
+    )
+    catalog: str | None = Field(None, description="Catalog name (if applicable)")
+    description: str | None = Field(
+        None, description="Optional description of the query"
+    )
+
+    @field_validator("sql")
+    @classmethod
+    def sql_not_empty(cls, v: str) -> str:
+        """Reject whitespace-only SQL and return the stripped text."""
+        stripped = v.strip()
+        if not stripped:
+            raise ValueError("SQL query cannot be empty")
+        return stripped
+
+    @field_validator("label")
+    @classmethod
+    def label_not_empty(cls, v: str) -> str:
+        """Reject whitespace-only labels and return the stripped label.
+
+        This is an "after" validator, so the ``min_length``/``max_length``
+        constraints above are applied to the unstripped input first.
+        """
+        stripped = v.strip()
+        if not stripped:
+            raise ValueError("Label cannot be empty")
+        return stripped
+
+class SaveSqlQueryResponse(BaseModel):
+    """Response schema for a saved SQL query."""
+
+    # ``schema_name`` is declared with alias="schema"; allow populating it by
+    # field name too, otherwise ``schema_name=...`` does not set the field.
+    model_config = {"populate_by_name": True}
+
+    id: int = Field(..., description="Saved query ID")
+    label: str = Field(..., description="Query name")
+    sql: str = Field(..., description="SQL query text")
+    database_id: int = Field(..., description="Database ID")
+    schema_name: str | None = Field(None, description="Schema name", alias="schema")
+    catalog: str | None = Field(None, description="Catalog name (if applicable)")
+    description: str | None = Field(None, description="Query description")
+    url: str = Field(
+        ...,
+        description=(
+            "URL to open this saved query in SQL Lab (e.g., /sqllab?savedQueryId=42)"
+        ),
+    )
+
+
 class OpenSqlLabRequest(BaseModel):
     """Request schema for opening SQL Lab with context."""
 
diff --git a/superset/mcp_service/sql_lab/tool/__init__.py b/superset/mcp_service/sql_lab/tool/__init__.py
index 0fc7a0dd89f..d4f75d99610 100644
--- a/superset/mcp_service/sql_lab/tool/__init__.py
+++ b/superset/mcp_service/sql_lab/tool/__init__.py
@@ -23,8 +23,10 @@ from superset.mcp_service.sql_lab.tool.execute_sql import execute_sql
 from superset.mcp_service.sql_lab.tool.open_sql_lab_with_context import (
     open_sql_lab_with_context,
 )
+from superset.mcp_service.sql_lab.tool.save_sql_query import save_sql_query
 
 __all__ = [
     "execute_sql",
     "open_sql_lab_with_context",
+    "save_sql_query",
 ]
diff --git a/superset/mcp_service/sql_lab/tool/save_sql_query.py b/superset/mcp_service/sql_lab/tool/save_sql_query.py
new file mode 100644
index 00000000000..f9777793085
--- /dev/null
+++ b/superset/mcp_service/sql_lab/tool/save_sql_query.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Save SQL Query MCP Tool
+
+Tool for saving a SQL query as a named SavedQuery in Superset,
+so it appears in SQL Lab's "Saved Queries" list and can be
+reloaded/shared via URL.
+"""
+
+from __future__ import annotations
+
+import logging
+
+from fastmcp import Context
+from sqlalchemy.exc import SQLAlchemyError
+from superset_core.mcp.decorators import tool
+
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetErrorException, 
SupersetSecurityException
+from superset.extensions import event_logger
+from superset.mcp_service.sql_lab.schemas import (
+    SaveSqlQueryRequest,
+    SaveSqlQueryResponse,
+)
+from superset.mcp_service.utils.schema_utils import parse_request
+
+logger = logging.getLogger(__name__)
+
+
+@tool(tags=["mutate"])
+@parse_request(SaveSqlQueryRequest)
+async def save_sql_query(
+    request: SaveSqlQueryRequest, ctx: Context
+) -> SaveSqlQueryResponse:
+    """Save a SQL query so it appears in SQL Lab's Saved Queries list.
+
+    Creates a persistent SavedQuery that the user can reload from
+    SQL Lab, share via URL, and find in the Saved Queries page.
+    Requires a database_id, a label (name), and the SQL text.
+
+    Raises:
+        SupersetErrorException: no database exists with ``database_id``.
+        SupersetSecurityException: the current user cannot access the database.
+        SQLAlchemyError: persisting failed; the session is rolled back first.
+    """
+    await ctx.info(
+        "Saving SQL query: database_id=%s, label=%r"
+        % (request.database_id, request.label)
+    )
+
+    try:
+        from flask import g
+
+        from superset import db, security_manager
+        from superset.daos.query import SavedQueryDAO
+        from superset.mcp_service.utils.url_utils import get_superset_base_url
+        from superset.models.core import Database
+
+        # 1. Validate database exists and user has access
+        with event_logger.log_context(action="mcp.save_sql_query.db_validation"):
+            database = (
+                db.session.query(Database).filter_by(id=request.database_id).first()
+            )
+            if not database:
+                raise SupersetErrorException(
+                    SupersetError(
+                        message=(f"Database with ID {request.database_id} not found"),
+                        error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
+                        level=ErrorLevel.ERROR,
+                    )
+                )
+
+            if not security_manager.can_access_database(database):
+                raise SupersetSecurityException(
+                    SupersetError(
+                        message=(f"Access denied to database {database.database_name}"),
+                        error_type=(SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR),
+                        level=ErrorLevel.ERROR,
+                    )
+                )
+
+        # 2. Create the saved query (the model stores "" rather than None
+        #    for a missing schema/description).
+        with event_logger.log_context(action="mcp.save_sql_query.create"):
+            saved_query = SavedQueryDAO.create(
+                attributes={
+                    "user_id": g.user.id,
+                    "db_id": request.database_id,
+                    "label": request.label,
+                    "sql": request.sql,
+                    "schema": request.schema_name or "",
+                    "catalog": request.catalog,
+                    "description": request.description or "",
+                }
+            )
+            db.session.commit()  # pylint: disable=consider-using-transaction
+
+        # 3. Build response
+        base_url = get_superset_base_url()
+        saved_query_url = f"{base_url}/sqllab?savedQueryId={saved_query.id}"
+
+        await ctx.info(
+            "Saved query created: id=%s, url=%s" % (saved_query.id, saved_query_url)
+        )
+
+        return SaveSqlQueryResponse(
+            id=saved_query.id,
+            label=saved_query.label,
+            sql=saved_query.sql,
+            database_id=request.database_id,
+            # ``schema_name`` is declared with alias="schema"; pass the alias so
+            # the value is populated regardless of populate_by_name.
+            schema=request.schema_name,
+            catalog=getattr(saved_query, "catalog", None),
+            description=request.description,
+            url=saved_query_url,
+        )
+
+    except (SupersetErrorException, SupersetSecurityException):
+        raise
+    except SQLAlchemyError as e:
+        from superset import db
+
+        # Undo the partially-applied insert before propagating the error.
+        db.session.rollback()
+        await ctx.error(
+            "Failed to save SQL query: error=%s, database_id=%s"
+            % (str(e), request.database_id)
+        )
+        raise
diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py
new file mode 100644
index 00000000000..469ca9fd43c
--- /dev/null
+++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py
@@ -0,0 +1,467 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Unit tests for save_sql_query MCP tool schemas and logic.
+"""
+
+import importlib
+import sys
+import types
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+from pydantic import ValidationError
+
+from superset.mcp_service.sql_lab.schemas import (
+    SaveSqlQueryRequest,
+    SaveSqlQueryResponse,
+)
+
+
+class TestSaveSqlQueryRequest:
+    """Test SaveSqlQueryRequest schema validation."""
+
+    def test_valid_request(self) -> None:
+        req = SaveSqlQueryRequest(
+            database_id=1,
+            label="Revenue Query",
+            sql="SELECT SUM(revenue) FROM sales",
+        )
+        assert req.database_id == 1
+        assert req.label == "Revenue Query"
+        assert req.sql == "SELECT SUM(revenue) FROM sales"
+
+    def test_with_optional_fields(self) -> None:
+        req = SaveSqlQueryRequest(
+            database_id=1,
+            label="Revenue Query",
+            sql="SELECT 1",
+            schema="public",
+            catalog="main",
+            description="Sums revenue",
+        )
+        assert req.schema_name == "public"
+        assert req.catalog == "main"
+        assert req.description == "Sums revenue"
+
+    def test_empty_sql_fails(self) -> None:
+        with pytest.raises(ValidationError, match="SQL query cannot be empty"):
+            SaveSqlQueryRequest(database_id=1, label="test", sql="  ")
+
+    def test_empty_label_fails(self) -> None:
+        with pytest.raises(ValidationError, match="Label cannot be empty"):
+            SaveSqlQueryRequest(database_id=1, label="  ", sql="SELECT 1")
+
+    def test_sql_is_stripped(self) -> None:
+        req = SaveSqlQueryRequest(database_id=1, label="test", sql="  SELECT 1  ")
+        assert req.sql == "SELECT 1"
+
+    def test_label_is_stripped(self) -> None:
+        req = SaveSqlQueryRequest(database_id=1, label="  My Query  ", sql="SELECT 1")
+        assert req.label == "My Query"
+
+    def test_label_max_length(self) -> None:
+        with pytest.raises(ValidationError, match="String should have at most 256"):
+            SaveSqlQueryRequest(database_id=1, label="a" * 257, sql="SELECT 1")
+
+    def test_schema_alias(self) -> None:
+        """The field accepts 'schema' as alias for 'schema_name'."""
+        req = SaveSqlQueryRequest(
+            database_id=1,
+            label="test",
+            sql="SELECT 1",
+            schema="public",
+        )
+        assert req.schema_name == "public"
+
+
+class TestSaveSqlQueryResponse:
+    """Test SaveSqlQueryResponse schema.
+
+    ``schema_name`` is populated here through its ``schema`` alias.
+    """
+
+    def test_response_fields(self) -> None:
+        # Required fields only; optional fields keep their None defaults.
+        resp = SaveSqlQueryResponse(
+            id=42,
+            label="Revenue",
+            sql="SELECT 1",
+            database_id=1,
+            url="/sqllab?savedQueryId=42",
+        )
+        assert resp.id == 42
+        assert resp.label == "Revenue"
+        assert resp.url == "/sqllab?savedQueryId=42"
+
+    def test_response_with_optional_fields(self) -> None:
+        # ``schema`` is the alias of the ``schema_name`` field.
+        resp = SaveSqlQueryResponse(
+            id=42,
+            label="Revenue",
+            sql="SELECT 1",
+            database_id=1,
+            schema="public",
+            description="A query",
+            url="/sqllab?savedQueryId=42",
+        )
+        assert resp.schema_name == "public"
+        assert resp.description == "A query"
+
+
+def _force_passthrough_decorators():
+    """Force superset_core MCP tool decorator to be a passthrough.
+
+    In CI, superset_core is fully installed and the real @tool decorator
+    includes authentication middleware. For unit tests we want to bypass
+    auth and test the tool logic directly, so we always replace the
+    decorator with a passthrough regardless of installation state.
+
+    This mutates ``sys.modules`` process-wide; pass the returned dict to
+    ``_restore_modules`` (e.g. in a ``finally``) to undo it.
+
+    Returns a dict of original sys.modules entries so they can be restored.
+    """
+
+    # Works for both bare ``@tool`` and parametrized ``@tool(...)`` usage:
+    # returns the function unchanged, or a decorator that does so.
+    def _passthrough_tool(func=None, **kwargs):
+        if func is not None:
+            return func
+        return lambda f: f
+
+    # The same passthrough is exposed on every import path the tool
+    # modules may use to reach ``tool``.
+    mock_mcp = MagicMock()
+    mock_mcp.tool = _passthrough_tool
+
+    mock_decorators = MagicMock()
+    mock_decorators.tool = _passthrough_tool
+
+    mock_api = MagicMock()
+    mock_api.mcp = mock_mcp
+
+    # Save original modules so we can restore them later
+    saved_modules: dict[str, types.ModuleType] = {}
+
+    # Only mock the specific decorator submodules, NOT the top-level
+    # superset_core package. Replacing sys.modules["superset_core"] with
+    # a MagicMock causes 'superset_core' is not a package errors for
+    # other submodules (queries, common) that are imported by sibling
+    # tool files during test collection.
+    mock_keys = [
+        "superset_core.api",
+        "superset_core.api.mcp",
+        "superset_core.api.types",
+        "superset_core.mcp",
+        "superset_core.mcp.decorators",
+    ]
+    for key in mock_keys:
+        if key in sys.modules:
+            saved_modules[key] = sys.modules[key]
+
+    sys.modules["superset_core.api"] = mock_api
+    sys.modules["superset_core.api.mcp"] = mock_mcp
+    sys.modules["superset_core.mcp"] = mock_mcp
+    sys.modules["superset_core.mcp.decorators"] = mock_decorators
+    # setdefault: an already-imported real superset_core.api.types is kept.
+    sys.modules.setdefault("superset_core.api.types", MagicMock())
+
+    return saved_modules
+
+
+def _restore_modules(saved_modules: dict[str, types.ModuleType]) -> None:
+    """Restore original sys.modules entries after passthrough mocking.
+
+    Evicts the mocked decorator packages and every sql_lab tool module
+    (which were imported with passthrough decorators), then reinstates
+    the entries recorded in ``saved_modules``.
+    """
+    # Remove mock entries for decorator paths and tool modules imported
+    # under patched decorators. Do NOT remove the top-level superset_core
+    # package or unrelated submodules (queries, common, etc.).
+    mock_prefixes = (
+        "superset_core.api",
+        "superset_core.mcp",
+        "superset.mcp_service.sql_lab.tool",
+    )
+
+    def _is_mocked(key: str) -> bool:
+        # Match the package itself or a dotted submodule only, so sibling
+        # modules that merely share a name prefix are left alone.
+        return any(key == p or key.startswith(f"{p}.") for p in mock_prefixes)
+
+    for key in [k for k in list(sys.modules) if _is_mocked(k)]:
+        del sys.modules[key]
+    # Restore originals (including any previously-imported tool modules)
+    sys.modules.update(saved_modules)
+
+
+def _get_tool_module():
+    """Import save_sql_query with passthrough decorators (no auth).
+
+    Callers should pass the returned ``saved_modules`` to
+    ``_restore_modules`` in a ``finally`` block.
+
+    Returns (module, saved_modules) so callers can restore sys.modules.
+    """
+    saved_modules = _force_passthrough_decorators()
+    # Clear cached module imports so we get a fresh import with mocked
+    # decorators. This is necessary because in CI the real @tool decorator
+    # may have been applied during a previous import.
+    mod_name = "superset.mcp_service.sql_lab.tool.save_sql_query"
+    saved_tool_modules: dict[str, object] = {}
+    for key in list(sys.modules.keys()):
+        if key.startswith("superset.mcp_service.sql_lab.tool"):
+            saved_tool_modules[key] = sys.modules.pop(key)
+    # Merge evicted tool modules into the restore set so they are put back
+    # (not just dropped) when the caller restores sys.modules.
+    saved_modules.update(saved_tool_modules)
+    mod = importlib.import_module(mod_name)
+    return mod, saved_modules
+
+
+def _make_mock_ctx():
+    """Create a mock FastMCP Context with awaitable methods."""
+
+    # Coroutine no-op so ``await ctx.info(...)`` etc. works; a plain
+    # MagicMock return value is not awaitable.
+    async def _noop(*args, **kwargs):
+        pass
+
+    # Only info/error/warning are made awaitable; other attributes stay
+    # ordinary MagicMock attributes.
+    ctx = MagicMock()
+    ctx.info = _noop
+    ctx.error = _noop
+    ctx.warning = _noop
+    return ctx
+
+
+class TestSaveSqlQueryToolLogic:
+    """Test save_sql_query tool internal logic.
+
+    The tool function uses lazy imports inside its body (from flask import g,
+    from superset import db, etc.). We patch at the import source so that
+    when the function runs, it picks up our mocks.
+
+    The @parse_request decorator injects ctx via get_context() and strips
+    __wrapped__, so we mock get_context and call the decorated function
+    directly (without unwrapping).
+    """
+
+    @pytest.mark.anyio
+    async def test_save_query_creates_saved_query(self) -> None:
+        """Verify the tool calls SavedQueryDAO.create with correct attrs."""
+        mod, saved = _get_tool_module()
+        try:
+            mock_ctx = _make_mock_ctx()
+
+            mock_db_obj = MagicMock()
+            mock_db_obj.id = 1
+            mock_db_obj.database_name = "test_db"
+
+            mock_sq = MagicMock()
+            mock_sq.id = 42
+            mock_sq.label = "Revenue Query"
+            mock_sq.sql = "SELECT SUM(revenue) FROM sales"
+            mock_sq.catalog = None
+
+            request = SaveSqlQueryRequest(
+                database_id=1,
+                label="Revenue Query",
+                sql="SELECT SUM(revenue) FROM sales",
+            )
+
+            mock_db_session = MagicMock()
+            (
+                mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
+            ) = mock_db_obj
+
+            mock_sm = MagicMock()
+            mock_sm.can_access_database.return_value = True
+
+            mock_dao = MagicMock()
+            mock_dao.create.return_value = mock_sq
+
+            mock_g = MagicMock()
+            mock_g.user = Mock(id=1)
+
+            mock_event_logger = MagicMock()
+            mock_event_logger.log_context.return_value.__enter__ = Mock()
+            mock_event_logger.log_context.return_value.__exit__ = Mock(
+                return_value=False
+            )
+
+            with (
+                patch(
+                    "fastmcp.server.dependencies.get_context",
+                    return_value=mock_ctx,
+                ),
+                patch("superset.db", mock_db_session),
+                patch("superset.security_manager", mock_sm),
+                patch("superset.daos.query.SavedQueryDAO", mock_dao),
+                patch(
+                    "superset.mcp_service.utils.url_utils.get_superset_base_url",
+                    return_value="http://localhost:8088",
+                ),
+                patch("flask.g", mock_g),
+                patch.object(mod, "event_logger", mock_event_logger),
+            ):
+                result = await mod.save_sql_query(request)
+
+                assert result.id == 42
+                assert result.label == "Revenue Query"
+                assert "savedQueryId=42" in result.url
+                mock_dao.create.assert_called_once()
+                call_attrs = mock_dao.create.call_args[1]["attributes"]
+                assert call_attrs["db_id"] == 1
+                assert call_attrs["label"] == "Revenue Query"
+                assert call_attrs["sql"] == "SELECT SUM(revenue) FROM sales"
+                assert call_attrs["user_id"] == 1
+                mock_db_session.session.commit.assert_called_once()
+        finally:
+            _restore_modules(saved)
+
+    @pytest.mark.anyio
+    async def test_save_query_database_not_found(self) -> None:
+        mod, saved = _get_tool_module()
+        try:
+            mock_ctx = _make_mock_ctx()
+
+            request = SaveSqlQueryRequest(
+                database_id=999,
+                label="Test",
+                sql="SELECT 1",
+            )
+
+            mock_db_session = MagicMock()
+            (
+                mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
+            ) = None
+
+            mock_g = MagicMock()
+            mock_g.user = Mock(id=1)
+
+            mock_event_logger = MagicMock()
+            mock_event_logger.log_context.return_value.__enter__ = Mock()
+            mock_event_logger.log_context.return_value.__exit__ = Mock(
+                return_value=False
+            )
+
+            with (
+                patch(
+                    "fastmcp.server.dependencies.get_context",
+                    return_value=mock_ctx,
+                ),
+                patch("superset.db", mock_db_session),
+                patch("flask.g", mock_g),
+                patch.object(mod, "event_logger", mock_event_logger),
+            ):
+                from superset.exceptions import SupersetErrorException
+
+                with pytest.raises(SupersetErrorException, match="not found"):
+                    await mod.save_sql_query(request)
+        finally:
+            _restore_modules(saved)
+
+    @pytest.mark.anyio
+    async def test_save_query_access_denied(self) -> None:
+        mod, saved = _get_tool_module()
+        try:
+            mock_ctx = _make_mock_ctx()
+
+            mock_db_obj = MagicMock()
+            mock_db_obj.id = 1
+            mock_db_obj.database_name = "test_db"
+
+            request = SaveSqlQueryRequest(
+                database_id=1,
+                label="Test",
+                sql="SELECT 1",
+            )
+
+            mock_db_session = MagicMock()
+            (
+                mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
+            ) = mock_db_obj
+
+            mock_sm = MagicMock()
+            mock_sm.can_access_database.return_value = False
+
+            mock_g = MagicMock()
+            mock_g.user = Mock(id=1)
+
+            mock_event_logger = MagicMock()
+            mock_event_logger.log_context.return_value.__enter__ = Mock()
+            mock_event_logger.log_context.return_value.__exit__ = Mock(
+                return_value=False
+            )
+
+            with (
+                patch(
+                    "fastmcp.server.dependencies.get_context",
+                    return_value=mock_ctx,
+                ),
+                patch("superset.db", mock_db_session),
+                patch("superset.security_manager", mock_sm),
+                patch("flask.g", mock_g),
+                patch.object(mod, "event_logger", mock_event_logger),
+            ):
+                from superset.exceptions import SupersetSecurityException
+
+                with pytest.raises(SupersetSecurityException, match="Access denied"):
+                    await mod.save_sql_query(request)
+        finally:
+            _restore_modules(saved)
+
+    @pytest.mark.anyio
+    async def test_save_query_with_schema_and_description(self) -> None:
+        mod, saved = _get_tool_module()
+        try:
+            mock_ctx = _make_mock_ctx()
+
+            mock_db_obj = MagicMock()
+            mock_db_obj.id = 1
+            mock_db_obj.database_name = "test_db"
+
+            mock_sq = MagicMock()
+            mock_sq.id = 10
+            mock_sq.label = "Test"
+            mock_sq.sql = "SELECT 1"
+            mock_sq.catalog = None
+
+            request = SaveSqlQueryRequest(
+                database_id=1,
+                label="Test",
+                sql="SELECT 1",
+                schema="public",
+                description="A test query",
+            )
+
+            mock_db_session = MagicMock()
+            (
+                mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
+            ) = mock_db_obj
+
+            mock_sm = MagicMock()
+            mock_sm.can_access_database.return_value = True
+
+            mock_dao = MagicMock()
+            mock_dao.create.return_value = mock_sq
+
+            mock_g = MagicMock()
+            mock_g.user = Mock(id=1)
+
+            mock_event_logger = MagicMock()
+            mock_event_logger.log_context.return_value.__enter__ = Mock()
+            mock_event_logger.log_context.return_value.__exit__ = Mock(
+                return_value=False
+            )
+
+            with (
+                patch(
+                    "fastmcp.server.dependencies.get_context",
+                    return_value=mock_ctx,
+                ),
+                patch("superset.db", mock_db_session),
+                patch("superset.security_manager", mock_sm),
+                patch("superset.daos.query.SavedQueryDAO", mock_dao),
+                patch(
+                    "superset.mcp_service.utils.url_utils.get_superset_base_url",
+                    return_value="http://localhost:8088",
+                ),
+                patch("flask.g", mock_g),
+                patch.object(mod, "event_logger", mock_event_logger),
+            ):
+                result = await mod.save_sql_query(request)
+
+                assert result.id == 10
+                call_attrs = mock_dao.create.call_args[1]["attributes"]
+                assert call_attrs["schema"] == "public"
+                assert call_attrs["description"] == "A test query"
+        finally:
+            _restore_modules(saved)

Reply via email to