This is an automated email from the ASF dual-hosted git repository.
richardfogaca pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 8fa5a75c704 fix(mcp): apply cached adhoc filters to chart retrieval
(#40099)
8fa5a75c704 is described below
commit 8fa5a75c704950439abc986c1239db16bc223d42
Author: Richard Fogaca Nienkotter
<[email protected]>
AuthorDate: Thu May 14 14:21:54 2026 -0300
fix(mcp): apply cached adhoc filters to chart retrieval (#40099)
---
superset/mcp_service/chart/chart_helpers.py | 451 ++++++++++++++++++++-
superset/mcp_service/chart/tool/get_chart_data.py | 236 ++---------
.../mcp_service/chart/tool/get_chart_preview.py | 105 +----
superset/mcp_service/chart/tool/get_chart_sql.py | 241 +----------
.../mcp_service/chart/test_chart_helpers.py | 179 ++++++++
.../mcp_service/chart/tool/test_get_chart_data.py | 179 ++++++++
.../chart/tool/test_get_chart_preview.py | 386 ++++++++++++++++++
.../mcp_service/chart/tool/test_get_chart_sql.py | 55 +++
8 files changed, 1324 insertions(+), 508 deletions(-)
diff --git a/superset/mcp_service/chart/chart_helpers.py
b/superset/mcp_service/chart/chart_helpers.py
index 05477e76eee..b6f2b5b30dc 100644
--- a/superset/mcp_service/chart/chart_helpers.py
+++ b/superset/mcp_service/chart/chart_helpers.py
@@ -26,14 +26,23 @@ URL parameter extraction. Config mapping logic lives in
chart_utils.py.
from __future__ import annotations
import logging
-from typing import TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from urllib.parse import parse_qs, urlparse
+from superset.constants import EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS
+
if TYPE_CHECKING:
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
# Temporal override fields that request-level ``extra_form_data`` is allowed
# to write directly onto an existing (saved) query-context query; consulted
# by ``merge_form_data_filters_into_query``. Frozen because it is a
# read-only lookup table and must not be mutated by importers.
QUERY_CONTEXT_EXTRA_FORM_DATA_OVERRIDE_KEYS = frozenset(
    {
        "granularity",
        "time_grain",
        "time_grain_sqla",
        "time_range",
    }
)
+
def find_chart_by_identifier(identifier: int | str) -> Slice | None:
"""Find a chart by numeric ID or UUID string.
@@ -69,6 +78,446 @@ def get_cached_form_data(form_data_key: str) -> str | None:
return None
def resolve_datasource_engine(datasource_id: Any, datasource_type: str) -> str:
    """Look up the database engine name that backs a datasource.

    The lookup is best-effort. An id that is not an ``int``/``str``, or any
    error raised while fetching the datasource, yields the generic ``"base"``
    engine so callers fall back to engine-agnostic filter normalization.
    """
    fallback_engine = "base"
    if isinstance(datasource_id, (int, str)):
        try:
            # Imported lazily to avoid a circular import at module load.
            from superset.daos.datasource import DatasourceDAO
            from superset.utils.core import DatasourceType

            resolved = DatasourceDAO.get_datasource(
                datasource_type=DatasourceType(datasource_type),
                database_id_or_uuid=datasource_id,
            )
            return resolved.database.db_engine_spec.engine
        except Exception:  # noqa: BLE001
            # Deliberately swallowed: the engine only tunes normalization.
            logger.debug("Could not resolve engine for datasource %s", datasource_id)
    return fallback_engine
+
+
def prepare_form_data_for_query(
    form_data: dict[str, Any],
    datasource_id: Any,
    datasource_type: str,
    extra_form_data: dict[str, Any] | None = None,
    datasource_engine: str | None = None,
) -> None:
    """Normalize ``form_data`` filters so a QueryObject payload can be built.

    Explore and the legacy viz layer merge dashboard/native filter payloads
    and split adhoc filters into the concrete ``filters``/``where``/``having``
    fields that QueryObject consumes. MCP tools that assemble query payloads
    themselves must run the same normalization before QueryContextFactory.

    Args:
        form_data: Explore form_data; mutated in place.
        datasource_id: Datasource id, used only to resolve the engine when
            ``datasource_engine`` is not supplied.
        datasource_type: Datasource type, used alongside ``datasource_id``.
        extra_form_data: Optional request-level payload merged on top of any
            ``extra_form_data`` already stored in ``form_data``.
        datasource_engine: Pre-resolved engine name; skips the lookup.
    """
    # avoid circular import
    from superset.utils.core import (
        convert_legacy_filters_into_adhoc,
        form_data_to_adhoc,
        merge_extra_filters,
        simple_filter_to_adhoc,
        split_adhoc_filters_into_base_filters,
    )

    if isinstance(form_data.get("adhoc_filters"), list):
        # Fold legacy SQL clauses and simple filters into the existing adhoc
        # list so they are kept alongside it. NOTE(review): presumably the
        # legacy converter below only converts them when no adhoc filters
        # exist — confirm against superset.utils.core.
        clause_filters = [
            form_data_to_adhoc(form_data, clause)
            for clause in ("having", "where")
            if form_data.get(clause)
        ]
        simple_filters = [
            simple_filter_to_adhoc(filter_, "where")
            for filter_ in form_data.get("filters") or []
            if filter_ is not None
        ]
        form_data["adhoc_filters"] = [
            *clause_filters,
            *simple_filters,
            *form_data["adhoc_filters"],
        ]

    if extra_form_data:
        form_data["extra_form_data"] = merge_extra_form_data(
            form_data.get("extra_form_data"),
            extra_form_data,
        )

    convert_legacy_filters_into_adhoc(form_data)
    merge_extra_filters(form_data)

    engine = datasource_engine or resolve_datasource_engine(
        datasource_id, datasource_type
    )
    split_adhoc_filters_into_base_filters(form_data, engine)
+
+
def merge_extra_form_data(
    existing: Any,
    incoming: dict[str, Any],
) -> dict[str, Any]:
    """Combine a cached ``extra_form_data`` payload with a request-level one.

    Keys only present in ``existing`` are kept as-is. For each key of
    ``incoming``:

    * list + list: concatenated, existing items first;
    * dict + dict: shallow-merged, incoming keys winning;
    * anything else: the incoming value replaces the existing one.

    A non-dict ``existing`` is treated as empty. A new dict is returned;
    neither argument is mutated.
    """
    result: dict[str, Any] = dict(existing) if isinstance(existing, dict) else {}
    for key, new_value in incoming.items():
        old_value = result.get(key)
        if isinstance(old_value, list) and isinstance(new_value, list):
            combined: Any = [*old_value, *new_value]
        elif isinstance(old_value, dict) and isinstance(new_value, dict):
            combined = {**old_value, **new_value}
        else:
            combined = new_value
        result[key] = combined
    return result
+
+
def apply_form_data_filters_to_query(
    query: dict[str, Any],
    form_data: dict[str, Any],
) -> None:
    """Copy normalized filter fields from ``form_data`` into a fresh query.

    ``filters`` is copied when non-empty; otherwise the query is guaranteed
    to at least have a ``filters`` key (an existing one is left untouched).
    ``time_range``, ``where`` and ``having`` are copied only when truthy.
    ``query`` is mutated in place.
    """
    normalized_filters = form_data.get("filters")
    if normalized_filters:
        query["filters"] = normalized_filters
    elif "filters" not in query:
        query["filters"] = []

    for field in ("time_range", "where", "having"):
        value = form_data.get(field)
        if value:
            query[field] = value
+
+
+def _join_sql_clause(existing_clause: str, additional_clause: str) -> str:
+ """AND two SQL filter clauses while preserving their original grouping."""
+ return f"({existing_clause}) AND ({additional_clause})"
+
+
+def _is_temporal_override_filter(
+ filter_: dict[str, Any],
+ form_data: dict[str, Any],
+) -> bool:
+ return (
+ filter_.get("op") == "TEMPORAL_RANGE"
+ and form_data.get("time_range") is not None
+ and filter_.get("val") == form_data.get("time_range")
+ and (
+ form_data.get("granularity") is None
+ or filter_.get("col") == form_data.get("granularity")
+ )
+ )
+
+
def merge_form_data_filters_into_query(
    query: dict[str, Any],
    form_data: dict[str, Any],
) -> None:
    """Merge normalized ``form_data`` predicates into an existing query.

    Saved query contexts may already carry query-specific ``filters``,
    ``where`` or ``having``. This helper:

    * appends ``form_data["filters"]`` to the query's filters, skipping
      temporal filters that only restate ``form_data["time_range"]``;
    * copies the temporal override fields listed in
      ``QUERY_CONTEXT_EXTRA_FORM_DATA_OVERRIDE_KEYS`` (when non-None) onto the
      query, replacing any saved value;
    * ANDs ``where``/``having`` onto the query's existing clauses.

    ``query`` is mutated in place.
    """
    extra_filters = [
        candidate
        for candidate in form_data.get("filters") or []
        if not _is_temporal_override_filter(candidate, form_data)
    ]
    if extra_filters:
        query["filters"] = [*(query.get("filters") or []), *extra_filters]

    for override_key in EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.values():
        if override_key not in QUERY_CONTEXT_EXTRA_FORM_DATA_OVERRIDE_KEYS:
            continue
        override_value = form_data.get(override_key)
        if override_value is not None:
            query[override_key] = override_value

    for clause in ("where", "having"):
        addition = form_data.get(clause)
        if not addition:
            continue
        current = query.get(clause)
        query[clause] = _join_sql_clause(current, addition) if current else addition
+
+
def merge_extra_form_data_filters_into_query(
    query: dict[str, Any],
    extra_form_data: dict[str, Any],
    datasource_id: Any,
    datasource_type: str,
) -> None:
    """Fold request-level ``extra_form_data`` predicates into ``query``.

    A scratch form_data that holds only an empty ``adhoc_filters`` list is
    normalized with ``extra_form_data`` applied, so that it ends up
    containing just the request's own predicates and overrides. Those are
    then merged onto the existing query payload, which is mutated in place.
    """
    scratch_form_data: dict[str, Any] = {"adhoc_filters": []}
    prepare_form_data_for_query(
        scratch_form_data,
        datasource_id,
        datasource_type,
        extra_form_data=extra_form_data,
    )
    merge_form_data_filters_into_query(query, scratch_form_data)
+
+
def resolve_metrics(form_data: dict[str, Any], viz_type: str) -> list[Any]:
    """Extract the metric definitions a chart queries.

    Bubble charts keep their measures in the ``x``, ``y`` and ``size`` fields
    (empty ones are skipped). Every other chart uses ``metrics``, falling
    back to the singular ``metric`` field when ``metrics`` is empty, missing
    or ``None``.

    Always returns a new list: never ``None`` and never the list object stored
    in ``form_data``, so callers can mutate the result safely. A scalar
    ``metrics`` value (e.g. a bare string) is wrapped in a list rather than
    being iterated character by character.
    """
    if viz_type == "bubble":
        return [m for field in ("x", "y", "size") if (m := form_data.get(field))]

    metrics = form_data.get("metrics")
    if metrics:
        if isinstance(metrics, (list, tuple)):
            return list(metrics)
        return [metrics]

    if metric := form_data.get("metric"):
        return [metric]
    return []
+
+
def resolve_groupby(form_data: dict[str, Any]) -> list[Any]:
    """Extract the dimension columns a chart groups by.

    Resolution order:

    1. Raw-mode tables: a copy of the ``all_columns`` list.
    2. ``groupby`` (a bare string is wrapped, not split into characters).
    3. Otherwise, string values of ``entity``, ``series`` and the ``columns``
       list, de-duplicated in that order.
    4. Finally, ``all_columns`` when nothing else produced a column.
    """
    all_columns = form_data.get("all_columns")
    all_columns_is_list = isinstance(all_columns, list)
    if all_columns_is_list and form_data.get("query_mode") == "raw":
        return list(all_columns)

    declared = form_data.get("groupby") or []
    # A scalar string (e.g. heatmap_v2 migrated from the legacy heatmap) must
    # not be exploded into characters by list().
    dimensions: list[Any] = (
        [declared] if isinstance(declared, str) else list(declared)
    )
    if dimensions:
        return dimensions

    candidates: list[Any] = [form_data.get("entity"), form_data.get("series")]
    form_columns = form_data.get("columns")
    if isinstance(form_columns, list):
        candidates.extend(form_columns)
    for candidate in candidates:
        if isinstance(candidate, str) and candidate not in dimensions:
            dimensions.append(candidate)

    if not dimensions and all_columns_is_list:
        dimensions.extend(all_columns)
    return dimensions
+
+
def resolve_metrics_and_groupby(
    form_data: dict[str, Any],
    chart: Any | None = None,
) -> tuple[list[Any], list[Any]]:
    """Resolve the metrics and groupby columns a chart should query.

    The viz type is taken from ``form_data["viz_type"]``. When that value is
    missing, ``None`` or empty, it falls back to ``chart.viz_type`` (if a
    chart is given). This is the same resolution used by
    ``build_query_dicts_from_form_data``.

    Single-value KPI charts (``big_number``, ``big_number_total`` and
    ``pop_kpi``) store one measure under ``metric`` and never group. They
    return ``([metric], [])``, or ``([], [])`` when no metric is set.
    """
    viz_type: str = (
        form_data.get("viz_type")
        or (getattr(chart, "viz_type", "") if chart else "")
        or ""
    )
    if viz_type in ("big_number", "big_number_total", "pop_kpi"):
        metric = form_data.get("metric")
        return ([metric] if metric else []), []

    return resolve_metrics(form_data, viz_type), resolve_groupby(form_data)
+
+
def extract_x_axis_col(form_data: dict[str, Any]) -> str | None:
    """Return the x-axis column name from ``form_data``.

    ``x_axis`` may be a plain column name or a dict carrying
    ``column_name``. Anything that does not yield a non-empty string
    (missing, empty, other shapes) gives ``None``.
    """
    x_axis = form_data.get("x_axis")
    if isinstance(x_axis, dict):
        x_axis = x_axis.get("column_name")
    if isinstance(x_axis, str) and x_axis:
        return x_axis
    return None
+
+
def _build_single_query_dict(
    form_data: dict[str, Any],
    columns: list[Any],
    metrics: list[Any],
    row_limit: int | None = None,
    order_desc: bool | None = None,
) -> dict[str, Any]:
    """Build one ``queries`` entry for QueryContextFactory.

    An explicit ``row_limit`` wins over ``form_data["row_limit"]``. The limit
    and ``order_desc`` are only included when not ``None``. Normalized filter
    fields are copied from ``form_data`` last.
    """
    limit = form_data.get("row_limit") if row_limit is None else row_limit

    query_dict: dict[str, Any] = {"columns": columns, "metrics": metrics}
    if limit is not None:
        query_dict["row_limit"] = limit
    if order_desc is not None:
        query_dict["order_desc"] = order_desc

    apply_form_data_filters_to_query(query_dict, form_data)
    return query_dict
+
+
def _build_mixed_timeseries_secondary(
    form_data: dict[str, Any],
    x_axis_col: str | None,
    engine: str,
    row_limit: int | None = None,
    order_desc: bool | None = None,
) -> dict[str, Any]:
    """Build the second ("B") query of a ``mixed_timeseries`` chart.

    The query uses the ``*_b`` form_data fields (``metrics_b``,
    ``groupby_b``, ``time_range_b``, ``row_limit_b``, ``adhoc_filters_b``).
    It starts from the primary query's normalized filters and prepends the
    shared x-axis column to its groupby. When ``adhoc_filters_b`` is
    non-empty, its split ``filters``/``where``/``having`` replace the
    inherited ones. Keys that end up empty are removed.
    """
    # avoid circular import
    from superset.utils.core import split_adhoc_filters_into_base_filters

    secondary_metrics: list[Any] = list(form_data.get("metrics_b") or [])
    groupby_field = form_data.get("groupby_b") or []
    secondary_groupby: list[Any] = (
        [groupby_field] if isinstance(groupby_field, str) else list(groupby_field)
    )
    if x_axis_col and x_axis_col not in secondary_groupby:
        secondary_groupby.insert(0, x_axis_col)

    query_b = _build_single_query_dict(
        form_data,
        secondary_groupby,
        secondary_metrics,
        row_limit=row_limit,
        order_desc=order_desc,
    )

    if time_range_b := form_data.get("time_range_b"):
        query_b["time_range"] = time_range_b
    if row_limit is None:
        row_limit_b = form_data.get("row_limit_b")
        if row_limit_b is not None:
            query_b["row_limit"] = row_limit_b

    adhoc_filters_b = form_data.get("adhoc_filters_b")
    if not adhoc_filters_b:
        # NOTE(review): an empty adhoc_filters_b falls through here, so query B
        # keeps the primary chart's filters/where/having. Confirm this matches
        # how Explore builds the secondary query.
        return query_b

    secondary_fd: dict[str, Any] = {"adhoc_filters": adhoc_filters_b}
    split_adhoc_filters_into_base_filters(secondary_fd, engine)
    for key in ("filters", "where", "having"):
        value = secondary_fd.get(key)
        if value:
            query_b[key] = value
        else:
            query_b.pop(key, None)
    return query_b
+
+
def build_query_dicts_from_form_data(
    form_data: dict[str, Any],
    datasource_id: Any,
    datasource_type: str,
    chart: Any | None = None,
    extra_form_data: dict[str, Any] | None = None,
    row_limit: int | None = None,
    order_desc: bool | None = None,
) -> list[dict[str, Any]]:
    """Build chart-type-aware QueryContext ``queries`` entries from form_data.

    Steps:

    1. Resolve the datasource engine once, then normalize ``form_data``
       filters in place, merging ``extra_form_data``.
    2. Resolve the metrics and groupby columns for the chart's viz type.
    3. For viz types whose queries are keyed on an ``x_axis`` column, prepend
       that column to the groupby list when it is not already there. This
       covers the timeseries family, area, waterfall and heatmap_v2 charts.
    4. Build the primary query, plus the secondary "B" query for
       ``mixed_timeseries``.

    Note: ``form_data`` is mutated by the normalization step.
    """
    engine = resolve_datasource_engine(datasource_id, datasource_type)
    prepare_form_data_for_query(
        form_data,
        datasource_id,
        datasource_type,
        extra_form_data,
        datasource_engine=engine,
    )

    metrics, groupby = resolve_metrics_and_groupby(form_data, chart)
    viz_type: str = (
        form_data.get("viz_type")
        or (getattr(chart, "viz_type", "") if chart else "")
        or ""
    )

    # Previously only echarts_timeseries*/mixed_timeseries got their x_axis,
    # so area/waterfall/heatmap_v2 queries silently lost that dimension.
    # NOTE(review): the non-timeseries entries mirror those Explore plugins'
    # x_axis control — confirm against the frontend buildQuery functions.
    x_axis_viz_types = ("mixed_timeseries", "echarts_area", "waterfall", "heatmap_v2")
    uses_x_axis = (
        viz_type.startswith("echarts_timeseries") or viz_type in x_axis_viz_types
    )

    x_axis_col: str | None = None
    if uses_x_axis:
        x_axis_col = extract_x_axis_col(form_data)
        if x_axis_col and x_axis_col not in groupby:
            groupby = [x_axis_col, *groupby]

    queries = [
        _build_single_query_dict(
            form_data,
            groupby,
            metrics,
            row_limit=row_limit,
            order_desc=order_desc,
        )
    ]
    if viz_type == "mixed_timeseries":
        queries.append(
            _build_mixed_timeseries_secondary(
                form_data,
                x_axis_col,
                engine,
                row_limit=row_limit,
                order_desc=order_desc,
            )
        )
    return queries
+
+
def resolve_form_data_datasource(
    form_data: dict[str, Any],
    chart: Any | None = None,
) -> tuple[int | str | None, str]:
    """Work out which datasource a form_data payload targets.

    The id is taken from the first source that provides one:

    1. ``form_data["datasource_id"]``;
    2. the id half of a combined ``form_data["datasource"]`` value such as
       ``"7__table"``, converted to ``int`` when it is purely decimal digits
       and otherwise kept as a string (e.g. a UUID);
    3. ``chart.datasource_id`` when a chart is supplied.

    The type comes from ``form_data["datasource_type"]``, or from the type
    half of the combined value when the id was taken from it, then from
    ``chart.datasource_type``. It defaults to ``"table"`` when none of these
    yields a non-empty string.

    Returns:
        ``(datasource_id, datasource_type)``. The id may still be ``None``
        (or another falsy value) when nothing could be resolved.
    """
    datasource_id = form_data.get("datasource_id")
    datasource_type = form_data.get("datasource_type")

    combined = form_data.get("datasource")
    if not datasource_id and isinstance(combined, str) and "__" in combined:
        id_part, _, type_part = combined.partition("__")
        # isdecimal() rather than isdigit(): characters such as "²" count as
        # digits but are rejected by int().
        datasource_id = int(id_part) if id_part.isdecimal() else id_part
        datasource_type = type_part

    if chart:
        if not datasource_id:
            datasource_id = getattr(chart, "datasource_id", None)
        if not datasource_type:
            datasource_type = getattr(chart, "datasource_type", None)

    if isinstance(datasource_type, str) and datasource_type:
        return datasource_id, datasource_type
    return datasource_id, "table"
+
+
def build_query_context_from_form_data(
    form_data: dict[str, Any],
    chart: Any | None = None,
    extra_form_data: dict[str, Any] | None = None,
    row_limit: int | None = None,
    order_desc: bool | None = None,
    result_type: Any = None,
    force: bool = False,
) -> Any:
    """Build a QueryContext from chart-type-aware Explore form_data.

    The datasource is resolved from ``form_data`` (falling back to
    ``chart``). The per-chart-type queries are then built, which normalizes
    ``form_data`` in place, and everything is handed to QueryContextFactory.

    Args:
        form_data: Explore form_data; mutated during normalization.
        chart: Optional chart used as a fallback for viz/datasource info.
        extra_form_data: Optional request-level dashboard/native filters.
        row_limit: Explicit row limit overriding ``form_data["row_limit"]``.
        order_desc: Sort direction passed to every query when not ``None``.
        result_type: Forwarded to ``QueryContextFactory.create``.
        force: Forwarded to ``QueryContextFactory.create``.

    Raises:
        ValueError: If no usable datasource id (a non-empty ``int``/``str``)
            can be determined.
    """
    # avoid circular import
    from superset.common.query_context_factory import QueryContextFactory

    datasource_id, datasource_type = resolve_form_data_datasource(form_data, chart)
    # A blank string passes the type check but is not a usable identifier.
    if not isinstance(datasource_id, (int, str)) or (
        isinstance(datasource_id, str) and not datasource_id.strip()
    ):
        raise ValueError(
            "Cannot determine datasource ID from form_data. "
            "Provide a chart identifier or ensure form_data contains "
            "'datasource_id' or 'datasource'."
        )

    queries = build_query_dicts_from_form_data(
        form_data,
        datasource_id,
        datasource_type,
        chart=chart,
        extra_form_data=extra_form_data,
        row_limit=row_limit,
        order_desc=order_desc,
    )
    return QueryContextFactory().create(
        datasource={"id": datasource_id, "type": datasource_type},
        queries=queries,
        form_data=form_data,
        result_type=result_type,
        force=force,
    )
+
+
def extract_form_data_key_from_url(url: str | None) -> str | None:
"""Extract the form_data_key query parameter from an explore URL.
diff --git a/superset/mcp_service/chart/tool/get_chart_data.py
b/superset/mcp_service/chart/tool/get_chart_data.py
index a14fdcb73cb..5b6c0b3a5a5 100644
--- a/superset/mcp_service/chart/tool/get_chart_data.py
+++ b/superset/mcp_service/chart/tool/get_chart_data.py
@@ -35,8 +35,11 @@ from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError,
SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import (
+ build_query_context_from_form_data,
+ build_query_dicts_from_form_data,
find_chart_by_identifier,
get_cached_form_data,
+ merge_extra_form_data_filters_into_query,
)
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
@@ -55,7 +58,6 @@ from superset.mcp_service.utils.oauth2_utils import (
build_oauth2_redirect_message,
OAUTH2_CONFIG_ERROR_MESSAGE,
)
-from superset.utils.core import merge_extra_filters
logger = logging.getLogger(__name__)
@@ -94,16 +96,6 @@ def _sanitize_chart_data_for_llm_context(chart_data:
ChartData) -> ChartData:
return ChartData.model_validate(payload)
-def _apply_extra_form_data(
- form_data: dict[str, Any], extra_form_data: dict[str, Any] | None
-) -> None:
- """Merge dashboard native filters into chart form_data in-place."""
- if not extra_form_data:
- return
- form_data["extra_form_data"] = extra_form_data
- merge_extra_filters(form_data)
-
-
@tool(
tags=["data"],
class_permission_name="Chart",
@@ -293,65 +285,18 @@ async def get_chart_data( # noqa: C901
# If using cached form_data, we need to build query_context from it
if using_unsaved_state and cached_form_data_dict is not None:
# Build query context from cached form_data (unsaved state)
- from superset.common.query_context_factory import
QueryContextFactory
-
- factory = QueryContextFactory()
row_limit = (
request.limit
or cached_form_data_dict.get("row_limit")
or current_app.config["ROW_LIMIT"]
)
- # Get datasource info from cached form_data or fall back to
chart
- datasource_id = cached_form_data_dict.get(
- "datasource_id", chart.datasource_id
- )
- datasource_type = cached_form_data_dict.get(
- "datasource_type", chart.datasource_type
- )
-
- # Handle different chart types that have different form_data
- # structures. Some charts use "metric" (singular), not
"metrics"
- # (plural): big_number, big_number_total, pop_kpi.
- # These charts also don't have groupby columns.
- cached_viz_type = cached_form_data_dict.get(
- "viz_type", chart.viz_type or ""
- )
- if cached_viz_type in ("big_number", "big_number_total",
"pop_kpi"):
- metric = cached_form_data_dict.get("metric")
- cached_metrics = [metric] if metric else []
- cached_groupby: list[str] = []
- else:
- cached_metrics = cached_form_data_dict.get("metrics", [])
- raw_groupby = cached_form_data_dict.get("groupby", [])
- # Guard against string groupby (e.g. heatmap_v2 migrated
- # from legacy heatmap where all_columns_y was a string)
- if isinstance(raw_groupby, str):
- cached_groupby = [raw_groupby]
- else:
- cached_groupby = list(raw_groupby)
-
- _apply_extra_form_data(cached_form_data_dict,
request.extra_form_data)
-
- cached_query: dict[str, Any] = {
- "filters": cached_form_data_dict.get("filters", []),
- "columns": cached_groupby,
- "metrics": cached_metrics,
- "row_limit": row_limit,
- "order_desc": cached_form_data_dict.get("order_desc",
True),
- }
- # Include adhoc_filters so dashboard native filters are applied
- cached_adhoc = cached_form_data_dict.get("adhoc_filters")
- if cached_adhoc:
- cached_query["adhoc_filters"] = cached_adhoc
-
- query_context = factory.create(
- datasource={
- "id": datasource_id,
- "type": datasource_type,
- },
- queries=[cached_query],
- form_data=cached_form_data_dict,
+ query_context = build_query_context_from_form_data(
+ cached_form_data_dict,
+ chart=chart,
+ extra_form_data=request.extra_form_data,
+ row_limit=row_limit,
+ order_desc=cached_form_data_dict.get("order_desc", True),
force=request.force_refresh,
)
await ctx.debug(
@@ -420,102 +365,23 @@ async def get_chart_data( # noqa: C901
error_type="MissingQueryContext",
)
- singular_metric_no_groupby = (
- "big_number",
- "big_number_total",
- "pop_kpi",
- )
- singular_metric_types = (
- *singular_metric_no_groupby,
- "world_map",
- "treemap_v2",
- "sunburst_v2",
- "gauge_chart",
+ fallback_queries = build_query_dicts_from_form_data(
+ form_data,
+ chart.datasource_id,
+ chart.datasource_type,
+ chart=chart,
+ extra_form_data=request.extra_form_data,
+ row_limit=row_limit,
+ order_desc=True,
)
- if viz_type == "bubble":
- # Bubble charts store metrics in x, y, size fields
- bubble_metrics = []
- for field in ("x", "y", "size"):
- m = form_data.get(field)
- if m:
- bubble_metrics.append(m)
- metrics = bubble_metrics
- groupby_columns: list[str] = list(
- form_data.get("entity", None) and
[form_data["entity"]] or []
- )
- series_field = form_data.get("series")
- if series_field and series_field not in groupby_columns:
- groupby_columns.append(series_field)
- elif viz_type in singular_metric_types:
- # These chart types use "metric" (singular)
- metric = form_data.get("metric")
- metrics = [metric] if metric else []
- if viz_type in singular_metric_no_groupby:
- groupby_columns = []
- else:
- # Some singular-metric charts use groupby, entity,
- # series, or columns for dimensional breakdown
- groupby_columns = list(form_data.get("groupby") or [])
- entity = form_data.get("entity")
- if entity and entity not in groupby_columns:
- groupby_columns.append(entity)
- series = form_data.get("series")
- if series and series not in groupby_columns:
- groupby_columns.append(series)
- form_columns = form_data.get("columns")
- if form_columns and isinstance(form_columns, list):
- for col in form_columns:
- if isinstance(col, str) and col not in
groupby_columns:
- groupby_columns.append(col)
- else:
- # Standard charts use "metrics" (plural) and "groupby"
- metrics = form_data.get("metrics", [])
- raw_groupby = form_data.get("groupby") or []
- # Guard against string groupby (e.g. heatmap_v2 migrated
- # from legacy heatmap where all_columns_y was a string)
- if isinstance(raw_groupby, str):
- groupby_columns = [raw_groupby]
- else:
- groupby_columns = list(raw_groupby)
- # Some chart types use "columns" instead of "groupby"
- if not groupby_columns:
- form_columns = form_data.get("columns")
- if form_columns and isinstance(form_columns, list):
- for col in form_columns:
- if isinstance(col, str):
- groupby_columns.append(col)
-
- # Fallback: if metrics is still empty, try singular "metric"
- if not metrics:
- fallback_metric = form_data.get("metric")
- if fallback_metric:
- metrics = [fallback_metric]
-
- # Fallback: try entity/series if groupby is still empty
- if not groupby_columns:
- entity = form_data.get("entity")
- if entity:
- groupby_columns.append(entity)
- series = form_data.get("series")
- if series and series not in groupby_columns:
- groupby_columns.append(series)
-
- # Build query columns list: include both x_axis and groupby
- x_axis_config = form_data.get("x_axis")
- query_columns = groupby_columns.copy()
- if x_axis_config and isinstance(x_axis_config, str):
- if x_axis_config not in query_columns:
- query_columns.insert(0, x_axis_config)
- elif x_axis_config and isinstance(x_axis_config, dict):
- col_name = x_axis_config.get("column_name")
- if col_name and col_name not in query_columns:
- query_columns.insert(0, col_name)
-
# Safety net: if we could not extract any metrics or
# columns, return a clear error instead of the cryptic
# "Empty query?" that comes from deeper in the stack.
- if not metrics and not query_columns:
+ if all(
+ not query.get("metrics") and not query.get("columns")
+ for query in fallback_queries
+ ):
await ctx.warning(
"Cannot construct fallback query for chart %s "
"(viz_type=%s): no metrics, columns, or groupby "
@@ -534,26 +400,12 @@ async def get_chart_data( # noqa: C901
error_type="MissingQueryContext",
)
- _apply_extra_form_data(form_data, request.extra_form_data)
-
- fallback_query: dict[str, Any] = {
- "filters": form_data.get("filters", []),
- "columns": query_columns,
- "metrics": metrics,
- "row_limit": row_limit,
- "order_desc": True,
- }
- # Include adhoc_filters so dashboard native filters are applied
- fallback_adhoc = form_data.get("adhoc_filters")
- if fallback_adhoc:
- fallback_query["adhoc_filters"] = fallback_adhoc
-
query_context = factory.create(
datasource={
"id": chart.datasource_id,
"type": chart.datasource_type,
},
- queries=[fallback_query],
+ queries=fallback_queries,
form_data=form_data,
force=request.force_refresh,
)
@@ -566,9 +418,14 @@ async def get_chart_data( # noqa: C901
for query in query_context_json.get("queries", []):
query["row_limit"] = request.limit
- # Merge dashboard native filters into query_context's form_data
- qc_form_data = query_context_json.setdefault("form_data", {})
- _apply_extra_form_data(qc_form_data, request.extra_form_data)
+ if request.extra_form_data:
+ for query in query_context_json.get("queries", []):
+ merge_extra_form_data_filters_into_query(
+ query,
+ request.extra_form_data,
+ query_context_json["datasource"]["id"],
+ query_context_json["datasource"]["type"],
+ )
# Create QueryContext from the saved context using the schema
# This is exactly how the API does it
@@ -871,16 +728,14 @@ async def _query_from_form_data(
Used for unsaved charts where we only have form_data_key.
"""
from superset.commands.chart.data.get_data_command import ChartDataCommand
- from superset.common.query_context_factory import QueryContextFactory
datasource_id = form_data.get("datasource_id")
- datasource_type: str = form_data.get("datasource_type") or "table"
# Handle combined datasource field (e.g., "1__table")
if not datasource_id and form_data.get("datasource"):
parts = str(form_data["datasource"]).split("__")
if len(parts) == 2:
- datasource_id, datasource_type = parts[0], parts[1]
+ datasource_id = parts[0]
if not datasource_id:
return ChartError(
@@ -888,34 +743,17 @@ async def _query_from_form_data(
error_type="InvalidFormData",
)
- viz_type = form_data.get("viz_type", "unknown")
row_limit = (
request.limit or form_data.get("row_limit") or
current_app.config["ROW_LIMIT"]
)
-
- # Extract metrics and groupby based on chart type
- if viz_type in ("big_number", "big_number_total", "pop_kpi"):
- metric = form_data.get("metric")
- metrics = [metric] if metric else []
- groupby: list[str] = []
- else:
- metrics = form_data.get("metrics", [])
- groupby = list(form_data.get("groupby") or [])
+ viz_type = form_data.get("viz_type", "unknown")
try:
- factory = QueryContextFactory()
- query_context = factory.create(
- datasource={"id": datasource_id, "type": datasource_type},
- queries=[
- {
- "filters": form_data.get("filters", []),
- "columns": groupby,
- "metrics": metrics,
- "row_limit": row_limit,
- "order_desc": form_data.get("order_desc", True),
- }
- ],
- form_data=form_data,
+ query_context = build_query_context_from_form_data(
+ form_data,
+ extra_form_data=request.extra_form_data,
+ row_limit=row_limit,
+ order_desc=form_data.get("order_desc", True),
force=request.force_refresh,
)
diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py
b/superset/mcp_service/chart/tool/get_chart_preview.py
index 33281ec6fb2..798fdd4fda0 100644
--- a/superset/mcp_service/chart/tool/get_chart_preview.py
+++ b/superset/mcp_service/chart/tool/get_chart_preview.py
@@ -33,7 +33,10 @@ from superset.mcp_service.chart.ascii_charts import (
generate_ascii_chart,
generate_ascii_table,
)
-from superset.mcp_service.chart.chart_helpers import find_chart_by_identifier
+from superset.mcp_service.chart.chart_helpers import (
+ build_query_context_from_form_data,
+ find_chart_by_identifier,
+)
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
AccessibilityMetadata,
@@ -197,7 +200,6 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
def generate(self) -> ASCIIPreview | ChartError:
try:
from superset.commands.chart.data.get_data_command import
ChartDataCommand
- from superset.common.query_context_factory import
QueryContextFactory
from superset.utils import json as utils_json
form_data = utils_json.loads(self.chart.params) if
self.chart.params else {}
@@ -214,50 +216,11 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
error_type="InvalidChart",
)
- # Build query for chart data
- x_axis_config = form_data.get("x_axis")
- groupby_columns = form_data.get("groupby", [])
- metrics = form_data.get("metrics", [])
-
- # Table charts in raw mode use all_columns or columns
- all_columns = form_data.get("all_columns", [])
- raw_columns = form_data.get("columns", [])
- if form_data.get("query_mode") == "raw" and (all_columns or
raw_columns):
- columns = list(all_columns or raw_columns)
- else:
- columns = groupby_columns.copy()
- if x_axis_config and isinstance(x_axis_config, str):
- columns.append(x_axis_config)
- elif x_axis_config and isinstance(x_axis_config, dict):
- if "column_name" in x_axis_config:
- columns.append(x_axis_config["column_name"])
-
- if not columns and not metrics:
- return ChartError(
- error=(
- "Cannot generate ASCII preview: chart has no columns
or "
- "metrics in its configuration. This chart type may not
"
- "support ASCII preview."
- ),
- error_type="UnsupportedChart",
- )
-
- factory = QueryContextFactory()
- query_context = factory.create(
- datasource={
- "id": self.chart.datasource_id,
- "type": self.chart.datasource_type,
- },
- queries=[
- {
- "filters": form_data.get("filters", []),
- "columns": columns,
- "metrics": metrics,
- "row_limit": 50,
- "order_desc": True,
- }
- ],
- form_data=form_data,
+ query_context = build_query_context_from_form_data(
+ form_data,
+ chart=self.chart,
+ row_limit=50,
+ order_desc=True,
force=False,
)
@@ -303,7 +266,6 @@ class TablePreviewStrategy(PreviewFormatStrategy):
def generate(self) -> TablePreview | ChartError:
try:
from superset.commands.chart.data.get_data_command import
ChartDataCommand
- from superset.common.query_context_factory import
QueryContextFactory
from superset.utils import json as utils_json
form_data = utils_json.loads(self.chart.params) if
self.chart.params else {}
@@ -315,24 +277,11 @@ class TablePreviewStrategy(PreviewFormatStrategy):
error_type="InvalidChart",
)
- columns = _build_query_columns(form_data)
-
- factory = QueryContextFactory()
- query_context = factory.create(
- datasource={
- "id": self.chart.datasource_id,
- "type": self.chart.datasource_type,
- },
- queries=[
- {
- "filters": form_data.get("filters", []),
- "columns": columns,
- "metrics": form_data.get("metrics", []),
- "row_limit": 20,
- "order_desc": True,
- }
- ],
- form_data=form_data,
+ query_context = build_query_context_from_form_data(
+ form_data,
+ chart=self.chart,
+ row_limit=20,
+ order_desc=True,
force=False,
)
@@ -386,7 +335,6 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
# Get chart data directly using the same logic as get_chart_data
tool
# but without calling the MCP tool wrapper
from superset.commands.chart.data.get_data_command import
ChartDataCommand
- from superset.common.query_context_factory import
QueryContextFactory
from superset.daos.chart import ChartDAO
from superset.utils import json as utils_json
@@ -419,26 +367,11 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
utils_json.loads(self.chart.params) if self.chart.params
else {}
)
- # Build columns list: include both x_axis and groupby
- columns = _build_query_columns(form_data)
-
- # Create query context for data retrieval
- factory = QueryContextFactory()
- query_context = factory.create(
- datasource={
- "id": self.chart.datasource_id,
- "type": self.chart.datasource_type,
- },
- queries=[
- {
- "filters": form_data.get("filters", []),
- "columns": columns,
- "metrics": form_data.get("metrics", []),
- "row_limit": 1000, # More data for visualization
- "order_desc": True,
- }
- ],
- form_data=form_data,
+ query_context = build_query_context_from_form_data(
+ form_data,
+ chart=self.chart,
+ row_limit=1000,
+ order_desc=True,
force=self.request.force_refresh,
)
diff --git a/superset/mcp_service/chart/tool/get_chart_sql.py
b/superset/mcp_service/chart/tool/get_chart_sql.py
index 1792b928831..8555bfc56d1 100644
--- a/superset/mcp_service/chart/tool/get_chart_sql.py
+++ b/superset/mcp_service/chart/tool/get_chart_sql.py
@@ -32,6 +32,13 @@ from superset.commands.exceptions import CommandException
from superset.commands.explore.form_data.parameters import CommandParameters
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.extensions import event_logger
+from superset.mcp_service.chart.chart_helpers import (
+ build_query_context_from_form_data,
+ extract_x_axis_col,
+ resolve_groupby,
+ resolve_metrics,
+ resolve_metrics_and_groupby,
+)
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
ChartError,
@@ -73,160 +80,25 @@ def _get_cached_form_data(form_data_key: str) -> str |
None:
def _resolve_metrics(form_data: dict[str, Any], viz_type: str) -> list[Any]:
"""Extract metrics from form_data, handling chart-type-specific fields."""
- # Bubble charts store measures in x, y, size fields
- if viz_type == "bubble":
- return [m for field in ("x", "y", "size") if (m :=
form_data.get(field))]
-
- metrics = form_data.get("metrics", [])
- # Fallback: some chart types store the measure as singular "metric"
- if not metrics and (metric := form_data.get("metric")):
- metrics = [metric]
- return metrics
-
-
-def _resolve_groupby(form_data: dict[str, Any]) -> list[str]:
- """Extract groupby columns from form_data with fallback aliases.
-
- Normalises scalar strings (e.g. heatmap_v2 migrated from legacy
- ``all_columns_y``) so that ``list("country")`` does not split into
- individual characters.
- """
- raw_groupby = form_data.get("groupby") or []
- if isinstance(raw_groupby, str):
- groupby: list[str] = [raw_groupby]
- else:
- groupby = list(raw_groupby)
-
- if groupby:
- return groupby
+ return resolve_metrics(form_data, viz_type)
- # Fallback: some chart types store dimensions in entity/series/columns
- for field in ("entity", "series"):
- value = form_data.get(field)
- if isinstance(value, str) and value not in groupby:
- groupby.append(value)
- form_columns = form_data.get("columns")
- if isinstance(form_columns, list):
- for col in form_columns:
- if isinstance(col, str) and col not in groupby:
- groupby.append(col)
-
- return groupby
+def _resolve_groupby(form_data: dict[str, Any]) -> list[Any]:
+ """Extract groupby columns from form_data with fallback aliases."""
+ return resolve_groupby(form_data)
def _resolve_metrics_and_groupby(
form_data: dict[str, Any],
chart: "Slice | None",
-) -> tuple[list[Any], list[str]]:
- """Resolve metrics and groupby columns from form_data.
-
- Handles chart-type-specific field names: singular ``metric`` for
- big-number variants, bubble ``x``/``y``/``size``, and fallback
- fields ``entity``, ``series``, and ``columns`` for dimensions.
- """
- viz_type = form_data.get(
- "viz_type", getattr(chart, "viz_type", "") if chart else ""
- )
-
- singular_metric_no_groupby = (
- "big_number",
- "big_number_total",
- "pop_kpi",
- )
- if viz_type in singular_metric_no_groupby:
- metrics: list[Any] = [metric] if (metric := form_data.get("metric"))
else []
- return metrics, []
-
- return _resolve_metrics(form_data, viz_type), _resolve_groupby(form_data)
+) -> tuple[list[Any], list[Any]]:
+ """Resolve metrics and groupby columns from form_data."""
+ return resolve_metrics_and_groupby(form_data, chart)
def _extract_x_axis_col(form_data: dict[str, Any]) -> str | None:
- """Return the x_axis column name from form_data, or None if not set.
-
- ``x_axis`` may be stored as a plain column-name string or as an adhoc
- column dict (``{"column_name": "...", ...}``).
- """
- x_axis = form_data.get("x_axis")
- if isinstance(x_axis, str) and x_axis:
- return x_axis
- if isinstance(x_axis, dict):
- col_name = x_axis.get("column_name")
- return col_name if isinstance(col_name, str) and col_name else None
- return None
-
-
-def _resolve_engine(
- datasource_id: Any,
- datasource_type: str,
-) -> str:
- """Return the DB engine name for *datasource_id*, or ``"base"`` on any
error."""
- if not isinstance(datasource_id, (int, str)):
- return "base"
- try:
- from superset.daos.datasource import DatasourceDAO
- from superset.utils.core import DatasourceType
-
- ds = DatasourceDAO.get_datasource(
- datasource_type=DatasourceType(datasource_type),
- database_id_or_uuid=datasource_id,
- )
- return ds.database.db_engine_spec.engine
- except Exception: # noqa: BLE001
- logger.debug("Could not resolve engine for datasource %s",
datasource_id)
- return "base"
-
-
-def _build_single_query_dict(
- form_data: dict[str, Any],
- columns: list[Any],
- metrics: list[Any],
-) -> dict[str, Any]:
- """Build one query entry for QueryContextFactory from form_data fields."""
- qd: dict[str, Any] = {"columns": columns, "metrics": metrics}
- if time_range := form_data.get("time_range"):
- qd["time_range"] = time_range
- if filters := form_data.get("filters"):
- qd["filters"] = filters
- if (row_limit := form_data.get("row_limit")) is not None:
- qd["row_limit"] = row_limit
- return qd
-
-
-def _build_mixed_timeseries_secondary(
- form_data: dict[str, Any],
- x_axis_col: str | None,
- engine: str = "base",
-) -> dict[str, Any]:
- """Build the secondary query dict for the ``mixed_timeseries`` viz type.
-
- ``mixed_timeseries`` has two independent series layers; the secondary
- layer uses ``metrics_b`` / ``groupby_b`` instead of the primary fields.
- Secondary-specific overrides (``time_range_b``, ``row_limit_b``,
- ``adhoc_filters_b``) replace the corresponding primary values so the
- generated SQL accurately reflects each series' independent configuration.
- """
- metrics_b: list[Any] = list(form_data.get("metrics_b") or [])
- raw_b = form_data.get("groupby_b") or []
- groupby_b: list[Any] = [raw_b] if isinstance(raw_b, str) else list(raw_b)
- if x_axis_col and x_axis_col not in groupby_b:
- groupby_b = [x_axis_col] + groupby_b
- qd = _build_single_query_dict(form_data, groupby_b, metrics_b)
- if time_range_b := form_data.get("time_range_b"):
- qd["time_range"] = time_range_b
- if (row_limit_b := form_data.get("row_limit_b")) is not None:
- qd["row_limit"] = row_limit_b
- # Process adhoc_filters_b into concrete filter clauses for the secondary
- # query, mirroring how split_adhoc_filters_into_base_filters handles the
- # primary adhoc_filters in _build_query_context_from_form_data.
- if adhoc_filters_b := form_data.get("adhoc_filters_b"):
- from superset.utils.core import split_adhoc_filters_into_base_filters
-
- secondary_fd: dict[str, Any] = {"adhoc_filters": adhoc_filters_b}
- split_adhoc_filters_into_base_filters(secondary_fd, engine)
- if secondary_filters := secondary_fd.get("filters"):
- qd["filters"] = secondary_filters
- return qd
+ """Return the x_axis column name from form_data, or None if not set."""
+ return extract_x_axis_col(form_data)
def _build_query_context_from_form_data(
@@ -239,85 +111,10 @@ def _build_query_context_from_form_data(
instead of executing the query.
"""
from superset.common.chart_data import ChartDataResultType
- from superset.common.query_context_factory import QueryContextFactory
-
- factory = QueryContextFactory()
-
- datasource_id = form_data.get("datasource_id")
- datasource_type = form_data.get("datasource_type")
-
- # Unsaved Explore state often stores datasource as a combined field
- # like "123__table" instead of separate datasource_id/datasource_type.
- if not datasource_id and (combined := form_data.get("datasource")):
- if isinstance(combined, str) and "__" in combined:
- parts = combined.split("__", 1)
- datasource_id = int(parts[0]) if parts[0].isdigit() else parts[0]
- datasource_type = parts[1] if len(parts) > 1 else None
-
- if not datasource_id and chart:
- datasource_id = getattr(chart, "datasource_id", None)
- if not datasource_type and chart:
- datasource_type = getattr(chart, "datasource_type", None)
-
- metrics, groupby = _resolve_metrics_and_groupby(form_data, chart)
-
- # Preprocess adhoc_filters into where/having/filters on form_data so
- # that the QueryObject receives concrete filter clauses. This mirrors
- # the view-layer call in viz.py:process_query_filters.
- from superset.utils.core import (
- merge_extra_filters,
- split_adhoc_filters_into_base_filters,
- )
-
- resolved_type_str: str = (
- datasource_type if isinstance(datasource_type, str) else "table"
- )
- engine = _resolve_engine(datasource_id, resolved_type_str)
- merge_extra_filters(form_data)
- split_adhoc_filters_into_base_filters(form_data, engine)
-
- viz_type: str = (
- form_data.get("viz_type")
- or (getattr(chart, "viz_type", "") if chart else "")
- or ""
- )
- is_timeseries = (
- viz_type.startswith("echarts_timeseries") or viz_type ==
"mixed_timeseries"
- )
-
- # For echarts_timeseries_* and mixed_timeseries charts the temporal
- # column is stored in x_axis rather than groupby. Prepend it so the
- # generated SQL includes the time axis.
- x_axis_col: str | None = None
- if is_timeseries:
- x_axis_col = _extract_x_axis_col(form_data)
- if x_axis_col and x_axis_col not in groupby:
- groupby = [x_axis_col] + groupby
-
- queries: list[dict[str, Any]] = [
- _build_single_query_dict(form_data, groupby, metrics)
- ]
-
- # mixed_timeseries exposes two independent query layers (primary and
- # secondary). Build the second query from metrics_b / groupby_b so
- # that get_chart_sql returns SQL for both and neither is silently lost.
- if viz_type == "mixed_timeseries":
- queries.append(_build_mixed_timeseries_secondary(form_data,
x_axis_col, engine))
-
- # Ensure datasource fields satisfy DatasourceDict typing requirements.
- # datasource_id must be int | str; datasource_type must be str.
- if not isinstance(datasource_id, (int, str)):
- raise ValueError(
- "Cannot determine datasource ID from form_data. "
- "Provide a chart identifier or ensure form_data contains "
- "'datasource_id' or 'datasource'."
- )
- resolved_id: int | str = datasource_id
- return factory.create(
- datasource={"id": resolved_id, "type": resolved_type_str},
- queries=queries,
- form_data=form_data,
+ return build_query_context_from_form_data(
+ form_data,
+ chart=chart,
result_type=ChartDataResultType.QUERY,
force=False,
)
diff --git a/tests/unit_tests/mcp_service/chart/test_chart_helpers.py
b/tests/unit_tests/mcp_service/chart/test_chart_helpers.py
index 5318f0fe8ac..964226ab012 100644
--- a/tests/unit_tests/mcp_service/chart/test_chart_helpers.py
+++ b/tests/unit_tests/mcp_service/chart/test_chart_helpers.py
@@ -18,9 +18,14 @@
from unittest.mock import MagicMock, patch
from superset.mcp_service.chart.chart_helpers import (
+ apply_form_data_filters_to_query,
+ build_query_dicts_from_form_data,
extract_form_data_key_from_url,
find_chart_by_identifier,
get_cached_form_data,
+ merge_extra_form_data_filters_into_query,
+ merge_form_data_filters_into_query,
+ prepare_form_data_for_query,
)
@@ -106,3 +111,177 @@ def test_get_cached_form_data_key_error(mock_init,
mock_run):
mock_init.return_value = None
result = get_cached_form_data("bad_key")
assert result is None
+
+
+def test_prepare_form_data_for_query_preserves_existing_filters_with_adhoc(
+ monkeypatch,
+):
+ monkeypatch.setattr(
+ "superset.mcp_service.chart.chart_helpers.resolve_datasource_engine",
+ lambda datasource_id, datasource_type: "base",
+ )
+ form_data = {
+ "filters": [{"col": "gender", "op": "==", "val": "boy"}],
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "expressionType": "SIMPLE",
+ "subject": "gender",
+ "operator": "==",
+ "comparator": "girl",
+ }
+ ],
+ }
+ query = {}
+
+ prepare_form_data_for_query(form_data, 1, "table")
+ apply_form_data_filters_to_query(query, form_data)
+
+ assert query["filters"] == [
+ {"col": "gender", "op": "==", "val": "boy"},
+ {"col": "gender", "op": "==", "val": "girl"},
+ ]
+
+
+def test_prepare_form_data_for_query_merges_cached_and_request_extra_form_data(
+ monkeypatch,
+):
+ monkeypatch.setattr(
+ "superset.mcp_service.chart.chart_helpers.resolve_datasource_engine",
+ lambda datasource_id, datasource_type: "base",
+ )
+ form_data = {
+ "adhoc_filters": [],
+ "extra_form_data": {
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "expressionType": "SIMPLE",
+ "subject": "country",
+ "operator": "==",
+ "comparator": "US",
+ }
+ ],
+ "time_range": "Last year",
+ },
+ }
+ query = {}
+
+ prepare_form_data_for_query(
+ form_data,
+ 1,
+ "table",
+ {
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "expressionType": "SIMPLE",
+ "subject": "gender",
+ "operator": "==",
+ "comparator": "boy",
+ }
+ ],
+ "time_range": "No filter",
+ },
+ )
+ apply_form_data_filters_to_query(query, form_data)
+
+ assert query["filters"] == [
+ {"col": "country", "op": "==", "val": "US"},
+ {"col": "gender", "op": "==", "val": "boy"},
+ ]
+ assert query["time_range"] == "No filter"
+
+
+def test_build_query_dicts_from_form_data_uses_raw_all_columns(monkeypatch):
+ monkeypatch.setattr(
+ "superset.mcp_service.chart.chart_helpers.resolve_datasource_engine",
+ lambda datasource_id, datasource_type: "base",
+ )
+ form_data = {
+ "viz_type": "handlebars",
+ "query_mode": "raw",
+ "all_columns": ["state", "city"],
+ "adhoc_filters": [],
+ }
+
+ queries = build_query_dicts_from_form_data(form_data, 1, "table")
+
+ assert queries == [
+ {
+ "columns": ["state", "city"],
+ "metrics": [],
+ "filters": [],
+ }
+ ]
+
+
+def test_merge_form_data_filters_into_query_applies_regular_overrides():
+ query = {
+ "filters": [{"col": "country", "op": "==", "val": "US"}],
+ "time_range": "Last year",
+ "granularity": "created_at",
+ "time_grain": "P1Y",
+ "time_grain_sqla": "P1Y",
+ "where": "region = 'NA'",
+ "having": "SUM(num) > 10",
+ }
+
+ merge_form_data_filters_into_query(
+ query,
+ {
+ "filters": [{"col": "gender", "op": "==", "val": "boy"}],
+ "time_range": "No filter",
+ "granularity": "updated_at",
+ "time_grain": "P1D",
+ "time_grain_sqla": "P1D",
+ "where": "name IS NOT NULL",
+ "having": "COUNT(*) > 1",
+ },
+ )
+
+ assert query["filters"] == [
+ {"col": "country", "op": "==", "val": "US"},
+ {"col": "gender", "op": "==", "val": "boy"},
+ ]
+ assert query["time_range"] == "No filter"
+ assert query["granularity"] == "updated_at"
+ assert query["time_grain"] == "P1D"
+ assert query["time_grain_sqla"] == "P1D"
+ assert query["where"] == "(region = 'NA') AND (name IS NOT NULL)"
+ assert query["having"] == "(SUM(num) > 10) AND (COUNT(*) > 1)"
+
+
+def test_merge_extra_form_data_filters_into_query_adds_only_extra_predicates(
+ monkeypatch,
+):
+ monkeypatch.setattr(
+ "superset.mcp_service.chart.chart_helpers.resolve_datasource_engine",
+ lambda datasource_id, datasource_type: "base",
+ )
+ query = {
+ "filters": [{"col": "country", "op": "==", "val": "US"}],
+ "time_range": "Last year",
+ "granularity": "created_at",
+ "time_grain_sqla": "P1Y",
+ }
+
+ merge_extra_form_data_filters_into_query(
+ query,
+ {
+ "filters": [{"col": "gender", "op": "==", "val": "boy"}],
+ "granularity_sqla": "updated_at",
+ "time_range": "No filter",
+ "time_grain_sqla": "P1D",
+ },
+ 1,
+ "table",
+ )
+
+ assert query["filters"] == [
+ {"col": "country", "op": "==", "val": "US"},
+ {"col": "gender", "op": "==", "val": "boy"},
+ ]
+ assert query["time_range"] == "No filter"
+ assert query["granularity"] == "updated_at"
+ assert query["time_grain_sqla"] == "P1D"
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
index 8d54cacfabd..b36e175d1aa 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
@@ -19,6 +19,9 @@
Tests for the get_chart_data request schema and chart type fallback handling.
"""
+import importlib
+from contextlib import nullcontext
+from types import SimpleNamespace
from typing import Any
import pytest
@@ -30,6 +33,7 @@ from superset.mcp_service.chart.schemas import (
PerformanceMetadata,
)
from superset.mcp_service.chart.tool.get_chart_data import (
+ _query_from_form_data,
_sanitize_chart_data_for_llm_context,
)
from superset.mcp_service.utils import sanitize_for_llm_context
@@ -356,6 +360,181 @@ class TestChartDataSanitization:
assert result.data[0][escaped_key] == sanitize_for_llm_context("value")
+class _AsyncContext:
+ async def report_progress(self, *args: Any, **kwargs: Any) -> None:
+ pass
+
+
+class TestUnsavedChartDataQueryConstruction:
+ @pytest.mark.asyncio
+ async def test_form_data_key_adhoc_filters_become_query_filters(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Cached form_data adhoc filters should constrain unsaved chart
data."""
+ chart_data_module = importlib.import_module(
+ "superset.mcp_service.chart.tool.get_chart_data"
+ )
+ query_context_factory_module = importlib.import_module(
+ "superset.common.query_context_factory"
+ )
+ get_data_command_module = importlib.import_module(
+ "superset.commands.chart.data.get_data_command"
+ )
+
+ captured_query_contexts: list[dict[str, Any]] = []
+
+ class QueryContextFactory:
+ def create(self, **kwargs: Any) -> object:
+ captured_query_contexts.append(kwargs)
+ return object()
+
+ class ChartDataCommand:
+ def __init__(self, query_context: object) -> None:
+ self.query_context = query_context
+
+ def validate(self) -> None:
+ pass
+
+ def run(self) -> dict[str, Any]:
+ return {
+ "queries": [
+ {
+ "data": [{"gender": "boy", "count": 1}],
+ "colnames": ["gender", "count"],
+ "rowcount": 1,
+ }
+ ]
+ }
+
+ monkeypatch.setattr(
+ query_context_factory_module,
+ "QueryContextFactory",
+ QueryContextFactory,
+ )
+ monkeypatch.setattr(
+ get_data_command_module, "ChartDataCommand", ChartDataCommand
+ )
+ monkeypatch.setattr(
+ chart_data_module,
+ "event_logger",
+ SimpleNamespace(log_context=lambda **kwargs: nullcontext()),
+ )
+
+ adhoc_filter = {
+ "clause": "WHERE",
+ "expressionType": "SIMPLE",
+ "subject": "gender",
+ "operator": "==",
+ "comparator": "boy",
+ }
+
+ await _query_from_form_data(
+ {
+ "datasource_id": 1,
+ "datasource_type": "table",
+ "viz_type": "table",
+ "groupby": ["gender"],
+ "metrics": ["count"],
+ "row_limit": 10,
+ "adhoc_filters": [adhoc_filter],
+ },
+ GetChartDataRequest(form_data_key="cached-key"),
+ _AsyncContext(),
+ )
+
+ query = captured_query_contexts[0]["queries"][0]
+ assert query["filters"] == [{"col": "gender", "op": "==", "val":
"boy"}]
+ assert "adhoc_filters" not in query
+
+ @pytest.mark.asyncio
+ async def test_form_data_key_mixed_timeseries_builds_secondary_query(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Unsaved mixed-timeseries form_data should preserve both query
layers."""
+ chart_data_module = importlib.import_module(
+ "superset.mcp_service.chart.tool.get_chart_data"
+ )
+ query_context_factory_module = importlib.import_module(
+ "superset.common.query_context_factory"
+ )
+ get_data_command_module = importlib.import_module(
+ "superset.commands.chart.data.get_data_command"
+ )
+
+ captured_query_contexts: list[dict[str, Any]] = []
+
+ class QueryContextFactory:
+ def create(self, **kwargs: Any) -> object:
+ captured_query_contexts.append(kwargs)
+ return object()
+
+ class ChartDataCommand:
+ def __init__(self, query_context: object) -> None:
+ self.query_context = query_context
+
+ def validate(self) -> None:
+ pass
+
+ def run(self) -> dict[str, Any]:
+ return {
+ "queries": [
+ {
+ "data": [{"ds": "2024-01-01", "sales": 1}],
+ "colnames": ["ds", "sales"],
+ "rowcount": 1,
+ },
+ {
+ "data": [{"ds": "2024-01-01", "profit": 2}],
+ "colnames": ["ds", "profit"],
+ "rowcount": 1,
+ },
+ ]
+ }
+
+ monkeypatch.setattr(
+ query_context_factory_module,
+ "QueryContextFactory",
+ QueryContextFactory,
+ )
+ monkeypatch.setattr(
+ get_data_command_module, "ChartDataCommand", ChartDataCommand
+ )
+ monkeypatch.setattr(
+ chart_data_module,
+ "event_logger",
+ SimpleNamespace(log_context=lambda **kwargs: nullcontext()),
+ )
+ monkeypatch.setattr(
+
"superset.mcp_service.chart.chart_helpers.resolve_datasource_engine",
+ lambda datasource_id, datasource_type: "base",
+ )
+
+ await _query_from_form_data(
+ {
+ "datasource": "1__table",
+ "viz_type": "mixed_timeseries",
+ "x_axis": "ds",
+ "groupby": ["country"],
+ "metrics": ["sum__sales"],
+ "groupby_b": ["state"],
+ "metrics_b": ["sum__profit"],
+ },
+ GetChartDataRequest(form_data_key="cached-key", limit=99),
+ _AsyncContext(),
+ )
+
+ queries = captured_query_contexts[0]["queries"]
+ assert len(queries) == 2
+ assert queries[0]["columns"] == ["ds", "country"]
+ assert queries[0]["metrics"] == ["sum__sales"]
+ assert queries[0]["row_limit"] == 99
+ assert queries[1]["columns"] == ["ds", "state"]
+ assert queries[1]["metrics"] == ["sum__profit"]
+ assert queries[1]["row_limit"] == 99
+
+
class TestWorldMapChartFallback:
"""Tests for world_map chart fallback query construction."""
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
index e5fcf909f7f..98b5e5fff7b 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py
@@ -19,6 +19,10 @@
Unit tests for get_chart_preview MCP tool
"""
+import importlib
+from types import SimpleNamespace
+from typing import Any
+
import pytest
from superset.mcp_service.chart.schemas import (
@@ -34,8 +38,11 @@ from superset.mcp_service.chart.schemas import (
)
from superset.mcp_service.chart.tool.get_chart_preview import (
_sanitize_chart_preview_for_llm_context,
+ ASCIIPreviewStrategy,
+ TablePreviewStrategy,
)
from superset.mcp_service.utils import sanitize_for_llm_context
+from superset.utils import json as utils_json
class TestPreviewXAxisInQueryContext:
@@ -277,6 +284,385 @@ class TestGetChartPreview:
# This is a structural test - actual integration tests would verify
# the tool returns data matching this structure
+ def test_table_preview_converts_saved_adhoc_filters_to_query_filters(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Saved chart adhoc filters should constrain table previews."""
+ query_context_factory_module = importlib.import_module(
+ "superset.common.query_context_factory"
+ )
+ get_data_command_module = importlib.import_module(
+ "superset.commands.chart.data.get_data_command"
+ )
+
+ captured_query_contexts: list[dict[str, Any]] = []
+
+ class QueryContextFactory:
+ def create(self, **kwargs: Any) -> object:
+ captured_query_contexts.append(kwargs)
+ return object()
+
+ class ChartDataCommand:
+ def __init__(self, query_context: object) -> None:
+ self.query_context = query_context
+
+ def validate(self) -> None:
+ pass
+
+ def run(self) -> dict[str, Any]:
+ return {
+ "queries": [
+ {
+ "data": [{"gender": "boy", "count": 1}],
+ "colnames": ["gender", "count"],
+ "rowcount": 1,
+ }
+ ]
+ }
+
+ monkeypatch.setattr(
+ query_context_factory_module,
+ "QueryContextFactory",
+ QueryContextFactory,
+ )
+ monkeypatch.setattr(
+ get_data_command_module, "ChartDataCommand", ChartDataCommand
+ )
+
+ adhoc_filter = {
+ "clause": "WHERE",
+ "expressionType": "SIMPLE",
+ "subject": "gender",
+ "operator": "==",
+ "comparator": "boy",
+ }
+ chart = SimpleNamespace(
+ id=0,
+ slice_name="Unsaved Chart Preview",
+ viz_type="table",
+ datasource_id=1,
+ datasource_type="table",
+ params=utils_json.dumps(
+ {
+ "viz_type": "table",
+ "groupby": ["gender"],
+ "metrics": ["count"],
+ "adhoc_filters": [adhoc_filter],
+ }
+ ),
+ )
+
+ preview = TablePreviewStrategy(
+ chart,
+ GetChartPreviewRequest(identifier=1, format="table"),
+ ).generate()
+
+ assert isinstance(preview, TablePreview)
+ query = captured_query_contexts[0]["queries"][0]
+ assert query["filters"] == [{"col": "gender", "op": "==", "val":
"boy"}]
+ assert "adhoc_filters" not in query
+
+ def test_table_preview_uses_singular_metric(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """Preview query construction should handle charts without
metrics[]."""
+ query_context_factory_module = importlib.import_module(
+ "superset.common.query_context_factory"
+ )
+ get_data_command_module = importlib.import_module(
+ "superset.commands.chart.data.get_data_command"
+ )
+
+ captured_query_contexts: list[dict[str, Any]] = []
+
+ class QueryContextFactory:
+ def create(self, **kwargs: Any) -> object:
+ captured_query_contexts.append(kwargs)
+ return object()
+
+ class ChartDataCommand:
+ def __init__(self, query_context: object) -> None:
+ self.query_context = query_context
+
+ def validate(self) -> None:
+ pass
+
+ def run(self) -> dict[str, Any]:
+ return {
+ "queries": [
+ {
+ "data": [{"count": 10}],
+ "colnames": ["count"],
+ "rowcount": 1,
+ }
+ ]
+ }
+
+ monkeypatch.setattr(
+ query_context_factory_module,
+ "QueryContextFactory",
+ QueryContextFactory,
+ )
+ monkeypatch.setattr(
+ get_data_command_module, "ChartDataCommand", ChartDataCommand
+ )
+
+ metric = {"label": "count", "expressionType": "SIMPLE"}
+ chart = SimpleNamespace(
+ id=0,
+ slice_name="Big Number Preview",
+ viz_type="big_number",
+ datasource_id=1,
+ datasource_type="table",
+ params=utils_json.dumps(
+ {
+ "viz_type": "big_number",
+ "metric": metric,
+ }
+ ),
+ )
+
+ preview = TablePreviewStrategy(
+ chart,
+ GetChartPreviewRequest(identifier=1, format="table"),
+ ).generate()
+
+ assert isinstance(preview, TablePreview)
+ query = captured_query_contexts[0]["queries"][0]
+ assert query["columns"] == []
+ assert query["metrics"] == [metric]
+
+ def test_ascii_preview_uses_shared_query_builder(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """ASCII preview should use chart-type-aware query construction."""
+ query_context_factory_module = importlib.import_module(
+ "superset.common.query_context_factory"
+ )
+ get_data_command_module = importlib.import_module(
+ "superset.commands.chart.data.get_data_command"
+ )
+
+ captured_query_contexts: list[dict[str, Any]] = []
+
+ class QueryContextFactory:
+ def create(self, **kwargs: Any) -> object:
+ captured_query_contexts.append(kwargs)
+ return object()
+
+ class ChartDataCommand:
+ def __init__(self, query_context: object) -> None:
+ self.query_context = query_context
+
+ def validate(self) -> None:
+ pass
+
+ def run(self) -> dict[str, Any]:
+ return {
+ "queries": [
+ {
+ "data": [{"count": 10}],
+ "colnames": ["count"],
+ "rowcount": 1,
+ }
+ ]
+ }
+
+ monkeypatch.setattr(
+ query_context_factory_module,
+ "QueryContextFactory",
+ QueryContextFactory,
+ )
+ monkeypatch.setattr(
+ get_data_command_module, "ChartDataCommand", ChartDataCommand
+ )
+
+ metric = {"label": "count", "expressionType": "SIMPLE"}
+ chart = SimpleNamespace(
+ id=0,
+ slice_name="Big Number Preview",
+ viz_type="big_number",
+ datasource_id=1,
+ datasource_type="table",
+ params=utils_json.dumps(
+ {
+ "viz_type": "big_number",
+ "metric": metric,
+ }
+ ),
+ )
+
+ preview = ASCIIPreviewStrategy(
+ chart,
+ GetChartPreviewRequest(identifier=1, format="ascii"),
+ ).generate()
+
+ assert isinstance(preview, ASCIIPreview)
+ query = captured_query_contexts[0]["queries"][0]
+ assert query["columns"] == []
+ assert query["metrics"] == [metric]
+ assert query["row_limit"] == 50
+
+ @pytest.mark.asyncio
+ async def test_form_data_key_overrides_saved_params_for_table_preview(
+ self,
+ monkeypatch: pytest.MonkeyPatch,
+ ) -> None:
+ """form_data_key should drive table preview query construction."""
+ from contextlib import nullcontext
+
+ get_chart_preview_module = importlib.import_module(
+ "superset.mcp_service.chart.tool.get_chart_preview"
+ )
+ query_context_factory_module = importlib.import_module(
+ "superset.common.query_context_factory"
+ )
+ get_data_command_module = importlib.import_module(
+ "superset.commands.chart.data.get_data_command"
+ )
+ get_form_data_module = importlib.import_module(
+ "superset.commands.explore.form_data.get"
+ )
+
+ class AsyncContext:
+ async def debug(self, *args: Any, **kwargs: Any) -> None:
+ pass
+
+ async def error(self, *args: Any, **kwargs: Any) -> None:
+ pass
+
+ async def info(self, *args: Any, **kwargs: Any) -> None:
+ pass
+
+ async def report_progress(self, *args: Any, **kwargs: Any) -> None:
+ pass
+
+ async def warning(self, *args: Any, **kwargs: Any) -> None:
+ pass
+
+ captured_query_contexts: list[dict[str, Any]] = []
+
+ class QueryContextFactory:
+ def create(self, **kwargs: Any) -> object:
+ captured_query_contexts.append(kwargs)
+ return object()
+
+ class ChartDataCommand:
+ def __init__(self, query_context: object) -> None:
+ self.query_context = query_context
+
+ def validate(self) -> None:
+ pass
+
+ def run(self) -> dict[str, Any]:
+ return {
+ "queries": [
+ {
+ "data": [{"gender": "boy", "count": 1}],
+ "colnames": ["gender", "count"],
+ "rowcount": 1,
+ }
+ ]
+ }
+
+ saved_filter = {
+ "clause": "WHERE",
+ "expressionType": "SIMPLE",
+ "subject": "gender",
+ "operator": "==",
+ "comparator": "girl",
+ }
+ cached_filter = {
+ "clause": "WHERE",
+ "expressionType": "SIMPLE",
+ "subject": "gender",
+ "operator": "==",
+ "comparator": "boy",
+ }
+ chart = SimpleNamespace(
+ id=42,
+ slice_name="Saved Chart Preview",
+ viz_type="table",
+ datasource_id=1,
+ datasource_type="table",
+ params=utils_json.dumps(
+ {
+ "viz_type": "table",
+ "groupby": ["gender"],
+ "metrics": ["count"],
+ "adhoc_filters": [saved_filter],
+ }
+ ),
+ )
+
+ monkeypatch.setattr(
+ get_chart_preview_module,
+ "find_chart_by_identifier",
+ lambda identifier: chart,
+ )
+ monkeypatch.setattr(
+ get_chart_preview_module,
+ "validate_chart_dataset",
+ lambda *args, **kwargs: SimpleNamespace(
+ is_valid=True,
+ warnings=[],
+ error=None,
+ ),
+ )
+ monkeypatch.setattr(
+ get_chart_preview_module.db.session,
+ "refresh",
+ lambda chart: None,
+ )
+ monkeypatch.setattr(
+ get_chart_preview_module.event_logger,
+ "log_context",
+ lambda **kwargs: nullcontext(),
+ )
+ monkeypatch.setattr(
+ query_context_factory_module,
+ "QueryContextFactory",
+ QueryContextFactory,
+ )
+ monkeypatch.setattr(
+ get_data_command_module,
+ "ChartDataCommand",
+ ChartDataCommand,
+ )
+ monkeypatch.setattr(
+ get_form_data_module.GetFormDataCommand,
+ "__init__",
+ lambda self, cmd_params: None,
+ )
+ monkeypatch.setattr(
+ get_form_data_module.GetFormDataCommand,
+ "run",
+ lambda self: utils_json.dumps(
+ {
+ "viz_type": "table",
+ "groupby": ["gender"],
+ "metrics": ["count"],
+ "adhoc_filters": [cached_filter],
+ }
+ ),
+ )
+
+ result = await get_chart_preview_module._get_chart_preview_internal(
+ GetChartPreviewRequest(
+ identifier=42,
+ form_data_key="cached-key",
+ format="table",
+ ),
+ AsyncContext(),
+ )
+
+ assert isinstance(result, ChartPreview)
+ query = captured_query_contexts[0]["queries"][0]
+ assert query["filters"] == [{"col": "gender", "op": "==", "val":
"boy"}]
+
@pytest.mark.asyncio
async def test_preview_dimensions(self):
"""Test preview dimensions in response."""
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py
b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py
index 3e6d588fa6c..2ae4903378c 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_sql.py
@@ -869,6 +869,61 @@ class TestBuildQueryContextTimeseriesAndMixed:
secondary_filters = queries[1].get("filters", [])
assert {"col": "channel", "op": "==", "val": "organic"} in
secondary_filters
+ @patch("superset.common.query_context_factory.QueryContextFactory")
+ @patch("superset.daos.datasource.DatasourceDAO.get_datasource")
+ def test_mixed_timeseries_adhoc_filters_b_replaces_primary_sql_clauses(
+ self, mock_get_ds, mock_factory_cls
+ ):
+ """Secondary adhoc filters should not inherit primary SQL
where/having."""
+ mock_ds = Mock()
+ mock_ds.database.db_engine_spec.engine = "postgresql"
+ mock_get_ds.return_value = mock_ds
+
+ mock_factory = Mock()
+ mock_factory.create.return_value = Mock()
+ mock_factory_cls.return_value = mock_factory
+
+ form_data = {
+ "datasource_id": 1,
+ "datasource_type": "table",
+ "viz_type": "mixed_timeseries",
+ "x_axis": "ds",
+ "metrics": ["sum__revenue"],
+ "groupby": [],
+ "metrics_b": ["count"],
+ "groupby_b": [],
+ "adhoc_filters": [
+ {
+ "clause": "WHERE",
+ "expressionType": "SQL",
+ "sqlExpression": "country = 'US'",
+ },
+ {
+ "clause": "HAVING",
+ "expressionType": "SQL",
+ "sqlExpression": "SUM(revenue) > 100",
+ },
+ ],
+ "adhoc_filters_b": [
+ {
+ "clause": "WHERE",
+ "expressionType": "SQL",
+ "sqlExpression": "channel = 'organic'",
+ }
+ ],
+ }
+
+ with patch("superset.common.chart_data.ChartDataResultType") as
mock_rt:
+ mock_rt.QUERY = "QUERY"
+ _build_query_context_from_form_data(form_data, chart=None)
+
+ primary, secondary = mock_factory.create.call_args[1]["queries"]
+ assert primary["where"] == "(country = 'US')"
+ assert primary["having"] == "(SUM(revenue) > 100)"
+ assert secondary["where"] == "(channel = 'organic')"
+ assert "country = 'US'" not in secondary["where"]
+ assert "having" not in secondary
+
class TestResolveDatasourceName:
"""Tests for _resolve_datasource_name helper."""