This is an automated email from the ASF dual-hosted git repository. kgabryje pushed a commit to branch what-if in repository https://gitbox.apache.org/repos/asf/superset.git
commit 790f15d8f2b3c283955197d3575d94979bcafa79 Author: Kamil Gabryjelski <[email protected]> AuthorDate: Thu Dec 18 17:29:18 2025 +0100 optimize select queries --- superset/models/helpers.py | 127 ++++++++--- tests/unit_tests/connectors/sqla/models_test.py | 271 ++++++++++++++++++++++-- 2 files changed, 350 insertions(+), 48 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 564e36344a..0bc8c521f5 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -63,6 +63,9 @@ from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy_utils import UUIDType +import sqlglot +from sqlglot import exp + from superset import db, is_feature_enabled from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.common.db_query_status import QueryStatus @@ -2062,6 +2065,29 @@ class ExploreMixin: # pylint: disable=too-many-public-methods return from_clause, cte + def _extract_columns_from_sql( + self, sql_expression: str, available_columns: set[str] + ) -> set[str]: + """ + Extract column references from a SQL expression using sqlglot. + + :param sql_expression: The SQL expression to parse + :param available_columns: Set of known column names in the dataset + :returns: Set of column names found in the expression that exist in available_columns + """ + try: + # Parse the expression as a SELECT statement to handle it properly + parsed = sqlglot.parse_one(f"SELECT {sql_expression}") + found_columns: set[str] = set() + for column in parsed.find_all(exp.Column): + col_name = column.name + if col_name in available_columns: + found_columns.add(col_name) + return found_columns + except Exception: # noqa: BLE001 + # If parsing fails, return all available columns as fallback + return available_columns + def _collect_needed_columns( # noqa: C901 self, columns: Optional[list[Column]] = None, @@ -2070,6 +2096,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods filter: Optional[list[utils.QueryObjectFilterClause]] = None, orderby: Optional[list[OrderBy]] = None, granularity: Optional[str] = None, + extras: Optional[dict[str, Any]] = None, ) -> Optional[set[str]]: """ Collect all column names needed by the query for what-if transformation. @@ -2079,6 +2106,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods or None if all columns should be included (e.g., for complex metrics) """ needed: set[str] = set() + available_columns = {col.column_name for col in self.columns} # Add granularity column (time column) if granularity: @@ -2089,10 +2117,16 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if isinstance(col, str): needed.add(col) elif isinstance(col, dict): - if col.get("sqlExpression"): - # Adhoc column with SQL expression - can't determine columns - return None - if col.get("column_name"): + if sql_expr := col.get("sqlExpression"): + # Check if it's just a simple column reference + if col.get("isColumnReference"): + needed.add(sql_expr) + else: + # Parse SQL expression to extract column references + needed.update( + self._extract_columns_from_sql(sql_expr, available_columns) + ) + elif col.get("column_name"): needed.add(col["column_name"]) # Add columns from groupby @@ -2100,30 +2134,39 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if isinstance(col, str): needed.add(col) elif isinstance(col, dict): - if col.get("sqlExpression"): - # Adhoc column with SQL expression - can't determine columns - return None - if col.get("column_name"): + if sql_expr := col.get("sqlExpression"): + # Check if it's just a simple column reference + if col.get("isColumnReference"): + needed.add(sql_expr) + else: + # Parse SQL expression to extract column references + needed.update( + self._extract_columns_from_sql(sql_expr, available_columns) + ) + elif col.get("column_name"): needed.add(col["column_name"]) - # Add columns from metrics (try to extract column references) - # For complex metrics (SQL expressions or saved metrics), we need all columns - # because we can't easily parse what columns they reference + # Add columns from metrics for metric in metrics or []: if isinstance(metric, str): - # Saved metric - can't determine columns, need all - return None # Signal to use all columns + # Saved metric - we need to look it up to find the expression + # For now, fallback to selecting all columns + return None elif isinstance(metric, dict): expression_type = metric.get("expressionType") if expression_type == "SQL": - # SQL expression - can't determine columns, need all - return None # Signal to use all columns - # SIMPLE adhoc metric - check for column reference - metric_column = metric.get("column") - if isinstance(metric_column, dict): - col_name = metric_column.get("column_name") - if isinstance(col_name, str): - needed.add(col_name) + # Parse SQL expression to extract column references + if sql_expr := metric.get("sqlExpression"): + needed.update( + self._extract_columns_from_sql(sql_expr, available_columns) + ) + else: + # SIMPLE adhoc metric - check for column reference + metric_column = metric.get("column") + if isinstance(metric_column, dict): + col_name = metric_column.get("column_name") + if isinstance(col_name, str): + needed.add(col_name) # Add columns from filters for flt in filter or []: @@ -2131,10 +2174,16 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if isinstance(col, str): needed.add(col) elif isinstance(col, dict): - if col.get("sqlExpression"): - # Adhoc column filter - can't determine columns - return None - if col.get("column_name"): + if sql_expr := col.get("sqlExpression"): + # Check if it's just a simple column reference + if col.get("isColumnReference"): + needed.add(sql_expr) + else: + # Parse SQL expression to extract column references + needed.update( + self._extract_columns_from_sql(sql_expr, available_columns) + ) + elif col.get("column_name"): needed.add(col["column_name"]) # Add columns from orderby @@ -2144,12 +2193,31 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if isinstance(col, str): needed.add(col) elif isinstance(col, dict): - if col.get("sqlExpression"): - # Adhoc column orderby - can't determine columns - return None - if col.get("column_name"): + if sql_expr := col.get("sqlExpression"): + # Check if it's just a simple column reference + if col.get("isColumnReference"): + needed.add(sql_expr) + else: + # Parse SQL expression to extract column references + needed.update( + self._extract_columns_from_sql( + sql_expr, available_columns + ) + ) + elif col.get("column_name"): needed.add(col["column_name"]) + # Add columns from extras.where and extras.having (raw SQL clauses) + if extras: + if where := extras.get("where"): + needed.update( + self._extract_columns_from_sql(where, available_columns) + ) + if having := extras.get("having"): + needed.update( + self._extract_columns_from_sql(having, available_columns) + ) + return needed def adhoc_metric_to_sqla( @@ -2982,6 +3050,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods filter=filter, orderby=orderby, granularity=granularity, + extras=extras, ) tbl, cte = self.get_from_clause(template_processor, what_if=what_if) diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py index 8e094248a3..104edc60e2 100644 --- a/tests/unit_tests/connectors/sqla/models_test.py +++ b/tests/unit_tests/connectors/sqla/models_test.py @@ -1094,34 +1094,41 @@ def test_collect_needed_columns_empty(mocker: MockerFixture) -> None: assert needed == set() -def test_collect_needed_columns_returns_none_for_sql_metrics( +def test_collect_needed_columns_extracts_from_sql_metrics( mocker: MockerFixture, ) -> None: """ - Test that _collect_needed_columns returns None for SQL-type adhoc metrics, - indicating all columns should be included. + Test that _collect_needed_columns uses sqlglot to extract column references + from SQL-type adhoc metrics. """ database = mocker.MagicMock() table = SqlaTable( table_name="sales", database=database, - columns=[TableColumn(column_name="col1")], + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="revenue"), + TableColumn(column_name="orders"), + ], ) - # SQL-type adhoc metric - can't determine columns + # SQL-type adhoc metric with complex expression needed = table._collect_needed_columns( columns=["date"], metrics=[ { "expressionType": "SQL", - "sqlExpression": "SUM(hidden_column)", + "sqlExpression": "SUM(revenue) / NULLIF(COUNT(orders), 0)", } ], ) - # Should return None to signal all columns needed - assert needed is None + # Should extract column references using sqlglot + assert needed is not None + assert "date" in needed + assert "revenue" in needed + assert "orders" in needed def test_collect_needed_columns_returns_none_for_saved_metrics( @@ -1189,57 +1196,283 @@ def test_apply_what_if_transform_all_columns_when_needed_none( assert "col4" in compiled -def test_collect_needed_columns_returns_none_for_adhoc_columns( +def test_collect_needed_columns_extracts_from_adhoc_columns( mocker: MockerFixture, ) -> None: """ - Test that _collect_needed_columns returns None for adhoc columns - with SQL expressions, indicating all columns should be included. + Test that _collect_needed_columns uses sqlglot to extract column references + from adhoc columns with SQL expressions. """ database = mocker.MagicMock() table = SqlaTable( table_name="sales", database=database, - columns=[TableColumn(column_name="col1")], + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="first_name"), + TableColumn(column_name="last_name"), + ], ) # Adhoc column with SQL expression in columns list needed = table._collect_needed_columns( columns=[ "date", - {"label": "custom_col", "sqlExpression": "CONCAT(first_name, last_name)"}, + {"label": "full_name", "sqlExpression": "CONCAT(first_name, last_name)"}, ], metrics=[], ) - # Should return None to signal all columns needed - assert needed is None + # Should extract column references using sqlglot + assert needed is not None + assert "date" in needed + assert "first_name" in needed + assert "last_name" in needed -def test_collect_needed_columns_returns_none_for_adhoc_groupby( +def test_collect_needed_columns_extracts_from_adhoc_groupby( mocker: MockerFixture, ) -> None: """ - Test that _collect_needed_columns returns None for adhoc columns in groupby. + Test that _collect_needed_columns uses sqlglot to extract column references + from adhoc columns in groupby. """ database = mocker.MagicMock() table = SqlaTable( table_name="sales", database=database, - columns=[TableColumn(column_name="col1")], + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="revenue"), + ], ) # Adhoc column in groupby needed = table._collect_needed_columns( + columns=["revenue"], groupby=[ {"label": "year", "sqlExpression": "EXTRACT(YEAR FROM date)"}, ], metrics=[], ) - assert needed is None + # Should extract column references using sqlglot + assert needed is not None + assert "date" in needed + assert "revenue" in needed + + +def test_collect_needed_columns_handles_column_reference( + mocker: MockerFixture, +) -> None: + """ + Test that _collect_needed_columns correctly handles columns with + isColumnReference=True, extracting the column name from sqlExpression. + """ + database = mocker.MagicMock() + + table = SqlaTable( + table_name="sales", + database=database, + columns=[TableColumn(column_name="date"), TableColumn(column_name="customers")], + ) + + # Column with sqlExpression but isColumnReference=True (common in time-series charts) + needed = table._collect_needed_columns( + columns=[ + { + "timeGrain": "P1D", + "columnType": "BASE_AXIS", + "sqlExpression": "date", + "label": "date", + "expressionType": "SQL", + "isColumnReference": True, + } + ], + metrics=[ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "column": {"column_name": "customers"}, + } + ], + ) + + assert needed is not None + assert "date" in needed + assert "customers" in needed + + +def test_collect_needed_columns_extracts_from_complex_sql_expression( + mocker: MockerFixture, +) -> None: + """ + Test that _collect_needed_columns uses sqlglot to extract column references + from complex SQL expressions. + """ + database = mocker.MagicMock() + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="first_name"), + TableColumn(column_name="last_name"), + ], + ) + + # Complex SQL expression without isColumnReference + needed = table._collect_needed_columns( + columns=[ + "date", + { + "sqlExpression": "CONCAT(first_name, ' ', last_name)", + "label": "full_name", + "expressionType": "SQL", + }, + ], + metrics=[], + ) + + # Should extract column references using sqlglot + assert needed is not None + assert "date" in needed + assert "first_name" in needed + assert "last_name" in needed + + +def test_collect_needed_columns_extracts_from_adhoc_filter( + mocker: MockerFixture, +) -> None: + """ + Test that _collect_needed_columns uses sqlglot to extract column references + from adhoc filters with SQL expressions. + """ + database = mocker.MagicMock() + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="revenue"), + TableColumn(column_name="quantity"), + TableColumn(column_name="price"), + ], + ) + + # Adhoc filter with SQL expression + needed = table._collect_needed_columns( + columns=["date"], + metrics=[ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "column": {"column_name": "revenue"}, + } + ], + filter=[ + { + "col": { + "sqlExpression": "quantity * price", + "label": "total_value", + "expressionType": "SQL", + }, + "op": ">", + "val": 1000, + } + ], + ) + + # Should extract column references using sqlglot + assert needed is not None + assert "date" in needed + assert "revenue" in needed + assert "quantity" in needed + assert "price" in needed + + +def test_collect_needed_columns_extracts_from_extras_where( + mocker: MockerFixture, +) -> None: + """ + Test that _collect_needed_columns extracts column references from + extras.where (raw SQL WHERE clause). + """ + database = mocker.MagicMock() + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="revenue"), + TableColumn(column_name="region"), + TableColumn(column_name="customers"), + ], + ) + + # extras.where contains raw SQL filter + needed = table._collect_needed_columns( + columns=["date"], + metrics=[ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "column": {"column_name": "revenue"}, + } + ], + extras={"where": "(region = 'US')"}, + ) + + # Should extract 'region' from the where clause + assert needed is not None + assert "date" in needed + assert "revenue" in needed + assert "region" in needed + assert "customers" not in needed # Not referenced + + +def test_collect_needed_columns_extracts_from_extras_having( + mocker: MockerFixture, +) -> None: + """ + Test that _collect_needed_columns extracts column references from + extras.having (raw SQL HAVING clause). + """ + database = mocker.MagicMock() + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="revenue"), + TableColumn(column_name="orders"), + ], + ) + + # extras.having contains raw SQL filter + needed = table._collect_needed_columns( + columns=["date"], + metrics=[ + { + "expressionType": "SIMPLE", + "aggregate": "SUM", + "column": {"column_name": "revenue"}, + } + ], + extras={"having": "SUM(orders) > 100"}, + ) + + # Should extract 'orders' from the having clause + assert needed is not None + assert "date" in needed + assert "revenue" in needed + assert "orders" in needed def test_apply_what_if_transform_with_single_filter(mocker: MockerFixture) -> None:
