This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 822eb2e  feat(sql): add jinja support to metrics and expressions (#15247)
822eb2e is described below

commit 822eb2e27e83238e2cd05a371b1d835f593d2907
Author: Ville Brofeldt <[email protected]>
AuthorDate: Sat Jun 19 08:29:04 2021 +0300

    feat(sql): add jinja support to metrics and expressions (#15247)
    
    * feat(sql): add jinja support to metrics and expressions
    
    * add test
---
 superset/common/query_object.py     |  4 +--
 superset/connectors/druid/models.py | 12 ++++++---
 superset/connectors/sqla/models.py  | 26 +++++++++++++++----
 superset/typing.py                  | 16 +++++++++++-
 superset/utils/core.py              |  4 +--
 superset/viz.py                     |  7 ++---
 tests/sqla_models_tests.py          | 51 ++++++++++++++++++++++++++++++++++++-
 7 files changed, 101 insertions(+), 19 deletions(-)

diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 8fd281c..77f85b4 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -17,7 +17,7 @@
 # pylint: disable=R
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional
 
 from flask_babel import gettext as _
 from pandas import DataFrame
@@ -103,7 +103,7 @@ class QueryObject:
         applied_time_extras: Optional[Dict[str, str]] = None,
         apply_fetch_values_predicate: bool = False,
         granularity: Optional[str] = None,
-        metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
+        metrics: Optional[List[Metric]] = None,
         groupby: Optional[List[str]] = None,
         filters: Optional[List[Dict[str, Any]]] = None,
         time_range: Optional[str] = None,
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 1839f76..14eda04 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -55,7 +55,13 @@ from superset.exceptions import SupersetException
 from superset.extensions import encrypted_field_factory
 from superset.models.core import Database
 from superset.models.helpers import AuditMixinNullable, ImportExportMixin, 
QueryResult
-from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict
+from superset.typing import (
+    AdhocMetric,
+    FilterValues,
+    Granularity,
+    Metric,
+    QueryObjectDict,
+)
 from superset.utils import core as utils
 from superset.utils.date_parser import parse_human_datetime, 
parse_human_timedelta
 
@@ -1010,7 +1016,7 @@ class DruidDatasource(Model, BaseDatasource):
         return ret
 
     @staticmethod
-    def druid_type_from_adhoc_metric(adhoc_metric: Dict[str, Any]) -> str:
+    def druid_type_from_adhoc_metric(adhoc_metric: AdhocMetric) -> str:
         column_type = adhoc_metric["column"]["type"].lower()
         aggregate = adhoc_metric["aggregate"].lower()
 
@@ -1025,7 +1031,7 @@ class DruidDatasource(Model, BaseDatasource):
     def get_aggregations(
         metrics_dict: Dict[str, Any],
         saved_metrics: Set[str],
-        adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
+        adhoc_metrics: Optional[List[AdhocMetric]] = None,
     ) -> "OrderedDict[str, Any]":
         """
         Returns a dictionary of aggregation metric names to aggregation json 
objects
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 2f7b6d1..31f475f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -22,7 +22,18 @@ from collections import defaultdict, OrderedDict
 from contextlib import closing
 from dataclasses import dataclass, field  # pylint: disable=wrong-import-order
 from datetime import datetime, timedelta
-from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, 
Type, Union
+from typing import (
+    Any,
+    cast,
+    Dict,
+    Hashable,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import pandas as pd
 import sqlalchemy as sa
@@ -241,7 +252,9 @@ class TableColumn(Model, BaseColumn):
         column_spec = db_engine_spec.get_column_spec(self.type)
         type_ = column_spec.sqla_type if column_spec else None
         if self.expression:
-            col = literal_column(self.expression, type_=type_)
+            tp = self.table.get_template_processor()
+            expression = tp.process_template(self.expression)
+            col = literal_column(expression, type_=type_)
         else:
             col = column(self.column_name, type_=type_)
         col = self.table.make_sqla_column_compatible(col, label)
@@ -879,7 +892,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         label = utils.get_metric_name(metric)
 
         if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
-            column_name = metric["column"].get("column_name")
+            column_name = cast(str, metric["column"].get("column_name"))
             table_column: Optional[TableColumn] = 
columns_by_name.get(column_name)
             if table_column:
                 sqla_column = table_column.get_sqla_col()
@@ -887,7 +900,9 @@ class SqlaTable(  # pylint: 
disable=too-many-public-methods,too-many-instance-at
                 sqla_column = column(column_name)
             sqla_metric = 
self.sqla_aggregations[metric["aggregate"]](sqla_column)
         elif expression_type == utils.AdhocMetricExpressionType.SQL:
-            sqla_metric = literal_column(metric.get("sqlExpression"))
+            tp = self.get_template_processor()
+            expression = tp.process_template(cast(str, 
metric["sqlExpression"]))
+            sqla_metric = literal_column(expression)
         else:
             raise QueryObjectValidationError("Adhoc metric expressionType is 
invalid")
 
@@ -1060,8 +1075,9 @@ class SqlaTable(  # pylint: 
disable=too-many-public-methods,too-many-instance-at
         # Since orderby may use adhoc metrics, too; we need to process them 
first
         orderby_exprs: List[ColumnElement] = []
         for orig_col, ascending in orderby:
-            col: Union[Metric, ColumnElement] = orig_col
+            col: Union[AdhocMetric, ColumnElement] = orig_col
             if isinstance(col, dict):
+                col = cast(AdhocMetric, col)
                 if utils.is_adhoc_metric(col):
                     # add adhoc sort by column to columns_by_name if not exists
                     col = self.adhoc_metric_to_sqla(col, columns_by_name)
diff --git a/superset/typing.py b/superset/typing.py
index 0a7ef59..f428831 100644
--- a/superset/typing.py
+++ b/superset/typing.py
@@ -19,8 +19,23 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 from flask import Flask
 from flask_caching import Cache
+from typing_extensions import TypedDict
 from werkzeug.wrappers import Response
 
+
+class AdhocMetricColumn(TypedDict):
+    column_name: Optional[str]
+    type: str
+
+
+class AdhocMetric(TypedDict):
+    aggregate: str
+    column: AdhocMetricColumn
+    expressionType: str
+    label: str
+    sqlExpression: Optional[str]
+
+
 CacheConfig = Union[Callable[[Flask], Cache], Dict[str, Any]]
 DbapiDescriptionRow = Tuple[
     str, str, Optional[str], Optional[str], Optional[int], Optional[int], bool
@@ -31,7 +46,6 @@ FilterValue = Union[datetime, float, int, str]
 FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
 FormData = Dict[str, Any]
 Granularity = Union[str, Dict[str, Union[str, float]]]
-AdhocMetric = Dict[str, Any]
 Metric = Union[AdhocMetric, str]
 OrderBy = Tuple[Metric, bool]
 QueryObjectDict = Dict[str, Any]
diff --git a/superset/utils/core.py b/superset/utils/core.py
index c8c352d..12c66e4 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -96,7 +96,7 @@ from superset.exceptions import (
     SupersetException,
     SupersetTimeoutException,
 )
-from superset.typing import FlaskResponse, FormData, Metric
+from superset.typing import AdhocMetric, FlaskResponse, FormData, Metric
 from superset.utils.dates import datetime_to_epoch, EPOCH
 from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
 
@@ -1494,7 +1494,7 @@ def get_column_name_from_metric(metric: Metric) -> Optional[str]:
     :return: column name if simple metric, otherwise None
     """
     if is_adhoc_metric(metric):
-        metric = cast(Dict[str, Any], metric)
+        metric = cast(AdhocMetric, metric)
         if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
             return cast(Dict[str, Any], metric["column"])["column_name"]
     return None
diff --git a/superset/viz.py b/superset/viz.py
index 53ea2ca..04adab3 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -66,7 +66,7 @@ from superset.exceptions import (
 from superset.extensions import cache_manager, security_manager
 from superset.models.cache import CacheKey
 from superset.models.helpers import QueryResult
-from superset.typing import QueryObjectDict, VizData, VizPayload
+from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
 from superset.utils import core as utils, csv
 from superset.utils.cache import set_and_log_cache
 from superset.utils.core import (
@@ -526,10 +526,7 @@ class BaseViz:
                     for col in (query_obj.get("columns") or [])
                     + (query_obj.get("groupby") or [])
                     + utils.get_column_names_from_metrics(
-                        cast(
-                            List[Union[str, Dict[str, Any]]],
-                            query_obj.get("metrics") or [],
-                        )
+                        cast(List[Metric], query_obj.get("metrics") or [],)
                     )
                     if col not in self.datasource.column_names
                 ]
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index a759270..edd9dce 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -26,7 +26,12 @@ from superset.db_engine_specs.bigquery import BigQueryEngineSpec
 from superset.db_engine_specs.druid import DruidEngineSpec
 from superset.exceptions import QueryObjectValidationError
 from superset.models.core import Database
-from superset.utils.core import GenericDataType, get_example_database, 
FilterOperator
+from superset.utils.core import (
+    AdhocMetricExpressionType,
+    FilterOperator,
+    GenericDataType,
+    get_example_database,
+)
 from tests.fixtures.birth_names_dashboard import 
load_birth_names_dashboard_with_slices
 
 from .base_tests import SupersetTestCase
@@ -168,6 +173,50 @@ class TestDatabaseModel(SupersetTestCase):
             db.session.delete(table)
         db.session.commit()
 
+    @patch("superset.jinja_context.g")
+    def test_jinja_metrics_and_calc_columns(self, flask_g):
+        flask_g.user.username = "abc"
+        base_query_obj = {
+            "granularity": None,
+            "from_dttm": None,
+            "to_dttm": None,
+            "groupby": ["user", "expr"],
+            "metrics": [
+                {
+                    "expressionType": AdhocMetricExpressionType.SQL,
+                    "sqlExpression": "SUM(case when user = '{{ 
current_username() }}' "
+                    "then 1 else 0 end)",
+                    "label": "SUM(userid)",
+                }
+            ],
+            "is_timeseries": False,
+            "filter": [],
+        }
+
+        table = SqlaTable(
+            table_name="test_has_jinja_metric_and_expr",
+            sql="SELECT '{{ current_username() }}' as user",
+            database=get_example_database(),
+        )
+        TableColumn(
+            column_name="expr",
+            expression="case when '{{ current_username() }}' = 'abc' "
+            "then 'yes' else 'no' end",
+            type="VARCHAR(100)",
+            table=table,
+        )
+        db.session.commit()
+
+        sqla_query = table.get_sqla_query(**base_query_obj)
+        query = table.database.compile_sqla_query(sqla_query.sqla_query)
+        # assert expression
+        assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
+        # assert metric
+        assert "SUM(case when user = 'abc' then 1 else 0 end)" in query
+        # Cleanup
+        db.session.delete(table)
+        db.session.commit()
+
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_where_operators(self):
         class FilterTestCase(NamedTuple):

Reply via email to