This is an automated email from the ASF dual-hosted git repository.

kgabryje pushed a commit to branch what-if
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 4dab58f8c0e0e30934a275c8baae856eff1e7dee
Author: Kamil Gabryjelski <[email protected]>
AuthorDate: Tue Dec 16 14:31:09 2025 +0100

    feat: WHAT IF - backend
---
 superset/connectors/sqla/models.py              |  72 ++++-
 superset/models/helpers.py                      | 108 ++++++-
 tests/unit_tests/connectors/sqla/models_test.py | 414 +++++++++++++++++++++++-
 3 files changed, 582 insertions(+), 12 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8f7a60d270..9ec3ceaacf 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1516,11 +1516,79 @@ class SqlaTable(
     def get_from_clause(
         self,
         template_processor: BaseTemplateProcessor | None = None,
+        what_if: dict[str, Any] | None = None,
     ) -> tuple[TableClause | Alias, str | None]:
         if not self.is_virtual:
-            return self.get_sqla_table(), None
+            tbl = self.get_sqla_table()
+            if what_if:
+                tbl = self._apply_what_if_transform(tbl, what_if)
+            return tbl, None
 
-        return super().get_from_clause(template_processor)
+        from_clause, cte = super().get_from_clause(template_processor, what_if=None)
+        if what_if:
+            from_clause = self._apply_what_if_transform(from_clause, what_if)
+        return from_clause, cte
+
+    def _apply_what_if_transform(
+        self,
+        source: TableClause | Alias,
+        what_if: dict[str, Any],
+    ) -> Alias:
+        """
+        Wrap the source table/subquery with a subquery that applies
+        column transformations for what-if analysis.
+
+        :param source: Original table or subquery to transform
+        :param what_if: Dict containing 'modifications' list with column/multiplier
+            pairs and 'needed_columns' set with columns required by the query
+        :returns: Aliased subquery with transformations applied
+        """
+        modifications = what_if.get("modifications", [])
+        if not modifications:
+            return source  # type: ignore
+
+        # Build a dict of column -> multiplier
+        mod_map = {m["column"]: m["multiplier"] for m in modifications}
+
+        # Get columns needed by the query + modified columns
+        # None means we need all columns (e.g., for complex SQL metrics)
+        needed_columns: set[str] | None = what_if.get("needed_columns")
+        modified_column_names = set(mod_map.keys())
+
+        # Determine which columns to select
+        available_columns = {col.column_name for col in self.columns}
+        if needed_columns is None:
+            # Use all available columns
+            columns_to_select = available_columns
+        else:
+            # Use only needed columns + modified columns
+            columns_to_select = needed_columns | modified_column_names
+
+        # Build select list with only needed columns
+        select_columns = []
+
+        for col_name in columns_to_select:
+            # Skip columns that don't exist in the datasource
+            if col_name not in available_columns:
+                continue
+
+            if col_name in mod_map:
+                # Apply transformation: column * multiplier AS column
+                multiplier = mod_map[col_name]
+                transformed = (sa.column(col_name) * sa.literal(multiplier)).label(
+                    col_name
+                )
+                select_columns.append(transformed)
+            else:
+                select_columns.append(sa.column(col_name))
+
+        if not select_columns:
+            # Fallback: if no columns to select, return source unchanged
+            return source  # type: ignore
+
+        # Create subquery with transformations
+        subq = sa.select(*select_columns).select_from(source)
+        return subq.alias("__what_if")
 
     def adhoc_metric_to_sqla(
         self,
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index a4fb9e3fea..564e36344a 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -2016,7 +2016,9 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         return self.db_engine_spec.get_text_clause(clause)
 
     def get_from_clause(
-        self, template_processor: Optional[BaseTemplateProcessor] = None
+        self,
+        template_processor: Optional[BaseTemplateProcessor] = None,
+        what_if: Optional[dict[str, Any]] = None,
     ) -> tuple[Union[TableClause, Alias], Optional[str]]:
         """
         Return where to select the columns and metrics from. Either a physical table
@@ -2060,6 +2062,96 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
 
         return from_clause, cte
 
+    def _collect_needed_columns(  # noqa: C901
+        self,
+        columns: Optional[list[Column]] = None,
+        groupby: Optional[list[Column]] = None,
+        metrics: Optional[list[Metric]] = None,
+        filter: Optional[list[utils.QueryObjectFilterClause]] = None,
+        orderby: Optional[list[OrderBy]] = None,
+        granularity: Optional[str] = None,
+    ) -> Optional[set[str]]:
+        """
+        Collect all column names needed by the query for what-if transformation.
+        This allows us to only select necessary columns instead of SELECT *.
+
+        :returns: Set of column names that are referenced by the query,
+                  or None if all columns should be included (e.g., for complex metrics)
+        """
+        needed: set[str] = set()
+
+        # Add granularity column (time column)
+        if granularity:
+            needed.add(granularity)
+
+        # Add columns from dimensions/columns list
+        for col in columns or []:
+            if isinstance(col, str):
+                needed.add(col)
+            elif isinstance(col, dict):
+                if col.get("sqlExpression"):
+                    # Adhoc column with SQL expression - can't determine columns
+                    return None
+                if col.get("column_name"):
+                    needed.add(col["column_name"])
+
+        # Add columns from groupby
+        for col in groupby or []:
+            if isinstance(col, str):
+                needed.add(col)
+            elif isinstance(col, dict):
+                if col.get("sqlExpression"):
+                    # Adhoc column with SQL expression - can't determine columns
+                    return None
+                if col.get("column_name"):
+                    needed.add(col["column_name"])
+
+        # Add columns from metrics (try to extract column references)
+        # For complex metrics (SQL expressions or saved metrics), we need all columns
+        # because we can't easily parse what columns they reference
+        for metric in metrics or []:
+            if isinstance(metric, str):
+                # Saved metric - can't determine columns, need all
+                return None  # Signal to use all columns
+            elif isinstance(metric, dict):
+                expression_type = metric.get("expressionType")
+                if expression_type == "SQL":
+                    # SQL expression - can't determine columns, need all
+                    return None  # Signal to use all columns
+                # SIMPLE adhoc metric - check for column reference
+                metric_column = metric.get("column")
+                if isinstance(metric_column, dict):
+                    col_name = metric_column.get("column_name")
+                    if isinstance(col_name, str):
+                        needed.add(col_name)
+
+        # Add columns from filters
+        for flt in filter or []:
+            col = flt.get("col")
+            if isinstance(col, str):
+                needed.add(col)
+            elif isinstance(col, dict):
+                if col.get("sqlExpression"):
+                    # Adhoc column filter - can't determine columns
+                    return None
+                if col.get("column_name"):
+                    needed.add(col["column_name"])
+
+        # Add columns from orderby
+        for order_item in orderby or []:
+            if isinstance(order_item, (list, tuple)) and len(order_item) >= 1:
+                col = order_item[0]
+                if isinstance(col, str):
+                    needed.add(col)
+                elif isinstance(col, dict):
+                    if col.get("sqlExpression"):
+                        # Adhoc column orderby - can't determine columns
+                        return None
+                    if col.get("column_name"):
+                        needed.add(col["column_name"])
+
+        return needed
+
     def adhoc_metric_to_sqla(
         self,
         metric: AdhocMetric,
@@ -2879,7 +2971,19 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
 
         # Process FROM clause early to populate removed_filters from virtual dataset
         # templates before we decide whether to add time filters
-        tbl, cte = self.get_from_clause(template_processor)
+        what_if = extras.get("what_if") if extras else None
+        if what_if:
+            # Collect columns needed by the query for efficient what-if transformation
+            what_if = dict(what_if)  # Copy to avoid mutating original
+            what_if["needed_columns"] = self._collect_needed_columns(
+                columns=columns,
+                groupby=groupby,
+                metrics=metrics,
+                filter=filter,
+                orderby=orderby,
+                granularity=granularity,
+            )
+        tbl, cte = self.get_from_clause(template_processor, what_if=what_if)
 
         if granularity:
             if granularity not in columns_by_name or not dttm_col:
diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py
index 437ddf9151..41be629987 100644
--- a/tests/unit_tests/connectors/sqla/models_test.py
+++ b/tests/unit_tests/connectors/sqla/models_test.py
@@ -18,7 +18,7 @@
 import pandas as pd
 import pytest
 from pytest_mock import MockerFixture
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, select
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm.session import Session
 
@@ -187,10 +187,13 @@ def test_query_datasources_by_permissions_with_catalog_schema(
         ["[my_db].[db1].[schema1]", "[my_other_db].[schema]"],  # type: ignore
     )
     clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
-    assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == (
-        "tables.perm IN ('[my_db].[table1](id:1)') OR "
-        "tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR "  # noqa: E501
-        "tables.catalog_perm IN ('[my_db].[db1]')"
+    assert (
+        str(clause.compile(engine, compile_kwargs={"literal_binds": True}))
+        == (
+            "tables.perm IN ('[my_db].[table1](id:1)') OR "
+            "tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR "  # noqa: E501
+            "tables.catalog_perm IN ('[my_db].[db1]')"
+        )
     )
 
 
@@ -763,9 +766,9 @@ def test_get_sqla_table_quoting_for_cross_catalog(
     # The compiled SQL should contain each part quoted separately
     assert expected_in_sql in compiled, f"Expected {expected_in_sql} in SQL: {compiled}"
     # Should NOT have the entire identifier quoted as one string
-    assert not_expected_in_sql not in compiled, (
-        f"Should not have {not_expected_in_sql} in SQL: {compiled}"
-    )
+    assert (
+        not_expected_in_sql not in compiled
+    ), f"Should not have {not_expected_in_sql} in SQL: {compiled}"
 
 
 def test_get_sqla_table_without_cross_catalog_ignores_catalog(
@@ -842,3 +845,398 @@ def test_quoted_name_prevents_double_quoting(mocker: MockerFixture) -> None:
     # Should have each part quoted separately:
     # GOOD: "MY_DB"."MY_SCHEMA"."MY_TABLE"
     assert '"MY_DB"."MY_SCHEMA"."MY_TABLE"' in compiled
+
+
+def test_apply_what_if_transform_single_modification(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform correctly transforms a single column.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    # Create table with columns
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="ad_spend"),
+            TableColumn(column_name="revenue"),
+        ],
+    )
+
+    # Get the base table
+    source = table.get_sqla_table()
+
+    # Apply what-if transformation
+    what_if = {
+        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
+        "needed_columns": {"date", "ad_spend", "revenue"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    # Compile to SQL and verify
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have the subquery alias
+    assert "__what_if" in compiled
+    # Should have the multiplication
+    assert "ad_spend * 1.1" in compiled
+
+
+def test_apply_what_if_transform_multiple_modifications(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform correctly transforms multiple columns.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="ad_spend"),
+            TableColumn(column_name="revenue"),
+            TableColumn(column_name="conversions"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    what_if = {
+        "modifications": [
+            {"column": "ad_spend", "multiplier": 1.1},
+            {"column": "revenue", "multiplier": 0.95},
+        ],
+        "needed_columns": {"date", "ad_spend", "revenue"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    assert "ad_spend * 1.1" in compiled
+    assert "revenue * 0.95" in compiled
+    # conversions should not be in the query since it's not in needed_columns
+    assert "conversions" not in compiled
+
+
+def test_apply_what_if_transform_no_modifications(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform returns source unchanged when no modifications.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    what_if = {
+        "modifications": [],
+        "needed_columns": {"date", "ad_spend"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    # Should return source unchanged
+    assert result is source
+
+
+def test_apply_what_if_transform_only_needed_columns(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform only includes needed
+    columns plus modified columns.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    # Create table with many columns
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="col1"),
+            TableColumn(column_name="col2"),
+            TableColumn(column_name="col3"),
+            TableColumn(column_name="ad_spend"),
+            TableColumn(column_name="col5"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # Only need col1, but modifying ad_spend
+    what_if = {
+        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
+        "needed_columns": {"col1"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have col1 (needed) and ad_spend (modified)
+    assert "col1" in compiled
+    assert "ad_spend" in compiled
+    # Should NOT have other columns
+    assert "col2" not in compiled
+    assert "col3" not in compiled
+    assert "col5" not in compiled
+
+
+def test_apply_what_if_transform_nonexistent_column(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform handles modifications for non-existent columns.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="revenue"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # Try to modify a column that doesn't exist
+    what_if = {
+        "modifications": [{"column": "nonexistent_column", "multiplier": 1.1}],
+        "needed_columns": {"date", "revenue"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should still work, just without the nonexistent column
+    assert "date" in compiled
+    assert "revenue" in compiled
+    assert "nonexistent_column" not in compiled
+
+
+def test_collect_needed_columns(mocker: MockerFixture) -> None:
+    """
+    Test that _collect_needed_columns extracts columns from query parameters.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="region"),
+            TableColumn(column_name="ad_spend"),
+            TableColumn(column_name="revenue"),
+        ],
+    )
+
+    # Test with various query parameters
+    needed = table._collect_needed_columns(
+        columns=["date", "region"],
+        groupby=["date"],
+        metrics=[
+            {
+                "expressionType": "SIMPLE",
+                "column": {"column_name": "ad_spend"},
+                "aggregate": "SUM",
+            }
+        ],
+        filter=[{"col": "revenue", "op": ">", "val": 100}],
+        orderby=[("date", True)],
+        granularity="date",
+    )
+
+    # Should include all referenced columns
+    assert "date" in needed  # from columns, groupby, orderby, granularity
+    assert "region" in needed  # from columns
+    assert "ad_spend" in needed  # from metrics
+    assert "revenue" in needed  # from filter
+
+
+def test_collect_needed_columns_empty(mocker: MockerFixture) -> None:
+    """
+    Test that _collect_needed_columns handles empty/None parameters.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[TableColumn(column_name="col1")],
+    )
+
+    needed = table._collect_needed_columns(
+        columns=None,
+        groupby=None,
+        metrics=None,
+        filter=None,
+        orderby=None,
+        granularity=None,
+    )
+
+    assert needed == set()
+
+
+def test_collect_needed_columns_returns_none_for_sql_metrics(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns returns None for SQL-type adhoc metrics,
+    indicating all columns should be included.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[TableColumn(column_name="col1")],
+    )
+
+    # SQL-type adhoc metric - can't determine columns
+    needed = table._collect_needed_columns(
+        columns=["date"],
+        metrics=[
+            {
+                "expressionType": "SQL",
+                "sqlExpression": "SUM(hidden_column)",
+            }
+        ],
+    )
+
+    # Should return None to signal all columns needed
+    assert needed is None
+
+
+def test_collect_needed_columns_returns_none_for_saved_metrics(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns returns None for saved metrics (strings),
+    indicating all columns should be included.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[TableColumn(column_name="col1")],
+    )
+
+    # Saved metric (string) - can't determine columns
+    needed = table._collect_needed_columns(
+        columns=["date"],
+        metrics=["saved_metric_name"],
+    )
+
+    # Should return None to signal all columns needed
+    assert needed is None
+
+
+def test_apply_what_if_transform_all_columns_when_needed_none(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _apply_what_if_transform includes all columns when needed_columns is None.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="col1"),
+            TableColumn(column_name="col2"),
+            TableColumn(column_name="ad_spend"),
+            TableColumn(column_name="col4"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # needed_columns is None - should include all columns
+    what_if = {
+        "modifications": [{"column": "ad_spend", "multiplier": 1.1}],
+        "needed_columns": None,
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have all columns
+    assert "col1" in compiled
+    assert "col2" in compiled
+    assert "ad_spend" in compiled
+    assert "col4" in compiled
+
+
+def test_collect_needed_columns_returns_none_for_adhoc_columns(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns returns None for adhoc columns
+    with SQL expressions, indicating all columns should be included.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[TableColumn(column_name="col1")],
+    )
+
+    # Adhoc column with SQL expression in columns list
+    needed = table._collect_needed_columns(
+        columns=[
+            "date",
+            {"label": "custom_col", "sqlExpression": "CONCAT(first_name, last_name)"},
+        ],
+        metrics=[],
+    )
+
+    # Should return None to signal all columns needed
+    assert needed is None
+
+
+def test_collect_needed_columns_returns_none_for_adhoc_groupby(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _collect_needed_columns returns None for adhoc columns in groupby.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[TableColumn(column_name="col1")],
+    )
+
+    # Adhoc column in groupby
+    needed = table._collect_needed_columns(
+        groupby=[
+            {"label": "year", "sqlExpression": "EXTRACT(YEAR FROM date)"},
+        ],
+        metrics=[],
+    )
+
+    assert needed is None

Reply via email to