This is an automated email from the ASF dual-hosted git repository. kgabryje pushed a commit to branch what-if in repository https://gitbox.apache.org/repos/asf/superset.git
# commit 4dab58f8c0e0e30934a275c8baae856eff1e7dee -- "feat: WHAT IF - backend"
# File: superset/connectors/sqla/models.py (methods of ``SqlaTable``)


def get_from_clause(
    self,
    template_processor: BaseTemplateProcessor | None = None,
    what_if: dict[str, Any] | None = None,
) -> tuple[TableClause | Alias, str | None]:
    """
    Return the FROM clause for this dataset plus an optional CTE string.

    Physical datasets select from ``get_sqla_table()`` and never carry a
    CTE; virtual datasets delegate to the parent implementation. In both
    cases a truthy ``what_if`` wraps the result via
    ``_apply_what_if_transform``.

    :param template_processor: Forwarded to the parent implementation for
        virtual datasets
    :param what_if: Optional what-if specification (see
        ``_apply_what_if_transform`` for the keys that are read)
    :returns: Tuple of (FROM clause, CTE or None)
    """
    if self.is_virtual:
        # ``what_if=None`` on purpose: the wrapper is applied below, exactly
        # once, regardless of what the parent does with that argument.
        from_clause, cte = super().get_from_clause(template_processor, what_if=None)
    else:
        from_clause, cte = self.get_sqla_table(), None

    if what_if:
        from_clause = self._apply_what_if_transform(from_clause, what_if)
    return from_clause, cte


def _apply_what_if_transform(
    self,
    source: TableClause | Alias,
    what_if: dict[str, Any],
) -> TableClause | Alias:
    """
    Wrap the source table/subquery in a subquery that scales selected
    columns for what-if analysis.

    Columns are emitted in the order of ``self.columns`` so the generated
    SQL text is the same from run to run; iterating a ``set`` of names would
    make the order depend on string hashing.

    :param source: Original table or subquery to transform
    :param what_if: Dict with a ``modifications`` list of
        ``{"column": ..., "multiplier": ...}`` entries and an optional
        ``needed_columns`` collection of column names; when that key is
        missing or ``None`` every column of the dataset is selected
    :returns: Subquery aliased ``__what_if``, or ``source`` itself when there
        are no modifications or no selectable column remains
    """
    modifications = what_if.get("modifications") or []
    if not modifications:
        return source

    # column name -> multiplier
    multipliers = {mod["column"]: mod["multiplier"] for mod in modifications}

    # ``None`` means "select everything" (e.g. the query uses SQL metrics
    # whose column references cannot be determined).
    needed_columns = what_if.get("needed_columns")
    wanted: set[str] | None = None
    if needed_columns is not None:
        wanted = set(needed_columns) | multipliers.keys()

    # NOTE(review): every entry of ``self.columns`` is referenced as a plain
    # column of ``source``. If the dataset can contain expression-based
    # (calculated) columns they would not exist there -- confirm and filter.
    select_columns = []
    emitted: set[str] = set()
    for dataset_column in self.columns:
        name = dataset_column.column_name
        if name in emitted:
            continue
        # Names that are requested but unknown to the dataset never show up
        # in this loop, so they are skipped implicitly.
        if wanted is not None and name not in wanted:
            continue
        emitted.add(name)

        plain = sa.column(name)
        if name in multipliers:
            # column * multiplier AS column
            scaled = plain * sa.literal(multipliers[name])
            select_columns.append(scaled.label(name))
        else:
            select_columns.append(plain)

    if not select_columns:
        # Nothing to project: leave the source untouched.
        return source

    return sa.select(*select_columns).select_from(source).alias("__what_if")


# Next file in this commit: superset/models/helpers.py
# File: superset/models/helpers.py (method of ``ExploreMixin``)


def _collect_needed_columns(
    self,
    columns: Optional[list[Column]] = None,
    groupby: Optional[list[Column]] = None,
    metrics: Optional[list[Metric]] = None,
    filter: Optional[list[utils.QueryObjectFilterClause]] = None,
    orderby: Optional[list[OrderBy]] = None,
    granularity: Optional[str] = None,
) -> Optional[set[str]]:
    """
    Collect all column names needed by the query for what-if transformation.

    The query builder calls this with the query's own ``columns``,
    ``groupby``, ``metrics``, ``filter``, ``orderby`` and ``granularity`` and
    stores the result under ``what_if["needed_columns"]`` before asking for
    the FROM clause, so that only the necessary columns are selected.

    :param columns: Dimension columns: names or adhoc column dicts
    :param groupby: Group-by columns: names or adhoc column dicts
    :param metrics: Saved metric names or adhoc metric dicts
    :param filter: Filter clauses; only their ``col`` entry is inspected
    :param orderby: Sequences whose first element is the ordering key
    :param granularity: Name of the time column, if any
    :returns: Set of column names referenced by the query, or ``None`` if
        all columns should be included because at least one item (saved
        metric, SQL metric, adhoc SQL column) cannot be analysed
    """

    def column_refs(item: Any) -> Optional[set[str]]:
        """Columns referenced by a column-like item; ``None`` if unknown."""
        if isinstance(item, str):
            return {item}
        if isinstance(item, dict):
            if item.get("sqlExpression"):
                # Adhoc column with SQL expression - can't determine columns
                return None
            name = item.get("column_name")
            return {name} if name else set()
        return set()

    def metric_refs(metric: Any) -> Optional[set[str]]:
        """Columns referenced by a metric; ``None`` if unknown."""
        if isinstance(metric, str):
            # Saved metric - its expression is not available here
            return None
        if isinstance(metric, dict):
            if metric.get("expressionType") == "SQL":
                return None
            # SIMPLE adhoc metric - check for column reference
            column = metric.get("column")
            if isinstance(column, dict):
                name = column.get("column_name")
                if isinstance(name, str):
                    return {name}
        return set()

    def order_refs(key: Any) -> Optional[set[str]]:
        """Columns referenced by an ordering key (a column or a metric)."""
        refs = column_refs(key)
        if refs is None or not isinstance(key, dict):
            # NOTE(review): a string key is taken to be a column name; if it
            # can also name a saved metric this under-reports -- confirm.
            return refs
        # A dict key may be an adhoc metric rather than an adhoc column, so
        # also pick up the column a SIMPLE metric aggregates over.
        from_metric = metric_refs(key)
        return None if from_metric is None else refs | from_metric

    needed: set[str] = set()

    # Add granularity column (time column)
    if granularity:
        needed.add(granularity)

    # Generators keep evaluation lazy: later inputs are not touched once an
    # earlier one has already forced the "all columns" answer.
    filter_columns = (flt.get("col") for flt in filter or [])
    order_keys = (
        item[0] for item in orderby or [] if isinstance(item, (list, tuple)) and item
    )

    sources = (
        (column_refs, columns or []),
        (column_refs, groupby or []),
        (metric_refs, metrics or []),
        (column_refs, filter_columns),
        (order_refs, order_keys),
    )
    for resolve, items in sources:
        for item in items:
            refs = resolve(item)
            if refs is None:
                return None  # Signal to use all columns
            needed |= refs

    return needed


# Next file in this commit: tests/unit_tests/connectors/sqla/models_test.py
# File: tests/unit_tests/connectors/sqla/models_test.py -- tests added by the
# "feat: WHAT IF - backend" commit. They rely on ``select`` being imported
# from ``sqlalchemy`` next to ``create_engine``.


def _build_sales_table(
    mocker: MockerFixture,
    column_names: list[str],
    sqlite: bool = False,
) -> SqlaTable:
    """Create a ``sales`` dataset over a mocked database with the given columns."""
    database = mocker.MagicMock()
    if sqlite:
        database.db_engine_spec.engine = "sqlite"
    return SqlaTable(
        table_name="sales",
        database=database,
        columns=[TableColumn(column_name=name) for name in column_names],
    )


def _compile_what_if(table: SqlaTable, what_if: dict) -> str:
    """Apply the what-if transform to ``table`` and render it as SQLite SQL."""
    engine = create_engine("sqlite://")
    wrapped = table._apply_what_if_transform(table.get_sqla_table(), what_if)
    statement = select(wrapped)
    return str(statement.compile(engine, compile_kwargs={"literal_binds": True}))


def test_apply_what_if_transform_single_modification(mocker: MockerFixture) -> None:
    """
    A single modified column is multiplied inside the ``__what_if`` subquery.
    """
    table = _build_sales_table(mocker, ["date", "ad_spend", "revenue"], sqlite=True)

    sql = _compile_what_if(
        table,
        {
            "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
            "needed_columns": {"date", "ad_spend", "revenue"},
        },
    )

    # Subquery alias is present and the multiplication is rendered
    assert "__what_if" in sql
    assert "ad_spend * 1.1" in sql


def test_apply_what_if_transform_multiple_modifications(mocker: MockerFixture) -> None:
    """
    Several modified columns are each multiplied by their own factor.
    """
    table = _build_sales_table(
        mocker,
        ["date", "ad_spend", "revenue", "conversions"],
        sqlite=True,
    )

    sql = _compile_what_if(
        table,
        {
            "modifications": [
                {"column": "ad_spend", "multiplier": 1.1},
                {"column": "revenue", "multiplier": 0.95},
            ],
            "needed_columns": {"date", "ad_spend", "revenue"},
        },
    )

    assert "ad_spend * 1.1" in sql
    assert "revenue * 0.95" in sql
    # conversions is not in needed_columns, so it must not be selected
    assert "conversions" not in sql


def test_apply_what_if_transform_no_modifications(mocker: MockerFixture) -> None:
    """
    With an empty modifications list the source object is returned as is.
    """
    table = _build_sales_table(mocker, ["date", "ad_spend"])
    source = table.get_sqla_table()

    result = table._apply_what_if_transform(
        source,
        {"modifications": [], "needed_columns": {"date", "ad_spend"}},
    )

    assert result is source


def test_apply_what_if_transform_only_needed_columns(mocker: MockerFixture) -> None:
    """
    Only needed columns plus modified columns end up in the subquery.
    """
    table = _build_sales_table(
        mocker,
        ["col1", "col2", "col3", "ad_spend", "col5"],
        sqlite=True,
    )

    # Only col1 is needed by the query, while ad_spend is being modified
    sql = _compile_what_if(
        table,
        {
            "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
            "needed_columns": {"col1"},
        },
    )

    for expected in ("col1", "ad_spend"):
        assert expected in sql
    for unexpected in ("col2", "col3", "col5"):
        assert unexpected not in sql


def test_apply_what_if_transform_nonexistent_column(mocker: MockerFixture) -> None:
    """
    A modification that targets an unknown column is ignored.
    """
    table = _build_sales_table(mocker, ["date", "revenue"], sqlite=True)

    sql = _compile_what_if(
        table,
        {
            "modifications": [{"column": "nonexistent_column", "multiplier": 1.1}],
            "needed_columns": {"date", "revenue"},
        },
    )

    # The query still works, just without the unknown column
    assert "date" in sql
    assert "revenue" in sql
    assert "nonexistent_column" not in sql


def test_collect_needed_columns(mocker: MockerFixture) -> None:
    """
    Column names are gathered from every kind of query parameter.
    """
    table = _build_sales_table(mocker, ["date", "region", "ad_spend", "revenue"])

    needed = table._collect_needed_columns(
        columns=["date", "region"],
        groupby=["date"],
        metrics=[
            {
                "expressionType": "SIMPLE",
                "column": {"column_name": "ad_spend"},
                "aggregate": "SUM",
            }
        ],
        filter=[{"col": "revenue", "op": ">", "val": 100}],
        orderby=[("date", True)],
        granularity="date",
    )

    assert "date" in needed  # columns, groupby, orderby and granularity
    assert "region" in needed  # columns
    assert "ad_spend" in needed  # metrics
    assert "revenue" in needed  # filter


def test_collect_needed_columns_empty(mocker: MockerFixture) -> None:
    """
    All-``None`` parameters yield an empty set.
    """
    table = _build_sales_table(mocker, ["col1"])

    needed = table._collect_needed_columns(
        columns=None,
        groupby=None,
        metrics=None,
        filter=None,
        orderby=None,
        granularity=None,
    )

    assert needed == set()


def test_collect_needed_columns_returns_none_for_sql_metrics(
    mocker: MockerFixture,
) -> None:
    """
    A SQL-type adhoc metric makes the result ``None`` (all columns needed).
    """
    table = _build_sales_table(mocker, ["col1"])
    sql_metric = {
        "expressionType": "SQL",
        "sqlExpression": "SUM(hidden_column)",
    }

    needed = table._collect_needed_columns(columns=["date"], metrics=[sql_metric])

    assert needed is None


def test_collect_needed_columns_returns_none_for_saved_metrics(
    mocker: MockerFixture,
) -> None:
    """
    A saved metric (plain string) makes the result ``None`` (all columns needed).
    """
    table = _build_sales_table(mocker, ["col1"])

    needed = table._collect_needed_columns(
        columns=["date"],
        metrics=["saved_metric_name"],
    )

    assert needed is None


def test_apply_what_if_transform_all_columns_when_needed_none(
    mocker: MockerFixture,
) -> None:
    """
    ``needed_columns=None`` selects every column of the dataset.
    """
    all_names = ["col1", "col2", "ad_spend", "col4"]
    table = _build_sales_table(mocker, all_names, sqlite=True)

    sql = _compile_what_if(
        table,
        {
            "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
            "needed_columns": None,
        },
    )

    for name in all_names:
        assert name in sql


def test_collect_needed_columns_returns_none_for_adhoc_columns(
    mocker: MockerFixture,
) -> None:
    """
    An adhoc SQL column in ``columns`` makes the result ``None``.
    """
    table = _build_sales_table(mocker, ["col1"])
    adhoc_column = {
        "label": "custom_col",
        "sqlExpression": "CONCAT(first_name, last_name)",
    }

    needed = table._collect_needed_columns(
        columns=["date", adhoc_column],
        metrics=[],
    )

    assert needed is None


def test_collect_needed_columns_returns_none_for_adhoc_groupby(
    mocker: MockerFixture,
) -> None:
    """
    An adhoc SQL column in ``groupby`` makes the result ``None``.
    """
    table = _build_sales_table(mocker, ["col1"])
    adhoc_groupby = {"label": "year", "sqlExpression": "EXTRACT(YEAR FROM date)"}

    needed = table._collect_needed_columns(groupby=[adhoc_groupby], metrics=[])

    assert needed is None
