This is an automated email from the ASF dual-hosted git repository. lyndsi pushed a commit to branch lyndsi/sql-lab-new-explore-button-functionality-and-move-save-dataset-to-split-save-button in repository https://gitbox.apache.org/repos/asf/superset.git
commit f992a027b5db71485fea86bc5d7d7f67e8c56cb6 Author: Hugh A. Miles II <[email protected]> AuthorDate: Mon Jun 6 15:09:01 2022 +0000 add POC ExploreMixin --- superset/models/helpers.py | 673 ++++++++++++++++----------------------------- 1 file changed, 234 insertions(+), 439 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 6cd6d17c7b..08efb59b60 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -24,21 +24,19 @@ from datetime import datetime, timedelta from json.decoder import JSONDecodeError from typing import ( Any, + Callable, cast, Dict, + Hashable, List, - Mapping, NamedTuple, Optional, Set, - Text, Tuple, Type, - TYPE_CHECKING, Union, ) -import dateutil.parser import humanize import numpy as np import pandas as pd @@ -52,48 +50,33 @@ from flask_appbuilder.models.decorators import renders from flask_appbuilder.models.mixins import AuditMixin from flask_appbuilder.security.sqla.models import User from flask_babel import lazy_gettext as _ -from jinja2.exceptions import TemplateError -from sqlalchemy import and_, Column, or_, UniqueConstraint +from sqlalchemy import and_, or_, UniqueConstraint from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Mapper, Session from sqlalchemy.orm.exc import MultipleResultsFound -from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause +from sqlalchemy.sql.elements import ColumnClause, TextClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy_utils import UUIDType -from superset import app, is_feature_enabled, security_manager -from superset.advanced_data_type.types import AdvancedDataTypeResponse +from superset import app, db, is_feature_enabled, security_manager from superset.common.db_query_status import QueryStatus -from superset.constants import EMPTY_STRING, NULL_STRING -from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import ( - AdvancedDataTypeResponseError, - QueryClauseValidationException, - QueryObjectValidationError, - SupersetSecurityException, +from superset.jinja_context import ( + BaseTemplateProcessor, + ExtraCache, + get_template_processor, ) -from superset.extensions import feature_flag_manager -from superset.jinja_context import BaseTemplateProcessor -from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, sanitize_clause -from superset.superset_typing import ( - AdhocMetric, - FilterValue, - FilterValues, - Metric, - OrderBy, - QueryObjectDict, +from superset.sql_parse import ( + extract_table_references, + ParsedQuery, + sanitize_clause, + Table as TableName, ) from superset.utils import core as utils -from superset.utils.core import get_user_id -if TYPE_CHECKING: - from superset.connectors.sqla.models import SqlMetric, TableColumn - from superset.db_engine_specs import BaseEngineSpec - from superset.models.core import Database +VIRTUAL_TABLE_ALIAS = "virtual_table" -config = app.config logger = logging.getLogger(__name__) CTE_ALIAS = "__cte" @@ -634,6 +617,17 @@ def clone_model( return target.__class__(**data) +from typing import Any, Dict, List, NamedTuple + +import sqlparse +from sqlalchemy import Column +from sqlalchemy.sql.elements import ColumnElement, Label, literal_column + +from superset.exceptions import QueryObjectValidationError +from superset.superset_typing import AdhocMetric, Metric, OrderBy, QueryObjectDict +from superset.utils import core as utils + + # todo(hugh): centralize where this code lives class QueryStringExtended(NamedTuple): applied_template_filters: Optional[List[str]] @@ -651,113 +645,63 @@ class SqlaQuery(NamedTuple): sqla_query: Select -class ExploreMixin: # pylint: disable=too-many-public-methods +class ExploreMixin: """ Allows any flask_appbuilder.Model (Query, Table, etc.) to be used to power a chart inside /explore """ - sqla_aggregations = { - "COUNT_DISTINCT": lambda column_name: sa.func.COUNT(sa.distinct(column_name)), - "COUNT": sa.func.COUNT, - "SUM": sa.func.SUM, - "AVG": sa.func.AVG, - "MIN": sa.func.MIN, - "MAX": sa.func.MAX, - } - @property - def query(self) -> str: - raise NotImplementedError() + def data(self): + return {"foo": "bar"} @property - def database_id(self) -> int: - raise NotImplementedError() + def owners_data(self): + return [] @property - def owners_data(self) -> List[Any]: - raise NotImplementedError() + def metrics(self): + return [] @property - def metrics(self) -> List[Any]: - raise NotImplementedError() + def uid(self): + return "foo" @property - def uid(self) -> str: - raise NotImplementedError() + def is_rls_supported(self): + return False @property - def is_rls_supported(self) -> bool: - raise NotImplementedError() + def cache_timeout(self): + return None @property - def cache_timeout(self) -> int: - raise NotImplementedError() + def column_names(self): + return ["ethnic_minority", "gender"] @property - def column_names(self) -> List[str]: - raise NotImplementedError() + def columns(self): + return ["<col_name>"] @property - def offset(self) -> int: - raise NotImplementedError() + def offset(self): + return 0 @property - def main_dttm_col(self) -> Optional[str]: - raise NotImplementedError() + def main_dttm_col(self) -> str: # todo - this should be a real column + return "ds" @property def dttm_cols(self) -> List[str]: - raise NotImplementedError() - - @property - def db_engine_spec(self) -> Type["BaseEngineSpec"]: - raise NotImplementedError() - - @property - def database(self) -> Type["Database"]: - raise NotImplementedError() - - @property - def schema(self) -> str: - raise NotImplementedError() - - @property - def sql(self) -> str: - raise NotImplementedError() - - @property - def columns(self) -> List[Any]: - raise NotImplementedError() - - @property - def get_fetch_values_predicate(self) -> List[Any]: - raise NotImplementedError() + return [] + # l = [c.column_name for c in self.columns if c.is_dttm] + # if self.main_dttm_col and self.main_dttm_col not in l: + # l.append(self.main_dttm_col) + # return l @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: - raise NotImplementedError() - - def _process_sql_expression( # type: ignore # pylint: disable=no-self-use - self, - expression: Optional[str], - database_id: int, - schema: str, - template_processor: Optional[BaseTemplateProcessor], - ) -> Optional[str]: - if template_processor and expression: - expression = template_processor.process_template(expression) - if expression: - expression = validate_adhoc_subquery( - expression, - database_id, - schema, - ) - try: - expression = sanitize_clause(expression) - except QueryClauseValidationException as ex: - raise QueryObjectValidationError(ex.message) from ex - return expression + def get_extra_cache_keys(query_obj): + return [] def make_sqla_column_compatible( self, sqla_col: ColumnElement, label: Optional[str] = None @@ -776,72 +720,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods sqla_col.key = label_expected return sqla_col - def mutate_query_from_config(self, sql: str) -> str: - """Apply config's SQL_QUERY_MUTATOR - - Typically adds comments to the query with context""" - sql_query_mutator = config["SQL_QUERY_MUTATOR"] - if sql_query_mutator: - sql = sql_query_mutator( - sql, - user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0. - security_manager=security_manager, - database=self.database, - ) - return sql - - @staticmethod - def _apply_cte(sql: str, cte: Optional[str]) -> str: - """ - Append a CTE before the SELECT statement if defined - - :param sql: SELECT statement - :param cte: CTE statement - :return: - """ - if cte: - sql = f"{cte}\n{sql}" - return sql - - @staticmethod - def validate_adhoc_subquery( - sql: str, - database_id: int, - default_schema: str, - ) -> str: - """ - Check if adhoc SQL contains sub-queries or nested sub-queries with table. - - If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS - predicates to it. - - :param sql: adhoc sql expression - :raise SupersetSecurityException if sql contains sub-queries or - nested sub-queries with table - """ - - statements = [] - for statement in sqlparse.parse(sql): - if has_table_query(statement): - if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): - raise SupersetSecurityException( - SupersetError( - error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, - message=_("Custom SQL fields cannot contain sub-queries."), - level=ErrorLevel.ERROR, - ) - ) - statement = insert_rls(statement, database_id, default_schema) - statements.append(statement) - - return ";\n".join(str(statement) for statement in statements) - def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) - sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore - sql = self._apply_cte(sql, sqlaq.cte) + sql = self.database.compile_sqla_query(sqlaq.sqla_query) + # sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) - sql = self.mutate_query_from_config(sql) + # sql = self.mutate_query_from_config(sql) return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, labels_expected=sqlaq.labels_expected, @@ -849,43 +733,6 @@ class ExploreMixin: # pylint: disable=too-many-public-methods sql=sql, ) - def _normalize_prequery_result_type( - self, - row: pd.Series, - dimension: str, - columns_by_name: Dict[str, "TableColumn"], - ) -> Union[str, int, float, bool, Text]: - """ - Convert a prequery result type to its equivalent Python type. - - Some databases like Druid will return timestamps as strings, but do not perform - automatic casting when comparing these strings to a timestamp. For cases like - this we convert the value via the appropriate SQL transform. - - :param row: A prequery record - :param dimension: The dimension name - :param columns_by_name: The mapping of columns by name - :return: equivalent primitive python type - """ - - value = row[dimension] - - if isinstance(value, np.generic): - value = value.item() - - column_ = columns_by_name[dimension] - db_extra: Dict[str, Any] = self.database.get_extra() # type: ignore - - if column_.type and column_.is_temporal and isinstance(value, str): - sql = self.db_engine_spec.convert_dttm( - column_.type, dateutil.parser.parse(value), db_extra=db_extra - ) - - if sql: - value = self.text(sql) - - return value - def make_orderby_compatible( self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] ) -> None: @@ -917,48 +764,56 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def exc_query(self, qry: Any) -> QueryResult: qry_start_dttm = datetime.now() + # todo(hugh): apply filters for extended query query_str_ext = self.get_query_str_extended(qry) sql = query_str_ext.sql status = QueryStatus.SUCCESS errors = None error_message = None - def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: - """ - Some engines change the case or generate bespoke column names, either by - default or due to lack of support for aliasing. This function ensures that - the column names in the DataFrame correspond to what is expected by - the viz components. - Sometimes a query may also contain only order by columns that are not used - as metrics or groupby columns, but need to present in the SQL `select`, - filtering by `labels_expected` make sure we only return columns users want. - :param df: Original DataFrame returned by the engine - :return: Mutated DataFrame - """ - labels_expected = query_str_ext.labels_expected - if df is not None and not df.empty: - if len(df.columns) < len(labels_expected): - raise QueryObjectValidationError( - _("Db engine did not return all queried columns") - ) - if len(df.columns) > len(labels_expected): - df = df.iloc[:, 0 : len(labels_expected)] - df.columns = labels_expected - return df + # def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: + # """ + # Some engines change the case or generate bespoke column names, either by + # default or due to lack of support for aliasing. This function ensures that + # the column names in the DataFrame correspond to what is expected by + # the viz components. + # Sometimes a query may also contain only order by columns that are not used + # as metrics or groupby columns, but need to present in the SQL `select`, + # filtering by `labels_expected` make sure we only return columns users want. + # :param df: Original DataFrame returned by the engine + # :return: Mutated DataFrame + # """ + # labels_expected = query_str_ext.labels_expected + # if df is not None and not df.empty: + # if len(df.columns) < len(labels_expected): + # raise QueryObjectValidationError( + # _("Db engine did not return all queried columns") + # ) + # if len(df.columns) > len(labels_expected): + # df = df.iloc[:, 0: len(labels_expected)] + # df.columns = labels_expected + # return df try: - df = self.database.get_df( - sql, self.schema, mutator=assign_column_label # type: ignore - ) + # todo(hugh) fix this + # df = self.database.get_df( + # sql, self.schema, mutator=assign_column_label) + df = self.database.get_df(sql, self.schema) except Exception as ex: # pylint: disable=broad-except df = pd.DataFrame() status = QueryStatus.FAILED logger.warning( "Query %s on schema %s failed", sql, self.schema, exc_info=True ) + # todo(hugh): how are we handling errors + # db_engine_spec = self.db_engine_spec + # errors = [ + # dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex) + # ] error_message = utils.error_msg_from_exception(ex) return QueryResult( + # applied_template_filters=query_str_ext.applied_template_filters, status=status, df=df, duration=datetime.now() - qry_start_dttm, @@ -1005,6 +860,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods or a virtual table with it's own subquery. If the FROM is referencing a CTE, the CTE is returned as the second value in the return tuple. """ + # todo(hugh): fix this + # if not self.is_virtual: + # return self.get_sqla_table(), None from_sql = self.get_rendered_sql(template_processor) parsed_query = ParsedQuery(from_sql) @@ -1018,49 +876,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods cte = self.db_engine_spec.get_cte_query(from_sql) from_clause = ( - sa.table(CTE_ALIAS) + table(CTE_ALIAS) if cte else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) ) return from_clause, cte - def adhoc_metric_to_sqla( - self, - metric: AdhocMetric, - columns_by_name: Dict[str, "TableColumn"], # # pylint: disable=unused-argument - template_processor: Optional[BaseTemplateProcessor] = None, - ) -> ColumnElement: - """ - Turn an adhoc metric into a sqlalchemy column. - - :param dict metric: Adhoc metric definition - :param dict columns_by_name: Columns for the current table - :param template_processor: template_processor instance - :returns: The metric defined as a sqlalchemy column - :rtype: sqlalchemy.sql.column - """ - expression_type = metric.get("expressionType") - label = utils.get_metric_name(metric) - - if expression_type == utils.AdhocMetricExpressionType.SIMPLE: - metric_column = metric.get("column") or {} - column_name = cast(str, metric_column.get("column_name")) - sqla_column = sa.column(column_name) - sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) - elif expression_type == utils.AdhocMetricExpressionType.SQL: - expression = self._process_sql_expression( # type: ignore - expression=metric["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) - sqla_metric = literal_column(expression) - else: - raise QueryObjectValidationError("Adhoc metric expressionType is invalid") - - return self.make_sqla_column_compatible(sqla_metric, label) - @property def template_params_dict(self) -> Dict[Any, Any]: return {} @@ -1247,7 +1069,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods row_offset: Optional[int] = None, timeseries_limit: Optional[int] = None, timeseries_limit_metric: Optional[Metric] = None, - ) -> SqlaQuery: + ) -> Any: """Querying any sqla table from this common interface""" if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col @@ -1265,7 +1087,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods "time_column": granularity, "time_grain": time_grain, "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.get("column_name") for col in self.columns], + "table_columns": self.column_names, "filter": filter, } columns = columns or [] @@ -1275,14 +1097,16 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if is_timeseries and timeseries_limit: series_limit = timeseries_limit series_limit_metric = series_limit_metric or timeseries_limit_metric - template_kwargs.update(self.template_params_dict) + template_kwargs.update(self.template_params_dict) # todo extra_cache_keys: List[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys removed_filters: List[str] = [] applied_template_filters: List[str] = [] template_kwargs["removed_filters"] = removed_filters template_kwargs["applied_filters"] = applied_template_filters - template_processor = None # self.get_template_processor(**template_kwargs) + template_processor = ( + None # self.get_template_processor(**template_kwargs) #todo + ) db_engine_spec = self.db_engine_spec prequeries: List[str] = [] orderby = orderby or [] @@ -1293,36 +1117,40 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - columns_by_name: Dict[str, "TableColumn"] = { - col.get("column_name"): col - for col in self.columns # col.column_name: col for col in self.columns - } - - if not granularity and is_timeseries: - raise QueryObjectValidationError( - _( - "Datetime column not provided as part table configuration " - "and is required by this type of chart" - ) - ) - if not metrics and not columns and not groupby: - raise QueryObjectValidationError(_("Empty query?")) + # columns_by_name: Dict[str, sa.Table] = { + # col.column_name: col for col in self.columns + # } + # todo(hugh): fix this + columns_by_name = {} + + # todo(hugh): how are we handling metrics + # metrics_by_name: Dict[str, Column] = { # todo column vs metric? + # m.metric_name: m for m in self.metrics + # } + metrics_by_name: Dict[str, Column] = {} + + # if not granularity and is_timeseries: + # raise QueryObjectValidationError( + # _( + # "Datetime column not provided as part table configuration " + # "and is required by this type of chart" + # ) + # ) + # if not metrics and not columns and not groupby: + # raise QueryObjectValidationError(_("Empty query?")) metrics_exprs: List[ColumnElement] = [] - for metric in metrics: - if utils.is_adhoc_metric(metric): - assert isinstance(metric, dict) - metrics_exprs.append( - self.adhoc_metric_to_sqla( - metric=metric, - columns_by_name=columns_by_name, # type: ignore - template_processor=template_processor, - ) - ) - else: - raise QueryObjectValidationError( - _("Metric '%(metric)s' does not exist", metric=metric) - ) + # for metric in metrics: + # if utils.is_adhoc_metric(metric): + # assert isinstance(metric, dict) + # # metrics_exprs.append( + # # self.adhoc_metric_to_sqla(metric, columns_by_name)) + # elif isinstance(metric, str) and metric in metrics_by_name: + # metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) + # else: + # raise QueryObjectValidationError( + # _("Metric '%(metric)s' does not exist", metric=metric) + # ) if metrics_exprs: main_metric_expr = metrics_exprs[0] @@ -1342,16 +1170,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods col: Union[AdhocMetric, ColumnElement] = orig_col if isinstance(col, dict): col = cast(AdhocMetric, col) - if col.get("sqlExpression"): - col["sqlExpression"] = self._process_sql_expression( # type: ignore - expression=col["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists - col = self.adhoc_metric_to_sqla(col, columns_by_name) # type: ignore + # todo(hugh): figure out if we should have metrics + # col = self.adhoc_metric_to_sqla(col, columns_by_name) # if the adhoc metric has been defined before # use the existing instance. col = metrics_exprs_by_expr.get(str(col), col) @@ -1361,14 +1183,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True + elif col in metrics_by_name: + col = metrics_by_name[col].get_sqla_col() + need_groupby = True - if isinstance(col, ColumnElement): - orderby_exprs.append(col) - else: - # Could not convert a column reference to valid ColumnElement - raise QueryObjectValidationError( - _("Unknown column used in orderby: %(col)s", col=orig_col) - ) + # todo(hugh): fix this + # if isinstance(col, ColumnElement): + # orderby_exprs.append(col) + # else: + # # Could not convert a column reference to valid ColumnElement + # raise QueryObjectValidationError( + # _("Unknown column used in orderby: %(col)s", col=orig_col) + # ) select_exprs: List[Union[Column, Label]] = [] groupby_all_columns = {} @@ -1389,26 +1215,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods outer = table_col.get_timestamp_expression( time_grain=time_grain, label=selected, - template_processor=template_processor, + # template_processor=template_processor, ) # if groupby field equals a selected column elif selected in columns_by_name: - if isinstance(columns_by_name[selected], dict): - outer = sa.column(f"{selected}") - outer = self.make_sqla_column_compatible(outer, selected) - else: - outer = columns_by_name[selected].get_sqla_col() + outer = columns_by_name[selected].get_sqla_col() else: - selected = self.validate_adhoc_subquery( - selected, - self.database_id, - self.schema, - ) - outer = sa.column(f"{selected}") + outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( - col=selected, template_processor=template_processor + col=selected, # template_processor=template_processor ) groupby_all_columns[outer.name] = outer if not series_column_names or outer.name in series_column_names: @@ -1416,38 +1233,45 @@ class ExploreMixin: # pylint: disable=too-many-public-methods select_exprs.append(outer) elif columns: for selected in columns: - selected = self.validate_adhoc_subquery( - selected, - self.database_id, - self.schema, + select_exprs.append( + columns_by_name[selected].get_sqla_col() + if selected in columns_by_name + else self.make_sqla_column_compatible(literal_column(selected)) ) - if isinstance(columns_by_name[selected], dict): - select_exprs.append(sa.column(f"{selected}")) - else: - select_exprs.append( - columns_by_name[selected].get_sqla_col() - if selected in columns_by_name - else self.make_sqla_column_compatible(literal_column(selected)) - ) metrics_exprs = [] - if granularity: - if granularity not in columns_by_name or not dttm_col: - raise QueryObjectValidationError( - _( - 'Time column "%(col)s" does not exist in dataset', - col=granularity, - ) - ) - time_filters: List[Any] = [] - - if is_timeseries: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) - # always put timestamp as the first column - select_exprs.insert(0, timestamp) - groupby_all_columns[timestamp.name] = timestamp + # todo(hugh): fix this + # if granularity: + # if granularity not in columns_by_name or not dttm_col: + # raise QueryObjectValidationError( + # _( + # 'Time column "%(col)s" does not exist in dataset', + # col=granularity, + # ) + # ) + # time_filters = [] + + # if is_timeseries: + # timestamp = dttm_col.get_timestamp_expression( + # time_grain=time_grain, template_processor=template_processor + # ) + # # always put timestamp as the first column + # select_exprs.insert(0, timestamp) + # groupby_all_columns[timestamp.name] = timestamp + + # # Use main dttm column to support index with secondary dttm columns. + # if ( + # db_engine_spec.time_secondary_columns + # and self.main_dttm_col in self.dttm_cols + # and self.main_dttm_col != dttm_col.column_name + # ): + # time_filters.append( + # columns_by_name[self.main_dttm_col].get_time_filter( + # from_dttm, + # to_dttm, + # ) + # ) + # time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. when users use @@ -1466,6 +1290,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods qry = sa.select(select_exprs) + # todo(hugh) fix templating + # tbl, cte = self.get_from_clause(template_processor) tbl, cte = self.get_from_clause(template_processor) if groupby_all_columns: @@ -1480,18 +1306,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods flt_col = flt["col"] val = flt.get("val") op = flt["op"].upper() - col_obj: Optional["TableColumn"] = None - sqla_col: Optional[Column] = None + col_obj: Optional[Column] = None + sqla_col: Optional[sa.Column] = None if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col elif utils.is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore + sqla_col = self.adhoc_column_to_sqla(flt_col) else: col_obj = columns_by_name.get(flt_col) filter_grain = flt.get("grain") if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if utils.get_column_name(flt_col) in removed_filters: + if get_column_name(flt_col) in removed_filters: # Skip generating SQLA filter when the jinja template handles it. continue @@ -1502,81 +1328,35 @@ class ExploreMixin: # pylint: disable=too-many-public-methods sqla_col = col_obj.get_timestamp_expression( time_grain=filter_grain, template_processor=template_processor ) - elif col_obj and isinstance(col_obj, dict): - sqla_col = sa.column(col_obj.get("column_name")) elif col_obj: sqla_col = col_obj.get_sqla_col() - - if col_obj and isinstance(col_obj, dict): - col_type = col_obj.get("type") - else: - col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( - native_type=col_type, - db_extra=self.database.get_extra(), # type: ignore + col_obj.type if col_obj else None ) is_list_target = op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ) - - if col_obj and isinstance(col_obj, dict): - col_advanced_data_type = "" + if col_spec: + target_type = col_spec.generic_type else: - col_advanced_data_type = ( - col_obj.advanced_data_type if col_obj else "" - ) - - if col_spec and not col_advanced_data_type: - target_generic_type = col_spec.generic_type - else: - target_generic_type = utils.GenericDataType.STRING + target_type = GenericDataType.STRING eq = self.filter_values_handler( values=val, - target_generic_type=target_generic_type, - target_native_type=col_type, + target_column_type=target_type, is_list_target=is_list_target, - db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), # type: ignore ) - if ( - col_advanced_data_type != "" - and feature_flag_manager.is_feature_enabled( - "ENABLE_ADVANCED_DATA_TYPES" - ) - and col_advanced_data_type in ADVANCED_DATA_TYPES - ): - values = eq if is_list_target else [eq] # type: ignore - bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ - col_advanced_data_type - ].translate_type( - { - "type": col_advanced_data_type, - "values": values, - } - ) - if bus_resp["error_message"]: - raise AdvancedDataTypeResponseError( - _(bus_resp["error_message"]) - ) - - where_clause_and.append( - ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( - sqla_col, op, bus_resp["values"] - ) - ) - elif is_list_target: + if is_list_target: assert isinstance(eq, (tuple, list)) if len(eq) == 0: raise QueryObjectValidationError( _("Filter value list cannot be empty") ) - if len(eq) > len( - eq_without_none := [x for x in eq if x is not None] - ): + if None in eq: + eq = [x for x in eq if x is not None] is_null_cond = sqla_col.is_(None) if eq: - cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) + cond = or_(is_null_cond, sqla_col.in_(eq)) else: cond = is_null_cond else: @@ -1620,15 +1400,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) ) - # todo(hugh): fix this w/ template_processor - # where_clause_and += self.get_sqla_row_level_filters(template_processor) + if is_feature_enabled("ROW_LEVEL_SECURITY"): + where_clause_and += self._get_sqla_row_level_filters(template_processor) if extras: where = extras.get("where") if where: try: - where = template_processor.process_template( # type: ignore - f"({where})" - ) + where = template_processor.process_template(where) except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1636,13 +1414,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods msg=ex.message, ) ) from ex - where_clause_and += [self.text(where)] + where_clause_and += [self.text(f"({where})")] having = extras.get("having") if having: try: - having = template_processor.process_template( # type: ignore - f"({having})" - ) + having = template_processor.process_template(having) except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1650,10 +1426,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods msg=ex.message, ) ) from ex - having_clause_and += [self.text(having)] - if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore - qry = qry.where(self.get_fetch_values_predicate()) # type: ignore + having_clause_and += [self.text(f"({having})")] + if apply_fetch_values_predicate and self.fetch_values_predicate: + qry = qry.where(self.get_fetch_values_predicate()) if granularity: + time_filters = ( + [] + ) # todo(hugh): remove this once time filters are actually set qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) @@ -1673,7 +1452,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods and col.name in [select_col.name for select_col in select_exprs] ): col = literal_column(col.name) - direction = sa.asc if ascending else sa.desc + direction = asc if ascending else desc qry = qry.order_by(direction(col)) if row_limit: @@ -1692,13 +1471,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods inner_groupby_exprs = [] inner_select_exprs = [] for gby_name, gby_obj in groupby_series_columns.items(): - label = utils.get_column_name(gby_name) + label = get_column_name(gby_name) inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") inner_groupby_exprs.append(inner) inner_select_exprs.append(inner) inner_select_exprs += [inner_main_metric_expr] - subq = sa.select(inner_select_exprs).select_from(tbl) + subq = select(inner_select_exprs).select_from(tbl) inner_time_filter = [] if dttm_col and not db_engine_spec.time_groupby_inline: @@ -1712,7 +1491,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods subq = subq.group_by(*inner_groupby_exprs) ob = inner_main_metric_expr - direction = sa.desc if order_desc else sa.asc + if series_limit_metric: + ob = self._get_series_orderby( + series_limit_metric, metrics_by_name, columns_by_name + ) + direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) subq = subq.limit(series_limit) @@ -1722,9 +1505,21 @@ class ExploreMixin: # pylint: disable=too-many-public-methods # conditionally mutated, as it refers to the column alias in # the inner query col_name = db_engine_spec.make_label_compatible(gby_name + "__") - on_clause.append(gby_obj == sa.column(col_name)) + on_clause.append(gby_obj == column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) + else: + if series_limit_metric: + orderby = [ + ( + self._get_series_orderby( + series_limit_metric, + metrics_by_name, + columns_by_name, + ), + not order_desc, + ) + ] # run prequery to get top groups prequery_obj = { @@ -1742,7 +1537,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods "order_desc": True, } - result = self.query(prequery_obj) # type: ignore + result = self.query(prequery_obj) prequeries.append(result.query) dimensions = [ c @@ -1763,7 +1558,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) label = "rowcount" col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) + qry = select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] return SqlaQuery(
