This is an automated email from the ASF dual-hosted git repository.

vavila pushed a commit to branch fix/cache-keys-from-columns-metrics
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8deea9ac0c87416ff57216e799af166571fe31b8
Author: Vitor Avila <[email protected]>
AuthorDate: Mon Oct 28 18:11:25 2024 -0300

    fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja
---
 superset/connectors/sqla/models.py           | 29 ++++++++++--
 tests/integration_tests/sqla_models_tests.py | 68 ++++++++++++++++++++++++----
 2 files changed, 82 insertions(+), 15 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 75354ad355..3451d6f362 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -116,7 +116,12 @@ from superset.superset_typing import (
 )
 from superset.utils import core as utils, json
 from superset.utils.backports import StrEnum
-from superset.utils.core import GenericDataType, is_adhoc_column, MediumText
+from superset.utils.core import (
+    GenericDataType,
+    is_adhoc_column,
+    is_adhoc_metric,
+    MediumText,
+)
 
 config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
@@ -1980,10 +1985,24 @@ class SqlaTable(
             templatable_statements.append(extras["where"])
         if "having" in extras:
             templatable_statements.append(extras["having"])
-        if "columns" in query_obj:
-            templatable_statements += [
-                c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c)
-            ]
+        if columns := query_obj.get("columns"):
+            calculated_columns: dict[str, Any] = {
+                c.column_name: c.expression for c in self.columns if c.expression
+            }
+            for column_ in columns:
+                if is_adhoc_column(column_):
+                    templatable_statements.append(column_["sqlExpression"])
+                elif isinstance(column_, str) and column_ in calculated_columns:
+                    templatable_statements.append(calculated_columns[column_])
+        if metrics := query_obj.get("metrics"):
+            metrics_by_name: dict[str, str] = {
+                m.metric_name: m.expression for m in self.metrics
+            }
+            for metric in metrics:
+                if is_adhoc_metric(metric):
+                    templatable_statements.append(metric["sqlExpression"])  # type: ignore
+                elif isinstance(metric, str) and metric in metrics_by_name:
+                    templatable_statements.append(metrics_by_name[metric])
         if self.is_rls_supported:
             templatable_statements += [
                 f.clause for f in security_manager.get_rls_filters(self)
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 79d4bf00ed..3d723c95e2 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -913,47 +913,53 @@ def test_extra_cache_keys_in_sql_expression(
 
 @pytest.mark.usefixtures("app_context")
 @pytest.mark.parametrize(
-    "sql_expression,expected_cache_keys,has_extra_cache_keys",
+    "sql_expression,expected_cache_keys,has_extra_cache_keys,item_type",
     [
-        ("'{{ current_username() }}'", ["abc"], True),
-        ("(user != 'abc')", [], False),
+        ("'{{ current_username() }}'", ["abc"], True, "columns"),
+        ("(user != 'abc')", [], False, "columns"),
+        ("{{ current_user_id() }}", [1], True, "metrics"),
+        ("COUNT(*)", [], False, "metrics"),
     ],
 )
 @patch("superset.jinja_context.get_user_id", return_value=1)
 @patch("superset.jinja_context.get_username", return_value="abc")
-@patch("superset.jinja_context.get_user_email", return_value="[email protected]")
-def test_extra_cache_keys_in_columns(
-    mock_user_email,
+def test_extra_cache_keys_in_adhoc_metrics_and_columns(
     mock_username,
     mock_user_id,
     sql_expression,
     expected_cache_keys,
     has_extra_cache_keys,
+    item_type,
 ):
     table = SqlaTable(
         table_name="test_has_no_extra_cache_keys_table",
         sql="SELECT 'abc' as user",
         database=get_example_database(),
     )
+    base_type = "metrics" if item_type == "columns" else "columns"
     base_query_obj = {
         "granularity": None,
         "from_dttm": None,
         "to_dttm": None,
         "groupby": [],
-        "metrics": [],
+        base_type: [],
         "is_timeseries": False,
         "filter": [],
     }
 
-    query_obj = dict(
-        **base_query_obj,
-        columns=[
+    items = {
+        item_type: [
             {
                 "label": None,
                 "expressionType": "SQL",
                 "sqlExpression": sql_expression,
             }
         ],
+    }
+
+    query_obj = dict(
+        **base_query_obj,
+        **items,
     )
 
     extra_cache_keys = table.get_extra_cache_keys(query_obj)
@@ -961,6 +967,48 @@ def test_extra_cache_keys_in_columns(
     assert extra_cache_keys == expected_cache_keys
 
 
+@pytest.mark.usefixtures("app_context")
+@patch("superset.jinja_context.get_user_id", return_value=1)
+@patch("superset.jinja_context.get_username", return_value="abc")
+def test_extra_cache_keys_in_dataset_metrics_and_columns(
+    mock_username,
+    mock_user_id,
+):
+    table = SqlaTable(
+        table_name="test_has_no_extra_cache_keys_table",
+        sql="SELECT 'abc' as user",
+        database=get_example_database(),
+        columns=[
+            TableColumn(column_name="user", type="VARCHAR(255)"),
+            TableColumn(
+                column_name="username",
+                type="VARCHAR(255)",
+                expression="{{ current_username() }}",
+            ),
+        ],
+        metrics=[
+            SqlMetric(
+                metric_name="variable_profit",
+                expression="SUM(price) * {{ url_param('multiplier') }}",
+            ),
+        ],
+    )
+    query_obj = {
+        "granularity": None,
+        "from_dttm": None,
+        "to_dttm": None,
+        "groupby": [],
+        "columns": ["username"],
+        "metrics": ["variable_profit"],
+        "is_timeseries": False,
+        "filter": [],
+    }
+
+    extra_cache_keys = table.get_extra_cache_keys(query_obj)
+    assert table.has_extra_cache_key_calls(query_obj) is True
+    assert set(extra_cache_keys) == {"abc", None}
+
+
 @pytest.mark.usefixtures("app_context")
 @pytest.mark.parametrize(
     "row,dimension,result",

Reply via email to