This is an automated email from the ASF dual-hosted git repository. kgabryje pushed a commit to branch what-if in repository https://gitbox.apache.org/repos/asf/superset.git
commit 0a6bba1a1471801ac4e2f32850772de3ab3e64f5 Author: Kamil Gabryjelski <[email protected]> AuthorDate: Wed Dec 17 13:28:54 2025 +0100 backend for filters --- superset/connectors/sqla/models.py | 106 ++++++- tests/unit_tests/connectors/sqla/models_test.py | 399 ++++++++++++++++++++++++ 2 files changed, 496 insertions(+), 9 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 9ec3ceaacf..d8ec1fad38 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1529,6 +1529,72 @@ class SqlaTable( from_clause = self._apply_what_if_transform(from_clause, what_if) return from_clause, cte + def _build_what_if_filter_condition( + self, + filters: list[dict[str, Any]], + ) -> ColumnElement | None: + """ + Build a SQLAlchemy condition from a list of what-if filters. + + Supports operators: ==, !=, >, <, >=, <=, IN, NOT IN, TEMPORAL_RANGE + + :param filters: List of filter dicts with 'col', 'op', and 'val' keys + :returns: Combined SQLAlchemy condition (ANDed together), or None if no valid filters + """ + from superset.common.utils.time_range_utils import ( + get_since_until_from_time_range, + ) + from superset.utils.core import FilterOperator + + conditions: list[ColumnElement] = [] + available_columns = {col.column_name for col in self.columns} + + for flt in filters: + col_name = flt.get("col") + op = flt.get("op") + val = flt.get("val") + + # Skip if column doesn't exist in datasource + if col_name not in available_columns: + continue + + sqla_col = sa.column(col_name) + + if op == FilterOperator.EQUALS: + conditions.append(sqla_col == val) + elif op == FilterOperator.NOT_EQUALS: + conditions.append(sqla_col != val) + elif op == FilterOperator.GREATER_THAN: + conditions.append(sqla_col > val) + elif op == FilterOperator.LESS_THAN: + conditions.append(sqla_col < val) + elif op == FilterOperator.GREATER_THAN_OR_EQUALS: + conditions.append(sqla_col >= val) + elif op == FilterOperator.LESS_THAN_OR_EQUALS: + conditions.append(sqla_col <= val) + elif op == FilterOperator.IN: + if isinstance(val, list): + conditions.append(sqla_col.in_(val)) + elif op == FilterOperator.NOT_IN: + if isinstance(val, list): + conditions.append(~sqla_col.in_(val)) + elif op == FilterOperator.TEMPORAL_RANGE: + # Parse time range string like "2024-01-01 : 2024-03-31" or "Last week" + if isinstance(val, str): + since, until = get_since_until_from_time_range(time_range=val) + time_conditions = [] + if since: + time_conditions.append(sqla_col >= sa.literal(since)) + if until: + time_conditions.append(sqla_col < sa.literal(until)) + if time_conditions: + conditions.append(and_(*time_conditions)) + + if not conditions: + return None + + return and_(*conditions) + def _apply_what_if_transform( self, source: TableClause | Alias, @@ -1547,22 +1613,29 @@ class SqlaTable( if not modifications: return source # type: ignore - # Build a dict of column -> multiplier - mod_map = {m["column"]: m["multiplier"] for m in modifications} + # Build a dict of column -> modification config (including filters) + mod_map = {m["column"]: m for m in modifications} # Get columns needed by the query + modified columns # None means we need all columns (e.g., for complex SQL metrics) needed_columns: set[str] | None = what_if.get("needed_columns") modified_column_names = set(mod_map.keys()) + # Collect columns used in filters + filter_columns: set[str] = set() + for mod in modifications: + for flt in mod.get("filters", []): + if col_name := flt.get("col"): + filter_columns.add(col_name) + # Determine which columns to select available_columns = {col.column_name for col in self.columns} if needed_columns is None: # Use all available columns columns_to_select = available_columns else: - # Use only needed columns + modified columns - columns_to_select = needed_columns | modified_column_names + # Use only needed columns + modified columns + filter columns + columns_to_select = needed_columns | modified_column_names | filter_columns # Build select list with only needed columns select_columns = [] @@ -1573,11 +1646,26 @@ class SqlaTable( continue if col_name in mod_map: - # Apply transformation: column * multiplier AS column - multiplier = mod_map[col_name] - transformed = (sa.column(col_name) * sa.literal(multiplier)).label( - col_name - ) + mod = mod_map[col_name] + multiplier = mod["multiplier"] + filters = mod.get("filters", []) + col_ref = sa.column(col_name) + + if filters: + # Build conditional transformation with CASE WHEN + condition = self._build_what_if_filter_condition(filters) + if condition is not None: + transformed = sa.case( + (condition, col_ref * sa.literal(multiplier)), + else_=col_ref, + ).label(col_name) + else: + # No valid filter conditions, apply unconditionally + transformed = (col_ref * sa.literal(multiplier)).label(col_name) + else: + # No filters, apply transformation to all rows + transformed = (col_ref * sa.literal(multiplier)).label(col_name) + select_columns.append(transformed) else: select_columns.append(sa.column(col_name)) diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py index 41be629987..8e094248a3 100644 --- a/tests/unit_tests/connectors/sqla/models_test.py +++ b/tests/unit_tests/connectors/sqla/models_test.py @@ -1240,3 +1240,402 @@ def test_collect_needed_columns_returns_none_for_adhoc_groupby( ) assert needed is None + + +def test_apply_what_if_transform_with_single_filter(mocker: MockerFixture) -> None: + """ + Test that _apply_what_if_transform generates CASE WHEN for filtered modifications. + """ + engine = create_engine("sqlite://") + database = mocker.MagicMock() + database.db_engine_spec.engine = "sqlite" + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="product"), + TableColumn(column_name="ad_spend"), + TableColumn(column_name="revenue"), + ], + ) + + source = table.get_sqla_table() + + # Apply what-if with filter: only modify ad_spend where product = 'Widget' + what_if = { + "modifications": [ + { + "column": "ad_spend", + "multiplier": 1.2, + "filters": [{"col": "product", "op": "==", "val": "Widget"}], + } + ], + "needed_columns": {"date", "ad_spend", "revenue"}, + } + result = table._apply_what_if_transform(source, what_if) + + query = select(result) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Should have CASE WHEN with the filter condition + assert "CASE WHEN" in compiled + assert "product" in compiled + assert "'Widget'" in compiled + assert "ad_spend * 1.2" in compiled + assert "__what_if" in compiled + + +def test_apply_what_if_transform_with_multiple_filters(mocker: MockerFixture) -> None: + """ + Test that _apply_what_if_transform ANDs multiple filter conditions together. + """ + engine = create_engine("sqlite://") + database = mocker.MagicMock() + database.db_engine_spec.engine = "sqlite" + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="product"), + TableColumn(column_name="region"), + TableColumn(column_name="ad_spend"), + ], + ) + + source = table.get_sqla_table() + + # Multiple filters: product = 'Widget' AND region = 'US' + what_if = { + "modifications": [ + { + "column": "ad_spend", + "multiplier": 1.5, + "filters": [ + {"col": "product", "op": "==", "val": "Widget"}, + {"col": "region", "op": "==", "val": "US"}, + ], + } + ], + "needed_columns": {"date", "ad_spend"}, + } + result = table._apply_what_if_transform(source, what_if) + + query = select(result) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Should have both conditions ANDed + assert "CASE WHEN" in compiled + assert "product" in compiled + assert "'Widget'" in compiled + assert "region" in compiled + assert "'US'" in compiled + assert "AND" in compiled + assert "ad_spend * 1.5" in compiled + + +def test_apply_what_if_transform_with_in_operator(mocker: MockerFixture) -> None: + """ + Test that _apply_what_if_transform handles IN operator correctly. + """ + engine = create_engine("sqlite://") + database = mocker.MagicMock() + database.db_engine_spec.engine = "sqlite" + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="product"), + TableColumn(column_name="ad_spend"), + ], + ) + + source = table.get_sqla_table() + + # Filter: product IN ['Widget', 'Gadget'] + what_if = { + "modifications": [ + { + "column": "ad_spend", + "multiplier": 1.1, + "filters": [ + {"col": "product", "op": "IN", "val": ["Widget", "Gadget"]} + ], + } + ], + "needed_columns": {"ad_spend"}, + } + result = table._apply_what_if_transform(source, what_if) + + query = select(result) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Should have IN clause + assert "CASE WHEN" in compiled + assert "product IN" in compiled + assert "'Widget'" in compiled + assert "'Gadget'" in compiled + + +def test_apply_what_if_transform_filter_columns_included(mocker: MockerFixture) -> None: + """ + Test that filter columns are included in the subquery even if not in needed_columns. + """ + engine = create_engine("sqlite://") + database = mocker.MagicMock() + database.db_engine_spec.engine = "sqlite" + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="date"), + TableColumn(column_name="product"), + TableColumn(column_name="ad_spend"), + ], + ) + + source = table.get_sqla_table() + + # needed_columns doesn't include 'product', but it's used in filter + what_if = { + "modifications": [ + { + "column": "ad_spend", + "multiplier": 1.2, + "filters": [{"col": "product", "op": "==", "val": "Widget"}], + } + ], + "needed_columns": {"date", "ad_spend"}, # product NOT included + } + result = table._apply_what_if_transform(source, what_if) + + query = select(result) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # product should still be in the SELECT list because it's used in filter + assert "product" in compiled + assert "CASE WHEN" in compiled + + +def test_apply_what_if_transform_comparison_operators(mocker: MockerFixture) -> None: + """ + Test that _apply_what_if_transform handles comparison operators (>, <, >=, <=). + """ + engine = create_engine("sqlite://") + database = mocker.MagicMock() + database.db_engine_spec.engine = "sqlite" + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="quantity"), + TableColumn(column_name="ad_spend"), + ], + ) + + source = table.get_sqla_table() + + # Filter: quantity >= 100 + what_if = { + "modifications": [ + { + "column": "ad_spend", + "multiplier": 1.3, + "filters": [{"col": "quantity", "op": ">=", "val": 100}], + } + ], + "needed_columns": {"ad_spend"}, + } + result = table._apply_what_if_transform(source, what_if) + + query = select(result) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + assert "CASE WHEN" in compiled + assert "quantity >= 100" in compiled + assert "ad_spend * 1.3" in compiled + + +def test_build_what_if_filter_condition_skips_nonexistent_columns( + mocker: MockerFixture, +) -> None: + """ + Test that _build_what_if_filter_condition skips filters for non-existent columns. + """ + database = mocker.MagicMock() + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="product"), + TableColumn(column_name="ad_spend"), + ], + ) + + # Filter references non-existent column + filters = [ + {"col": "nonexistent", "op": "==", "val": "test"}, + {"col": "product", "op": "==", "val": "Widget"}, + ] + condition = table._build_what_if_filter_condition(filters) + + # Should still return a condition (just for product) + assert condition is not None + compiled = str(condition) + assert "product" in compiled + assert "nonexistent" not in compiled + + +def test_build_what_if_filter_condition_returns_none_for_all_invalid( + mocker: MockerFixture, +) -> None: + """ + Test that _build_what_if_filter_condition returns None if all filters are invalid. + """ + database = mocker.MagicMock() + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="product"), + ], + ) + + # All filters reference non-existent columns + filters = [ + {"col": "nonexistent1", "op": "==", "val": "test"}, + {"col": "nonexistent2", "op": "==", "val": "test"}, + ] + condition = table._build_what_if_filter_condition(filters) + + assert condition is None + + +def test_apply_what_if_transform_with_temporal_range_filter( + mocker: MockerFixture, +) -> None: + """ + Test that _apply_what_if_transform handles TEMPORAL_RANGE filter correctly. + """ + from datetime import datetime + + engine = create_engine("sqlite://") + database = mocker.MagicMock() + database.db_engine_spec.engine = "sqlite" + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="order_date"), + TableColumn(column_name="ad_spend"), + ], + ) + + # Mock get_since_until_from_time_range to avoid Flask app context requirement + mocker.patch( + "superset.common.utils.time_range_utils.get_since_until_from_time_range", + return_value=(datetime(2024, 1, 1), datetime(2024, 3, 31)), + ) + + source = table.get_sqla_table() + + what_if = { + "modifications": [ + { + "column": "ad_spend", + "multiplier": 1.2, + "filters": [ + { + "col": "order_date", + "op": "TEMPORAL_RANGE", + "val": "2024-01-01 : 2024-03-31", + } + ], + } + ], + "needed_columns": {"ad_spend"}, + } + result = table._apply_what_if_transform(source, what_if) + + query = select(result) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Should have CASE WHEN with time range conditions + assert "CASE WHEN" in compiled + assert "order_date >=" in compiled + assert "order_date <" in compiled + assert "ad_spend * 1.2" in compiled + + +def test_apply_what_if_transform_with_combined_filters( + mocker: MockerFixture, +) -> None: + """ + Test that _apply_what_if_transform handles combined product + time range filters. + """ + from datetime import datetime + + engine = create_engine("sqlite://") + database = mocker.MagicMock() + database.db_engine_spec.engine = "sqlite" + + table = SqlaTable( + table_name="sales", + database=database, + columns=[ + TableColumn(column_name="order_date"), + TableColumn(column_name="product"), + TableColumn(column_name="ad_spend"), + ], + ) + + # Mock get_since_until_from_time_range to avoid Flask app context requirement + mocker.patch( + "superset.common.utils.time_range_utils.get_since_until_from_time_range", + return_value=(datetime(2024, 1, 1), datetime(2024, 4, 1)), + ) + + source = table.get_sqla_table() + + # Combined filter: product = 'Widget' AND order_date in Q1 2024 + what_if = { + "modifications": [ + { + "column": "ad_spend", + "multiplier": 1.5, + "filters": [ + {"col": "product", "op": "==", "val": "Widget"}, + { + "col": "order_date", + "op": "TEMPORAL_RANGE", + "val": "2024-01-01 : 2024-04-01", + }, + ], + } + ], + "needed_columns": {"ad_spend"}, + } + result = table._apply_what_if_transform(source, what_if) + + query = select(result) + compiled = str(query.compile(engine, compile_kwargs={"literal_binds": True})) + + # Should have CASE WHEN with all conditions ANDed together + assert "CASE WHEN" in compiled + assert "product" in compiled + assert "'Widget'" in compiled + assert "order_date >=" in compiled + assert "order_date <" in compiled + assert "AND" in compiled + assert "ad_spend * 1.5" in compiled + # Both filter columns should be in the SELECT list + assert "order_date" in compiled
