This is an automated email from the ASF dual-hosted git repository. kasiazjc pushed a commit to branch fix/column-select-saved-tab-label in repository https://gitbox.apache.org/repos/asf/superset.git
commit 306120a277528adb4e0c90270dadc6de3f117df2 Author: Amin Ghadersohi <[email protected]> AuthorDate: Fri May 15 00:03:06 2026 +0000 fix(mcp): create_dataset — physical table access check + extended tests - Check database existence and call security_manager.raise_for_access(database, table) before invoking CreateDatasetCommand, mirroring the guard in DatabaseRestApi.table_metadata(). Returns DatabaseNotFoundError or AccessDeniedError respectively. - Consolidate optional-property dict construction to stay within C901 limit. - Add autouse fixture in tests to mock the pre-command access check so existing tests remain focused on command-level behavior. - Add tests: database_not_found, access_denied, no_schema (schema is optional), with_catalog (catalog forwarded to command). --- .../mcp_service/dataset/tool/create_dataset.py | 39 +++++-- .../dataset/tool/test_create_dataset.py | 112 +++++++++++++++++++++ 2 files changed, 142 insertions(+), 9 deletions(-) diff --git a/superset/mcp_service/dataset/tool/create_dataset.py b/superset/mcp_service/dataset/tool/create_dataset.py index 5bb660ca8c4..c2e0c79bee2 100644 --- a/superset/mcp_service/dataset/tool/create_dataset.py +++ b/superset/mcp_service/dataset/tool/create_dataset.py @@ -21,13 +21,16 @@ from typing import Any from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations -from superset.extensions import event_logger +from superset.daos.dataset import DatasetDAO +from superset.exceptions import SupersetSecurityException +from superset.extensions import event_logger, security_manager from superset.mcp_service.dataset.schemas import ( CreateDatasetRequest, DatasetError, DatasetInfo, serialize_dataset_object, ) +from superset.sql.parse import Table logger = logging.getLogger(__name__) @@ -62,6 +65,23 @@ async def create_dataset( % (request.database_id, request.schema, request.table_name) ) + # Verify the database exists and the caller has table-level access before + # registering. Mirrors the check in DatabaseRestApi.table_metadata(). + database = DatasetDAO.get_database_by_id(request.database_id) + if database is None: + await ctx.warning("Database %s not found" % request.database_id) + return DatasetError.create( + error=f"Database {request.database_id} not found", + error_type="DatabaseNotFoundError", + ) + + table = Table(request.table_name, request.schema, request.catalog) + try: + security_manager.raise_for_access(database=database, table=table) + except SupersetSecurityException as exc: + await ctx.warning("Access denied for table %r: %s" % (str(table), str(exc))) + return DatasetError.create(error=str(exc), error_type="AccessDeniedError") + try: from superset.commands.dataset.create import CreateDatasetCommand from superset.commands.dataset.exceptions import ( @@ -72,15 +92,16 @@ async def create_dataset( ) dataset_properties: dict[str, Any] = { - "database": request.database_id, - "table_name": request.table_name, + k: v + for k, v in { + "database": request.database_id, + "table_name": request.table_name, + "schema": request.schema, + "catalog": request.catalog, + "owners": request.owners, + }.items() + if v is not None } - if request.schema is not None: - dataset_properties["schema"] = request.schema - if request.catalog is not None: - dataset_properties["catalog"] = request.catalog - if request.owners is not None: - dataset_properties["owners"] = request.owners with event_logger.log_context(action="mcp.create_dataset.create"): dataset = CreateDatasetCommand(dataset_properties).run() diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py index 21e3bec329f..561a6664dd9 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_create_dataset.py @@ -32,6 +32,8 @@ logger = logging.getLogger(__name__) # Patch at source so lazy imports inside the tool function are intercepted. _CMD_PATH = "superset.commands.dataset.create.CreateDatasetCommand" +_DAO_PATH = "superset.mcp_service.dataset.tool.create_dataset.DatasetDAO" +_SEC_PATH = "superset.mcp_service.dataset.tool.create_dataset.security_manager" def _make_mock_dataset( @@ -91,6 +93,19 @@ def mock_auth(): class TestCreateDataset: """Tests for the create_dataset MCP tool.""" + @pytest.fixture(autouse=True) + def mock_dao_and_security(self): + """Default: valid database exists and access is granted. + + Patches the pre-command access check so individual tests that only care + about command behavior don't need to replicate this setup. + """ + with patch(_DAO_PATH) as mock_dao, patch(_SEC_PATH) as mock_sec: + mock_dao.get_database_by_id.return_value = MagicMock( + id=1, database_name="test_db" + ) + yield mock_dao, mock_sec + @patch(_CMD_PATH) @pytest.mark.asyncio async def test_create_dataset_success(self, mock_command_class, mcp_server): @@ -319,3 +334,100 @@ class TestCreateDataset: assert data["columns"][0]["column_name"] == "amount" assert len(data["metrics"]) == 1 assert data["metrics"][0]["metric_name"] == "total_sales" + + @pytest.mark.asyncio + async def test_create_dataset_database_not_found( + self, mock_dao_and_security, mcp_server + ): + """Returns DatabaseNotFoundError when the database_id does not exist.""" + mock_dao, _ = mock_dao_and_security + mock_dao.get_database_by_id.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + {"request": {"database_id": 999, "table_name": "orders"}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "DatabaseNotFoundError" + assert "999" in data["error"] + + @pytest.mark.asyncio + async def test_create_dataset_access_denied( + self, mock_dao_and_security, mcp_server + ): + """Returns AccessDeniedError when the caller lacks table-level access.""" + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.exceptions import SupersetSecurityException + + _, mock_sec = mock_dao_and_security + mock_sec.raise_for_access.side_effect = SupersetSecurityException( + SupersetError( + message="Access denied", + error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + {"request": {"database_id": 1, "table_name": "secret_table"}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "AccessDeniedError" + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_no_schema( + self, mock_command_class, mock_dao_and_security, mcp_server + ): + """schema is optional; omitting it does not pass it to the command.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + {"request": {"database_id": 1, "table_name": "orders"}}, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 42 + + call_kwargs = mock_command_class.call_args[0][0] + assert "schema" not in call_kwargs + + @patch(_CMD_PATH) + @pytest.mark.asyncio + async def test_create_dataset_with_catalog( + self, mock_command_class, mock_dao_and_security, mcp_server + ): + """catalog is forwarded to CreateDatasetCommand when provided.""" + mock_dataset = _make_mock_dataset() + mock_command = MagicMock() + mock_command.run.return_value = mock_dataset + mock_command_class.return_value = mock_command + + async with Client(mcp_server) as client: + result = await client.call_tool( + "create_dataset", + { + "request": { + "database_id": 1, + "table_name": "orders", + "catalog": "prod_catalog", + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 42 + + call_kwargs = mock_command_class.call_args[0][0] + assert call_kwargs["catalog"] == "prod_catalog" + assert "schema" not in call_kwargs
