This is an automated email from the ASF dual-hosted git repository.

kgabryje pushed a commit to branch what-if
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 790f15d8f2b3c283955197d3575d94979bcafa79
Author: Kamil Gabryjelski <[email protected]>
AuthorDate: Thu Dec 18 17:29:18 2025 +0100

    optimize select queries
---
 superset/models/helpers.py                      | 127 ++++++++---
 tests/unit_tests/connectors/sqla/models_test.py | 271 ++++++++++++++++++++++--
 2 files changed, 350 insertions(+), 48 deletions(-)

diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 564e36344a..0bc8c521f5 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -63,6 +63,9 @@ from sqlalchemy.sql.expression import Label, Select, TextAsFrom
 from sqlalchemy.sql.selectable import Alias, TableClause
 from sqlalchemy_utils import UUIDType
 
+import sqlglot
+from sqlglot import exp
+
 from superset import db, is_feature_enabled
 from superset.advanced_data_type.types import AdvancedDataTypeResponse
 from superset.common.db_query_status import QueryStatus
@@ -2062,6 +2065,29 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
 
         return from_clause, cte
 
+    def _extract_columns_from_sql(
+        self, sql_expression: str, available_columns: set[str]
+    ) -> set[str]:
+        """
+        Extract column references from a SQL expression using sqlglot.
+
+        :param sql_expression: The SQL expression to parse
+        :param available_columns: Set of known column names in the dataset
+        :returns: Set of column names found in the expression that exist in available_columns
+        """
+        try:
+            # Parse the expression as a SELECT statement to handle it properly
+            parsed = sqlglot.parse_one(f"SELECT {sql_expression}")
+            found_columns: set[str] = set()
+            for column in parsed.find_all(exp.Column):
+                col_name = column.name
+                if col_name in available_columns:
+                    found_columns.add(col_name)
+            return found_columns
+        except Exception:  # noqa: BLE001
+            # If parsing fails, return all available columns as fallback
+            return available_columns
+
     def _collect_needed_columns(  # noqa: C901
         self,
         columns: Optional[list[Column]] = None,
@@ -2070,6 +2096,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         filter: Optional[list[utils.QueryObjectFilterClause]] = None,
         orderby: Optional[list[OrderBy]] = None,
         granularity: Optional[str] = None,
+        extras: Optional[dict[str, Any]] = None,
     ) -> Optional[set[str]]:
         """
         Collect all column names needed by the query for what-if transformation.
@@ -2079,6 +2106,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                   or None if all columns should be included (e.g., for complex metrics)
         """
         needed: set[str] = set()
+        available_columns = {col.column_name for col in self.columns}
 
         # Add granularity column (time column)
         if granularity:
@@ -2089,10 +2117,16 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             if isinstance(col, str):
                 needed.add(col)
             elif isinstance(col, dict):
-                if col.get("sqlExpression"):
-                    # Adhoc column with SQL expression - can't determine columns
-                    return None
-                if col.get("column_name"):
+                if sql_expr := col.get("sqlExpression"):
+                    # Check if it's just a simple column reference
+                    if col.get("isColumnReference"):
+                        needed.add(sql_expr)
+                    else:
+                        # Parse SQL expression to extract column references
+                        needed.update(
+                            self._extract_columns_from_sql(sql_expr, available_columns)
+                        )
+                elif col.get("column_name"):
                     needed.add(col["column_name"])
 
         # Add columns from groupby
@@ -2100,30 +2134,39 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             if isinstance(col, str):
                 needed.add(col)
             elif isinstance(col, dict):
-                if col.get("sqlExpression"):
-                    # Adhoc column with SQL expression - can't determine columns
-                    return None
-                if col.get("column_name"):
+                if sql_expr := col.get("sqlExpression"):
+                    # Check if it's just a simple column reference
+                    if col.get("isColumnReference"):
+                        needed.add(sql_expr)
+                    else:
+                        # Parse SQL expression to extract column references
+                        needed.update(
+                            self._extract_columns_from_sql(sql_expr, available_columns)
+                        )
+                elif col.get("column_name"):
                     needed.add(col["column_name"])
 
-        # Add columns from metrics (try to extract column references)
-        # For complex metrics (SQL expressions or saved metrics), we need all columns
-        # because we can't easily parse what columns they reference
+        # Add columns from metrics
         for metric in metrics or []:
             if isinstance(metric, str):
-                # Saved metric - can't determine columns, need all
-                return None  # Signal to use all columns
+                # Saved metric - we need to look it up to find the expression
+                # For now, fallback to selecting all columns
+                return None
             elif isinstance(metric, dict):
                 expression_type = metric.get("expressionType")
                 if expression_type == "SQL":
-                    # SQL expression - can't determine columns, need all
-                    return None  # Signal to use all columns
-                # SIMPLE adhoc metric - check for column reference
-                metric_column = metric.get("column")
-                if isinstance(metric_column, dict):
-                    col_name = metric_column.get("column_name")
-                    if isinstance(col_name, str):
-                        needed.add(col_name)
+                    # Parse SQL expression to extract column references
+                    if sql_expr := metric.get("sqlExpression"):
+                        needed.update(
+                            self._extract_columns_from_sql(sql_expr, available_columns)
+                        )
+                else:
+                    # SIMPLE adhoc metric - check for column reference
+                    metric_column = metric.get("column")
+                    if isinstance(metric_column, dict):
+                        col_name = metric_column.get("column_name")
+                        if isinstance(col_name, str):
+                            needed.add(col_name)
 
         # Add columns from filters
         for flt in filter or []:
@@ -2131,10 +2174,16 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             if isinstance(col, str):
                 needed.add(col)
             elif isinstance(col, dict):
-                if col.get("sqlExpression"):
-                    # Adhoc column filter - can't determine columns
-                    return None
-                if col.get("column_name"):
+                if sql_expr := col.get("sqlExpression"):
+                    # Check if it's just a simple column reference
+                    if col.get("isColumnReference"):
+                        needed.add(sql_expr)
+                    else:
+                        # Parse SQL expression to extract column references
+                        needed.update(
+                            self._extract_columns_from_sql(sql_expr, available_columns)
+                        )
+                elif col.get("column_name"):
                     needed.add(col["column_name"])
 
         # Add columns from orderby
@@ -2144,12 +2193,31 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                 if isinstance(col, str):
                     needed.add(col)
                 elif isinstance(col, dict):
-                    if col.get("sqlExpression"):
-                        # Adhoc column orderby - can't determine columns
-                        return None
-                    if col.get("column_name"):
+                    if sql_expr := col.get("sqlExpression"):
+                        # Check if it's just a simple column reference
+                        if col.get("isColumnReference"):
+                            needed.add(sql_expr)
+                        else:
+                            # Parse SQL expression to extract column references
+                            needed.update(
+                                self._extract_columns_from_sql(
+                                    sql_expr, available_columns
+                                )
+                            )
+                    elif col.get("column_name"):
                         needed.add(col["column_name"])
 
+        # Add columns from extras.where and extras.having (raw SQL clauses)
+        if extras:
+            if where := extras.get("where"):
+                needed.update(
+                    self._extract_columns_from_sql(where, available_columns)
+                )
+            if having := extras.get("having"):
+                needed.update(
+                    self._extract_columns_from_sql(having, available_columns)
+                )
+
         return needed
 
     def adhoc_metric_to_sqla(
@@ -2982,6 +3050,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                 filter=filter,
                 orderby=orderby,
                 granularity=granularity,
+                extras=extras,
             )
         tbl, cte = self.get_from_clause(template_processor, what_if=what_if)
 
diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py
index 8e094248a3..104edc60e2 100644
--- a/tests/unit_tests/connectors/sqla/models_test.py
+++ b/tests/unit_tests/connectors/sqla/models_test.py
@@ -1094,34 +1094,41 @@ def test_collect_needed_columns_empty(mocker: MockerFixture) -> None:
     assert needed == set()
 
 
-def test_collect_needed_columns_returns_none_for_sql_metrics(
+def test_collect_needed_columns_extracts_from_sql_metrics(
     mocker: MockerFixture,
 ) -> None:
     """
-    Test that _collect_needed_columns returns None for SQL-type adhoc metrics,
-    indicating all columns should be included.
+    Test that _collect_needed_columns uses sqlglot to extract column references
+    from SQL-type adhoc metrics.
     """
     database = mocker.MagicMock()
 
     table = SqlaTable(
         table_name="sales",
         database=database,
-        columns=[TableColumn(column_name="col1")],
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="revenue"),
+            TableColumn(column_name="orders"),
+        ],
     )
 
-    # SQL-type adhoc metric - can't determine columns
+    # SQL-type adhoc metric with complex expression
     needed = table._collect_needed_columns(
         columns=["date"],
         metrics=[
             {
                 "expressionType": "SQL",
-                "sqlExpression": "SUM(hidden_column)",
+                "sqlExpression": "SUM(revenue) / NULLIF(COUNT(orders), 0)",
             }
         ],
     )
 
-    # Should return None to signal all columns needed
-    assert needed is None
+    # Should extract column references using sqlglot
+    assert needed is not None
+    assert "date" in needed
+    assert "revenue" in needed
+    assert "orders" in needed
 
 
 def test_collect_needed_columns_returns_none_for_saved_metrics(
@@ -1189,57 +1196,283 @@ def test_apply_what_if_transform_all_columns_when_needed_none(
     assert "col4" in compiled
 
 
-def test_collect_needed_columns_returns_none_for_adhoc_columns(
+def test_collect_needed_columns_extracts_from_adhoc_columns(
     mocker: MockerFixture,
 ) -> None:
     """
-    Test that _collect_needed_columns returns None for adhoc columns
-    with SQL expressions, indicating all columns should be included.
+    Test that _collect_needed_columns uses sqlglot to extract column references
+    from adhoc columns with SQL expressions.
     """
     database = mocker.MagicMock()
 
     table = SqlaTable(
         table_name="sales",
         database=database,
-        columns=[TableColumn(column_name="col1")],
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="first_name"),
+            TableColumn(column_name="last_name"),
+        ],
     )
 
     # Adhoc column with SQL expression in columns list
     needed = table._collect_needed_columns(
         columns=[
             "date",
-            {"label": "custom_col", "sqlExpression": "CONCAT(first_name, last_name)"},
+            {"label": "full_name", "sqlExpression": "CONCAT(first_name, last_name)"},
         ],
         metrics=[],
     )
 
-    # Should return None to signal all columns needed
-    assert needed is None
+    # Should extract column references using sqlglot
+    assert needed is not None
+    assert "date" in needed
+    assert "first_name" in needed
+    assert "last_name" in needed
 
 
-def test_collect_needed_columns_returns_none_for_adhoc_groupby(
+def test_collect_needed_columns_extracts_from_adhoc_groupby(
     mocker: MockerFixture,
 ) -> None:
     """
-    Test that _collect_needed_columns returns None for adhoc columns in groupby.
+    Test that _collect_needed_columns uses sqlglot to extract column references
+    from adhoc columns in groupby.
     """
     database = mocker.MagicMock()
 
     table = SqlaTable(
         table_name="sales",
         database=database,
-        columns=[TableColumn(column_name="col1")],
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="revenue"),
+        ],
     )
 
     # Adhoc column in groupby
     needed = table._collect_needed_columns(
+        columns=["revenue"],
         groupby=[
             {"label": "year", "sqlExpression": "EXTRACT(YEAR FROM date)"},
         ],
         metrics=[],
     )
 
-    assert needed is None
+    # Should extract column references using sqlglot
+    assert needed is not None
+    assert "date" in needed
+    assert "revenue" in needed
+
+
+def test_collect_needed_columns_handles_column_reference(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns correctly handles columns with
+    isColumnReference=True, extracting the column name from sqlExpression.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[TableColumn(column_name="date"), TableColumn(column_name="customers")],
+    )
+
+    # Column with sqlExpression but isColumnReference=True (common in time-series charts)
+    needed = table._collect_needed_columns(
+        columns=[
+            {
+                "timeGrain": "P1D",
+                "columnType": "BASE_AXIS",
+                "sqlExpression": "date",
+                "label": "date",
+                "expressionType": "SQL",
+                "isColumnReference": True,
+            }
+        ],
+        metrics=[
+            {
+                "expressionType": "SIMPLE",
+                "aggregate": "SUM",
+                "column": {"column_name": "customers"},
+            }
+        ],
+    )
+
+    assert needed is not None
+    assert "date" in needed
+    assert "customers" in needed
+
+
+def test_collect_needed_columns_extracts_from_complex_sql_expression(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns uses sqlglot to extract column references
+    from complex SQL expressions.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="first_name"),
+            TableColumn(column_name="last_name"),
+        ],
+    )
+
+    # Complex SQL expression without isColumnReference
+    needed = table._collect_needed_columns(
+        columns=[
+            "date",
+            {
+                "sqlExpression": "CONCAT(first_name, ' ', last_name)",
+                "label": "full_name",
+                "expressionType": "SQL",
+            },
+        ],
+        metrics=[],
+    )
+
+    # Should extract column references using sqlglot
+    assert needed is not None
+    assert "date" in needed
+    assert "first_name" in needed
+    assert "last_name" in needed
+
+
+def test_collect_needed_columns_extracts_from_adhoc_filter(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns uses sqlglot to extract column references
+    from adhoc filters with SQL expressions.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="revenue"),
+            TableColumn(column_name="quantity"),
+            TableColumn(column_name="price"),
+        ],
+    )
+
+    # Adhoc filter with SQL expression
+    needed = table._collect_needed_columns(
+        columns=["date"],
+        metrics=[
+            {
+                "expressionType": "SIMPLE",
+                "aggregate": "SUM",
+                "column": {"column_name": "revenue"},
+            }
+        ],
+        filter=[
+            {
+                "col": {
+                    "sqlExpression": "quantity * price",
+                    "label": "total_value",
+                    "expressionType": "SQL",
+                },
+                "op": ">",
+                "val": 1000,
+            }
+        ],
+    )
+
+    # Should extract column references using sqlglot
+    assert needed is not None
+    assert "date" in needed
+    assert "revenue" in needed
+    assert "quantity" in needed
+    assert "price" in needed
+
+
+def test_collect_needed_columns_extracts_from_extras_where(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns extracts column references from
+    extras.where (raw SQL WHERE clause).
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="revenue"),
+            TableColumn(column_name="region"),
+            TableColumn(column_name="customers"),
+        ],
+    )
+
+    # extras.where contains raw SQL filter
+    needed = table._collect_needed_columns(
+        columns=["date"],
+        metrics=[
+            {
+                "expressionType": "SIMPLE",
+                "aggregate": "SUM",
+                "column": {"column_name": "revenue"},
+            }
+        ],
+        extras={"where": "(region = 'US')"},
+    )
+
+    # Should extract 'region' from the where clause
+    assert needed is not None
+    assert "date" in needed
+    assert "revenue" in needed
+    assert "region" in needed
+    assert "customers" not in needed  # Not referenced
+
+
+def test_collect_needed_columns_extracts_from_extras_having(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns extracts column references from
+    extras.having (raw SQL HAVING clause).
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="revenue"),
+            TableColumn(column_name="orders"),
+        ],
+    )
+
+    # extras.having contains raw SQL filter
+    needed = table._collect_needed_columns(
+        columns=["date"],
+        metrics=[
+            {
+                "expressionType": "SIMPLE",
+                "aggregate": "SUM",
+                "column": {"column_name": "revenue"},
+            }
+        ],
+        extras={"having": "SUM(orders) > 100"},
+    )
+
+    # Should extract 'orders' from the having clause
+    assert needed is not None
+    assert "date" in needed
+    assert "revenue" in needed
+    assert "orders" in needed
 
 
 def test_apply_what_if_transform_with_single_filter(mocker: MockerFixture) -> None:

Reply via email to