This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 822eb2e feat(sql): add jinja support to metrics and expressions (#15247)
822eb2e is described below
commit 822eb2e27e83238e2cd05a371b1d835f593d2907
Author: Ville Brofeldt <[email protected]>
AuthorDate: Sat Jun 19 08:29:04 2021 +0300
feat(sql): add jinja support to metrics and expressions (#15247)
* feat(sql): add jinja support to metrics and expressions
* add test
---
superset/common/query_object.py | 4 +--
superset/connectors/druid/models.py | 12 ++++++---
superset/connectors/sqla/models.py | 26 +++++++++++++++----
superset/typing.py | 16 +++++++++++-
superset/utils/core.py | 4 +--
superset/viz.py | 7 ++---
tests/sqla_models_tests.py | 51 ++++++++++++++++++++++++++++++++++++-
7 files changed, 101 insertions(+), 19 deletions(-)
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 8fd281c..77f85b4 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -17,7 +17,7 @@
# pylint: disable=R
import logging
from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional, Union
+from typing import Any, Dict, List, NamedTuple, Optional
from flask_babel import gettext as _
from pandas import DataFrame
@@ -103,7 +103,7 @@ class QueryObject:
applied_time_extras: Optional[Dict[str, str]] = None,
apply_fetch_values_predicate: bool = False,
granularity: Optional[str] = None,
- metrics: Optional[List[Union[Dict[str, Any], str]]] = None,
+ metrics: Optional[List[Metric]] = None,
groupby: Optional[List[str]] = None,
filters: Optional[List[Dict[str, Any]]] = None,
time_range: Optional[str] = None,
diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index 1839f76..14eda04 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -55,7 +55,13 @@ from superset.exceptions import SupersetException
from superset.extensions import encrypted_field_factory
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, QueryResult
-from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict
+from superset.typing import (
+ AdhocMetric,
+ FilterValues,
+ Granularity,
+ Metric,
+ QueryObjectDict,
+)
from superset.utils import core as utils
from superset.utils.date_parser import parse_human_datetime, parse_human_timedelta
@@ -1010,7 +1016,7 @@ class DruidDatasource(Model, BaseDatasource):
return ret
@staticmethod
- def druid_type_from_adhoc_metric(adhoc_metric: Dict[str, Any]) -> str:
+ def druid_type_from_adhoc_metric(adhoc_metric: AdhocMetric) -> str:
column_type = adhoc_metric["column"]["type"].lower()
aggregate = adhoc_metric["aggregate"].lower()
@@ -1025,7 +1031,7 @@ class DruidDatasource(Model, BaseDatasource):
def get_aggregations(
metrics_dict: Dict[str, Any],
saved_metrics: Set[str],
- adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
+ adhoc_metrics: Optional[List[AdhocMetric]] = None,
) -> "OrderedDict[str, Any]":
"""
Returns a dictionary of aggregation metric names to aggregation json objects
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 2f7b6d1..31f475f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -22,7 +22,18 @@ from collections import defaultdict, OrderedDict
from contextlib import closing
from dataclasses import dataclass, field # pylint: disable=wrong-import-order
from datetime import datetime, timedelta
-from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Type, Union
+from typing import (
+ Any,
+ cast,
+ Dict,
+ Hashable,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
import pandas as pd
import sqlalchemy as sa
@@ -241,7 +252,9 @@ class TableColumn(Model, BaseColumn):
column_spec = db_engine_spec.get_column_spec(self.type)
type_ = column_spec.sqla_type if column_spec else None
if self.expression:
- col = literal_column(self.expression, type_=type_)
+ tp = self.table.get_template_processor()
+ expression = tp.process_template(self.expression)
+ col = literal_column(expression, type_=type_)
else:
col = column(self.column_name, type_=type_)
col = self.table.make_sqla_column_compatible(col, label)
@@ -879,7 +892,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
label = utils.get_metric_name(metric)
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
- column_name = metric["column"].get("column_name")
+ column_name = cast(str, metric["column"].get("column_name"))
table_column: Optional[TableColumn] = columns_by_name.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col()
@@ -887,7 +900,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
elif expression_type == utils.AdhocMetricExpressionType.SQL:
- sqla_metric = literal_column(metric.get("sqlExpression"))
+ tp = self.get_template_processor()
+ expression = tp.process_template(cast(str, metric["sqlExpression"]))
+ sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
@@ -1060,8 +1075,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
# Since orderby may use adhoc metrics, too; we need to process them first
orderby_exprs: List[ColumnElement] = []
for orig_col, ascending in orderby:
- col: Union[Metric, ColumnElement] = orig_col
+ col: Union[AdhocMetric, ColumnElement] = orig_col
if isinstance(col, dict):
+ col = cast(AdhocMetric, col)
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(col, columns_by_name)
diff --git a/superset/typing.py b/superset/typing.py
index 0a7ef59..f428831 100644
--- a/superset/typing.py
+++ b/superset/typing.py
@@ -19,8 +19,23 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from flask import Flask
from flask_caching import Cache
+from typing_extensions import TypedDict
from werkzeug.wrappers import Response
+
+class AdhocMetricColumn(TypedDict):
+ column_name: Optional[str]
+ type: str
+
+
+class AdhocMetric(TypedDict):
+ aggregate: str
+ column: AdhocMetricColumn
+ expressionType: str
+ label: str
+ sqlExpression: Optional[str]
+
+
CacheConfig = Union[Callable[[Flask], Cache], Dict[str, Any]]
DbapiDescriptionRow = Tuple[
str, str, Optional[str], Optional[str], Optional[int], Optional[int], bool
@@ -31,7 +46,6 @@ FilterValue = Union[datetime, float, int, str]
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
FormData = Dict[str, Any]
Granularity = Union[str, Dict[str, Union[str, float]]]
-AdhocMetric = Dict[str, Any]
Metric = Union[AdhocMetric, str]
OrderBy = Tuple[Metric, bool]
QueryObjectDict = Dict[str, Any]
diff --git a/superset/utils/core.py b/superset/utils/core.py
index c8c352d..12c66e4 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -96,7 +96,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
-from superset.typing import FlaskResponse, FormData, Metric
+from superset.typing import AdhocMetric, FlaskResponse, FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
@@ -1494,7 +1494,7 @@ def get_column_name_from_metric(metric: Metric) -> Optional[str]:
:return: column name if simple metric, otherwise None
"""
if is_adhoc_metric(metric):
- metric = cast(Dict[str, Any], metric)
+ metric = cast(AdhocMetric, metric)
if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
return cast(Dict[str, Any], metric["column"])["column_name"]
return None
diff --git a/superset/viz.py b/superset/viz.py
index 53ea2ca..04adab3 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -66,7 +66,7 @@ from superset.exceptions import (
from superset.extensions import cache_manager, security_manager
from superset.models.cache import CacheKey
from superset.models.helpers import QueryResult
-from superset.typing import QueryObjectDict, VizData, VizPayload
+from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
@@ -526,10 +526,7 @@ class BaseViz:
for col in (query_obj.get("columns") or [])
+ (query_obj.get("groupby") or [])
+ utils.get_column_names_from_metrics(
- cast(
- List[Union[str, Dict[str, Any]]],
- query_obj.get("metrics") or [],
- )
+ cast(List[Metric], query_obj.get("metrics") or [],)
)
if col not in self.datasource.column_names
]
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index a759270..edd9dce 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -26,7 +26,12 @@ from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError
from superset.models.core import Database
-from superset.utils.core import GenericDataType, get_example_database, FilterOperator
+from superset.utils.core import (
+ AdhocMetricExpressionType,
+ FilterOperator,
+ GenericDataType,
+ get_example_database,
+)
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
from .base_tests import SupersetTestCase
@@ -168,6 +173,50 @@ class TestDatabaseModel(SupersetTestCase):
db.session.delete(table)
db.session.commit()
+ @patch("superset.jinja_context.g")
+ def test_jinja_metrics_and_calc_columns(self, flask_g):
+ flask_g.user.username = "abc"
+ base_query_obj = {
+ "granularity": None,
+ "from_dttm": None,
+ "to_dttm": None,
+ "groupby": ["user", "expr"],
+ "metrics": [
+ {
+ "expressionType": AdhocMetricExpressionType.SQL,
+ "sqlExpression": "SUM(case when user = '{{ current_username() }}' "
+ "then 1 else 0 end)",
+ "label": "SUM(userid)",
+ }
+ ],
+ "is_timeseries": False,
+ "filter": [],
+ }
+
+ table = SqlaTable(
+ table_name="test_has_jinja_metric_and_expr",
+ sql="SELECT '{{ current_username() }}' as user",
+ database=get_example_database(),
+ )
+ TableColumn(
+ column_name="expr",
+ expression="case when '{{ current_username() }}' = 'abc' "
+ "then 'yes' else 'no' end",
+ type="VARCHAR(100)",
+ table=table,
+ )
+ db.session.commit()
+
+ sqla_query = table.get_sqla_query(**base_query_obj)
+ query = table.database.compile_sqla_query(sqla_query.sqla_query)
+ # assert expression
+ assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query
+ # assert metric
+ assert "SUM(case when user = 'abc' then 1 else 0 end)" in query
+ # Cleanup
+ db.session.delete(table)
+ db.session.commit()
+
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_where_operators(self):
class FilterTestCase(NamedTuple):