This is an automated email from the ASF dual-hosted git repository. aminghadersohi pushed a commit to branch mcp-rls-plugins-99978 in repository https://gitbox.apache.org/repos/asf/superset.git
commit 3d528921baba787f786fe1c2bd4a557772596e43 Author: Mehmet Salih Yavuz <[email protected]> AuthorDate: Thu May 21 15:59:09 2026 +0300 feat(mcp): add find_users tool and owner filter columns for listings (#39679) Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]> --- superset/mcp_service/app.py | 17 ++ superset/mcp_service/chart/schemas.py | 22 +- superset/mcp_service/chart/tool/list_charts.py | 7 +- superset/mcp_service/common/schema_discovery.py | 8 +- superset/mcp_service/dashboard/schemas.py | 23 +- .../mcp_service/dashboard/tool/list_dashboards.py | 9 +- superset/mcp_service/database/schemas.py | 9 +- superset/mcp_service/dataset/schemas.py | 17 +- superset/mcp_service/dataset/tool/list_datasets.py | 7 +- superset/mcp_service/mcp_core.py | 15 +- superset/mcp_service/privacy.py | 13 +- superset/mcp_service/system/schemas.py | 84 ++++++- superset/mcp_service/system/tool/__init__.py | 2 + superset/mcp_service/system/tool/find_users.py | 101 ++++++++ .../mcp_service/dataset/tool/test_dataset_tools.py | 18 +- .../mcp_service/system/tool/test_find_users.py | 257 +++++++++++++++++++++ .../system/tool/test_get_current_user.py | 66 ++++-- .../mcp_service/system/tool/test_get_schema.py | 40 +++- 18 files changed, 626 insertions(+), 89 deletions(-) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 8589a51d156..1869c6eeea9 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -158,6 +158,7 @@ Schema Discovery: System Information: - get_instance_info: Get instance-wide statistics, metadata, and current user identity +- find_users: Resolve a person's name to user IDs for use as a filter value - health_check: Simple health check tool (takes NO parameters, call without arguments) - generate_bug_report: Build a PII-sanitized bug report to send to Preset support (use when the user says the MCP is broken or asks how to report an issue) @@ -199,6 +200,16 @@ Some tools do not use a request wrapper, so follow each 
tool's schema Recommended Workflows: +To filter dashboards/charts/datasets by a person ("show me what <name> is working on"): +1. find_users(request={{"query": "<name>"}}) -> resolve to user IDs +2. Pick the matching user.id from the response +3. list_dashboards(request={{"filters": [ + {{"col": "created_by_fk", "opr": "eq", "value": <id>}} + ]}}) — same shape for list_charts / list_datasets. + (use changed_by_fk for "last modified by", or "in" with a list of IDs for + multiple matches). Do NOT pass the person's name as the search parameter — + search matches titles, not people. + To add a chart to an existing dashboard: 1. add_chart_to_existing_dashboard(dashboard_id, chart_id) -> updates dashboard directly - If permission_denied=True is returned: inform the user they lack edit rights, @@ -371,6 +382,11 @@ Input format: contact details, roles, admin status, ownership, or access-list information. - Do NOT infer access-list answers from dashboard metadata such as published status, role restrictions, empty owner lists, or schema fields. +- find_users is sanctioned ONLY for resolving a name the user supplied into a + user ID for filtering (e.g., "what is <name> working on" -> filter + list_dashboards by created_by_fk). Do NOT use find_users to answer "who owns + X", "who can access X", "is <name> an admin", or to enumerate the directory. + Never return find_users output to the user verbatim. - Do NOT use execute_sql to query user, role, owner, or access-list tables for this information. 
- You may reference the current user's own identity details when appropriate, such @@ -649,6 +665,7 @@ from superset.mcp_service.system import ( # noqa: F401, E402 resources as system_resources, ) from superset.mcp_service.system.tool import ( # noqa: F401, E402 + find_users, generate_bug_report, get_instance_info, get_schema, diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 07baeab9ba1..81ddd56a540 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -23,7 +23,7 @@ from __future__ import annotations import difflib from datetime import datetime -from typing import Annotated, Any, Dict, List, Literal, Protocol +from typing import Annotated, Any, cast, Dict, List, Literal, Protocol import humanize from pydantic import ( @@ -146,7 +146,7 @@ class ChartInfo(BaseModel): ), ) form_data_key: str | None = Field( - None, + default=None, description=( "Cache key used to retrieve unsaved form_data. When present, indicates " "the form_data came from cache (unsaved edits) rather than the saved chart." @@ -523,17 +523,18 @@ class ChartFilter(ColumnOperator): value: The value to filter by (type depends on col and opr). """ - col: Literal[ + col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] "slice_name", "viz_type", "datasource_name", + "created_by_fk", + "changed_by_fk", ] = Field( ..., - description=( - "Column to filter on. Valid values: 'slice_name', 'viz_type', " - "'datasource_name'. Other column names are not valid filter columns " - "and will cause a validation error." - ), + description="Column to filter on. Use get_schema(model_type='chart') for " + "available filter columns. 
To filter by a person, first call find_users " + "to resolve a name to a user ID, then filter by created_by_fk or " + "changed_by_fk with that integer ID.", ) opr: ColumnOperatorEnum = Field( ..., @@ -1538,7 +1539,10 @@ class ListChartsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheControl): """ from superset.mcp_service.utils.schema_utils import parse_json_or_model_list - return parse_json_or_model_list(v, ChartFilter, "filters") + return cast( + List[ChartFilter], + parse_json_or_model_list(v, ChartFilter, "filters"), + ) @field_validator("select_columns", mode="before") @classmethod diff --git a/superset/mcp_service/chart/tool/list_charts.py b/superset/mcp_service/chart/tool/list_charts.py index bfb2c564540..c682e3fe63d 100644 --- a/superset/mcp_service/chart/tool/list_charts.py +++ b/superset/mcp_service/chart/tool/list_charts.py @@ -104,11 +104,16 @@ async def list_charts( list_charts(search="revenue", page=1) # DO NOT DO THIS Valid filter columns for ``filters[].col``: - ``slice_name``, ``viz_type``, ``datasource_name`` + ``slice_name``, ``viz_type``, ``datasource_name``, + ``created_by_fk``, ``changed_by_fk`` Sortable columns for ``order_column``: ``id``, ``slice_name``, ``viz_type``, ``description``, ``changed_on``, ``created_on`` + + To filter by a person, call find_users to resolve the name to a user ID, + then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq", + "value": <id>}] (or "changed_by_fk"). Do not pass the name as search. 
""" request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True) await ctx.info( diff --git a/superset/mcp_service/common/schema_discovery.py b/superset/mcp_service/common/schema_discovery.py index 71d7dc22f81..ca7b594dbbf 100644 --- a/superset/mcp_service/common/schema_discovery.py +++ b/superset/mcp_service/common/schema_discovery.py @@ -37,10 +37,12 @@ class ColumnMetadata(BaseModel): """Metadata for a selectable column.""" name: str = Field(..., description="Column name to use in select_columns") - description: str | None = Field(None, description="Column description") - type: str | None = Field(None, description="Data type (str, int, datetime, etc.)") + description: str | None = Field(default=None, description="Column description") + type: str | None = Field( + default=None, description="Data type (str, int, datetime, etc.)" + ) is_default: bool = Field( - False, description="Whether this column is included by default" + default=False, description="Whether this column is included by default" ) diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py index 68a680863e5..9b889ff27ee 100644 --- a/superset/mcp_service/dashboard/schemas.py +++ b/superset/mcp_service/dashboard/schemas.py @@ -67,7 +67,7 @@ from __future__ import annotations import logging from datetime import datetime -from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING +from typing import Annotated, Any, cast, Dict, List, Literal, TYPE_CHECKING import humanize from pydantic import ( @@ -169,16 +169,20 @@ class DashboardFilter(ColumnOperator): value: The value to filter by (type depends on col and opr). """ - col: Literal[ + col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] "dashboard_title", "published", "favorite", + "created_by_fk", + "changed_by_fk", ] = Field( ..., description=( - "Column to filter on. Valid values: 'dashboard_title', 'published', " - "'favorite'. 
Other column names are not valid filter columns and will " - "cause a validation error." + "Column to filter on. Use " + "get_schema(model_type='dashboard') for available " + "filter columns. To filter by a person, first call find_users to " + "resolve a name to a user ID, then filter by created_by_fk or " + "changed_by_fk with that integer ID." ), ) opr: ColumnOperatorEnum = Field( @@ -223,7 +227,10 @@ class ListDashboardsRequest(OwnedByMeMixin, CreatedByMeMixin, MetadataCacheContr """ from superset.mcp_service.utils.schema_utils import parse_json_or_model_list - return parse_json_or_model_list(v, DashboardFilter, "filters") + return cast( + List[DashboardFilter], + parse_json_or_model_list(v, DashboardFilter, "filters"), + ) @field_validator("select_columns", mode="before") @classmethod @@ -392,14 +399,14 @@ class DashboardInfo(BaseModel): # Fields for permalink/filter state support permalink_key: str | None = Field( - None, + default=None, description=( "Permalink key used to retrieve filter state. When present, indicates " "the filter_state came from a permalink rather than the default dashboard." ), ) filter_state: Dict[str, Any] | None = Field( - None, + default=None, description=( "Filter state from permalink. Contains dataMask (native filter values), " "activeTabs, anchor, and urlParams. 
When present, represents the actual " diff --git a/superset/mcp_service/dashboard/tool/list_dashboards.py b/superset/mcp_service/dashboard/tool/list_dashboards.py index 8c479df2753..19f967e801f 100644 --- a/superset/mcp_service/dashboard/tool/list_dashboards.py +++ b/superset/mcp_service/dashboard/tool/list_dashboards.py @@ -98,11 +98,18 @@ async def list_dashboards( list_dashboards(search="sales", page=1) # DO NOT DO THIS Valid filter columns for ``filters[].col``: - ``dashboard_title``, ``published``, ``favorite`` + ``dashboard_title``, ``published``, ``favorite``, + ``created_by_fk``, ``changed_by_fk`` Sortable columns for ``order_column``: ``id``, ``dashboard_title``, ``slug``, ``published``, ``changed_on``, ``created_on`` + + To filter by a person (e.g. "dashboards Maxime is working on"), do NOT pass + the name as the search parameter — search matches titles and slugs only. + Instead, call find_users to resolve the name to a user ID, then pass it as + a filter: filters=[{"col": "created_by_fk", "opr": "eq", "value": <id>}] + (or "changed_by_fk" for "last modified by"). """ request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True) await ctx.info( diff --git a/superset/mcp_service/database/schemas.py b/superset/mcp_service/database/schemas.py index 020421550cc..00a5ab158fe 100644 --- a/superset/mcp_service/database/schemas.py +++ b/superset/mcp_service/database/schemas.py @@ -22,7 +22,7 @@ Pydantic schemas for database-related responses from __future__ import annotations from datetime import datetime -from typing import Annotated, Any, Dict, List, Literal +from typing import Annotated, Any, cast, Dict, List, Literal import humanize from pydantic import ( @@ -58,7 +58,7 @@ class DatabaseFilter(ColumnOperator): value: The value to filter by (type depends on col and opr). 
""" - col: Literal[ + col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] "database_name", "expose_in_sqllab", "allow_file_upload", @@ -242,7 +242,10 @@ class ListDatabasesRequest(CreatedByMeMixin, MetadataCacheControl): @classmethod def parse_filters(cls, v: Any) -> List[DatabaseFilter]: """Accept both JSON string and list of objects.""" - return parse_json_or_model_list(v, DatabaseFilter, "filters") + return cast( + List[DatabaseFilter], + parse_json_or_model_list(v, DatabaseFilter, "filters"), + ) @field_validator("select_columns", mode="before") @classmethod diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py index b5aa475ff2e..ce7a60c86fb 100644 --- a/superset/mcp_service/dataset/schemas.py +++ b/superset/mcp_service/dataset/schemas.py @@ -65,17 +65,18 @@ class DatasetFilter(ColumnOperator): value: The value to filter by (type depends on col and opr). """ - col: Literal[ + col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride] "table_name", "schema", "database_name", + "created_by_fk", + "changed_by_fk", ] = Field( ..., - description=( - "Column to filter on. Valid values: 'table_name', 'schema', " - "'database_name'. Other column names (e.g. 'created_by_fk', 'id') " - "are not valid filter columns and will cause a validation error." - ), + description="Column to filter on. Use get_schema(model_type='dataset') for " + "available filter columns. 
To filter by a person, first call find_users " + "to resolve a name to a user ID, then filter by created_by_fk or " + "changed_by_fk with that integer ID.", ) opr: ColumnOperatorEnum = Field( ..., @@ -658,7 +659,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None: params = None columns = [ TableColumnInfo( - column_name=getattr(col, "column_name", None), + column_name=getattr(col, "column_name", None) or "", verbose_name=getattr(col, "verbose_name", None), type=getattr(col, "type", None), is_dttm=getattr(col, "is_dttm", None), @@ -670,7 +671,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None: ] metrics = [ SqlMetricInfo( - metric_name=getattr(metric, "metric_name", None), + metric_name=getattr(metric, "metric_name", None) or "", verbose_name=getattr(metric, "verbose_name", None), expression=getattr(metric, "expression", None), description=getattr(metric, "description", None), diff --git a/superset/mcp_service/dataset/tool/list_datasets.py b/superset/mcp_service/dataset/tool/list_datasets.py index 099c15231ea..a1dd552dc48 100644 --- a/superset/mcp_service/dataset/tool/list_datasets.py +++ b/superset/mcp_service/dataset/tool/list_datasets.py @@ -109,10 +109,15 @@ async def list_datasets( list_datasets(search="sales", page=1) # DO NOT DO THIS Valid filter columns for ``filters[].col``: - ``table_name``, ``schema``, ``database_name`` + ``table_name``, ``schema``, ``database_name``, + ``created_by_fk``, ``changed_by_fk`` Sortable columns for ``order_column``: ``id``, ``table_name``, ``schema``, ``changed_on``, ``created_on`` + + To filter by a person, call find_users to resolve the name to a user ID, + then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq", + "value": <id>}] (or "changed_by_fk"). Do not pass the name as search. 
""" if ctx is None: raise RuntimeError("FastMCP context is required for list_datasets") diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py index 3494ec5227a..c5d7eabe21c 100644 --- a/superset/mcp_service/mcp_core.py +++ b/superset/mcp_service/mcp_core.py @@ -34,6 +34,7 @@ from superset.mcp_service.privacy import ( filter_user_directory_columns, SELF_REFERENCING_FILTER_COLUMNS, USER_DIRECTORY_FIELDS, + USER_FILTER_FIELDS, ) from superset.mcp_service.system.schemas import PaginationInfo from superset.mcp_service.utils import _is_uuid @@ -314,14 +315,6 @@ class ModelListCore(BaseCore, Generic[L]): has_previous=page > 0, ) - # Build response - def get_keys(obj: BaseModel | dict[str, Any] | Any) -> List[str]: - if hasattr(obj, "model_dump"): - return list(obj.model_dump().keys()) - elif isinstance(obj, dict): - return list(obj.keys()) - return [] - response_kwargs = { self.list_field_name: item_objs, "count": len(item_objs), @@ -595,7 +588,7 @@ class InstanceInfoCore(BaseCore): return counts def _calculate_time_based_metrics( - self, base_counts: Dict[str, int] + self, _base_counts: Dict[str, int] ) -> Dict[str, Dict[str, int]]: """Calculate time-based metrics for recent activity.""" now = datetime.now(timezone.utc) @@ -774,7 +767,9 @@ class ModelGetSchemaCore(BaseCore, Generic[S]): self.default_sort = default_sort self.default_sort_direction = default_sort_direction self.exclude_filter_columns = set(exclude_filter_columns or set()) - self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS) + # Hide user-directory columns from filter discovery, except the small + # set callers may legitimately filter by ID (resolved via find_users). 
+ self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS - USER_FILTER_FIELDS) def _get_filter_columns(self) -> Dict[str, List[str]]: """Get filterable columns and operators from the DAO.""" diff --git a/superset/mcp_service/privacy.py b/superset/mcp_service/privacy.py index 105a9492666..86dc552e789 100644 --- a/superset/mcp_service/privacy.py +++ b/superset/mcp_service/privacy.py @@ -44,13 +44,20 @@ USER_DIRECTORY_FIELDS = frozenset( } ) +# User-directory columns that may be used as filter values (an integer user ID). +# These remain stripped from select_columns, sort, search, and tool responses +# (so the directory itself is never exposed), but list tools may filter rows by +# them when the caller already has an ID — typically resolved via find_users. +USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"}) + # Internal DAO filter column names generated server-side when translating the # created_by_me / owned_by_me boolean flags (see mcp_core._prepend_self_lookup_filters). # These columns are never exposed to LLM callers; they are excluded from the # filters_applied response field to avoid leaking internal implementation details. -SELF_REFERENCING_FILTER_COLUMNS = frozenset( - {"created_by_fk", "owner", "created_by_fk_or_owner"} -) +# Note: ``created_by_fk`` is intentionally excluded — it is also a publicly +# advertised filter column (see USER_FILTER_FIELDS) so callers can filter by a +# user ID resolved via find_users. +SELF_REFERENCING_FILTER_COLUMNS = frozenset({"owner", "created_by_fk_or_owner"}) DATA_MODEL_METADATA_ACCESS_ATTR = "_requires_data_model_metadata_access" DATA_MODEL_METADATA_ERROR_TYPE = "DataModelMetadataRestricted" diff --git a/superset/mcp_service/system/schemas.py b/superset/mcp_service/system/schemas.py index 3324c0c2b51..663051624e7 100644 --- a/superset/mcp_service/system/schemas.py +++ b/superset/mcp_service/system/schemas.py @@ -25,9 +25,11 @@ system-level info. 
from __future__ import annotations from datetime import datetime -from typing import Any +from typing import Annotated, Any, List -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE class HealthCheckResponse(BaseModel): @@ -170,6 +172,84 @@ def serialize_user_object(user: Any) -> UserInfo | None: ) +class FindUsersRequest(BaseModel): + """Request schema for find_users tool. + + Resolves a person's name (or partial name, username, or email) to user IDs + so they can be passed to listing tools as filter values for created_by_fk + or changed_by_fk. This is the only sanctioned path for "show me what + <person> is working on" queries. + """ + + model_config = ConfigDict(extra="forbid") + + query: Annotated[ + str, + Field( + min_length=1, + max_length=200, + description=( + "Substring to match (case-insensitive) against username, " + "first_name, last_name, and email. Required and non-empty: " + "this tool does not enumerate the full user directory." + ), + ), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Maximum number of matches to return (max {MAX_PAGE_SIZE}).", + ), + ] + + @field_validator("query") + @classmethod + def _reject_blank_query(cls, value: str) -> str: + # min_length=1 alone admits whitespace-only strings, which strip to "" and + # produce a "%%" LIKE pattern that matches every user. Strip and require + # at least one non-space character. + stripped = value.strip() + if not stripped: + raise ValueError("query must contain at least one non-whitespace character") + return stripped + + +class UserMatch(BaseModel): + """Minimal user projection returned by find_users. + + Intentionally narrower than UserInfo: only the fields needed to disambiguate + matches and pass an id to created_by_fk / changed_by_fk filters. 
Email, + active flag, and roles are deliberately excluded to limit identity + exposure through this directory-resolution path. + """ + + id: int | None = None + username: str | None = None + first_name: str | None = None + last_name: str | None = None + + +class FindUsersResponse(BaseModel): + """Response schema for find_users tool.""" + + users: List[UserMatch] = Field( + default_factory=list, + description=( + "Matching users. Pass user.id as the value for created_by_fk or " + "changed_by_fk filters on list_dashboards, list_charts, and " + "list_datasets." + ), + ) + count: int = Field(..., description="Number of users returned in this response.") + truncated: bool = Field( + default=False, + description="True when the query matched more rows than page_size allows.", + ) + + class TagInfo(BaseModel): id: int | None = None name: str | None = None diff --git a/superset/mcp_service/system/tool/__init__.py b/superset/mcp_service/system/tool/__init__.py index f55f3d37ee9..c5725cc6d6e 100644 --- a/superset/mcp_service/system/tool/__init__.py +++ b/superset/mcp_service/system/tool/__init__.py @@ -17,12 +17,14 @@ """System tools for MCP service.""" +from .find_users import find_users from .generate_bug_report import generate_bug_report from .get_instance_info import get_instance_info from .get_schema import get_schema from .health_check import health_check __all__ = [ + "find_users", "generate_bug_report", "health_check", "get_instance_info", diff --git a/superset/mcp_service/system/tool/find_users.py b/superset/mcp_service/system/tool/find_users.py new file mode 100644 index 00000000000..0db21522508 --- /dev/null +++ b/superset/mcp_service/system/tool/find_users.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""find_users MCP tool: resolve a person's name to user IDs for filtering.""" + +import logging + +from fastmcp import Context +from sqlalchemy import or_ +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import db, event_logger, security_manager +from superset.mcp_service.system.schemas import ( + FindUsersRequest, + FindUsersResponse, + UserMatch, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["core"], + annotations=ToolAnnotations( + title="Find users", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def find_users(request: FindUsersRequest, ctx: Context) -> FindUsersResponse: + """Resolve a person's name to user IDs so they can be used as filter values. + + Use this when the caller asks "show me <person>'s dashboards/charts/datasets" + or "what is <person> working on". Take the matching user.id and pass it as + the value for a created_by_fk or changed_by_fk filter on list_dashboards, + list_charts, or list_datasets. + + Matches case-insensitively against username, first_name, last_name, and + email. The query is required and non-empty; this tool does not enumerate + the full user directory. + + Privacy: returning a user's identity here is sanctioned only for resolving + filter values. 
Do not use the response to answer "who owns X", "who can + access X", or any access-list question — those remain off-limits per the + server instructions. + """ + await ctx.info( + "Resolving user query: query=%s, page_size=%s" + % (request.query, request.page_size) + ) + + user_model = security_manager.user_model + needle = f"%{request.query.strip()}%" + + with event_logger.log_context(action="mcp.find_users.query"): + query = ( + db.session.query(user_model) + .filter( + or_( + user_model.username.ilike(needle), + user_model.first_name.ilike(needle), + user_model.last_name.ilike(needle), + user_model.email.ilike(needle), + ) + ) + .order_by(user_model.username.asc()) + ) + # Fetch one extra row to detect truncation without a separate count query. + rows = query.limit(request.page_size + 1).all() + + truncated = len(rows) > request.page_size + rows = rows[: request.page_size] + + users: list[UserMatch] = [ + UserMatch( + id=getattr(row, "id", None), + username=getattr(row, "username", None), + first_name=getattr(row, "first_name", None), + last_name=getattr(row, "last_name", None), + ) + for row in rows + ] + + await ctx.info( + "Resolved user query: matches=%s, truncated=%s" % (len(users), truncated) + ) + return FindUsersResponse(users=users, count=len(users), truncated=truncated) diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py index 25a8a4c0fc6..adc6e73ad03 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py @@ -2043,12 +2043,10 @@ class TestListDatasetsCreatedByMe: with pytest.raises(ValidationError, match="created_by_me"): ListDatasetsRequest(created_by_me=True, search="My tables") - def test_dataset_filter_rejects_created_by_fk(self): - """created_by_fk is not a public filter column; use created_by_me instead.""" - from pydantic import ValidationError - - with 
pytest.raises(ValidationError): - DatasetFilter(col="created_by_fk", opr="eq", value=1) + def test_dataset_filter_accepts_created_by_fk(self): + """created_by_fk is exposed for person-filtering via find_users.""" + f = DatasetFilter(col="created_by_fk", opr="eq", value=1) + assert f.col == "created_by_fk" class TestListDatasetsOwnedByMe: @@ -2115,14 +2113,10 @@ class TestListDatasetsRequestWrapper: assert f.col == col def test_dataset_filter_invalid_col_raises(self) -> None: - """Column names not in the Literal are rejected with a validation error. - - This guards against LLMs passing ``created_by_fk`` or similar - internal column names that are not exposed as filter fields. - """ + """Column names not in the Literal are rejected with a validation error.""" from pydantic import ValidationError - for bad_col in ("created_by_fk", "id", "database_id", "owner"): + for bad_col in ("id", "database_id", "owner"): with pytest.raises(ValidationError): DatasetFilter(col=bad_col, opr="eq", value="1") diff --git a/tests/unit_tests/mcp_service/system/tool/test_find_users.py b/tests/unit_tests/mcp_service/system/tool/test_find_users.py new file mode 100644 index 00000000000..e666c73db20 --- /dev/null +++ b/tests/unit_tests/mcp_service/system/tool/test_find_users.py @@ -0,0 +1,257 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for find_users MCP tool and its filter contract.""" + +import importlib +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.system.schemas import FindUsersRequest, FindUsersResponse +from superset.utils import json + +# Import the submodule directly so ``patch.object`` targets the module (not the +# ``find_users`` function that ``tool/__init__.py`` re-exports onto the +# package). The package attribute is the function, so dotted-string patches +# like ``superset.mcp_service.system.tool.find_users.db`` can resolve to the +# function in some import orderings and fail with AttributeError. +find_users_module = importlib.import_module( + "superset.mcp_service.system.tool.find_users" +) + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +def _make_user(id_, username, first=None, last=None, email=None, active=True): + """Build a Mock user with the attributes serialize_user_object reads.""" + user = Mock( + spec=["id", "username", "first_name", "last_name", "email", "active", "roles"] + ) + user.id = id_ + user.username = username + user.first_name = first + user.last_name = last + user.email = email + user.active = active + user.roles = [] + return user + + +def _patch_user_query(rows): + """Patch the SQLAlchemy chain used by find_users to return a fixed result set.""" + chain = MagicMock() + chain.filter.return_value = chain + 
chain.order_by.return_value = chain + chain.limit.return_value = chain + chain.all.return_value = rows + session = MagicMock() + session.query.return_value = chain + return session, chain + + +# --------------------------------------------------------------------------- +# Schema tests +# --------------------------------------------------------------------------- + + +def test_find_users_request_rejects_empty_query(): + with pytest.raises(ValidationError): + FindUsersRequest(query="") + + +def test_find_users_request_rejects_extra_fields(): + with pytest.raises(ValidationError): + FindUsersRequest(query="maxime", random_field="x") + + +def test_find_users_response_default_truncated_false(): + resp = FindUsersResponse(users=[], count=0) + assert resp.truncated is False + + +# --------------------------------------------------------------------------- +# Tool-level tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_find_users_returns_matches(mcp_server): + rows = [ + _make_user( + 7, "maxime", first="Maxime", last="Beauchemin", email="[email protected]" + ) + ] + session, _ = _patch_user_query(rows) + + with ( + patch.object(find_users_module, "db") as mock_db, + patch.object(find_users_module, "security_manager") as mock_sm, + patch.object(find_users_module, "or_") as mock_or, + ): + mock_db.session = session + mock_sm.user_model = MagicMock() + mock_or.return_value = MagicMock() + + async with Client(mcp_server) as client: + result = await client.call_tool( + "find_users", {"request": {"query": "maxime"}} + ) + + data = json.loads(result.content[0].text) + assert data["count"] == 1 + assert data["truncated"] is False + assert data["users"][0]["id"] == 7 + assert data["users"][0]["username"] == "maxime" + assert data["users"][0]["first_name"] == "Maxime" + assert data["users"][0]["last_name"] == "Beauchemin" + # Privacy: minimal projection excludes identity attributes that aren't + # required 
for filter resolution. Catch regressions on the response shape. + for forbidden in ("email", "active", "roles"): + assert forbidden not in data["users"][0] + # or_ should have been built across the four matched columns + assert mock_or.called + assert len(mock_or.call_args.args) == 4 + + +@pytest.mark.asyncio +async def test_find_users_truncates_when_more_rows_than_page_size(mcp_server): + # page_size=2 with 3 returned rows -> truncated, response trimmed to 2 + rows = [ + _make_user(1, "a"), + _make_user(2, "b"), + _make_user(3, "c"), + ] + session, chain = _patch_user_query(rows) + + with ( + patch.object(find_users_module, "db") as mock_db, + patch.object(find_users_module, "security_manager") as mock_sm, + patch.object(find_users_module, "or_") as mock_or, + ): + mock_db.session = session + mock_sm.user_model = MagicMock() + mock_or.return_value = MagicMock() + + async with Client(mcp_server) as client: + result = await client.call_tool( + "find_users", {"request": {"query": "a", "page_size": 2}} + ) + + # Tool requested page_size+1 rows for truncation detection + chain.limit.assert_called_with(3) + + data = json.loads(result.content[0].text) + assert data["count"] == 2 + assert data["truncated"] is True + assert [u["id"] for u in data["users"]] == [1, 2] + + +@pytest.mark.asyncio +async def test_find_users_rejects_empty_query_via_client(mcp_server): + async with Client(mcp_server) as client: + with pytest.raises(ToolError): + await client.call_tool("find_users", {"request": {"query": ""}}) + + +@pytest.mark.parametrize("blank", [" ", " ", "\t", "\n \t"]) +def test_find_users_request_rejects_whitespace_only_query(blank): + # Whitespace-only queries would strip to "" and produce a LIKE "%%" pattern + # that enumerates the entire user directory. The validator must reject them.
+ with pytest.raises(ValidationError): + FindUsersRequest(query=blank) + + +def test_find_users_request_strips_query_whitespace(): + # Validator should normalize the stored query so downstream LIKE patterns + # don't carry leading/trailing whitespace. + request = FindUsersRequest(query=" maxime ") + assert request.query == "maxime" + + +# --------------------------------------------------------------------------- +# Filter contract: created_by_fk / changed_by_fk filtering on list tools +# --------------------------------------------------------------------------- + + +@patch("superset.daos.dashboard.DashboardDAO.list") +@pytest.mark.asyncio +async def test_list_dashboards_passes_created_by_fk_filter_to_dao( + mock_list, mcp_server +): + """list_dashboards should accept created_by_fk filter and forward it.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + await client.call_tool( + "list_dashboards", + { + "request": { + "filters": [{"col": "created_by_fk", "opr": "eq", "value": 7}], + "page": 1, + "page_size": 10, + } + }, + ) + + assert mock_list.called + forwarded_filters = mock_list.call_args.kwargs.get("column_operators") + assert forwarded_filters is not None + assert any( + getattr(f, "col", None) == "created_by_fk" and getattr(f, "value", None) == 7 + for f in forwarded_filters + ) + + +@patch("superset.daos.chart.ChartDAO.list") +@pytest.mark.asyncio +async def test_list_charts_passes_changed_by_fk_filter_to_dao(mock_list, mcp_server): + """list_charts should accept changed_by_fk filter and forward it.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + await client.call_tool( + "list_charts", + { + "request": { + "filters": [{"col": "changed_by_fk", "opr": "eq", "value": 7}], + "page": 1, + "page_size": 10, + } + }, + ) + + assert mock_list.called + forwarded_filters = mock_list.call_args.kwargs.get("column_operators") + assert forwarded_filters is not None + assert any(getattr(f, "col", None) ==
"changed_by_fk" for f in forwarded_filters) diff --git a/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py index e10d534d158..5406568d81e 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py +++ b/tests/unit_tests/mcp_service/system/tool/test_get_current_user.py @@ -22,6 +22,8 @@ from unittest.mock import Mock, patch import pytest from fastmcp import Client +from fastmcp.client.client import CallToolResult +from mcp.types import TextContent from pydantic import ValidationError from superset.mcp_service.app import mcp @@ -45,6 +47,14 @@ get_schema_module = importlib.import_module( "superset.mcp_service.system.tool.get_schema" ) + +def _result_text(result: CallToolResult) -> str: + """Return the text payload from the first content block of a tool result.""" + block = result.content[0] + assert isinstance(block, TextContent) + return block.text + + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -198,7 +208,7 @@ async def test_get_schema_returns_structured_privacy_error_for_dataset(mcp_serve {"request": {"model_type": "dataset"}}, ) - data = json.loads(result.content[0].text) + data = json.loads(_result_text(result)) assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE assert data["privacy_scope"] == "data_model" @@ -241,7 +251,7 @@ async def test_get_schema_redacts_chart_data_model_fields(mcp_server): {"request": {"model_type": "chart"}}, ) - data = json.loads(result.content[0].text) + data = json.loads(_result_text(result)) schema_info = data["schema_info"] assert all( column["name"] not in CHART_DATA_MODEL_COLUMNS @@ -389,7 +399,7 @@ class TestGetInstanceInfoCurrentUserViaMCP: async with Client(mcp_server) as client: result = await client.call_tool("get_instance_info", {"request": {}}) - data = json.loads(result.content[0].text) + 
data = json.loads(_result_text(result)) assert "current_user" in data cu = data["current_user"] assert cu["id"] == 5 @@ -418,7 +428,7 @@ class TestGetInstanceInfoCurrentUserViaMCP: async with Client(mcp_server) as client: result = await client.call_tool("get_instance_info", {"request": {}}) - data = json.loads(result.content[0].text) + data = json.loads(_result_text(result)) assert data["current_user"] is None @pytest.mark.asyncio @@ -444,7 +454,7 @@ class TestGetInstanceInfoCurrentUserViaMCP: async with Client(mcp_server) as client: result = await client.call_tool("get_instance_info", {"request": {}}) - data = json.loads(result.content[0].text) + data = json.loads(_result_text(result)) cu = data["current_user"] assert cu["id"] == 99 assert cu["username"] == "bot" @@ -460,28 +470,50 @@ class TestGetInstanceInfoCurrentUserViaMCP: # --------------------------------------------------------------------------- -def test_chart_filter_rejects_created_by_fk() -> None: - """created_by_fk is not a valid ChartFilter column; use created_by_me instead.""" - with pytest.raises(ValidationError): - ChartFilter(col="created_by_fk", opr="eq", value=42) +def test_chart_filter_rejects_user_directory_columns_other_than_fk() -> None: + """ChartFilter still rejects user-directory columns that expose names.""" + for col in ("created_by_name", "owners", "changed_by"): + with pytest.raises(ValidationError): + ChartFilter.model_validate({"col": col, "opr": "eq", "value": "anything"}) + + +def test_chart_filter_accepts_created_and_changed_by_fk() -> None: + """ChartFilter allows filtering by created_by_fk / changed_by_fk (user IDs).""" + for col in ("created_by_fk", "changed_by_fk"): + f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": 42}) + assert f.col == col def test_chart_filter_rejects_invalid_column(): """Test that ChartFilter rejects invalid column names.""" with pytest.raises(ValidationError): - ChartFilter(col="nonexistent_column", opr="eq", value=42) + 
ChartFilter.model_validate( + {"col": "nonexistent_column", "opr": "eq", "value": 42} + ) -def test_dashboard_filter_rejects_created_by_fk(): - """created_by_fk is not a valid DashboardFilter column; use created_by_me.""" - with pytest.raises(ValidationError): - DashboardFilter(col="created_by_fk", opr="eq", value=42) +def test_dashboard_filter_rejects_user_directory_columns_other_than_fk() -> None: + """DashboardFilter still rejects user-directory columns that expose names.""" + for col in ("created_by_name", "owners", "changed_by"): + with pytest.raises(ValidationError): + DashboardFilter.model_validate( + {"col": col, "opr": "eq", "value": "anything"} + ) + + +def test_dashboard_filter_accepts_created_and_changed_by_fk() -> None: + """DashboardFilter allows filtering by created_by_fk / changed_by_fk.""" + for col in ("created_by_fk", "changed_by_fk"): + f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": 42}) + assert f.col == col def test_dashboard_filter_rejects_invalid_column(): """Test that DashboardFilter rejects invalid column names.""" with pytest.raises(ValidationError): - DashboardFilter(col="nonexistent_column", opr="eq", value=42) + DashboardFilter.model_validate( + {"col": "nonexistent_column", "opr": "eq", "value": 42} + ) # --------------------------------------------------------------------------- @@ -492,12 +524,12 @@ def test_dashboard_filter_rejects_invalid_column(): def test_chart_filter_existing_columns_still_work(): """Test that pre-existing chart filter columns are not broken.""" for col in ("slice_name", "viz_type", "datasource_name"): - f = ChartFilter(col=col, opr="eq", value="test") + f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": "test"}) assert f.col == col def test_dashboard_filter_existing_columns_still_work(): """Test that pre-existing dashboard filter columns are not broken.""" for col in ("dashboard_title", "published", "favorite"): - f = DashboardFilter(col=col, opr="eq", value="test") + f = 
DashboardFilter.model_validate({"col": col, "opr": "eq", "value": "test"}) assert f.col == col diff --git a/tests/unit_tests/mcp_service/system/tool/test_get_schema.py b/tests/unit_tests/mcp_service/system/tool/test_get_schema.py index 1df8dc2ca48..5a64b6511d5 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_get_schema.py +++ b/tests/unit_tests/mcp_service/system/tool/test_get_schema.py @@ -326,11 +326,19 @@ class TestGetSchemaToolViaClient: async def test_get_schema_omits_user_directory_columns( self, mock_filters, mcp_server ): - """Test that schema discovery does not advertise user/access fields.""" + """Test that schema discovery does not advertise user/access fields. + + created_by_fk and changed_by_fk are intentionally allowed in + filter_columns so callers can filter by user ID resolved via find_users, + but they remain hidden from select_columns and sortable_columns so the + directory itself is never exposed. + """ mock_filters.return_value = { "dashboard_title": ["eq", "ilike"], "owner": ["rel_m_m"], "published": ["eq"], + "created_by_fk": ["eq", "in"], + "changed_by_fk": ["eq", "in"], } async with Client(mcp_server) as client: @@ -352,9 +360,16 @@ class TestGetSchemaToolViaClient: "owner", ): assert field not in select_column_names - assert field not in info["filter_columns"] assert field not in info["sortable_columns"] + # User-name and relationship fields stay out of filter_columns + for field in ("owners", "roles", "created_by", "changed_by", "owner"): + assert field not in info["filter_columns"] + + # ID-only filter columns are advertised so callers can filter via find_users + assert "created_by_fk" in info["filter_columns"] + assert "changed_by_fk" in info["filter_columns"] + @patch("superset.daos.chart.ChartDAO.get_filterable_columns_and_operators") @pytest.mark.asyncio async def test_get_schema_chart_omits_self_referencing_filter_columns( @@ -362,8 +377,9 @@ class TestGetSchemaToolViaClient: ): """Test that chart schema does not advertise 
self-referencing filter columns. - Even if the DAO returns created_by_fk or owner, they must be excluded so - LLMs cannot discover and use them to enumerate user IDs. + Even if the DAO returns owner or created_by_fk_or_owner, they must be + excluded — these synthetic columns are generated server-side from the + owned_by_me flag and are not directly usable by LLM callers. """ mock_filters.return_value = { "slice_name": ["eq", "ilike"], @@ -381,7 +397,7 @@ class TestGetSchemaToolViaClient: info = data["schema_info"] assert "slice_name" in info["filter_columns"] - for field in ("created_by_fk", "owner", "created_by_fk_or_owner"): + for field in ("owner", "created_by_fk_or_owner"): assert field not in info["filter_columns"] @patch("superset.daos.dataset.DatasetDAO.get_filterable_columns_and_operators") @@ -391,8 +407,9 @@ class TestGetSchemaToolViaClient: ): """Test that dataset schema does not advertise self-referencing filter columns. - Even if the DAO returns created_by_fk or owner, they must be excluded so - LLMs cannot discover and use them to enumerate user IDs. + Even if the DAO returns owner or created_by_fk_or_owner, they must be + excluded — these synthetic columns are generated server-side from the + owned_by_me flag and are not directly usable by LLM callers. """ mock_filters.return_value = { "table_name": ["eq", "ilike"], @@ -410,7 +427,7 @@ class TestGetSchemaToolViaClient: info = data["schema_info"] assert "table_name" in info["filter_columns"] - for field in ("created_by_fk", "owner", "created_by_fk_or_owner"): + for field in ("owner", "created_by_fk_or_owner"): assert field not in info["filter_columns"] @patch("superset.daos.dashboard.DashboardDAO.get_filterable_columns_and_operators") @@ -420,8 +437,9 @@ class TestGetSchemaToolViaClient: ): """Test dashboard schema omits self-referencing filter columns. - Even if the DAO returns created_by_fk or owner, they must be excluded - so LLMs cannot discover and use them to enumerate user IDs. 
+ Even if the DAO returns owner or created_by_fk_or_owner, they must be + excluded — these synthetic columns are generated server-side from the + owned_by_me flag and are not directly usable by LLM callers. """ mock_filters.return_value = { "dashboard_title": ["eq", "ilike"], @@ -439,7 +457,7 @@ class TestGetSchemaToolViaClient: info = data["schema_info"] assert "dashboard_title" in info["filter_columns"] - for field in ("created_by_fk", "owner", "created_by_fk_or_owner"): + for field in ("owner", "created_by_fk_or_owner"): assert field not in info["filter_columns"]
