This is an automated email from the ASF dual-hosted git repository.
arivero pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 97a66f7a647 feat(mcp): add BM25 tool search transform to reduce initial context size (#38562)
97a66f7a647 is described below
commit 97a66f7a6474f79be36e4cfd16b6f6df938c1e5b
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Fri Mar 13 18:06:11 2026 +0100
feat(mcp): add BM25 tool search transform to reduce initial context size (#38562)
Co-authored-by: Claude Opus 4.6 <[email protected]>
---
pyproject.toml | 2 +-
requirements/development.txt | 61 ++--
superset/mcp_service/chart/schemas.py | 323 ++++++---------------
superset/mcp_service/common/cache_schemas.py | 35 +--
superset/mcp_service/dashboard/schemas.py | 64 ++--
superset/mcp_service/flask_singleton.py | 70 ++---
superset/mcp_service/mcp_config.py | 34 +++
superset/mcp_service/server.py | 141 ++++++++-
superset/mcp_service/system/schemas.py | 100 +++----
.../mcp_service/test_mcp_tool_registration.py | 25 +-
.../mcp_service/test_tool_search_transform.py | 173 +++++++++++
11 files changed, 552 insertions(+), 476 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 70feeeb4d3e..12d5224cb38 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -144,7 +144,7 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
-fastmcp = ["fastmcp==2.14.3"]
+fastmcp = ["fastmcp>=3.1.0,<4.0"]
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
gevent = ["gevent>=23.9.1"]
diff --git a/requirements/development.txt b/requirements/development.txt
index 993bbfdc3b4..296fe15d820 100644
--- a/requirements/development.txt
+++ b/requirements/development.txt
@@ -10,6 +10,8 @@
# via
# -r requirements/development.in
# apache-superset
+aiofile==3.9.0
+ # via py-key-value-aio
alembic==1.15.2
# via
# -c requirements/base-constraint.txt
@@ -26,8 +28,10 @@ anyio==4.11.0
# via
# httpx
# mcp
+ # py-key-value-aio
# sse-starlette
# starlette
+ # watchfiles
apispec==6.6.1
# via
# -c requirements/base-constraint.txt
@@ -65,9 +69,7 @@ bcrypt==4.3.0
# -c requirements/base-constraint.txt
# paramiko
beartype==0.22.5
- # via
- # py-key-value-aio
- # py-key-value-shared
+ # via py-key-value-aio
billiard==4.2.1
# via
# -c requirements/base-constraint.txt
@@ -100,6 +102,8 @@ cachetools==6.2.1
# -c requirements/base-constraint.txt
# google-auth
# py-key-value-aio
+caio==0.9.25
+ # via aiofile
cattrs==25.1.1
# via
# -c requirements/base-constraint.txt
@@ -138,7 +142,6 @@ click==8.2.1
# click-repl
# flask
# flask-appbuilder
- # typer
# uvicorn
click-didyoumean==0.3.1
# via
@@ -156,8 +159,6 @@ click-repl==0.3.0
# via
# -c requirements/base-constraint.txt
# celery
-cloudpickle==3.1.2
- # via pydocket
cmdstanpy==1.1.0
# via prophet
colorama==0.4.6
@@ -206,8 +207,6 @@ deprecation==2.1.0
# apache-superset
dill==0.4.0
# via pylint
-diskcache==5.6.3
- # via py-key-value-aio
distlib==0.3.8
# via virtualenv
dnspython==2.7.0
@@ -237,9 +236,7 @@ et-xmlfile==2.0.0
# openpyxl
exceptiongroup==1.3.0
# via fastmcp
-fakeredis==2.32.1
- # via pydocket
-fastmcp==2.14.3
+fastmcp==3.1.0
# via apache-superset
filelock==3.20.3
# via
@@ -474,6 +471,8 @@ jsonpath-ng==1.7.0
# via
# -c requirements/base-constraint.txt
# apache-superset
+jsonref==1.1.0
+ # via fastmcp
jsonschema==4.23.0
# via
# -c requirements/base-constraint.txt
@@ -504,8 +503,6 @@ limits==5.1.0
# via
# -c requirements/base-constraint.txt
# flask-limiter
-lupa==2.6
- # via fakeredis
mako==1.3.10
# via
# -c requirements/base-constraint.txt
@@ -603,7 +600,7 @@ openpyxl==3.1.5
# -c requirements/base-constraint.txt
# pandas
opentelemetry-api==1.39.1
- # via pydocket
+ # via fastmcp
ordered-set==4.1.0
# via
# -c requirements/base-constraint.txt
@@ -622,6 +619,7 @@ packaging==25.0
# deprecation
# docker
# duckdb-engine
+ # fastmcp
# google-cloud-bigquery
# gunicorn
# limits
@@ -653,8 +651,6 @@ parsedatetime==2.6
# apache-superset
pathable==0.4.3
# via jsonschema-path
-pathvalidate==3.3.1
- # via py-key-value-aio
pgsanity==0.2.9
# via
# -c requirements/base-constraint.txt
@@ -691,8 +687,6 @@ prison==0.2.1
# flask-appbuilder
progress==1.6
# via apache-superset
-prometheus-client==0.23.1
- # via pydocket
prompt-toolkit==3.0.51
# via
# -c requirements/base-constraint.txt
@@ -714,12 +708,8 @@ psutil==6.1.0
# via apache-superset
psycopg2-binary==2.9.9
# via apache-superset
-py-key-value-aio==0.3.0
- # via
- # fastmcp
- # pydocket
-py-key-value-shared==0.3.0
- # via py-key-value-aio
+py-key-value-aio==0.4.4
+ # via fastmcp
pyarrow==16.1.0
# via
# -c requirements/base-constraint.txt
@@ -758,8 +748,6 @@ pydantic-settings==2.10.1
# via mcp
pydata-google-auth==1.9.0
# via pandas-gbq
-pydocket==0.17.1
- # via fastmcp
pydruid==0.6.9
# via apache-superset
pyfakefs==5.3.5
@@ -844,8 +832,6 @@ python-dotenv==1.1.0
# apache-superset
# fastmcp
# pydantic-settings
-python-json-logger==4.0.0
- # via pydocket
python-ldap==3.4.4
# via apache-superset
python-multipart==0.0.20
@@ -866,15 +852,13 @@ pyyaml==6.0.2
# -c requirements/base-constraint.txt
# apache-superset
# apispec
+ # fastmcp
# jsonschema-path
# pre-commit
redis==5.3.1
# via
# -c requirements/base-constraint.txt
# apache-superset
- # fakeredis
- # py-key-value-aio
- # pydocket
referencing==0.36.2
# via
# -c requirements/base-constraint.txt
@@ -910,9 +894,7 @@ rich==13.9.4
# cyclopts
# fastmcp
# flask-limiter
- # pydocket
# rich-rst
- # typer
rich-rst==1.3.1
# via cyclopts
rpds-py==0.25.0
@@ -944,8 +926,6 @@ setuptools==80.9.0
# pydata-google-auth
# zope-event
# zope-interface
-shellingham==1.5.4
- # via typer
shillelagh==1.4.3
# via
# -c requirements/base-constraint.txt
@@ -973,7 +953,6 @@ sniffio==1.3.1
sortedcontainers==2.4.0
# via
# -c requirements/base-constraint.txt
- # fakeredis
# trio
sqlalchemy==1.4.54
# via
@@ -1034,8 +1013,6 @@ trio-websocket==0.12.2
# via
# -c requirements/base-constraint.txt
# selenium
-typer==0.20.0
- # via pydocket
typing-extensions==4.15.0
# via
# -c requirements/base-constraint.txt
@@ -1048,16 +1025,14 @@ typing-extensions==4.15.0
# limits
# mcp
# opentelemetry-api
- # py-key-value-shared
+ # py-key-value-aio
# pydantic
# pydantic-core
- # pydocket
# pyopenssl
# referencing
# selenium
# shillelagh
# starlette
- # typer
# typing-inspection
typing-inspection==0.4.1
# via
@@ -1072,6 +1047,8 @@ tzdata==2025.2
# pandas
tzlocal==5.2
# via trino
+uncalled-for==0.2.0
+ # via fastmcp
url-normalize==2.2.1
# via
# -c requirements/base-constraint.txt
@@ -1101,6 +1078,8 @@ watchdog==6.0.0
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-extensions-cli
+watchfiles==1.1.1
+ # via fastmcp
wcwidth==0.2.13
# via
# -c requirements/base-constraint.txt
diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py
index 4cc2bffa07f..373f667e586 100644
--- a/superset/mcp_service/chart/schemas.py
+++ b/superset/mcp_service/chart/schemas.py
@@ -384,15 +384,12 @@ class ChartList(BaseModel):
class ColumnRef(BaseModel):
name: str = Field(
...,
- description="Column name",
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
- label: str | None = Field(
- None, description="Display label for the column", max_length=500
- )
- dtype: str | None = Field(None, description="Data type hint")
+ label: str | None = Field(None, max_length=500)
+ dtype: str | None = None
aggregate: (
Literal[
"SUM",
@@ -407,11 +404,7 @@ class ColumnRef(BaseModel):
"PERCENTILE",
]
| None
- ) = Field(
- None,
- description="SQL aggregation function. Only these validated functions
are "
- "supported to prevent SQL errors.",
- )
+ ) = Field(None, description="SQL aggregate function")
@field_validator("name")
@classmethod
@@ -431,25 +424,22 @@ class ColumnRef(BaseModel):
class AxisConfig(BaseModel):
- title: str | None = Field(None, description="Axis title", max_length=200)
- scale: Literal["linear", "log"] | None = Field(
- "linear", description="Axis scale type"
- )
- format: str | None = Field(
- None, description="Format string (e.g. '$,.2f')", max_length=50
- )
+ title: str | None = Field(None, max_length=200)
+ scale: Literal["linear", "log"] | None = "linear"
+ format: str | None = Field(None, description="e.g. '$,.2f'", max_length=50)
class LegendConfig(BaseModel):
- show: bool = Field(True, description="Whether to show legend")
- position: Literal["top", "bottom", "left", "right"] | None = Field(
- "right", description="Legend position"
- )
+ show: bool = True
+ position: Literal["top", "bottom", "left", "right"] | None = "right"
class FilterConfig(BaseModel):
column: str = Field(
- ..., description="Column to filter on", min_length=1, max_length=255
+ ...,
+ min_length=1,
+ max_length=255,
+ pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
op: Literal[
"=",
@@ -465,17 +455,11 @@ class FilterConfig(BaseModel):
"NOT IN",
] = Field(
...,
- description=(
- "Filter operator. Use LIKE/ILIKE for pattern matching with %
wildcards "
- "(e.g., '%mario%'). Use IN/NOT IN with a list of values."
- ),
+ description="LIKE/ILIKE use % wildcards. IN/NOT IN take a list.",
)
value: str | int | float | bool | list[str | int | float | bool] = Field(
...,
- description=(
- "Filter value. For IN/NOT IN operators, provide a list of values. "
- "For LIKE/ILIKE, use % as wildcard (e.g., '%mario%')."
- ),
+ description="For IN/NOT IN, provide a list.",
)
@field_validator("column")
@@ -516,26 +500,13 @@ class FilterConfig(BaseModel):
class PieChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
- chart_type: Literal["pie"] = Field(
- ...,
- description=(
- "Chart type discriminator - MUST be 'pie' for pie/donut charts. "
- "This field is REQUIRED and tells Superset which chart "
- "configuration schema to use."
- ),
- )
- dimension: ColumnRef = Field(
- ..., description="Category column that defines the pie slices"
- )
+ chart_type: Literal["pie"] = "pie"
+ dimension: ColumnRef = Field(..., description="Category column for slices")
metric: ColumnRef = Field(
- ...,
- description=(
- "Value metric that determines slice sizes. "
- "Must include an aggregate function (e.g., SUM, COUNT)."
- ),
+ ..., description="Value metric (needs aggregate e.g. SUM, COUNT)"
)
- donut: bool = Field(False, description="Render as a donut chart with a
center hole")
- show_labels: bool = Field(True, description="Display labels on slices")
+ donut: bool = False
+ show_labels: bool = True
label_type: Literal[
"key",
"value",
@@ -544,63 +515,32 @@ class PieChartConfig(BaseModel):
"key_percent",
"key_value_percent",
"value_percent",
- ] = Field("key_value_percent", description="Type of labels to show on
slices")
- sort_by_metric: bool = Field(True, description="Sort slices by metric
value")
- show_legend: bool = Field(True, description="Whether to show legend")
- filters: List[FilterConfig] | None = Field(None, description="Filters to
apply")
- row_limit: int = Field(
- 100,
- description="Maximum number of slices to display",
- ge=1,
- le=10000,
- )
- number_format: str = Field(
- "SMART_NUMBER",
- description="Number format string",
- max_length=50,
- )
- show_total: bool = Field(False, description="Display aggregate count in
center")
- labels_outside: bool = Field(True, description="Place labels outside the
pie")
- outer_radius: int = Field(
- 70,
- description="Outer edge radius as a percentage (1-100)",
- ge=1,
- le=100,
- )
+ ] = "key_value_percent"
+ sort_by_metric: bool = True
+ show_legend: bool = True
+ filters: List[FilterConfig] | None = None
+ row_limit: int = Field(100, description="Max slices", ge=1, le=10000)
+ number_format: str = Field("SMART_NUMBER", max_length=50)
+ show_total: bool = Field(False, description="Show total in center")
+ labels_outside: bool = True
+ outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1,
le=100)
inner_radius: int = Field(
- 30,
- description="Inner radius as a percentage for donut (1-100)",
- ge=1,
- le=100,
+ 30, description="Donut inner radius % (1-100)", ge=1, le=100
)
class PivotTableChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
- chart_type: Literal["pivot_table"] = Field(
- ...,
- description=(
- "Chart type discriminator - MUST be 'pivot_table' for interactive "
- "pivot tables. This field is REQUIRED."
- ),
- )
- rows: List[ColumnRef] = Field(
- ...,
- min_length=1,
- description="Row grouping columns (at least one required)",
- )
+ chart_type: Literal["pivot_table"] = "pivot_table"
+ rows: List[ColumnRef] = Field(..., min_length=1, description="Row grouping
columns")
columns: List[ColumnRef] | None = Field(
- None,
- description="Column grouping columns (optional, for cross-tabulation)",
+ None, description="Column groups for cross-tabulation"
)
metrics: List[ColumnRef] = Field(
...,
min_length=1,
- description=(
- "Metrics to aggregate. Each must have an aggregate function "
- "(e.g., SUM, COUNT, AVG)."
- ),
+ description="Metrics (need aggregate e.g. SUM, COUNT, AVG)",
)
aggregate_function: Literal[
"Sum",
@@ -614,108 +554,56 @@ class PivotTableChartConfig(BaseModel):
"Count Unique Values",
"First",
"Last",
- ] = Field("Sum", description="Default aggregation function for the pivot
table")
- show_row_totals: bool = Field(True, description="Show row totals")
- show_column_totals: bool = Field(True, description="Show column totals")
- transpose: bool = Field(False, description="Swap rows and columns")
- combine_metric: bool = Field(
- False,
- description="Display metrics side by side within columns",
- )
- filters: List[FilterConfig] | None = Field(None, description="Filters to
apply")
- row_limit: int = Field(
- 10000,
- description="Maximum number of cells",
- ge=1,
- le=50000,
- )
- value_format: str = Field(
- "SMART_NUMBER",
- description="Value format string",
- max_length=50,
- )
+ ] = "Sum"
+ show_row_totals: bool = True
+ show_column_totals: bool = True
+ transpose: bool = False
+ combine_metric: bool = Field(False, description="Metrics side by side in
columns")
+ filters: List[FilterConfig] | None = None
+ row_limit: int = Field(10000, description="Max cells", ge=1, le=50000)
+ value_format: str = Field("SMART_NUMBER", max_length=50)
class MixedTimeseriesChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
- chart_type: Literal["mixed_timeseries"] = Field(
- ...,
- description=(
- "Chart type discriminator - MUST be 'mixed_timeseries' for charts "
- "that combine two different series types (e.g., line + bar). "
- "This field is REQUIRED."
- ),
- )
- x: ColumnRef = Field(..., description="X-axis temporal column (shared)")
- time_grain: TimeGrain | None = Field(
- None,
- description=(
- "Time granularity for the x-axis. "
- "Common values: PT1H (hourly), P1D (daily), P1W (weekly), "
- "P1M (monthly), P1Y (yearly)."
- ),
- )
+ chart_type: Literal["mixed_timeseries"] = "mixed_timeseries"
+ x: ColumnRef = Field(..., description="Shared temporal X-axis column")
+ time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W,
P1M, P1Y")
# Primary series (Query A)
- y: List[ColumnRef] = Field(
- ...,
- min_length=1,
- description="Primary Y-axis metrics (Query A)",
- )
- primary_kind: Literal["line", "bar", "area", "scatter"] = Field(
- "line", description="Primary series chart type"
- )
- group_by: ColumnRef | None = Field(
- None, description="Group by column for primary series"
- )
+ y: List[ColumnRef] = Field(..., min_length=1, description="Primary Y-axis
metrics")
+ primary_kind: Literal["line", "bar", "area", "scatter"] = "line"
+ group_by: ColumnRef | None = Field(None, description="Primary series group
by")
# Secondary series (Query B)
y_secondary: List[ColumnRef] = Field(
- ...,
- min_length=1,
- description="Secondary Y-axis metrics (Query B)",
- )
- secondary_kind: Literal["line", "bar", "area", "scatter"] = Field(
- "bar", description="Secondary series chart type"
+ ..., min_length=1, description="Secondary Y-axis metrics"
)
+ secondary_kind: Literal["line", "bar", "area", "scatter"] = "bar"
group_by_secondary: ColumnRef | None = Field(
- None, description="Group by column for secondary series"
+ None, description="Secondary series group by"
)
# Display options
- show_legend: bool = Field(True, description="Whether to show legend")
- x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
- y_axis: AxisConfig | None = Field(None, description="Primary Y-axis
configuration")
- y_axis_secondary: AxisConfig | None = Field(
- None, description="Secondary Y-axis configuration"
- )
- filters: List[FilterConfig] | None = Field(None, description="Filters to
apply")
+ show_legend: bool = True
+ x_axis: AxisConfig | None = None
+ y_axis: AxisConfig | None = None
+ y_axis_secondary: AxisConfig | None = None
+ filters: List[FilterConfig] | None = None
class TableChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
- chart_type: Literal["table"] = Field(
- ..., description="Chart type (REQUIRED: must be 'table')"
- )
+ chart_type: Literal["table"] = "table"
viz_type: Literal["table", "ag-grid-table"] = Field(
- "table",
- description=(
- "Visualization type: 'table' for standard table, 'ag-grid-table'
for "
- "AG Grid Interactive Table with advanced features like column
resizing, "
- "sorting, filtering, and server-side pagination"
- ),
+ "table", description="'ag-grid-table' for interactive features"
)
columns: List[ColumnRef] = Field(
...,
min_length=1,
- description=(
- "Columns to display. Must have at least one column. Each column
must have "
- "a unique label "
- "(either explicitly set via 'label' field or auto-generated "
- "from name/aggregate)"
- ),
+ description="Columns with unique labels",
)
- filters: List[FilterConfig] | None = Field(None, description="Filters to
apply")
- sort_by: List[str] | None = Field(None, description="Columns to sort by")
+ filters: List[FilterConfig] | None = None
+ sort_by: List[str] | None = None
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "TableChartConfig":
@@ -748,56 +636,26 @@ class TableChartConfig(BaseModel):
class XYChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
- chart_type: Literal["xy"] = Field(
- ...,
- description=(
- "Chart type discriminator - MUST be 'xy' for XY charts "
- "(line, bar, area, scatter). "
- "This field is REQUIRED and tells Superset which chart "
- "configuration schema to use."
- ),
- )
+ chart_type: Literal["xy"] = "xy"
x: ColumnRef = Field(..., description="X-axis column")
y: List[ColumnRef] = Field(
- ...,
- min_length=1,
- description="Y-axis columns (metrics). Must have at least one Y-axis
column. "
- "Each column must have a unique label "
- "that doesn't conflict with x-axis or group_by labels",
- )
- kind: Literal["line", "bar", "area", "scatter"] = Field(
- "line", description="Chart visualization type"
+ ..., min_length=1, description="Y-axis metrics (unique labels)"
)
+ kind: Literal["line", "bar", "area", "scatter"] = "line"
time_grain: TimeGrain | None = Field(
- None,
- description=(
- "Time granularity for the x-axis when it's a temporal column. "
- "Common values: PT1S (second), PT1M (minute), PT1H (hour), "
- "P1D (day), P1W (week), P1M (month), P3M (quarter), P1Y (year). "
- "If not specified, Superset will use its default behavior."
- ),
+ None, description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y"
)
orientation: Literal["vertical", "horizontal"] | None = Field(
- None,
- description=(
- "Bar chart orientation. Only applies when kind='bar'. "
- "'vertical' (default): bars extend upward. "
- "'horizontal': bars extend rightward, useful for long category
names."
- ),
- )
- stacked: bool = Field(
- False,
- description="Stack bars/areas on top of each other instead of
side-by-side",
+ None, description="Bar orientation (only for kind='bar')"
)
+ stacked: bool = False
group_by: ColumnRef | None = Field(
- None,
- description="Column to group by (creates series/breakdown). "
- "Use this field for series grouping — do NOT use 'series'.",
+ None, description="Series breakdown column (not 'series')"
)
- x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
- y_axis: AxisConfig | None = Field(None, description="Y-axis configuration")
- legend: LegendConfig | None = Field(None, description="Legend
configuration")
- filters: List[FilterConfig] | None = Field(None, description="Filters to
apply")
+ x_axis: AxisConfig | None = None
+ y_axis: AxisConfig | None = None
+ legend: LegendConfig | None = None
+ filters: List[FilterConfig] | None = None
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "XYChartConfig":
@@ -949,21 +807,12 @@ class GenerateChartRequest(QueryCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID,
UUID)")
config: ChartConfig = Field(..., description="Chart configuration")
chart_name: str | None = Field(
- None,
- description="Custom chart name (optional, auto-generates if not
provided)",
- max_length=255,
- )
- save_chart: bool = Field(
- default=False,
- description="Whether to permanently save the chart in Superset",
- )
- generate_preview: bool = Field(
- default=True,
- description="Whether to generate a preview image",
+ None, description="Auto-generates if omitted", max_length=255
)
+ save_chart: bool = Field(default=False, description="Save permanently in
Superset")
+ generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] =
Field(
default_factory=lambda: ["url"],
- description="List of preview formats to generate",
)
@field_validator("chart_name")
@@ -1002,20 +851,14 @@ class GenerateExploreLinkRequest(FormDataCacheControl):
class UpdateChartRequest(QueryCacheControl):
- identifier: int | str = Field(..., description="Chart identifier (ID,
UUID)")
- config: ChartConfig = Field(..., description="New chart configuration")
+ identifier: int | str = Field(..., description="Chart ID or UUID")
+ config: ChartConfig
chart_name: str | None = Field(
- None,
- description="New chart name (optional, will auto-generate if not
provided)",
- max_length=255,
- )
- generate_preview: bool = Field(
- default=True,
- description="Whether to generate a preview after updating",
+ None, description="Auto-generates if omitted", max_length=255
)
+ generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] =
Field(
default_factory=lambda: ["url"],
- description="List of preview formats to generate",
)
@field_validator("chart_name")
@@ -1027,15 +870,11 @@ class UpdateChartRequest(QueryCacheControl):
class UpdateChartPreviewRequest(FormDataCacheControl):
form_data_key: str = Field(..., description="Existing form_data_key to
update")
- dataset_id: int | str = Field(..., description="Dataset identifier (ID,
UUID)")
- config: ChartConfig = Field(..., description="New chart configuration")
- generate_preview: bool = Field(
- default=True,
- description="Whether to generate a preview after updating",
- )
+ dataset_id: int | str = Field(..., description="Dataset ID or UUID")
+ config: ChartConfig
+ generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] =
Field(
default_factory=lambda: ["url"],
- description="List of preview formats to generate",
)
diff --git a/superset/mcp_service/common/cache_schemas.py b/superset/mcp_service/common/cache_schemas.py
index 59b3112b29f..bf51bbf49b0 100644
--- a/superset/mcp_service/common/cache_schemas.py
+++ b/superset/mcp_service/common/cache_schemas.py
@@ -37,22 +37,10 @@ class CacheControlMixin(BaseModel):
- Dashboard Cache: Caches rendered dashboard components
"""
- use_cache: bool = Field(
- default=True,
- description=(
- "Whether to use Superset's cache layers. When True, will serve
from "
- "cache if available (query results, metadata, form data). When
False, "
- "will bypass cache and fetch fresh data."
- ),
- )
+ use_cache: bool = Field(default=True, description="Use cache if available")
force_refresh: bool = Field(
- default=False,
- description=(
- "Whether to force refresh cached data. When True, will invalidate "
- "existing cache entries and fetch fresh data, then update the
cache. "
- "Overrides use_cache=True if both are specified."
- ),
+ default=False, description="Invalidate cache and fetch fresh data"
)
@@ -65,12 +53,7 @@ class QueryCacheControl(CacheControlMixin):
"""
cache_timeout: int | None = Field(
- default=None,
- description=(
- "Override the default cache timeout in seconds for this query. "
- "If not specified, uses dataset-level or global cache settings. "
- "Set to 0 to disable caching for this specific query."
- ),
+ default=None, description="Cache timeout override in seconds (0 to
disable)"
)
@@ -83,11 +66,7 @@ class MetadataCacheControl(CacheControlMixin):
"""
refresh_metadata: bool = Field(
- default=False,
- description=(
- "Whether to refresh metadata cache for datasets, tables, and
columns. "
- "Useful when database schema has changed and you need fresh
metadata."
- ),
+ default=False, description="Refresh metadata cache for schema changes"
)
@@ -100,11 +79,7 @@ class FormDataCacheControl(CacheControlMixin):
"""
cache_form_data: bool = Field(
- default=True,
- description=(
- "Whether to cache the form data configuration for future use. "
- "When False, generates temporary configurations that are not
cached."
- ),
+ default=True, description="Cache form data for future use"
)
diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py
index 48aa72a952c..ba56cf774b0 100644
--- a/superset/mcp_service/dashboard/schemas.py
+++ b/superset/mcp_service/dashboard/schemas.py
@@ -287,45 +287,31 @@ class GetDashboardInfoRequest(MetadataCacheControl):
class DashboardInfo(BaseModel):
- id: int | None = Field(None, description="Dashboard ID")
- dashboard_title: str | None = Field(None, description="Dashboard title")
- slug: str | None = Field(None, description="Dashboard slug")
- description: str | None = Field(None, description="Dashboard description")
- css: str | None = Field(None, description="Custom CSS for the dashboard")
- certified_by: str | None = Field(None, description="Who certified the
dashboard")
- certification_details: str | None = Field(None, description="Certification
details")
- json_metadata: str | None = Field(
- None, description="Dashboard metadata (JSON string)"
- )
- position_json: str | None = Field(None, description="Chart positions (JSON
string)")
- published: bool | None = Field(
- None, description="Whether the dashboard is published"
- )
- is_managed_externally: bool | None = Field(
- None, description="Whether managed externally"
- )
- external_url: str | None = Field(None, description="External URL")
- created_on: str | datetime | None = Field(None, description="Creation
timestamp")
- changed_on: str | datetime | None = Field(
- None, description="Last modification timestamp"
- )
- created_by: str | None = Field(None, description="Dashboard creator
(username)")
- changed_by: str | None = Field(None, description="Last modifier
(username)")
- uuid: str | None = Field(None, description="Dashboard UUID (converted to
string)")
- url: str | None = Field(None, description="Dashboard URL")
- created_on_humanized: str | None = Field(
- None, description="Humanized creation time"
- )
- changed_on_humanized: str | None = Field(
- None, description="Humanized modification time"
- )
- chart_count: int = Field(0, description="Number of charts in the
dashboard")
- owners: List[UserInfo] = Field(default_factory=list,
description="Dashboard owners")
- tags: List[TagInfo] = Field(default_factory=list, description="Dashboard
tags")
- roles: List[RoleInfo] = Field(default_factory=list, description="Dashboard
roles")
- charts: List[ChartInfo] = Field(
- default_factory=list, description="Dashboard charts"
- )
+ id: int | None = None
+ dashboard_title: str | None = None
+ slug: str | None = None
+ description: str | None = None
+ css: str | None = None
+ certified_by: str | None = None
+ certification_details: str | None = None
+ json_metadata: str | None = None
+ position_json: str | None = None
+ published: bool | None = None
+ is_managed_externally: bool | None = None
+ external_url: str | None = None
+ created_on: str | datetime | None = None
+ changed_on: str | datetime | None = None
+ created_by: str | None = None
+ changed_by: str | None = None
+ uuid: str | None = None
+ url: str | None = None
+ created_on_humanized: str | None = None
+ changed_on_humanized: str | None = None
+ chart_count: int = 0
+ owners: List[UserInfo] = Field(default_factory=list)
+ tags: List[TagInfo] = Field(default_factory=list)
+ roles: List[RoleInfo] = Field(default_factory=list)
+ charts: List[ChartInfo] = Field(default_factory=list)
# Fields for permalink/filter state support
permalink_key: str | None = Field(
diff --git a/superset/mcp_service/flask_singleton.py b/superset/mcp_service/flask_singleton.py
index 3c35d31a6a4..d3a7124ec13 100644
--- a/superset/mcp_service/flask_singleton.py
+++ b/superset/mcp_service/flask_singleton.py
@@ -25,7 +25,6 @@ Following the Stack Overflow recommendation:
"""
import logging
-import os
from flask import current_app, Flask, has_app_context
@@ -52,62 +51,45 @@ try:
logger.info("Reusing existing Flask app from app context for MCP
service")
# Use _get_current_object() to get the actual Flask app, not the
LocalProxy
app = current_app._get_current_object()
+ elif appbuilder_initialized:
+ # appbuilder is initialized but we have no app context. Calling
+ # create_app() here would invoke appbuilder.init_app() a second
+ # time with a *different* Flask app, overwriting shared internal
+ # state (views, security manager, etc.). Fail loudly instead of
+ # silently corrupting the singleton.
+ raise RuntimeError(
+ "appbuilder is already initialized but no Flask app context is "
+ "available. Cannot call create_app() as it would re-initialize "
+ "appbuilder with a different Flask app instance."
+ )
else:
- # Either appbuilder is not initialized (standalone MCP server),
- # or appbuilder is initialized but we're not in an app context
- # (edge case - should rarely happen). In both cases, create a minimal
app.
+ # Standalone MCP server — Superset models are deeply coupled to
+ # appbuilder, security_manager, event_logger, encrypted_field_factory,
+ # etc. so we use create_app() for full initialization rather than
+ # trying to init a minimal subset (which leads to cascading failures).
#
- # We avoid calling create_app() which would run full FAB initialization
- # and could corrupt the shared appbuilder singleton if main app starts.
- from superset.app import SupersetApp
+ # create_app() is safe here because in standalone mode the main
+ # Superset web server is not running in-process.
+ from superset.app import create_app
from superset.mcp_service.mcp_config import get_mcp_config
- if appbuilder_initialized:
- logger.warning(
- "Appbuilder initialized but not in app context - "
- "creating separate MCP Flask app"
- )
- else:
- logger.info("Creating minimal Flask app for standalone MCP
service")
-
- # Disable debug mode to avoid side-effects like file watchers
- _mcp_app = SupersetApp(__name__)
+ logger.info("Creating fully initialized Flask app for standalone MCP
service")
+ _mcp_app = create_app()
_mcp_app.debug = False
- # Load configuration
- config_module = os.environ.get("SUPERSET_CONFIG", "superset.config")
- _mcp_app.config.from_object(config_module)
-
- # Apply MCP-specific configuration
+ # Apply MCP-specific configuration on top
mcp_config = get_mcp_config(_mcp_app.config)
_mcp_app.config.update(mcp_config)
- # Initialize only the minimal dependencies needed for MCP service
with _mcp_app.app_context():
- try:
- from superset.extensions import db
-
- db.init_app(_mcp_app)
-
- # Initialize only MCP-specific dependencies
- # MCP tools import directly from superset.daos/models, so we
only need
- # the MCP decorator injection, not the full superset_core
abstraction
- from superset.core.mcp.core_mcp_injection import (
- initialize_core_mcp_dependencies,
- )
-
- initialize_core_mcp_dependencies()
+ from superset.core.mcp.core_mcp_injection import (
+ initialize_core_mcp_dependencies,
+ )
- logger.info(
- "Minimal MCP dependencies initialized for standalone MCP
service"
- )
- except Exception as e:
- logger.warning(
- "Failed to initialize dependencies for MCP service: %s", e
- )
+ initialize_core_mcp_dependencies()
app = _mcp_app
- logger.info("Minimal Flask app instance created successfully for MCP
service")
+ logger.info("Flask app fully initialized for standalone MCP service")
except Exception as e:
logger.error("Failed to create Flask app: %s", e)
diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py
index 75221f854ef..e1e65c11f0b 100644
--- a/superset/mcp_service/mcp_config.py
+++ b/superset/mcp_service/mcp_config.py
@@ -227,10 +227,44 @@ MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
"get_chart_preview", # Returns URLs, not data
"generate_explore_link", # Returns URLs
"open_sql_lab_with_context", # Returns URLs
+ "search_tools", # Returns tool schemas for discovery (intentionally large)
],
}
+# =============================================================================
+# MCP Tool Search Transform Configuration
+# =============================================================================
+#
+# Overview:
+# ---------
+# When enabled, replaces the full tool catalog with a search interface.
+# LLMs see only 2 synthetic tools (search_tools + call_tool) plus any
+# pinned tools, and discover other tools on-demand via natural language search.
+# This reduces initial context by ~70% (from ~40k tokens to ~5-8k tokens).
+#
+# Strategies:
+# -----------
+# - "bm25": Natural language search using BM25 ranking (recommended)
+# - "regex": Pattern-based search using regular expressions
+#
+# Rollback:
+# ---------
+# Set "enabled": False in MCP_TOOL_SEARCH_CONFIG (superset_config.py) to roll back.
+# =============================================================================
+MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
+ "enabled": True, # Enabled by default — reduces initial context by ~70%
+ "strategy": "bm25", # "bm25" (natural language) or "regex" (pattern matching)
+ "max_results": 5, # Max tools returned per search
+ "always_visible": [ # Tools always shown in list_tools (pinned)
+ "health_check",
+ "get_instance_info",
+ ],
+ "search_tool_name": "search_tools", # Name of the search tool
+ "call_tool_name": "call_tool", # Name of the call proxy tool
+}
+
+
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
"""Default MCP auth factory using app.config values."""
if not app.config.get("MCP_AUTH_ENABLED", False):
diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py
index 5435a66e68b..732e8f4622f 100644
--- a/superset/mcp_service/server.py
+++ b/superset/mcp_service/server.py
@@ -24,12 +24,17 @@ For multi-pod deployments, configure MCP_EVENT_STORE_CONFIG with Redis URL.
import logging
import os
+from collections.abc import Sequence
from typing import Any
import uvicorn
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
-from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CONFIG
+from superset.mcp_service.mcp_config import (
+ get_mcp_factory_config,
+ MCP_STORE_CONFIG,
+ MCP_TOOL_SEARCH_CONFIG,
+)
from superset.mcp_service.middleware import (
create_response_size_guard_middleware,
GlobalErrorHandlerMiddleware,
@@ -111,8 +116,7 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
if config is None:
config = MCP_STORE_CONFIG
- redis_url = config.get("CACHE_REDIS_URL")
- if not redis_url:
+ if not config.get("CACHE_REDIS_URL"):
logging.info("EventStore: Using in-memory storage (single-pod mode)")
return None
@@ -151,6 +155,117 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
return None
+def _strip_titles(obj: Any, in_properties_map: bool = False) -> Any:
+ """Recursively strip schema metadata ``title`` keys.
+
+ Keeps real field names inside ``properties`` (e.g. a property literally
+ named ``title``), while removing auto-generated schema title metadata.
+ """
+ if isinstance(obj, dict):
+ result: dict[str, Any] = {}
+ for key, value in obj.items():
+ if key == "title" and not in_properties_map:
+ continue
+ result[key] = _strip_titles(value, in_properties_map=(key == "properties"))
+ return result
+ if isinstance(obj, list):
+ return [_strip_titles(item, in_properties_map=False) for item in obj]
+ return obj
+
+
+def _serialize_tools_without_output_schema(
+ tools: Sequence[Any],
+) -> list[dict[str, Any]]:
+ """Serialize tools to JSON, stripping outputSchema and titles to reduce tokens.
+
+ LLMs only need inputSchema to call tools. outputSchema accounts for
+ 50-80% of the per-tool schema size, and auto-generated 'title' fields
+ add ~12% bloat. Stripping both cuts search result tokens significantly.
+ """
+ results = []
+ for tool in tools:
+ data = tool.to_mcp_tool().model_dump(mode="json", exclude_none=True)
+ data.pop("outputSchema", None)
+ if input_schema := data.get("inputSchema"):
+ data["inputSchema"] = _strip_titles(input_schema)
+ results.append(data)
+ return results
+
+
+def _fix_call_tool_arguments(tool: Any) -> Any:
+ """Fix anyOf schema in call_tool ``arguments`` for MCP bridge compatibility.
+
+ FastMCP's BaseSearchTransform defines ``arguments`` as
+ ``dict[str, Any] | None`` which emits an ``anyOf`` JSON Schema.
+ Some MCP bridges (mcp-remote, Claude Desktop) don't handle ``anyOf``
+ and strip it, leaving the field without a ``type`` — causing all
+ call_tool invocations to fail with "Input should be a valid dictionary".
+
+ Replaces the ``anyOf`` with a flat ``type: object``.
+ """
+ if "arguments" in (props := (tool.parameters or {}).get("properties", {})):
+ props["arguments"] = {
+ "additionalProperties": True,
+ "default": None,
+ "description": "Arguments to pass to the tool",
+ "type": "object",
+ }
+ return tool
+
+
+def _apply_tool_search_transform(mcp_instance: Any, config: dict[str, Any]) -> None:
+ """Apply tool search transform to reduce initial context size.
+
+ When enabled, replaces the full tool catalog with a search interface.
+ LLMs see only synthetic search/call tools plus pinned tools, and
+ discover other tools on-demand via natural language search.
+
+ Uses subclassing (not monkey-patching) to override ``_make_call_tool``
+ and fix the ``arguments`` schema for MCP bridge compatibility.
+
+ NOTE: ``_make_call_tool`` is a private API in FastMCP 3.x
+ (fastmcp>=3.1.0,<4.0). If FastMCP changes or removes this method
+ in a future major version, these subclasses will need to be updated.
+ """
+ strategy = config.get("strategy", "bm25")
+ kwargs: dict[str, Any] = {
+ "max_results": config.get("max_results", 5),
+ "always_visible": config.get("always_visible", []),
+ "search_tool_name": config.get("search_tool_name", "search_tools"),
+ "call_tool_name": config.get("call_tool_name", "call_tool"),
+ "search_result_serializer": _serialize_tools_without_output_schema,
+ }
+
+ if strategy == "regex":
+ from fastmcp.server.transforms.search import RegexSearchTransform
+
+ class _FixedRegexSearchTransform(RegexSearchTransform):
+ """Regex search with fixed call_tool arguments schema."""
+
+ def _make_call_tool(self) -> Any:
+ return _fix_call_tool_arguments(super()._make_call_tool())
+
+ transform = _FixedRegexSearchTransform(**kwargs)
+ else:
+ from fastmcp.server.transforms.search import BM25SearchTransform
+
+ class _FixedBM25SearchTransform(BM25SearchTransform):
+ """BM25 search with fixed call_tool arguments schema."""
+
+ def _make_call_tool(self) -> Any:
+ return _fix_call_tool_arguments(super()._make_call_tool())
+
+ transform = _FixedBM25SearchTransform(**kwargs)
+
+ mcp_instance.add_transform(transform)
+ logger.info(
+ "Tool search transform enabled (strategy=%s, max_results=%d, pinned=%s)",
+ strategy,
+ kwargs["max_results"],
+ kwargs["always_visible"],
+ )
+
+
def _create_auth_provider(flask_app: Any) -> Any | None:
"""Create an auth provider from Flask app config.
@@ -218,6 +333,11 @@ def run_server(
logging.info("Creating MCP app from factory configuration...")
factory_config = get_mcp_factory_config()
mcp_instance = create_mcp_app(**factory_config)
+
+ # Apply tool search transform if configured
+ tool_search_config = MCP_TOOL_SEARCH_CONFIG
+ if tool_search_config.get("enabled", False):
+ _apply_tool_search_transform(mcp_instance, tool_search_config)
else:
# Use default initialization with auth from Flask config
logging.info("Creating MCP app with default configuration...")
@@ -233,8 +353,7 @@ def run_server(
middleware_list = []
# Add caching middleware (innermost – runs closest to the tool)
- caching_middleware = create_response_caching_middleware()
- if caching_middleware:
+ if caching_middleware := create_response_caching_middleware():
middleware_list.append(caching_middleware)
# Add response size guard (protects LLM clients from huge responses)
@@ -252,6 +371,18 @@ def run_server(
middleware=middleware_list or None,
)
+ # Apply tool search transform if configured
+ tool_search_config = flask_app.config.get(
+ "MCP_TOOL_SEARCH_CONFIG", MCP_TOOL_SEARCH_CONFIG
+ )
+ if tool_search_config.get("enabled", False):
+ _apply_tool_search_transform(mcp_instance, tool_search_config)
+ # Ensure the configured search tool name is excluded from the
+ # response size guard (search results are intentionally large)
+ if size_guard_middleware:
+ search_name = tool_search_config.get("search_tool_name", "search_tools")
+ size_guard_middleware.excluded_tools.add(search_name)
+
# Create EventStore for session management (Redis for multi-pod, None for in-memory)
event_store = create_event_store(event_store_config)
diff --git a/superset/mcp_service/system/schemas.py b/superset/mcp_service/system/schemas.py
index 9810cc4d3b3..ae0fab91e62 100644
--- a/superset/mcp_service/system/schemas.py
+++ b/superset/mcp_service/system/schemas.py
@@ -58,54 +58,40 @@ class GetSupersetInstanceInfoRequest(BaseModel):
class InstanceSummary(BaseModel):
- total_dashboards: int = Field(..., description="Total number of dashboards")
- total_charts: int = Field(..., description="Total number of charts")
- total_datasets: int = Field(..., description="Total number of datasets")
- total_databases: int = Field(..., description="Total number of databases")
- total_users: int = Field(..., description="Total number of users")
- total_roles: int = Field(..., description="Total number of roles")
- total_tags: int = Field(..., description="Total number of tags")
- avg_charts_per_dashboard: float = Field(
- ..., description="Average number of charts per dashboard"
- )
+ total_dashboards: int
+ total_charts: int
+ total_datasets: int
+ total_databases: int
+ total_users: int
+ total_roles: int
+ total_tags: int
+ avg_charts_per_dashboard: float
class RecentActivity(BaseModel):
- dashboards_created_last_30_days: int = Field(
- ..., description="Dashboards created in the last 30 days"
- )
- charts_created_last_30_days: int = Field(
- ..., description="Charts created in the last 30 days"
- )
- datasets_created_last_30_days: int = Field(
- ..., description="Datasets created in the last 30 days"
- )
- dashboards_modified_last_7_days: int = Field(
- ..., description="Dashboards modified in the last 7 days"
- )
- charts_modified_last_7_days: int = Field(
- ..., description="Charts modified in the last 7 days"
- )
- datasets_modified_last_7_days: int = Field(
- ..., description="Datasets modified in the last 7 days"
- )
+ dashboards_created_last_30_days: int
+ charts_created_last_30_days: int
+ datasets_created_last_30_days: int
+ dashboards_modified_last_7_days: int
+ charts_modified_last_7_days: int
+ datasets_modified_last_7_days: int
class DashboardBreakdown(BaseModel):
- published: int = Field(..., description="Number of published dashboards")
- unpublished: int = Field(..., description="Number of unpublished dashboards")
- certified: int = Field(..., description="Number of certified dashboards")
- with_charts: int = Field(..., description="Number of dashboards with charts")
- without_charts: int = Field(..., description="Number of dashboards without charts")
+ published: int
+ unpublished: int
+ certified: int
+ with_charts: int
+ without_charts: int
class DatabaseBreakdown(BaseModel):
- by_type: Dict[str, int] = Field(..., description="Breakdown of databases by type")
+ by_type: Dict[str, int]
class PopularContent(BaseModel):
- top_tags: List[str] = Field(..., description="Most popular tags")
- top_creators: List[str] = Field(..., description="Most active creators")
+ top_tags: List[str] = Field(default_factory=list)
+ top_creators: List[str] = Field(default_factory=list)
class FeatureAvailability(BaseModel):
@@ -125,33 +111,19 @@ class FeatureAvailability(BaseModel):
class InstanceInfo(BaseModel):
- instance_summary: InstanceSummary = Field(
- ..., description="Instance summary information"
- )
- recent_activity: RecentActivity = Field(
- ..., description="Recent activity information"
- )
- dashboard_breakdown: DashboardBreakdown = Field(
- ..., description="Dashboard breakdown information"
- )
- database_breakdown: DatabaseBreakdown = Field(
- ..., description="Database breakdown by type"
- )
- popular_content: PopularContent = Field(
- ..., description="Popular content information"
- )
+ instance_summary: InstanceSummary
+ recent_activity: RecentActivity
+ dashboard_breakdown: DashboardBreakdown
+ database_breakdown: DatabaseBreakdown
+ popular_content: PopularContent
current_user: UserInfo | None = Field(
None,
- description="The authenticated user making the request. "
- "Use current_user.id with created_by_fk filter to find your own assets.",
- )
- feature_availability: FeatureAvailability = Field(
- ...,
description=(
- "Dynamic feature availability for the current user and deployment"
+ "Use current_user.id with created_by_fk filter to find your own assets."
),
)
- timestamp: datetime = Field(..., description="Response timestamp")
+ feature_availability: FeatureAvailability
+ timestamp: datetime
class UserInfo(BaseModel):
@@ -207,10 +179,10 @@ class RoleInfo(BaseModel):
class PaginationInfo(BaseModel):
- page: int = Field(..., description="Current page number")
- page_size: int = Field(..., description="Number of items per page")
- total_count: int = Field(..., description="Total number of items")
- total_pages: int = Field(..., description="Total number of pages")
- has_next: bool = Field(..., description="Whether there is a next page")
- has_previous: bool = Field(..., description="Whether there is a previous page")
+ page: int
+ page_size: int
+ total_count: int
+ total_pages: int
+ has_next: bool
+ has_previous: bool
model_config = ConfigDict(ser_json_timedelta="iso8601")
diff --git a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py
index be470bc4fd0..9307a671659 100644
--- a/tests/unit_tests/mcp_service/test_mcp_tool_registration.py
+++ b/tests/unit_tests/mcp_service/test_mcp_tool_registration.py
@@ -17,25 +17,32 @@
"""Test MCP app imports and tool/prompt registration."""
+import asyncio
+
+
+def _run(coro):
+ """Run an async coroutine synchronously."""
+ return asyncio.run(coro)
+
def test_mcp_app_imports_successfully():
"""Test that the MCP app can be imported without errors."""
from superset.mcp_service.app import mcp
assert mcp is not None
- assert hasattr(mcp, "_tool_manager")
- tools = mcp._tool_manager._tools
- assert len(tools) > 0
- assert "health_check" in tools
- assert "list_charts" in tools
+ tools = _run(mcp.list_tools())
+ tool_names = [t.name for t in tools]
+ assert len(tool_names) > 0
+ assert "health_check" in tool_names
+ assert "list_charts" in tool_names
def test_mcp_prompts_registered():
"""Test that MCP prompts are registered."""
from superset.mcp_service.app import mcp
- prompts = mcp._prompt_manager._prompts
+ prompts = _run(mcp.list_prompts())
assert len(prompts) > 0
@@ -48,12 +55,10 @@ def test_mcp_resources_registered():
"""
from superset.mcp_service.app import mcp
- resource_manager = mcp._resource_manager
- resources = resource_manager._resources
+ resources = _run(mcp.list_resources())
assert len(resources) > 0, "No MCP resources registered"
- # Verify the two documented resources are registered
- resource_uris = set(resources.keys())
+ resource_uris = {str(r.uri) for r in resources}
assert "chart://configs" in resource_uris, (
"chart://configs resource not registered - "
"check superset/mcp_service/chart/__init__.py exists"
diff --git a/tests/unit_tests/mcp_service/test_tool_search_transform.py b/tests/unit_tests/mcp_service/test_tool_search_transform.py
new file mode 100644
index 00000000000..7adfeacad47
--- /dev/null
+++ b/tests/unit_tests/mcp_service/test_tool_search_transform.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for MCP tool search transform configuration and application."""
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+from fastmcp.server.transforms.search import BM25SearchTransform, RegexSearchTransform
+
+from superset.mcp_service.mcp_config import MCP_TOOL_SEARCH_CONFIG
+from superset.mcp_service.server import (
+ _apply_tool_search_transform,
+ _fix_call_tool_arguments,
+ _serialize_tools_without_output_schema,
+)
+
+
+def test_tool_search_config_defaults():
+ """Default config has expected keys and values."""
+ assert MCP_TOOL_SEARCH_CONFIG["enabled"] is True
+ assert MCP_TOOL_SEARCH_CONFIG["strategy"] == "bm25"
+ assert MCP_TOOL_SEARCH_CONFIG["max_results"] == 5
+ assert "health_check" in MCP_TOOL_SEARCH_CONFIG["always_visible"]
+ assert "get_instance_info" in MCP_TOOL_SEARCH_CONFIG["always_visible"]
+ assert MCP_TOOL_SEARCH_CONFIG["search_tool_name"] == "search_tools"
+ assert MCP_TOOL_SEARCH_CONFIG["call_tool_name"] == "call_tool"
+
+
+def test_apply_bm25_transform():
+ """BM25 subclass is created and added when strategy is 'bm25'."""
+ mock_mcp = MagicMock()
+ config = {
+ "strategy": "bm25",
+ "max_results": 5,
+ "always_visible": ["health_check"],
+ "search_tool_name": "search_tools",
+ "call_tool_name": "call_tool",
+ }
+
+ _apply_tool_search_transform(mock_mcp, config)
+
+ mock_mcp.add_transform.assert_called_once()
+ transform = mock_mcp.add_transform.call_args[0][0]
+ assert isinstance(transform, BM25SearchTransform)
+
+
+def test_apply_regex_transform():
+ """Regex subclass is created and added when strategy is 'regex'."""
+ mock_mcp = MagicMock()
+ config = {
+ "strategy": "regex",
+ "max_results": 10,
+ "always_visible": ["health_check", "get_instance_info"],
+ "search_tool_name": "find_tools",
+ "call_tool_name": "invoke_tool",
+ }
+
+ _apply_tool_search_transform(mock_mcp, config)
+
+ mock_mcp.add_transform.assert_called_once()
+ transform = mock_mcp.add_transform.call_args[0][0]
+ assert isinstance(transform, RegexSearchTransform)
+
+
+def test_apply_transform_uses_defaults_for_missing_keys():
+ """Missing config keys fall back to sensible defaults (BM25)."""
+ mock_mcp = MagicMock()
+ config = {} # All keys missing — should use defaults
+
+ _apply_tool_search_transform(mock_mcp, config)
+
+ mock_mcp.add_transform.assert_called_once()
+ transform = mock_mcp.add_transform.call_args[0][0]
+ assert isinstance(transform, BM25SearchTransform)
+
+
+def test_fix_call_tool_arguments_replaces_anyof():
+ """_fix_call_tool_arguments replaces anyOf with flat type: object."""
+ tool = SimpleNamespace(
+ parameters={
+ "type": "object",
+ "properties": {
+ "name": {"type": "string"},
+ "arguments": {
+ "anyOf": [
+ {"type": "object", "additionalProperties": True},
+ {"type": "null"},
+ ],
+ "default": None,
+ },
+ },
+ }
+ )
+
+ result = _fix_call_tool_arguments(tool)
+
+ assert result.parameters["properties"]["arguments"] == {
+ "additionalProperties": True,
+ "default": None,
+ "description": "Arguments to pass to the tool",
+ "type": "object",
+ }
+ # Other properties untouched
+ assert result.parameters["properties"]["name"] == {"type": "string"}
+
+
+def test_fix_call_tool_arguments_no_arguments_field():
+ """_fix_call_tool_arguments is a no-op when arguments field is absent."""
+ tool = SimpleNamespace(
+ parameters={
+ "type": "object",
+ "properties": {"name": {"type": "string"}},
+ }
+ )
+
+ result = _fix_call_tool_arguments(tool)
+
+ assert "arguments" not in result.parameters["properties"]
+
+
+def test_serialize_tools_strips_output_schema():
+ """Custom serializer removes outputSchema from tool definitions."""
+ mock_tool = MagicMock()
+ mock_mcp_tool = MagicMock()
+ mock_mcp_tool.model_dump.return_value = {
+ "name": "test_tool",
+ "description": "A test tool",
+ "inputSchema": {"type": "object", "properties": {"x": {"type": "integer"}}},
+ "outputSchema": {
+ "type": "object",
+ "properties": {"result": {"type": "string"}},
+ },
+ }
+ mock_tool.to_mcp_tool.return_value = mock_mcp_tool
+
+ result = _serialize_tools_without_output_schema([mock_tool])
+
+ assert len(result) == 1
+ assert result[0]["name"] == "test_tool"
+ assert "inputSchema" in result[0]
+ assert "outputSchema" not in result[0]
+
+
+def test_serialize_tools_handles_no_output_schema():
+ """Custom serializer works when tool has no outputSchema."""
+ mock_tool = MagicMock()
+ mock_mcp_tool = MagicMock()
+ mock_mcp_tool.model_dump.return_value = {
+ "name": "simple_tool",
+ "inputSchema": {"type": "object"},
+ }
+ mock_tool.to_mcp_tool.return_value = mock_mcp_tool
+
+ result = _serialize_tools_without_output_schema([mock_tool])
+
+ assert len(result) == 1
+ assert result[0]["name"] == "simple_tool"
+ assert "outputSchema" not in result[0]