This is an automated email from the ASF dual-hosted git repository.
vavila pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 09d3f60d85 fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja (#30735)
09d3f60d85 is described below
commit 09d3f60d85c1e1bf5f030191ee78522b8f414705
Author: Vitor Avila <[email protected]>
AuthorDate: Tue Oct 29 10:14:27 2024 -0300
fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja (#30735)
---
superset/connectors/sqla/models.py | 51 +++++++++++------
tests/integration_tests/sqla_models_tests.py | 83 ++++++++++++++++++++++------
2 files changed, 98 insertions(+), 36 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 75354ad355..fb7409adba 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -116,7 +116,6 @@ from superset.superset_typing import (
)
from superset.utils import core as utils, json
from superset.utils.backports import StrEnum
-from superset.utils.core import GenericDataType, is_adhoc_column, MediumText
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@@ -477,7 +476,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
]
filtered_columns: list[Column] = []
- column_types: set[GenericDataType] = set()
+ column_types: set[utils.GenericDataType] = set()
for column_ in data["columns"]:
generic_type = column_.get("type_generic")
if generic_type is not None:
@@ -511,7 +510,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
def filter_values_handler( # pylint: disable=too-many-arguments
values: FilterValues | None,
operator: str,
- target_generic_type: GenericDataType,
+ target_generic_type: utils.GenericDataType,
target_native_type: str | None = None,
is_list_target: bool = False,
db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
@@ -829,10 +828,10 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
advanced_data_type = Column(String(255))
groupby = Column(Boolean, default=True)
filterable = Column(Boolean, default=True)
- description = Column(MediumText())
+ description = Column(utils.MediumText())
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
is_dttm = Column(Boolean, default=False)
- expression = Column(MediumText())
+ expression = Column(utils.MediumText())
python_date_format = Column(String(255))
extra = Column(Text)
@@ -892,21 +891,21 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
"""
Check if the column has a boolean datatype.
"""
- return self.type_generic == GenericDataType.BOOLEAN
+ return self.type_generic == utils.GenericDataType.BOOLEAN
@property
def is_numeric(self) -> bool:
"""
Check if the column has a numeric datatype.
"""
- return self.type_generic == GenericDataType.NUMERIC
+ return self.type_generic == utils.GenericDataType.NUMERIC
@property
def is_string(self) -> bool:
"""
Check if the column has a string datatype.
"""
- return self.type_generic == GenericDataType.STRING
+ return self.type_generic == utils.GenericDataType.STRING
@property
def is_temporal(self) -> bool:
@@ -918,7 +917,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
"""
if self.is_dttm is not None:
return self.is_dttm
- return self.type_generic == GenericDataType.TEMPORAL
+ return self.type_generic == utils.GenericDataType.TEMPORAL
@property
def database(self) -> Database:
@@ -935,7 +934,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
@property
def type_generic(self) -> utils.GenericDataType | None:
if self.is_dttm:
- return GenericDataType.TEMPORAL
+ return utils.GenericDataType.TEMPORAL
return (
column_spec.generic_type
@@ -1038,12 +1037,12 @@ class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model
metric_name = Column(String(255), nullable=False)
verbose_name = Column(String(1024))
metric_type = Column(String(32))
- description = Column(MediumText())
+ description = Column(utils.MediumText())
d3format = Column(String(128))
currency = Column(String(128))
warning_text = Column(Text)
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
- expression = Column(MediumText(), nullable=False)
+ expression = Column(utils.MediumText(), nullable=False)
extra = Column(Text)
table: Mapped[SqlaTable] = relationship(
@@ -1185,7 +1184,7 @@ class SqlaTable(
)
schema = Column(String(255))
catalog = Column(String(256), nullable=True, default=None)
- sql = Column(MediumText())
+ sql = Column(utils.MediumText())
is_sqllab_view = Column(Boolean, default=False)
template_params = Column(Text)
extra = Column(Text)
@@ -1980,10 +1979,26 @@ class SqlaTable(
templatable_statements.append(extras["where"])
if "having" in extras:
templatable_statements.append(extras["having"])
- if "columns" in query_obj:
- templatable_statements += [
- c["sqlExpression"] for c in query_obj["columns"] if
is_adhoc_column(c)
- ]
+ if columns := query_obj.get("columns"):
+ calculated_columns: dict[str, Any] = {
+ c.column_name: c.expression for c in self.columns if
c.expression
+ }
+ for column_ in columns:
+ if utils.is_adhoc_column(column_):
+ templatable_statements.append(column_["sqlExpression"])
+ elif isinstance(column_, str) and column_ in
calculated_columns:
+ templatable_statements.append(calculated_columns[column_])
+ if metrics := query_obj.get("metrics"):
+ metrics_by_name: dict[str, Any] = {
+ m.metric_name: m.expression for m in self.metrics
+ }
+ for metric in metrics:
+ if utils.is_adhoc_metric(metric) and (
+ sql := metric.get("sqlExpression")
+ ):
+ templatable_statements.append(sql)
+ elif isinstance(metric, str) and metric in metrics_by_name:
+ templatable_statements.append(metrics_by_name[metric])
if self.is_rls_supported:
templatable_statements += [
f.clause for f in security_manager.get_rls_filters(self)
@@ -2125,4 +2140,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
secondary=RLSFilterTables,
backref="row_level_security_filters",
)
- clause = Column(MediumText(), nullable=False)
+ clause = Column(utils.MediumText(), nullable=False)
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 79d4bf00ed..d4ca3bc1c1 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -15,11 +15,13 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
+from __future__ import annotations
+
import re
from datetime import datetime
-from typing import Any, NamedTuple, Optional, Union
+from typing import Any, Literal, NamedTuple, Optional, Union
from re import Pattern
-from unittest.mock import patch
+from unittest.mock import Mock, patch
import pytest
import numpy as np
@@ -913,54 +915,99 @@ def test_extra_cache_keys_in_sql_expression(
@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
- "sql_expression,expected_cache_keys,has_extra_cache_keys",
+ "sql_expression,expected_cache_keys,has_extra_cache_keys,item_type",
[
- ("'{{ current_username() }}'", ["abc"], True),
- ("(user != 'abc')", [], False),
+ ("'{{ current_username() }}'", ["abc"], True, "columns"),
+ ("(user != 'abc')", [], False, "columns"),
+ ("{{ current_user_id() }}", [1], True, "metrics"),
+ ("COUNT(*)", [], False, "metrics"),
],
)
@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
-@patch("superset.jinja_context.get_user_email", return_value="[email protected]")
-def test_extra_cache_keys_in_columns(
- mock_user_email,
- mock_username,
- mock_user_id,
- sql_expression,
- expected_cache_keys,
- has_extra_cache_keys,
+def test_extra_cache_keys_in_adhoc_metrics_and_columns(
+ mock_username: Mock,
+ mock_user_id: Mock,
+ sql_expression: str,
+ expected_cache_keys: list[str | None],
+ has_extra_cache_keys: bool,
+ item_type: Literal["columns", "metrics"],
):
table = SqlaTable(
table_name="test_has_no_extra_cache_keys_table",
sql="SELECT 'abc' as user",
database=get_example_database(),
)
- base_query_obj = {
+ base_query_obj: dict[str, Any] = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": [],
"metrics": [],
+ "columns": [],
"is_timeseries": False,
"filter": [],
}
- query_obj = dict(
- **base_query_obj,
- columns=[
+ items: dict[str, Any] = {
+ item_type: [
{
"label": None,
"expressionType": "SQL",
"sqlExpression": sql_expression,
}
],
- )
+ }
+
+ query_obj = {**base_query_obj, **items}
extra_cache_keys = table.get_extra_cache_keys(query_obj)
assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys
assert extra_cache_keys == expected_cache_keys
+@pytest.mark.usefixtures("app_context")
+@patch("superset.jinja_context.get_user_id", return_value=1)
+@patch("superset.jinja_context.get_username", return_value="abc")
+def test_extra_cache_keys_in_dataset_metrics_and_columns(
+ mock_username: Mock,
+ mock_user_id: Mock,
+):
+ table = SqlaTable(
+ table_name="test_has_no_extra_cache_keys_table",
+ sql="SELECT 'abc' as user",
+ database=get_example_database(),
+ columns=[
+ TableColumn(column_name="user", type="VARCHAR(255)"),
+ TableColumn(
+ column_name="username",
+ type="VARCHAR(255)",
+ expression="{{ current_username() }}",
+ ),
+ ],
+ metrics=[
+ SqlMetric(
+ metric_name="variable_profit",
+ expression="SUM(price) * {{ url_param('multiplier') }}",
+ ),
+ ],
+ )
+ query_obj: dict[str, Any] = {
+ "granularity": None,
+ "from_dttm": None,
+ "to_dttm": None,
+ "groupby": [],
+ "columns": ["username"],
+ "metrics": ["variable_profit"],
+ "is_timeseries": False,
+ "filter": [],
+ }
+
+ extra_cache_keys = table.get_extra_cache_keys(query_obj)
+ assert table.has_extra_cache_key_calls(query_obj) is True
+ assert set(extra_cache_keys) == {"abc", None}
+
+
@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"row,dimension,result",