This is an automated email from the ASF dual-hosted git repository. aminghadersohi pushed a commit to branch mcp-reports-99978 in repository https://gitbox.apache.org/repos/asf/superset.git
commit 1f25b31c879085079fcaaf17c15027d31a8a3052 Author: Amin Ghadersohi <[email protected]> AuthorDate: Wed May 20 22:41:23 2026 +0000 feat(mcp): add list and get tools for alerts and reports Adds list_reports and get_report_info MCP tools under a new superset/mcp_service/report/ domain, following the canonical database domain pattern. Includes unit tests and app.py registration. --- superset/mcp_service/app.py | 8 + superset/mcp_service/report/__init__.py | 16 + superset/mcp_service/report/schemas.py | 286 ++++++++++++++++++ superset/mcp_service/report/tool/__init__.py | 24 ++ .../mcp_service/report/tool/get_report_info.py | 119 ++++++++ superset/mcp_service/report/tool/list_reports.py | 166 ++++++++++ tests/unit_tests/mcp_service/report/__init__.py | 16 + .../unit_tests/mcp_service/report/tool/__init__.py | 16 + .../mcp_service/report/tool/test_report_tools.py | 335 +++++++++++++++++++++ 9 files changed, 986 insertions(+) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 0a68d168a07..ead45f5bea6 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -123,6 +123,10 @@ Database Connections: - list_databases: List database connections with advanced filters (1-based pagination) - get_database_info: Get detailed database connection info by ID (backend, capabilities) +Alerts & Reports: +- list_reports: List alerts and reports with filtering and search (1-based pagination) +- get_report_info: Get detailed alert/report schedule info by ID + Dataset Management: - list_datasets: List datasets with advanced filters (1-based pagination) - get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) @@ -620,6 +624,10 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402 from superset.mcp_service.explore.tool import ( # noqa: F401, E402 generate_explore_link, ) +from superset.mcp_service.report.tool import ( # noqa: F401, E402 + get_report_info, + list_reports, +) from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402 execute_sql, open_sql_lab_with_context, diff --git a/superset/mcp_service/report/__init__.py b/superset/mcp_service/report/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/superset/mcp_service/report/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/report/schemas.py b/superset/mcp_service/report/schemas.py new file mode 100644 index 00000000000..7a5f81fd6ed --- /dev/null +++ b/superset/mcp_service/report/schemas.py @@ -0,0 +1,286 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for report (alerts & reports) related responses. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal + +import humanize +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.privacy import filter_user_directory_fields +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + + +class ReportFilter(ColumnOperator): + """ + Filter object for report listing. + col: The column to filter on. Must be one of the allowed filter fields. + opr: The operator to use. Must be one of the supported operators. + value: The value to filter by (type depends on col and opr). + """ + + col: Literal[ + "name", + "type", + "active", + "dashboard_id", + "chart_id", + ] = Field( + ..., + description="Column to filter on. Use get_schema(model_type='report') for " + "available filter columns.", + ) + opr: ColumnOperatorEnum = Field( + ..., + description="Operator to use. Use get_schema(model_type='report') for " + "available operators.", + ) + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by (type depends on col and opr)" + ) + + +class ReportInfo(BaseModel): + id: int | None = Field(None, description="Report/Alert ID") + name: str | None = Field(None, description="Report/Alert name") + description: str | None = Field(None, description="Report/Alert description") + type: str | None = Field(None, description="Schedule type: 'Alert' or 'Report'") + active: bool | None = Field(None, description="Whether the schedule is active") + crontab: str | None = Field(None, description="Cron expression for scheduling") + dashboard_id: int | None = Field( + None, description="Associated dashboard ID, if any" + ) + chart_id: int | None = Field(None, description="Associated chart ID, if any") + owners: List[Any] | None = Field( + None, description="List of owners (filtered by privacy controls)" + ) + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + changed_on_humanized: str | None = Field( + None, description="Humanized modification time" + ) + created_on: str | datetime | None = Field(None, description="Creation timestamp") + created_on_humanized: str | None = Field( + None, description="Humanized creation time" + ) + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: + """Filter fields based on serialization context. + + If context contains 'select_columns', only include those fields. + Otherwise, include all fields (default behavior). + """ + data = filter_user_directory_fields(serializer(self)) + + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + + return data + + +class ReportList(BaseModel): + reports: List[ReportInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field( + default_factory=list, + description="Requested columns for the response", + ) + columns_loaded: List[str] = Field( + default_factory=list, + description="Columns that were actually loaded for each report", + ) + columns_available: List[str] = Field( + default_factory=list, + description="All columns available for selection via select_columns parameter", + ) + sortable_columns: List[str] = Field( + default_factory=list, + description="Columns that can be used with order_column parameter", + ) + filters_applied: List[ReportFilter] = Field( + default_factory=list, + description="List of advanced filter dicts applied to the query.", + ) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListReportsRequest(MetadataCacheControl): + """Request schema for list_reports.""" + + filters: Annotated[ + List[ReportFilter], + Field( + default_factory=list, + description="List of filter objects (column, operator, value). Each " + "filter is an object with 'col', 'opr', and 'value' " + "properties. Cannot be used together with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="List of columns to select. Defaults to common columns if not " + "specified.", + ), + ] + search: Annotated[ + str | None, + Field( + default=None, + description="Text search string to match against report fields. Cannot " + "be used together with 'filters'.", + ), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field( + default="desc", description="Direction to order results ('asc' or 'desc')" + ), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number for pagination (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[ReportFilter]: + """Accept both JSON string and list of objects.""" + return parse_json_or_model_list(v, ReportFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + """Accept JSON array, list, or comma-separated string.""" + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListReportsRequest": + """Prevent using both search and filters simultaneously.""" + if self.search and self.filters: + raise ValueError( + "Cannot use both 'search' and 'filters' parameters simultaneously. " + "Use either 'search' for text-based searching across multiple fields, " + "or 'filters' for precise column-based filtering, but not both." + ) + return self + + +class ReportError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "ReportError": + """Create a standardized ReportError with timestamp.""" + from datetime import datetime, timezone + + return cls( + error=error, error_type=error_type, timestamp=datetime.now(timezone.utc) + ) + + +class GetReportInfoRequest(MetadataCacheControl): + """Request schema for get_report_info — identifier is a numeric ID only.""" + + identifier: Annotated[ + int, + Field(description="Report/Alert numeric ID"), + ] + + +def _humanize_timestamp(dt: datetime | None) -> str | None: + """Convert a datetime to a humanized string like '2 hours ago'.""" + if dt is None: + return None + now = datetime.now(dt.tzinfo) if dt.tzinfo else datetime.now() + return humanize.naturaltime(now - dt) + + +def serialize_report_object(report: Any) -> ReportInfo | None: + if not report: + return None + + return ReportInfo( + id=getattr(report, "id", None), + name=getattr(report, "name", None), + description=getattr(report, "description", None), + type=getattr(report, "type", None), + active=getattr(report, "active", None), + crontab=getattr(report, "crontab", None), + dashboard_id=getattr(report, "dashboard_id", None), + chart_id=getattr(report, "chart_id", None), + owners=getattr(report, "owners", None), + changed_on=getattr(report, "changed_on", None), + changed_on_humanized=_humanize_timestamp(getattr(report, "changed_on", None)), + created_on=getattr(report, "created_on", None), + created_on_humanized=_humanize_timestamp(getattr(report, "created_on", None)), + ) diff --git a/superset/mcp_service/report/tool/__init__.py b/superset/mcp_service/report/tool/__init__.py new file mode 100644 index 00000000000..91a7d931615 --- /dev/null +++ b/superset/mcp_service/report/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_report_info import get_report_info +from .list_reports import list_reports + +__all__ = [ + "list_reports", + "get_report_info", +] diff --git a/superset/mcp_service/report/tool/get_report_info.py b/superset/mcp_service/report/tool/get_report_info.py new file mode 100644 index 00000000000..234082cb88d --- /dev/null +++ b/superset/mcp_service/report/tool/get_report_info.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Get report info FastMCP tool. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.report.schemas import ( + GetReportInfoRequest, + ReportError, + ReportInfo, + serialize_report_object, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="ReportSchedule", + annotations=ToolAnnotations( + title="Get report info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_report_info( + request: GetReportInfoRequest, ctx: Context +) -> ReportInfo | ReportError: + """Get alert or report schedule metadata by numeric ID. + + Returns schedule configuration including type (Alert/Report), active + status, cron expression, and associated dashboard or chart. + + IMPORTANT FOR LLM CLIENTS: + - Use numeric ID (e.g., 123) + - To find a report ID, use the list_reports tool first + + Example usage: + ```json + { + "identifier": 1 + } + ``` + """ + await ctx.info( + "Retrieving report information: identifier=%s" % (request.identifier,) + ) + + try: + from superset.daos.report import ReportScheduleDAO + + with event_logger.log_context(action="mcp.get_report_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=ReportScheduleDAO, + output_schema=ReportInfo, + error_schema=ReportError, + serializer=serialize_report_object, + supports_slug=False, + logger=logger, + ) + + result = get_tool.run_tool(request.identifier) + + if isinstance(result, ReportInfo): + await ctx.info( + "Report information retrieved successfully: " + "report_id=%s, name=%s, type=%s" + % ( + result.id, + result.name, + result.type, + ) + ) + else: + await ctx.warning( + "Report retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Report information retrieval failed: identifier=%s, error=%s, " + "error_type=%s" + % ( + request.identifier, + str(e), + type(e).__name__, + ) + ) + return ReportError( + error=f"Failed to get report info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/report/tool/list_reports.py b/superset/mcp_service/report/tool/list_reports.py new file mode 100644 index 00000000000..06a173ef23a --- /dev/null +++ b/superset/mcp_service/report/tool/list_reports.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List reports (alerts & reports) FastMCP tool. +""" + +import logging +from typing import TYPE_CHECKING + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +if TYPE_CHECKING: + from superset.reports.models import ReportSchedule + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.report.schemas import ( + ListReportsRequest, + ReportError, + ReportFilter, + ReportInfo, + ReportList, + serialize_report_object, +) + +logger = logging.getLogger(__name__) + +DEFAULT_REPORT_COLUMNS = ["id", "name", "type", "active", "crontab"] +SORTABLE_REPORT_COLUMNS = ["id", "name", "type", "active", "changed_on", "created_on"] +ALL_REPORT_COLUMNS = [ + "id", + "name", + "description", + "type", + "active", + "crontab", + "dashboard_id", + "chart_id", + "changed_on", + "created_on", +] + +_DEFAULT_LIST_REPORTS_REQUEST = ListReportsRequest() + + +@tool( + tags=["core"], + class_permission_name="ReportSchedule", + annotations=ToolAnnotations( + title="List reports", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_reports( + request: ListReportsRequest | None = None, + ctx: Context | None = None, +) -> ReportList | ReportError: + """List alerts and reports with filtering and search. + + Returns schedule metadata including name, type (Alert/Report), active + status, and cron expression. + + Sortable columns for order_column: id, name, type, active, changed_on, + created_on + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_reports") + + request = request or _DEFAULT_LIST_REPORTS_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing reports: page=%s, page_size=%s, search=%s" + % ( + request.page, + request.page_size, + request.search, + ) + ) + await ctx.debug( + "Report listing parameters: filters=%s, order_column=%s, " + "order_direction=%s, select_columns=%s" + % ( + request.filters, + request.order_column, + request.order_direction, + request.select_columns, + ) + ) + + try: + from superset.daos.report import ReportScheduleDAO + + def _serialize_report( + obj: "ReportSchedule | None", cols: list[str] | None + ) -> ReportInfo | None: + return serialize_report_object(obj) + + list_tool = ModelListCore( + dao_class=ReportScheduleDAO, + output_schema=ReportInfo, + item_serializer=_serialize_report, + filter_type=ReportFilter, + default_columns=DEFAULT_REPORT_COLUMNS, + search_columns=["name", "description"], + list_field_name="reports", + output_list_schema=ReportList, + all_columns=ALL_REPORT_COLUMNS, + sortable_columns=SORTABLE_REPORT_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_reports.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Reports listed successfully: count=%s, total_count=%s, total_pages=%s" + % ( + len(result.reports) if hasattr(result, "reports") else 0, + getattr(result, "total_count", None), + getattr(result, "total_pages", None), + ) + ) + + columns_to_filter = result.columns_requested + with event_logger.log_context(action="mcp.list_reports.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) + + except Exception as e: + await ctx.error( + "Report listing failed: page=%s, page_size=%s, error=%s, error_type=%s" + % ( + request.page, + request.page_size, + str(e), + type(e).__name__, + ) + ) + raise diff --git a/tests/unit_tests/mcp_service/report/__init__.py b/tests/unit_tests/mcp_service/report/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/report/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/report/tool/__init__.py b/tests/unit_tests/mcp_service/report/tool/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/report/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/report/tool/test_report_tools.py b/tests/unit_tests/mcp_service/report/tool/test_report_tools.py new file mode 100644 index 00000000000..b2bb9be992c --- /dev/null +++ b/tests/unit_tests/mcp_service/report/tool/test_report_tools.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.report.schemas import ListReportsRequest, ReportFilter +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +def create_mock_report( + report_id: int = 1, + name: str = "Daily Sales Report", + report_type: str = "Report", + active: bool = True, + crontab: str = "0 9 * * *", + description: str = "A daily report", + dashboard_id: int | None = None, + chart_id: int | None = None, +) -> MagicMock: + """Factory function to create mock report objects with sensible defaults.""" + report = MagicMock() + report.id = report_id + report.name = name + report.type = report_type + report.active = active + report.crontab = crontab + report.description = description + report.dashboard_id = dashboard_id + report.chart_id = chart_id + report.owners = [] + report.changed_on = None + report.created_on = None + return report + + [email protected] +def mcp_server(): + return mcp + + [email protected](autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + from unittest.mock import Mock + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +class TestReportFilterSchema: + """Tests for ReportFilter schema — filterable columns.""" + + def test_valid_filter_name(self): + f = ReportFilter(col="name", opr="eq", value="My Report") + assert f.col == "name" + + def test_valid_filter_type(self): + f = ReportFilter(col="type", opr="eq", value="Alert") + assert f.col == "type" + + def test_valid_filter_active(self): + f = ReportFilter(col="active", opr="eq", value=True) + assert f.col == "active" + + def test_valid_filter_dashboard_id(self): + f = ReportFilter(col="dashboard_id", opr="eq", value=1) + assert f.col == "dashboard_id" + + def test_valid_filter_chart_id(self): + f = ReportFilter(col="chart_id", opr="eq", value=42) + assert f.col == "chart_id" + + def test_invalid_filter_column_rejected(self): + """Columns not in the Literal set must be rejected.""" + with pytest.raises(ValidationError): + ReportFilter(col="not_a_real_column", opr="eq", value=1) + + def test_created_by_fk_is_rejected(self): + """created_by_fk is not a public filter column.""" + with pytest.raises(ValidationError): + ReportFilter(col="created_by_fk", opr="eq", value=1) + + +def test_list_reports_request_accepts_valid_fields(): + request = ListReportsRequest(page=1, page_size=10) + assert request.page == 1 + assert request.page_size == 10 + + +def test_list_reports_request_rejects_search_and_filters_together(): + with pytest.raises(ValidationError): + ListReportsRequest( + search="my report", + filters=[{"col": "active", "opr": "eq", "value": True}], + ) + + +@patch("superset.daos.report.ReportScheduleDAO.list") [email protected] +async def test_list_reports_basic(mock_list, mcp_server): + """Test basic report listing functionality.""" + report = create_mock_report() + mock_list.return_value = ([report], 1) + + async with Client(mcp_server) as client: + request = ListReportsRequest(page=1, page_size=10) + result = await client.call_tool( + "list_reports", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["reports"] is not None + assert len(data["reports"]) == 1 + assert data["reports"][0]["id"] == 1 + assert data["reports"][0]["name"] == "Daily Sales Report" + assert data["reports"][0]["type"] == "Report" + assert data["reports"][0]["active"] is True + assert data["reports"][0]["crontab"] == "0 9 * * *" + + +@patch("superset.daos.report.ReportScheduleDAO.list") [email protected] +async def test_list_reports_with_search(mock_list, mcp_server): + """Test report listing with search functionality.""" + report = create_mock_report(name="Weekly Alert") + mock_list.return_value = ([report], 1) + + async with Client(mcp_server) as client: + request = ListReportsRequest(page=1, page_size=10, search="Weekly") + result = await client.call_tool( + "list_reports", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["reports"] is not None + assert len(data["reports"]) == 1 + assert data["reports"][0]["name"] == "Weekly Alert" + + +@patch("superset.daos.report.ReportScheduleDAO.list") [email protected] +async def test_list_reports_with_type_filter(mock_list, mcp_server): + """Test report listing filtered by type.""" + report = create_mock_report(report_type="Alert") + mock_list.return_value = ([report], 1) + + async with Client(mcp_server) as client: + request = ListReportsRequest( + page=1, + page_size=10, + filters=[{"col": "type", "opr": "eq", "value": "Alert"}], + ) + result = await client.call_tool( + "list_reports", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert len(data["reports"]) == 1 + assert data["reports"][0]["type"] == "Alert" + + +@patch("superset.daos.report.ReportScheduleDAO.list") [email protected] +async def test_list_reports_does_not_expose_owners(mock_list, mcp_server): + """Test that owners field is stripped by privacy controls.""" + report = create_mock_report() + mock_list.return_value = ([report], 1) + + async with Client(mcp_server) as client: + request = ListReportsRequest( + page=1, + page_size=10, + select_columns=["id", "name", "owners"], + ) + result = await client.call_tool( + "list_reports", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + # owners is filtered by USER_DIRECTORY_FIELDS + assert "owners" not in data.get("columns_requested", []) + assert "owners" not in data.get("columns_loaded", []) + + +@patch("superset.daos.report.ReportScheduleDAO.list") [email protected] +async def test_list_reports_empty_results(mock_list, mcp_server): + """Test report listing with no results.""" + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + request = ListReportsRequest(page=1, page_size=10) + result = await client.call_tool( + "list_reports", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["reports"] == [] + assert data["count"] == 0 + assert data["total_count"] == 0 + + +@patch("superset.daos.report.ReportScheduleDAO.list") [email protected] +async def test_list_reports_api_error(mock_list, mcp_server): + """Test error handling when DAO raises an exception.""" + mock_list.side_effect = ToolError("Report DAO error") + + async with Client(mcp_server) as client: + request = ListReportsRequest(page=1, page_size=10) + with pytest.raises(ToolError) as excinfo: # noqa: PT012 + await client.call_tool("list_reports", {"request": request.model_dump()}) + assert "Report DAO error" in str(excinfo.value) + + +@patch("superset.daos.report.ReportScheduleDAO.list") [email protected] +async def test_list_reports_without_request_uses_defaults(mock_list, mcp_server): + """list_reports with no request payload should use default parameters.""" + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_reports", {}) + data = json.loads(result.content[0].text) + assert data["reports"] == [] + assert data["page"] == 1 + + +@patch("superset.daos.report.ReportScheduleDAO.find_by_id") [email protected] +async def test_get_report_info_basic(mock_find, mcp_server): + """Test basic get report info functionality.""" + report = create_mock_report() + mock_find.return_value = report + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_report_info", {"request": {"identifier": 1}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["name"] == "Daily Sales Report" + assert data["type"] == "Report" + assert data["active"] is True + assert data["crontab"] == "0 9 * * *" + assert "owners" not in data + + +@patch("superset.daos.report.ReportScheduleDAO.find_by_id") [email protected] +async def test_get_report_info_alert_type(mock_find, mcp_server): + """Test get report info for an Alert type schedule.""" + report = create_mock_report(report_type="Alert", name="Revenue Alert") + mock_find.return_value = report + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_report_info", {"request": {"identifier": 1}} + ) + data = json.loads(result.content[0].text) + assert data["type"] == "Alert" + assert data["name"] == "Revenue Alert" + + +@patch("superset.daos.report.ReportScheduleDAO.find_by_id") [email protected] +async def test_get_report_info_not_found(mock_find, mcp_server): + """Test get report info when report does not exist.""" + mock_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_report_info", {"request": {"identifier": 999}} + ) + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + + +@patch("superset.daos.report.ReportScheduleDAO.find_by_id") [email protected] +async def test_get_report_info_with_dashboard(mock_find, mcp_server): + """Test get report info with associated dashboard.""" + report = create_mock_report(dashboard_id=42) + mock_find.return_value = report + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_report_info", {"request": {"identifier": 1}} + ) + data = json.loads(result.content[0].text) + assert data["dashboard_id"] == 42 + assert data["chart_id"] is None + + +@patch("superset.daos.report.ReportScheduleDAO.find_by_id") [email protected] +async def test_get_report_info_with_chart(mock_find, mcp_server): + """Test get report info with associated chart.""" + report = create_mock_report(chart_id=7) + mock_find.return_value = report + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_report_info", {"request": {"identifier": 1}} + ) + data = json.loads(result.content[0].text) + assert data["chart_id"] == 7 + assert data["dashboard_id"] is None
