This is an automated email from the ASF dual-hosted git repository.

vavila pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 09d3f60d85 fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja (#30735)
09d3f60d85 is described below

commit 09d3f60d85c1e1bf5f030191ee78522b8f414705
Author: Vitor Avila <[email protected]>
AuthorDate: Tue Oct 29 10:14:27 2024 -0300

    fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja (#30735)
---
 superset/connectors/sqla/models.py           | 51 +++++++++++------
 tests/integration_tests/sqla_models_tests.py | 83 ++++++++++++++++++++++------
 2 files changed, 98 insertions(+), 36 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 75354ad355..fb7409adba 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -116,7 +116,6 @@ from superset.superset_typing import (
 )
 from superset.utils import core as utils, json
 from superset.utils.backports import StrEnum
-from superset.utils.core import GenericDataType, is_adhoc_column, MediumText
 
 config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
@@ -477,7 +476,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin):  # pylint: disable=
         ]
 
         filtered_columns: list[Column] = []
-        column_types: set[GenericDataType] = set()
+        column_types: set[utils.GenericDataType] = set()
         for column_ in data["columns"]:
             generic_type = column_.get("type_generic")
             if generic_type is not None:
@@ -511,7 +510,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin):  # pylint: disable=
     def filter_values_handler(  # pylint: disable=too-many-arguments
         values: FilterValues | None,
         operator: str,
-        target_generic_type: GenericDataType,
+        target_generic_type: utils.GenericDataType,
         target_native_type: str | None = None,
         is_list_target: bool = False,
         db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
@@ -829,10 +828,10 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
     advanced_data_type = Column(String(255))
     groupby = Column(Boolean, default=True)
     filterable = Column(Boolean, default=True)
-    description = Column(MediumText())
+    description = Column(utils.MediumText())
     table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
     is_dttm = Column(Boolean, default=False)
-    expression = Column(MediumText())
+    expression = Column(utils.MediumText())
     python_date_format = Column(String(255))
     extra = Column(Text)
 
@@ -892,21 +891,21 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
         """
         Check if the column has a boolean datatype.
         """
-        return self.type_generic == GenericDataType.BOOLEAN
+        return self.type_generic == utils.GenericDataType.BOOLEAN
 
     @property
     def is_numeric(self) -> bool:
         """
         Check if the column has a numeric datatype.
         """
-        return self.type_generic == GenericDataType.NUMERIC
+        return self.type_generic == utils.GenericDataType.NUMERIC
 
     @property
     def is_string(self) -> bool:
         """
         Check if the column has a string datatype.
         """
-        return self.type_generic == GenericDataType.STRING
+        return self.type_generic == utils.GenericDataType.STRING
 
     @property
     def is_temporal(self) -> bool:
@@ -918,7 +917,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
         """
         if self.is_dttm is not None:
             return self.is_dttm
-        return self.type_generic == GenericDataType.TEMPORAL
+        return self.type_generic == utils.GenericDataType.TEMPORAL
 
     @property
     def database(self) -> Database:
@@ -935,7 +934,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
     @property
     def type_generic(self) -> utils.GenericDataType | None:
         if self.is_dttm:
-            return GenericDataType.TEMPORAL
+            return utils.GenericDataType.TEMPORAL
 
         return (
             column_spec.generic_type
@@ -1038,12 +1037,12 @@ class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model
     metric_name = Column(String(255), nullable=False)
     verbose_name = Column(String(1024))
     metric_type = Column(String(32))
-    description = Column(MediumText())
+    description = Column(utils.MediumText())
     d3format = Column(String(128))
     currency = Column(String(128))
     warning_text = Column(Text)
     table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
-    expression = Column(MediumText(), nullable=False)
+    expression = Column(utils.MediumText(), nullable=False)
     extra = Column(Text)
 
     table: Mapped[SqlaTable] = relationship(
@@ -1185,7 +1184,7 @@ class SqlaTable(
     )
     schema = Column(String(255))
     catalog = Column(String(256), nullable=True, default=None)
-    sql = Column(MediumText())
+    sql = Column(utils.MediumText())
     is_sqllab_view = Column(Boolean, default=False)
     template_params = Column(Text)
     extra = Column(Text)
@@ -1980,10 +1979,26 @@ class SqlaTable(
             templatable_statements.append(extras["where"])
         if "having" in extras:
             templatable_statements.append(extras["having"])
-        if "columns" in query_obj:
-            templatable_statements += [
-                c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c)
-            ]
+        if columns := query_obj.get("columns"):
+            calculated_columns: dict[str, Any] = {
+                c.column_name: c.expression for c in self.columns if c.expression
+            }
+            for column_ in columns:
+                if utils.is_adhoc_column(column_):
+                    templatable_statements.append(column_["sqlExpression"])
+                elif isinstance(column_, str) and column_ in calculated_columns:
+                    templatable_statements.append(calculated_columns[column_])
+        if metrics := query_obj.get("metrics"):
+            metrics_by_name: dict[str, Any] = {
+                m.metric_name: m.expression for m in self.metrics
+            }
+            for metric in metrics:
+                if utils.is_adhoc_metric(metric) and (
+                    sql := metric.get("sqlExpression")
+                ):
+                    templatable_statements.append(sql)
+                elif isinstance(metric, str) and metric in metrics_by_name:
+                    templatable_statements.append(metrics_by_name[metric])
         if self.is_rls_supported:
             templatable_statements += [
                 f.clause for f in security_manager.get_rls_filters(self)
@@ -2125,4 +2140,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
         secondary=RLSFilterTables,
         backref="row_level_security_filters",
     )
-    clause = Column(MediumText(), nullable=False)
+    clause = Column(utils.MediumText(), nullable=False)
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 79d4bf00ed..d4ca3bc1c1 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -15,11 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 # isort:skip_file
+from __future__ import annotations
+
 import re
 from datetime import datetime
-from typing import Any, NamedTuple, Optional, Union
+from typing import Any, Literal, NamedTuple, Optional, Union
 from re import Pattern
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 import pytest
 
 import numpy as np
@@ -913,54 +915,99 @@ def test_extra_cache_keys_in_sql_expression(
 
 @pytest.mark.usefixtures("app_context")
 @pytest.mark.parametrize(
-    "sql_expression,expected_cache_keys,has_extra_cache_keys",
+    "sql_expression,expected_cache_keys,has_extra_cache_keys,item_type",
     [
-        ("'{{ current_username() }}'", ["abc"], True),
-        ("(user != 'abc')", [], False),
+        ("'{{ current_username() }}'", ["abc"], True, "columns"),
+        ("(user != 'abc')", [], False, "columns"),
+        ("{{ current_user_id() }}", [1], True, "metrics"),
+        ("COUNT(*)", [], False, "metrics"),
     ],
 )
 @patch("superset.jinja_context.get_user_id", return_value=1)
 @patch("superset.jinja_context.get_username", return_value="abc")
-@patch("superset.jinja_context.get_user_email", return_value="[email protected]")
-def test_extra_cache_keys_in_columns(
-    mock_user_email,
-    mock_username,
-    mock_user_id,
-    sql_expression,
-    expected_cache_keys,
-    has_extra_cache_keys,
+def test_extra_cache_keys_in_adhoc_metrics_and_columns(
+    mock_username: Mock,
+    mock_user_id: Mock,
+    sql_expression: str,
+    expected_cache_keys: list[str | None],
+    has_extra_cache_keys: bool,
+    item_type: Literal["columns", "metrics"],
 ):
     table = SqlaTable(
         table_name="test_has_no_extra_cache_keys_table",
         sql="SELECT 'abc' as user",
         database=get_example_database(),
     )
-    base_query_obj = {
+    base_query_obj: dict[str, Any] = {
         "granularity": None,
         "from_dttm": None,
         "to_dttm": None,
         "groupby": [],
         "metrics": [],
+        "columns": [],
         "is_timeseries": False,
         "filter": [],
     }
 
-    query_obj = dict(
-        **base_query_obj,
-        columns=[
+    items: dict[str, Any] = {
+        item_type: [
             {
                 "label": None,
                 "expressionType": "SQL",
                 "sqlExpression": sql_expression,
             }
         ],
-    )
+    }
+
+    query_obj = {**base_query_obj, **items}
 
     extra_cache_keys = table.get_extra_cache_keys(query_obj)
     assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys
     assert extra_cache_keys == expected_cache_keys
 
 
+@pytest.mark.usefixtures("app_context")
+@patch("superset.jinja_context.get_user_id", return_value=1)
+@patch("superset.jinja_context.get_username", return_value="abc")
+def test_extra_cache_keys_in_dataset_metrics_and_columns(
+    mock_username: Mock,
+    mock_user_id: Mock,
+):
+    table = SqlaTable(
+        table_name="test_has_no_extra_cache_keys_table",
+        sql="SELECT 'abc' as user",
+        database=get_example_database(),
+        columns=[
+            TableColumn(column_name="user", type="VARCHAR(255)"),
+            TableColumn(
+                column_name="username",
+                type="VARCHAR(255)",
+                expression="{{ current_username() }}",
+            ),
+        ],
+        metrics=[
+            SqlMetric(
+                metric_name="variable_profit",
+                expression="SUM(price) * {{ url_param('multiplier') }}",
+            ),
+        ],
+    )
+    query_obj: dict[str, Any] = {
+        "granularity": None,
+        "from_dttm": None,
+        "to_dttm": None,
+        "groupby": [],
+        "columns": ["username"],
+        "metrics": ["variable_profit"],
+        "is_timeseries": False,
+        "filter": [],
+    }
+
+    extra_cache_keys = table.get_extra_cache_keys(query_obj)
+    assert table.has_extra_cache_key_calls(query_obj) is True
+    assert set(extra_cache_keys) == {"abc", None}
+
+
 @pytest.mark.usefixtures("app_context")
 @pytest.mark.parametrize(
     "row,dimension,result",

Reply via email to