This is an automated email from the ASF dual-hosted git repository.

aminghadersohi pushed a commit to branch mcp-rls-plugins-99978
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 81aad1c540f8e350e7833fc3e98ed24f0c0a0e16
Author: Alexandru Soare <[email protected]>
AuthorDate: Thu May 21 15:16:16 2026 +0300

    fix(recommendation): Fix chart recommendation (#39886)
---
 superset/mcp_service/chart/tool/get_chart_data.py  | 196 +++++++++++++++++++--
 .../mcp_service/chart/tool/test_get_chart_data.py  | 134 ++++++++++++++
 2 files changed, 317 insertions(+), 13 deletions(-)

diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py
index 5b6c0b3a5a5..2d1b03db8fa 100644
--- a/superset/mcp_service/chart/tool/get_chart_data.py
+++ b/superset/mcp_service/chart/tool/get_chart_data.py
@@ -58,9 +58,177 @@ from superset.mcp_service.utils.oauth2_utils import (
     build_oauth2_redirect_message,
     OAUTH2_CONFIG_ERROR_MESSAGE,
 )
+from superset.utils.core import GenericDataType
 
 logger = logging.getLogger(__name__)
 
+# Maps Superset GenericDataType values (one per result column, taken from
+# the query result's "coltypes") to the string data types stored on
+# DataColumn. Values not listed here are treated as "string" by the caller.
+_GENERIC_TYPE_MAP: dict[int, str] = {
+    GenericDataType.NUMERIC: "numeric",
+    GenericDataType.STRING: "string",
+    GenericDataType.TEMPORAL: "temporal",
+    GenericDataType.BOOLEAN: "boolean",
+}
+
+# Maps Superset viz_type strings to canonical categories so we can
+# avoid recommending a chart type the user already has.
+# A viz_type missing from this map is used as its own category. The
+# category values must use the same vocabulary as _CANDIDATE_CATEGORY,
+# otherwise the "already has" filter silently stops matching.
+_VIZ_CATEGORY: dict[str, str] = {
+    "echarts_timeseries_line": "line",
+    "echarts_timeseries_smooth": "line",
+    "echarts_timeseries_step": "line",
+    "echarts_timeseries": "line",
+    "echarts_timeseries_bar": "bar",
+    "echarts_area": "area",
+    "echarts_timeseries_scatter": "scatter",
+    "mixed_timeseries": "line",
+    "table": "table",
+    "pie": "pie",
+    "big_number": "kpi",
+    "big_number_total": "kpi",
+    "pop_kpi": "kpi",
+    "dist_bar": "bar",
+    "line": "line",
+    "area": "area",
+    "scatter": "scatter",
+    "bubble": "bubble",
+    "treemap_v2": "treemap",
+    # Sunburst is grouped with treemap (both hierarchical part-of-whole views).
+    "sunburst_v2": "treemap",
+    "heatmap_v2": "heatmap",
+    "gauge_chart": "gauge",
+    "funnel": "funnel",
+    "histogram": "histogram",
+    "histogram_v2": "histogram",
+    "box_plot": "box_plot",
+    "world_map": "map",
+    "pivot_table_v2": "table",
+}
+
+# Upper bound on the number of suggestions returned to the caller.
+_MAX_RECOMMENDATIONS = 4
+
+
+def _recommend_visualizations(
+    viz_type: str,
+    columns: list[DataColumn],
+    row_count: int,
+) -> list[str]:
+    """Suggest visualization types based on column types,
+    cardinality, and the chart's current viz_type.
+
+    viz_type is the chart's current Superset viz_type; suggestions in the
+    same category (per _VIZ_CATEGORY) are omitted. columns is the column
+    metadata of the result set and row_count the number of returned rows.
+
+    Returns at most _MAX_RECOMMENDATIONS human-readable chart type names.
+    The list is never empty. If every candidate falls in the chart's own
+    category, a neutral "table" / "bar chart" alternative is returned
+    instead.
+    """
+    if not columns:
+        return ["table"]
+
+    current_category = _VIZ_CATEGORY.get(viz_type, viz_type)
+    candidates = _build_candidates(columns, row_count)
+
+    if not candidates:
+        candidates = ["table", "bar chart"]
+
+    recommendations = _filter_candidates(candidates, current_category)
+    if not recommendations:
+        # Every candidate was in the chart's own category. Example: a
+        # scatter chart over two numeric columns only yields "scatter plot".
+        # "table" and "bar chart" belong to different categories, so at
+        # least one of them always survives the filter.
+        recommendations = _filter_candidates(
+            ["table", "bar chart"], current_category
+        )
+    return recommendations
+
+
+def _build_candidates(
+    columns: list[DataColumn],
+    row_count: int,
+) -> list[str]:
+    """Choose candidate chart types from the mix of column data types.
+
+    Columns are split into temporal, numeric and categorical
+    (string/boolean) groups. The first rule that matches wins, and an
+    empty list means that no rule applied.
+    """
+    temporal = []
+    numeric = []
+    categorical = []
+    for column in columns:
+        kind = column.data_type
+        if kind == "temporal":
+            temporal.append(column)
+        elif kind == "numeric":
+            numeric.append(column)
+        elif kind in ("string", "boolean"):
+            categorical.append(column)
+
+    if numeric and temporal:
+        return _candidates_temporal_numeric(numeric, row_count)
+    if numeric and categorical:
+        return _candidates_categorical_numeric(numeric, categorical)
+    if len(numeric) > 1:
+        # categorical is necessarily empty here: the branch above returns
+        # whenever numeric and categorical columns are both present.
+        return _candidates_multi_numeric(numeric, categorical)
+    if len(numeric) == 1 and not (temporal or categorical):
+        return _candidates_single_numeric(numeric[0], row_count)
+    return []
+
+
+def _candidates_temporal_numeric(
+    numeric: list[DataColumn], row_count: int
+) -> list[str]:
+    """Suggestions for data with a time axis and at least one measure.
+
+    With fewer than five rows, a bar chart (or plain table) is offered
+    instead of line-style charts. A multi-line chart is added when more
+    than one measure is present.
+    """
+    if row_count < 5:
+        return ["bar chart", "table"]
+    series_extra = ["multi-line chart"] if len(numeric) > 1 else []
+    return ["line chart", "area chart", "bar chart", *series_extra]
+
+
+def _candidates_categorical_numeric(
+    numeric: list[DataColumn],
+    categorical: list[DataColumn],
+) -> list[str]:
+    """Suggestions for measures broken down by categorical columns.
+
+    - A pie chart is offered only for a single measure whose first
+      categorical column has at most 10 distinct values.
+    - Scatter plot and heatmap require two or more measures.
+    - A treemap is added when any categorical column has more than 5
+      distinct values.
+    """
+    suggestions = ["bar chart"]
+    measure_count = len(numeric)
+    if measure_count == 1 and categorical[0].unique_count <= 10:
+        suggestions.append("pie chart")
+    elif measure_count >= 2:
+        suggestions.extend(("scatter plot", "heatmap"))
+    wide_dimension = any(dim.unique_count > 5 for dim in categorical)
+    if wide_dimension:
+        suggestions.append("treemap")
+    return suggestions
+
+
+def _candidates_single_numeric(col: DataColumn, row_count: int) -> list[str]:
+    """Suggestions for a lone measure with no time or category axis.
+
+    A histogram is listed first when there are more than 20 rows and
+    more than 10 distinct values.
+    """
+    has_distribution = row_count > 20 and col.unique_count > 10
+    leading = ["histogram"] if has_distribution else []
+    return leading + ["big number / KPI", "gauge chart"]
+
+
+def _candidates_multi_numeric(
+    numeric: list[DataColumn],
+    categorical: list[DataColumn],
+) -> list[str]:
+    """Suggestions for several measures with no time axis.
+
+    A scatter plot is always offered. A bubble chart is added with three
+    or more measures, and a heatmap when categorical columns are given.
+    """
+    optional = {
+        "bubble chart": len(numeric) >= 3,
+        "heatmap": bool(categorical),
+    }
+    extras = [name for name, wanted in optional.items() if wanted]
+    return ["scatter plot", *extras]
+
+
+# Maps each candidate string to a canonical category for dedup
+# against the current viz_type.
+# These category values are compared with the values of _VIZ_CATEGORY, so
+# both maps must share one vocabulary. A candidate missing here looks up as
+# None and is therefore never excluded by the category filter.
+_CANDIDATE_CATEGORY: dict[str, str] = {
+    "line chart": "line",
+    "multi-line chart": "line",
+    "area chart": "area",
+    "bar chart": "bar",
+    "scatter plot": "scatter",
+    "bubble chart": "bubble",
+    "pie chart": "pie",
+    "treemap": "treemap",
+    "heatmap": "heatmap",
+    "big number / KPI": "kpi",
+    "gauge chart": "gauge",
+    "histogram": "histogram",
+    "table": "table",
+}
+
+
+def _filter_candidates(
+    candidates: list[str],
+    current_category: str,
+) -> list[str]:
+    """Return candidates in their original order after three steps:
+    duplicates are dropped, names whose category equals current_category
+    are removed, and the result is cut off at _MAX_RECOMMENDATIONS entries.
+    """
+    picked: list[str] = []
+    # dict.fromkeys keeps the first occurrence of each name, in order.
+    for name in dict.fromkeys(candidates):
+        if _CANDIDATE_CATEGORY.get(name) == current_category:
+            continue
+        picked.append(name)
+        if len(picked) >= _MAX_RECOMMENDATIONS:
+            break
+    return picked
+
 
 def _sanitize_chart_data_for_llm_context(chart_data: ChartData) -> ChartData:
     """Wrap chart data read-path descriptive fields before LLM exposure."""
@@ -484,8 +652,9 @@ async def get_chart_data(  # noqa: C901
                 )
 
             # Create rich column metadata
+            coltypes = query_result.get("coltypes", [])
             columns = []
-            for col_name in raw_columns:
+            for idx, col_name in enumerate(raw_columns):
                 # Sample some values for metadata
                 sample_values = [
                     row.get(col_name)
@@ -493,13 +662,16 @@ async def get_chart_data(  # noqa: C901
                     if row.get(col_name) is not None
                 ]
 
-                # Infer data type
+                # Use SQL-derived GenericDataType when available,
+                # fall back to Python isinstance heuristic
                 data_type = "string"
-                if sample_values:
-                    if all(isinstance(v, (int, float)) for v in sample_values):
-                        data_type = "numeric"
-                    elif all(isinstance(v, bool) for v in sample_values):
+                if coltypes:
+                    data_type = _GENERIC_TYPE_MAP.get(coltypes[idx], "string")
+                elif sample_values:
+                    if all(isinstance(v, bool) for v in sample_values):
                         data_type = "boolean"
+                    elif all(isinstance(v, (int, float)) for v in 
sample_values):
+                        data_type = "numeric"
 
                 columns.append(
                     DataColumn(
@@ -542,13 +714,11 @@ async def get_chart_data(  # noqa: C901
             else:
                 insights.append("Fresh data retrieved from database")
 
-            recommended_visualizations = []
-            if any(
-                "time" in col.lower() or "date" in col.lower() for col in 
raw_columns
-            ):
-                recommended_visualizations.extend(["line chart", "time 
series"])
-            if len(raw_columns) <= 3:
-                recommended_visualizations.extend(["bar chart", "scatter 
plot"])
+            recommended_visualizations = _recommend_visualizations(
+                viz_type=chart.viz_type or "unknown",
+                columns=columns,
+                row_count=len(data),
+            )
 
             # Performance metadata with cache awareness
             execution_time = int((time.time() - start_time) * 1000)
diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
index 5404f8985b6..7a368093313 100644
--- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
+++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py
@@ -33,11 +33,15 @@ from superset.mcp_service.chart.schemas import (
     PerformanceMetadata,
 )
 from superset.mcp_service.chart.tool.get_chart_data import (
+    _GENERIC_TYPE_MAP,
+    _MAX_RECOMMENDATIONS,
     _query_from_form_data,
+    _recommend_visualizations,
     _sanitize_chart_data_for_llm_context,
 )
 from superset.mcp_service.utils import sanitize_for_llm_context
 from superset.mcp_service.utils.sanitization import 
LLM_CONTEXT_ESCAPED_CLOSE_DELIMITER
+from superset.utils.core import GenericDataType
 
 
 def _collect_groupby_extras(
@@ -1167,3 +1171,133 @@ class TestChartDataCommandValidation:
                 )
 
             mock_command.run.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# Tests for _recommend_visualizations
+# ---------------------------------------------------------------------------
+
+
+def _col(
+    name: str,
+    data_type: str = "string",
+    unique_count: int = 5,
+    null_count: int = 0,
+) -> DataColumn:
+    """Build a DataColumn with no sample values, for recommendation tests.
+
+    The column name doubles as its display name.
+    """
+    fields = {
+        "name": name,
+        "display_name": name,
+        "data_type": data_type,
+        "sample_values": [],
+        "null_count": null_count,
+        "unique_count": unique_count,
+    }
+    return DataColumn(**fields)
+
+
+def test_recommend_temporal_and_numeric_suggests_line_chart():
+    cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
+    result = _recommend_visualizations("table", cols, row_count=50)
+    assert "line chart" in result
+    assert "area chart" in result
+
+
+def test_recommend_categorical_and_numeric_suggests_bar_chart():
+    cols = [_col("region", "string", unique_count=5), _col("sales", "numeric")]
+    result = _recommend_visualizations(
+        "echarts_timeseries_line", cols, row_count=50
+    )
+    assert "bar chart" in result
+
+
+def test_recommend_excludes_current_viz_type():
+    cols = [_col("created_at", "temporal"), _col("revenue", "numeric")]
+    result = _recommend_visualizations(
+        "echarts_timeseries_line", cols, row_count=50
+    )
+    assert "line chart" not in result
+    # Other chart families must still be offered.
+    assert "area chart" in result
+
+
+def test_recommend_multiple_numeric_suggests_scatter():
+    cols = [
+        _col("height", "numeric"),
+        _col("weight", "numeric"),
+        _col("age", "numeric"),
+    ]
+    result = _recommend_visualizations("table", cols, row_count=100)
+    assert "scatter plot" in result
+
+
+def test_recommend_single_numeric_suggests_kpi():
+    cols = [_col("total_revenue", "numeric")]
+    result = _recommend_visualizations("table", cols, row_count=1)
+    assert "big number / KPI" in result
+
+
+def test_recommend_all_strings_falls_back():
+    # No numeric columns means no rule-based candidates, so the generic
+    # fallback is used. Neither entry is in the "pie" category, so both stay.
+    cols = [_col("name", "string"), _col("address", "string")]
+    result = _recommend_visualizations("pie", cols, row_count=100)
+    assert result == ["table", "bar chart"]
+
+
+def test_recommend_high_cardinality_no_pie():
+    cols = [
+        _col("user_id", "string", unique_count=900),
+        _col("score", "numeric"),
+    ]
+    result = _recommend_visualizations("table", cols, row_count=1000)
+    assert "pie chart" not in result
+
+
+def test_recommend_caps_at_max():
+    cols = [_col("ts", "temporal"), _col("a", "numeric"), _col("b", "numeric")]
+    result = _recommend_visualizations("table", cols, row_count=100)
+    assert len(result) <= _MAX_RECOMMENDATIONS
+
+
+def test_recommend_empty_columns_returns_table():
+    result = _recommend_visualizations("table", [], row_count=0)
+    assert result == ["table"]
+
+
+def test_recommend_pie_only_for_low_cardinality():
+    cols = [
+        _col("department", "string", unique_count=25),
+        _col("headcount", "numeric"),
+    ]
+    result = _recommend_visualizations("table", cols, row_count=100)
+    assert "pie chart" not in result
+
+
+def test_recommend_temporal_few_rows_prefers_bar():
+    cols = [_col("date", "temporal"), _col("revenue", "numeric")]
+    result = _recommend_visualizations("table", cols, row_count=3)
+    assert "bar chart" in result
+    assert "line chart" not in result
+
+
+def test_recommend_single_numeric_high_cardinality_suggests_histogram():
+    cols = [_col("salary", "numeric", unique_count=500)]
+    result = _recommend_visualizations("table", cols, row_count=1000)
+    assert "histogram" in result
+
+def test_coltypes_populates_data_type():
+    """Verify that GenericDataType values from coltypes are mapped correctly."""
+    assert _GENERIC_TYPE_MAP[GenericDataType.NUMERIC] == "numeric"
+    assert _GENERIC_TYPE_MAP[GenericDataType.STRING] == "string"
+    assert _GENERIC_TYPE_MAP[GenericDataType.TEMPORAL] == "temporal"
+    assert _GENERIC_TYPE_MAP[GenericDataType.BOOLEAN] == "boolean"
+
+
+def test_bool_isinstance_check_before_int():
+    """bool is a subclass of int; verify bool check takes priority in fallback."""
+
+    # When coltypes is unavailable, the fallback isinstance heuristic
+    # must check bool before int/float, because isinstance(True, int) is True.
+    assert isinstance(True, int)
+    # NOTE(review): the statements below mirror the inline fallback in
+    # get_chart_data rather than calling it. This test documents the
+    # required ordering but cannot catch a regression there; extracting
+    # the heuristic into a helper would let it be tested directly.
+    sample_values = [True, False, True]
+    data_type = "string"
+    if all(isinstance(v, bool) for v in sample_values):
+        data_type = "boolean"
+    elif all(isinstance(v, (int, float)) for v in sample_values):
+        data_type = "numeric"
+    assert data_type == "boolean"

Reply via email to