This is an automated email from the ASF dual-hosted git repository. vavila pushed a commit to branch fix/cache-keys-from-columns-metrics in repository https://gitbox.apache.org/repos/asf/superset.git
commit 8deea9ac0c87416ff57216e799af166571fe31b8 Author: Vitor Avila <[email protected]> AuthorDate: Mon Oct 28 18:11:25 2024 -0300 fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja --- superset/connectors/sqla/models.py | 29 ++++++++++-- tests/integration_tests/sqla_models_tests.py | 68 ++++++++++++++++++++++++---- 2 files changed, 82 insertions(+), 15 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 75354ad355..3451d6f362 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -116,7 +116,12 @@ from superset.superset_typing import ( ) from superset.utils import core as utils, json from superset.utils.backports import StrEnum -from superset.utils.core import GenericDataType, is_adhoc_column, MediumText +from superset.utils.core import ( + GenericDataType, + is_adhoc_column, + is_adhoc_metric, + MediumText, +) config = app.config metadata = Model.metadata # pylint: disable=no-member @@ -1980,10 +1985,24 @@ class SqlaTable( templatable_statements.append(extras["where"]) if "having" in extras: templatable_statements.append(extras["having"]) - if "columns" in query_obj: - templatable_statements += [ - c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c) - ] + if columns := query_obj.get("columns"): + calculated_columns: dict[str, Any] = { + c.column_name: c.expression for c in self.columns if c.expression + } + for column_ in columns: + if is_adhoc_column(column_): + templatable_statements.append(column_["sqlExpression"]) + elif isinstance(column_, str) and column_ in calculated_columns: + templatable_statements.append(calculated_columns[column_]) + if metrics := query_obj.get("metrics"): + metrics_by_name: dict[str, str] = { + m.metric_name: m.expression for m in self.metrics + } + for metric in metrics: + if is_adhoc_metric(metric): + templatable_statements.append(metric["sqlExpression"]) # type: ignore + elif isinstance(metric, str) 
and metric in metrics_by_name: + templatable_statements.append(metrics_by_name[metric]) if self.is_rls_supported: templatable_statements += [ f.clause for f in security_manager.get_rls_filters(self) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 79d4bf00ed..3d723c95e2 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -913,47 +913,53 @@ def test_extra_cache_keys_in_sql_expression( @pytest.mark.usefixtures("app_context") @pytest.mark.parametrize( - "sql_expression,expected_cache_keys,has_extra_cache_keys", + "sql_expression,expected_cache_keys,has_extra_cache_keys,item_type", [ - ("'{{ current_username() }}'", ["abc"], True), - ("(user != 'abc')", [], False), + ("'{{ current_username() }}'", ["abc"], True, "columns"), + ("(user != 'abc')", [], False, "columns"), + ("{{ current_user_id() }}", [1], True, "metrics"), + ("COUNT(*)", [], False, "metrics"), ], ) @patch("superset.jinja_context.get_user_id", return_value=1) @patch("superset.jinja_context.get_username", return_value="abc") -@patch("superset.jinja_context.get_user_email", return_value="abc@test.com") -def test_extra_cache_keys_in_columns( - mock_user_email, +def test_extra_cache_keys_in_adhoc_metrics_and_columns( mock_username, mock_user_id, sql_expression, expected_cache_keys, has_extra_cache_keys, + item_type, ): table = SqlaTable( table_name="test_has_no_extra_cache_keys_table", sql="SELECT 'abc' as user", database=get_example_database(), ) + base_type = "metrics" if item_type == "columns" else "columns" base_query_obj = { "granularity": None, "from_dttm": None, "to_dttm": None, "groupby": [], - "metrics": [], + base_type: [], "is_timeseries": False, "filter": [], } - query_obj = dict( - **base_query_obj, - columns=[ + items = { + item_type: [ { "label": None, "expressionType": "SQL", "sqlExpression": sql_expression, } ], + } + + query_obj = dict( + **base_query_obj, + **items, 
 ) extra_cache_keys = table.get_extra_cache_keys(query_obj) @@ -961,6 +967,48 @@ def test_extra_cache_keys_in_columns( assert extra_cache_keys == expected_cache_keys +@pytest.mark.usefixtures("app_context") +@patch("superset.jinja_context.get_user_id", return_value=1) +@patch("superset.jinja_context.get_username", return_value="abc") +def test_extra_cache_keys_in_dataset_metrics_and_columns( + mock_username, + mock_user_id, +): + table = SqlaTable( + table_name="test_has_no_extra_cache_keys_table", + sql="SELECT 'abc' as user", + database=get_example_database(), + columns=[ + TableColumn(column_name="user", type="VARCHAR(255)"), + TableColumn( + column_name="username", + type="VARCHAR(255)", + expression="{{ current_username() }}", + ), + ], + metrics=[ + SqlMetric( + metric_name="variable_profit", + expression="SUM(price) * {{ url_param('multiplier') }}", + ), + ], + ) + query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": [], + "columns": ["username"], + "metrics": ["variable_profit"], + "is_timeseries": False, + "filter": [], + } + + extra_cache_keys = table.get_extra_cache_keys(query_obj) + assert table.has_extra_cache_key_calls(query_obj) is True + assert set(extra_cache_keys) == {"abc", None} + + @pytest.mark.usefixtures("app_context") @pytest.mark.parametrize( "row,dimension,result",
