This is an automated email from the ASF dual-hosted git repository. kasiazjc pushed a commit to branch fix/column-select-saved-tab-label in repository https://gitbox.apache.org/repos/asf/superset.git
commit 552777a06aa99648ebf871b1d186694d4203bb04 Author: Amin Ghadersohi <[email protected]> AuthorDate: Thu May 14 23:25:16 2026 +0000 feat(mcp): add create_dataset tool to register physical tables as datasets Adds create_dataset MCP tool that wraps POST /api/v1/dataset/ so callers can register an existing physical table as a Superset dataset without manual UI interaction. Returns DatasetInfo (same shape as get_dataset_info) so the resulting dataset_id feeds directly into generate_chart. - CreateDatasetRequest schema (database_id, schema, table_name, owners?) - Tool file with typed error handling (exists/not-found/validation/internal) - Registered in dataset/tool/__init__.py and app.py DEFAULT_INSTRUCTIONS - Unit tests covering success, owners, error cases, and full DatasetInfo shape --- superset/mcp_service/app.py | 2 + superset/mcp_service/dataset/schemas.py | 26 ++ superset/mcp_service/dataset/tool/__init__.py | 2 + .../mcp_service/dataset/tool/create_dataset.py | 115 ++++++++ .../dataset/tool/test_create_dataset.py | 302 +++++++++++++++++++++ 5 files changed, 447 insertions(+) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 7ab6a7a774b..c8c99490a64 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -76,6 +76,7 @@ Database Connections: Dataset Management: - list_datasets: List datasets with advanced filters (1-based pagination) - get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) +- create_dataset: Register an existing physical table as a dataset against a DB connection - create_virtual_dataset: Save a SQL query as a virtual dataset for charting - query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart @@ -544,6 +545,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402 list_databases, ) from superset.mcp_service.dataset.tool import ( # noqa: F401, E402 + create_dataset, 
create_virtual_dataset, get_dataset_info, list_datasets, diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index b5aa475ff2e..25ec182c1a9 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -323,6 +323,32 @@ class GetDatasetInfoRequest(MetadataCacheControl): ] +class CreateDatasetRequest(BaseModel): + """Request schema for create_dataset to register a physical table as a dataset.""" + + database_id: Annotated[ + int, + Field( + description="ID of the database connection to register the table against" + ), + ] + schema: Annotated[ + str, + Field(description="Schema (namespace) where the table lives, e.g. 'public'"), + ] + table_name: Annotated[ + str, + Field(description="Name of the physical table to register as a dataset"), + ] + owners: Annotated[ + List[int] | None, + Field( + default=None, + description="Optional list of owner user IDs. Defaults to calling user.", + ), + ] + + class CreateVirtualDatasetRequest(BaseModel): """Request schema for create_virtual_dataset.""" diff --git a/superset/mcp_service/dataset/tool/__init__.py b/superset/mcp_service/dataset/tool/__init__.py index cad8d4ed569..1481a20c76a 100644 --- a/superset/mcp_service/dataset/tool/__init__.py +++ b/superset/mcp_service/dataset/tool/__init__.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. 
+from .create_dataset import create_dataset from .create_virtual_dataset import create_virtual_dataset from .get_dataset_info import get_dataset_info from .list_datasets import list_datasets from .query_dataset import query_dataset __all__ = [ + "create_dataset", "create_virtual_dataset", "list_datasets", "get_dataset_info", diff --git a/superset/mcp_service/dataset/tool/create_dataset.py b/superset/mcp_service/dataset/tool/create_dataset.py new file mode 100644 index 00000000000..7d0d9ce4fa3 --- /dev/null +++ b/superset/mcp_service/dataset/tool/create_dataset.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import logging +from typing import Any + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.mcp_service.dataset.schemas import ( + CreateDatasetRequest, + DatasetError, + DatasetInfo, + serialize_dataset_object, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["mutate"], + class_permission_name="Dataset", + method_permission_name="write", + annotations=ToolAnnotations( + title="Register a physical table as a dataset", + readOnlyHint=False, + destructiveHint=False, + ), +) +async def create_dataset( + request: CreateDatasetRequest, ctx: Context +) -> DatasetInfo | DatasetError: + """Register an existing physical table as a Superset dataset. + + Use this tool when the user wants to make a physical database table available + for charting or exploration. The table must already exist in the target database. + + Workflow: + 1. Call list_databases to find the correct database_id + 2. Call this tool with database_id, schema, and table_name + 3. Use the returned id as dataset_id in generate_chart or generate_explore_link + + Returns DatasetInfo on success or DatasetError with error_type on failure. 
+ """ + await ctx.info( + "Registering physical table as dataset: database_id=%s, schema=%r, table=%r" + % (request.database_id, request.schema, request.table_name) + ) + + try: + from superset.commands.dataset.create import CreateDatasetCommand + from superset.commands.dataset.exceptions import ( + DatasetCreateFailedError, + DatasetExistsValidationError, + DatasetInvalidError, + TableNotFoundValidationError, + ) + + dataset_properties: dict[str, Any] = { + "database": request.database_id, + "schema": request.schema, + "table_name": request.table_name, + } + if request.owners is not None: + dataset_properties["owners"] = request.owners + + dataset = CreateDatasetCommand(dataset_properties).run() + + result = serialize_dataset_object(dataset) + if result is None: + return DatasetError.create( + error="Dataset was created but could not be serialized", + error_type="InternalError", + ) + + await ctx.info( + "Dataset registered: id=%s, table=%r" % (dataset.id, dataset.table_name) + ) + return result + + except DatasetExistsValidationError as exc: + await ctx.warning("Dataset already exists: %s" % str(exc)) + return DatasetError.create(error=str(exc), error_type="DatasetExistsError") + except TableNotFoundValidationError as exc: + await ctx.warning("Table not found: %s" % str(exc)) + return DatasetError.create(error=str(exc), error_type="TableNotFoundError") + except DatasetInvalidError as exc: + messages = exc.normalized_messages() + await ctx.warning("Dataset validation failed: %s" % (messages,)) + return DatasetError.create(error=str(messages), error_type="ValidationError") + except DatasetCreateFailedError as exc: + await ctx.error("Dataset creation failed: %s" % str(exc)) + return DatasetError.create(error=str(exc), error_type="CreateFailedError") + except Exception as exc: + await ctx.error( + "Unexpected error registering dataset: %s: %s" + % (type(exc).__name__, str(exc)) + ) + return DatasetError.create( + error=f"Failed to create dataset: {exc}", 
error_type="InternalError" + ) diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py new file mode 100644 index 00000000000..b65a3296839 --- /dev/null +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Unit tests for create_dataset MCP tool.""" + +import logging +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError + +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def _make_mock_dataset( + dataset_id: int = 42, + table_name: str = "orders", + schema: str = "public", + database_name: str = "main_db", +) -> MagicMock: + dataset = MagicMock() + dataset.id = dataset_id + dataset.table_name = table_name + dataset.schema = schema + dataset.description = None + dataset.certified_by = None + dataset.certification_details = None + dataset.changed_by = None + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by = None + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = f"[{database_name}].[{schema}]" + dataset.database = MagicMock() + dataset.database.database_name = database_name + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = None + dataset.template_params = None + dataset.extra = None + dataset.uuid = f"dataset-uuid-{dataset_id}" + dataset.columns = [] + dataset.metrics = [] + return dataset + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestCreateDataset: + """Tests for the create_dataset MCP tool.""" + + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @pytest.mark.asyncio + async def test_create_dataset_success(self, 
mock_command_class, mcp_server): + """Happy path: tool creates dataset and returns DatasetInfo.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 42 + assert data["table_name"] == "orders" + assert data["schema_name"] == "public" + + call_kwargs = mock_command_class.call_args[0][0] + assert call_kwargs["database"] == 1 + assert call_kwargs["schema"] == "public" + assert call_kwargs["table_name"] == "orders" + assert "owners" not in call_kwargs + + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @pytest.mark.asyncio + async def test_create_dataset_with_owners(self, mock_command_class, mcp_server): + """Owners list is forwarded to the command when supplied.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 2, + "schema": "sales", + "table_name": "transactions", + "owners": [5, 10], + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 42 + + call_kwargs = mock_command_class.call_args[0][0] + assert call_kwargs["owners"] == [5, 10] + + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @pytest.mark.asyncio + async def test_create_dataset_already_exists(self, mock_command_class, mcp_server): + """Returns DatasetError when a dataset for the table already exists.""" + from superset.commands.dataset.exceptions import 
DatasetExistsValidationError + from superset.sql.parse import Table + + mock_command = MagicMock() + mock_command.run.side_effect = DatasetExistsValidationError( + Table("orders", "public", None) + ) + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "DatasetExistsError" + assert "error" in data + + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @pytest.mark.asyncio + async def test_create_dataset_table_not_found(self, mock_command_class, mcp_server): + """Returns DatasetError when the physical table does not exist in the DB.""" + from superset.commands.dataset.exceptions import TableNotFoundValidationError + from superset.sql.parse import Table + + mock_command = MagicMock() + mock_command.run.side_effect = TableNotFoundValidationError( + Table("missing_table", "public", None) + ) + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "public", + "table_name": "missing_table", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "TableNotFoundError" + + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @pytest.mark.asyncio + async def test_create_dataset_unexpected_error( + self, mock_command_class, mcp_server + ): + """Unexpected exceptions are caught and returned as InternalError.""" + mock_command = MagicMock() + mock_command.run.side_effect = RuntimeError("DB connection lost") + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + 
"database_id": 1, + "schema": "public", + "table_name": "orders", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "InternalError" + assert "DB connection lost" in data["error"] + + @pytest.mark.asyncio + async def test_create_dataset_missing_required_fields(self, mcp_server): + """Missing required fields raise a validation error before the tool runs.""" + async with Client(mcp_server) as client: + with pytest.raises(ToolError): + await client.call_tool( + "create_dataset", + { + "request": { + # database_id and table_name are omitted intentionally + "schema": "public", + } + }, + ) + + @patch("superset.mcp_service.dataset.tool.create_dataset.CreateDatasetCommand") + @pytest.mark.asyncio + async def test_create_dataset_returns_full_dataset_info( + self, mock_command_class, mcp_server + ): + """The returned DatasetInfo includes columns, metrics, and all core fields.""" + mock_dataset = _make_mock_dataset( + dataset_id=99, table_name="sales", schema="dw" + ) + + col = MagicMock() + col.column_name = "amount" + col.verbose_name = "Amount" + col.type = "NUMERIC" + col.is_dttm = False + col.groupby = True + col.filterable = True + col.description = "Sale amount" + mock_dataset.columns = [col] + + metric = MagicMock() + metric.metric_name = "total_sales" + metric.verbose_name = "Total Sales" + metric.expression = "SUM(amount)" + metric.description = "Sum of amounts" + metric.d3format = None + mock_dataset.metrics = [metric] + + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "schema": "dw", + "table_name": "sales", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 99 + assert data["table_name"] == "sales" + assert data["schema_name"] == "dw" + assert data["is_virtual"] is False + assert 
len(data["columns"]) == 1 + assert data["columns"][0]["column_name"] == "amount" + assert len(data["metrics"]) == 1 + assert data["metrics"][0]["metric_name"] == "total_sales"
