This is an automated email from the ASF dual-hosted git repository.

kgabryje pushed a commit to branch what-if
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 0a6bba1a1471801ac4e2f32850772de3ab3e64f5
Author: Kamil Gabryjelski <[email protected]>
AuthorDate: Wed Dec 17 13:28:54 2025 +0100

    backend for filters
---
 superset/connectors/sqla/models.py              | 106 ++++++-
 tests/unit_tests/connectors/sqla/models_test.py | 399 ++++++++++++++++++++++++
 2 files changed, 496 insertions(+), 9 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 9ec3ceaacf..d8ec1fad38 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1529,6 +1529,72 @@ class SqlaTable(
             from_clause = self._apply_what_if_transform(from_clause, what_if)
         return from_clause, cte
 
+    def _build_what_if_filter_condition(
+        self,
+        filters: list[dict[str, Any]],
+    ) -> ColumnElement | None:
+        """
+        Build a SQLAlchemy condition from a list of what-if filters.
+
+        Supports: ==, !=, >, <, >=, <=, IN/NOT IN (list or tuple), TEMPORAL_RANGE
+
+        :param filters: Filter dicts with 'col', 'op', 'val'; unusable ones are skipped
+        :returns: The ANDed SQLAlchemy condition, or None if no filter is usable
+        """
+        from superset.common.utils.time_range_utils import (
+            get_since_until_from_time_range,
+        )
+        from superset.utils.core import FilterOperator
+
+        conditions: list[ColumnElement] = []
+        available_columns = {col.column_name for col in self.columns}
+
+        for flt in filters:
+            col_name = flt.get("col")
+            op = flt.get("op")
+            val = flt.get("val")
+
+            # Skip if column doesn't exist in datasource
+            if col_name not in available_columns:
+                continue
+
+            sqla_col = sa.column(col_name)
+
+            if op == FilterOperator.EQUALS:
+                conditions.append(sqla_col == val)
+            elif op == FilterOperator.NOT_EQUALS:
+                conditions.append(sqla_col != val)
+            elif op == FilterOperator.GREATER_THAN:
+                conditions.append(sqla_col > val)
+            elif op == FilterOperator.LESS_THAN:
+                conditions.append(sqla_col < val)
+            elif op == FilterOperator.GREATER_THAN_OR_EQUALS:
+                conditions.append(sqla_col >= val)
+            elif op == FilterOperator.LESS_THAN_OR_EQUALS:
+                conditions.append(sqla_col <= val)
+            elif op == FilterOperator.IN:
+                if isinstance(val, (list, tuple)):  # scalar values are skipped
+                    conditions.append(sqla_col.in_(val))
+            elif op == FilterOperator.NOT_IN:
+                if isinstance(val, (list, tuple)):  # scalar values are skipped
+                    conditions.append(~sqla_col.in_(val))
+            elif op == FilterOperator.TEMPORAL_RANGE:
+                # Parse time range string like "2024-01-01 : 2024-03-31" or "Last week"
+                if isinstance(val, str):
+                    since, until = get_since_until_from_time_range(time_range=val)
+                    time_conditions = []
+                    if since:
+                        time_conditions.append(sqla_col >= sa.literal(since))
+                    if until:
+                        time_conditions.append(sqla_col < sa.literal(until))
+                    if time_conditions:
+                        conditions.append(and_(*time_conditions))
+
+        if not conditions:
+            return None
+
+        return and_(*conditions)
+
     def _apply_what_if_transform(
         self,
         source: TableClause | Alias,
@@ -1547,22 +1613,29 @@ class SqlaTable(
         if not modifications:
             return source  # type: ignore
 
-        # Build a dict of column -> multiplier
-        mod_map = {m["column"]: m["multiplier"] for m in modifications}
+        # Build a dict of column -> modification config (including filters)
+        mod_map = {m["column"]: m for m in modifications}
 
         # Get columns needed by the query + modified columns
         # None means we need all columns (e.g., for complex SQL metrics)
         needed_columns: set[str] | None = what_if.get("needed_columns")
         modified_column_names = set(mod_map.keys())
 
+        # Collect columns used in filters ("filters" may be missing or null)
+        filter_columns: set[str] = set()
+        for mod in modifications:
+            for flt in mod.get("filters") or []:
+                if col_name := flt.get("col"):
+                    filter_columns.add(col_name)
+
         # Determine which columns to select
         available_columns = {col.column_name for col in self.columns}
         if needed_columns is None:
             # Use all available columns
             columns_to_select = available_columns
         else:
-            # Use only needed columns + modified columns
-            columns_to_select = needed_columns | modified_column_names
+            # Use only needed columns + modified columns + filter columns
+            columns_to_select = needed_columns | modified_column_names | filter_columns
 
         # Build select list with only needed columns
         select_columns = []
@@ -1573,11 +1646,26 @@ class SqlaTable(
                 continue
 
             if col_name in mod_map:
-                # Apply transformation: column * multiplier AS column
-                multiplier = mod_map[col_name]
-                transformed = (sa.column(col_name) * sa.literal(multiplier)).label(
-                    col_name
-                )
+                mod = mod_map[col_name]
+                multiplier = mod["multiplier"]
+                filters = mod.get("filters") or []
+                col_ref = sa.column(col_name)
+
+                if filters:
+                    # Build conditional transformation with CASE WHEN
+                    condition = self._build_what_if_filter_condition(filters)
+                    if condition is not None:
+                        transformed = sa.case(
+                            (condition, col_ref * sa.literal(multiplier)),
+                            else_=col_ref,
+                        ).label(col_name)
+                    else:
+                        # No usable filter (e.g. unknown columns): apply to all rows
+                        transformed = (col_ref * sa.literal(multiplier)).label(col_name)
+                else:
+                    # No filters, apply transformation to all rows
+                    transformed = (col_ref * sa.literal(multiplier)).label(col_name)
+
                 select_columns.append(transformed)
             else:
                 select_columns.append(sa.column(col_name))
diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py
index 41be629987..8e094248a3 100644
--- a/tests/unit_tests/connectors/sqla/models_test.py
+++ b/tests/unit_tests/connectors/sqla/models_test.py
@@ -1240,3 +1240,402 @@ def test_collect_needed_columns_returns_none_for_adhoc_groupby(
     )
 
     assert needed is None
+
+
+def test_apply_what_if_transform_with_single_filter(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform generates CASE WHEN for filtered modifications.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="product"),
+            TableColumn(column_name="ad_spend"),
+            TableColumn(column_name="revenue"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # Apply what-if with filter: only modify ad_spend where product = 'Widget'
+    what_if = {
+        "modifications": [
+            {
+                "column": "ad_spend",
+                "multiplier": 1.2,
+                "filters": [{"col": "product", "op": "==", "val": "Widget"}],
+            }
+        ],
+        "needed_columns": {"date", "ad_spend", "revenue"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have CASE WHEN with the filter condition
+    assert "CASE WHEN" in compiled
+    assert "product" in compiled
+    assert "'Widget'" in compiled
+    assert "ad_spend * 1.2" in compiled
+    assert "__what_if" in compiled
+
+
+def test_apply_what_if_transform_with_multiple_filters(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform ANDs multiple filter conditions together.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="product"),
+            TableColumn(column_name="region"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # Multiple filters: product = 'Widget' AND region = 'US'
+    what_if = {
+        "modifications": [
+            {
+                "column": "ad_spend",
+                "multiplier": 1.5,
+                "filters": [
+                    {"col": "product", "op": "==", "val": "Widget"},
+                    {"col": "region", "op": "==", "val": "US"},
+                ],
+            }
+        ],
+        "needed_columns": {"date", "ad_spend"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have both conditions ANDed
+    assert "CASE WHEN" in compiled
+    assert "product" in compiled
+    assert "'Widget'" in compiled
+    assert "region" in compiled
+    assert "'US'" in compiled
+    assert "AND" in compiled
+    assert "ad_spend * 1.5" in compiled
+
+
+def test_apply_what_if_transform_with_in_operator(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform handles IN operator correctly.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="product"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # Filter: product IN ['Widget', 'Gadget']
+    what_if = {
+        "modifications": [
+            {
+                "column": "ad_spend",
+                "multiplier": 1.1,
+                "filters": [
+                    {"col": "product", "op": "IN", "val": ["Widget", "Gadget"]}
+                ],
+            }
+        ],
+        "needed_columns": {"ad_spend"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have IN clause
+    assert "CASE WHEN" in compiled
+    assert "product IN" in compiled
+    assert "'Widget'" in compiled
+    assert "'Gadget'" in compiled
+
+
+def test_apply_what_if_transform_filter_columns_included(mocker: MockerFixture) -> None:
+    """
+    Test that filter columns are included in the subquery even if not in needed_columns.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="date"),
+            TableColumn(column_name="product"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # needed_columns doesn't include 'product', but it's used in filter
+    what_if = {
+        "modifications": [
+            {
+                "column": "ad_spend",
+                "multiplier": 1.2,
+                "filters": [{"col": "product", "op": "==", "val": "Widget"}],
+            }
+        ],
+        "needed_columns": {"date", "ad_spend"},  # product NOT included
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # CASE WHEN alone mentions product once; the SELECT list must add more hits
+    assert compiled.count("product") >= 2
+    assert "CASE WHEN" in compiled
+
+
+def test_apply_what_if_transform_comparison_operators(mocker: MockerFixture) -> None:
+    """
+    Test that _apply_what_if_transform renders a comparison-operator (>=) filter.
+    """
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="quantity"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    source = table.get_sqla_table()
+
+    # Filter: quantity >= 100
+    what_if = {
+        "modifications": [
+            {
+                "column": "ad_spend",
+                "multiplier": 1.3,
+                "filters": [{"col": "quantity", "op": ">=", "val": 100}],
+            }
+        ],
+        "needed_columns": {"ad_spend"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    assert "CASE WHEN" in compiled
+    assert "quantity >= 100" in compiled
+    assert "ad_spend * 1.3" in compiled
+
+
+def test_build_what_if_filter_condition_skips_nonexistent_columns(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _build_what_if_filter_condition skips filters for non-existent columns.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="product"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    # Filter references non-existent column
+    filters = [
+        {"col": "nonexistent", "op": "==", "val": "test"},
+        {"col": "product", "op": "==", "val": "Widget"},
+    ]
+    condition = table._build_what_if_filter_condition(filters)
+
+    # Should still return a condition (just for product)
+    assert condition is not None
+    compiled = str(condition)
+    assert "product" in compiled
+    assert "nonexistent" not in compiled
+
+
+def test_build_what_if_filter_condition_returns_none_for_all_invalid(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _build_what_if_filter_condition returns None if all filters are invalid.
+    """
+    database = mocker.MagicMock()
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="product"),
+        ],
+    )
+
+    # All filters reference non-existent columns
+    filters = [
+        {"col": "nonexistent1", "op": "==", "val": "test"},
+        {"col": "nonexistent2", "op": "==", "val": "test"},
+    ]
+    condition = table._build_what_if_filter_condition(filters)
+
+    assert condition is None
+
+
+def test_apply_what_if_transform_with_temporal_range_filter(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _apply_what_if_transform handles TEMPORAL_RANGE filter correctly.
+    """
+    from datetime import datetime
+
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="order_date"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    # Mock get_since_until_from_time_range to avoid Flask app context requirement
+    mocker.patch(
+        "superset.common.utils.time_range_utils.get_since_until_from_time_range",
+        return_value=(datetime(2024, 1, 1), datetime(2024, 3, 31)),
+    )
+
+    source = table.get_sqla_table()
+
+    what_if = {
+        "modifications": [
+            {
+                "column": "ad_spend",
+                "multiplier": 1.2,
+                "filters": [
+                    {
+                        "col": "order_date",
+                        "op": "TEMPORAL_RANGE",
+                        "val": "2024-01-01 : 2024-03-31",
+                    }
+                ],
+            }
+        ],
+        "needed_columns": {"ad_spend"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have CASE WHEN with time range conditions
+    assert "CASE WHEN" in compiled
+    assert "order_date >=" in compiled
+    assert "order_date <" in compiled
+    assert "ad_spend * 1.2" in compiled
+
+
+def test_apply_what_if_transform_with_combined_filters(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that _apply_what_if_transform handles combined product + time range filters.
+    """
+    from datetime import datetime
+
+    engine = create_engine("sqlite://")
+    database = mocker.MagicMock()
+    database.db_engine_spec.engine = "sqlite"
+
+    table = SqlaTable(
+        table_name="sales",
+        database=database,
+        columns=[
+            TableColumn(column_name="order_date"),
+            TableColumn(column_name="product"),
+            TableColumn(column_name="ad_spend"),
+        ],
+    )
+
+    # Mock get_since_until_from_time_range to avoid Flask app context requirement
+    mocker.patch(
+        "superset.common.utils.time_range_utils.get_since_until_from_time_range",
+        return_value=(datetime(2024, 1, 1), datetime(2024, 4, 1)),
+    )
+
+    source = table.get_sqla_table()
+
+    # Combined filter: product = 'Widget' AND order_date in Q1 2024
+    what_if = {
+        "modifications": [
+            {
+                "column": "ad_spend",
+                "multiplier": 1.5,
+                "filters": [
+                    {"col": "product", "op": "==", "val": "Widget"},
+                    {
+                        "col": "order_date",
+                        "op": "TEMPORAL_RANGE",
+                        "val": "2024-01-01 : 2024-04-01",
+                    },
+                ],
+            }
+        ],
+        "needed_columns": {"ad_spend"},
+    }
+    result = table._apply_what_if_transform(source, what_if)
+
+    query = select(result)
+    compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True}))
+
+    # Should have CASE WHEN with all conditions ANDed together
+    assert "CASE WHEN" in compiled
+    assert "product" in compiled
+    assert "'Widget'" in compiled
+    assert "order_date >=" in compiled
+    assert "order_date <" in compiled
+    assert "AND" in compiled
+    assert "ad_spend * 1.5" in compiled
+    assert compiled.count("product") >= 2  # CASE WHEN + SELECT list
+    assert compiled.count("order_date") >= 3  # twice in CASE WHEN + SELECT list

Reply via email to