This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch fix-explore-mixin
in repository https://gitbox.apache.org/repos/asf/superset.git
commit e386bc426c3080a48d9be1b5fafd8e6fbd84df63
Author: hughhhh <[email protected]>
AuthorDate: Fri Jan 20 17:25:56 2023 +0200

    re patch sqlatable into exploremixin
---
 superset/connectors/sqla/models.py              | 1308 ++++++++++++-----------
 superset/models/helpers.py                      |  349 +++---
 superset/models/sql_lab.py                      |    3 +-
 superset/result_set.py                          |    4 +-
 superset/utils/pandas_postprocessing/boxplot.py |    8 +-
 superset/utils/pandas_postprocessing/flatten.py |    2 +-
 6 files changed, 854 insertions(+), 820 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c5fd025f4e..b363188b87 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -105,7 +105,12 @@ from superset.jinja_context import (
 )
 from superset.models.annotations import Annotation
 from superset.models.core import Database
-from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult
+from superset.models.helpers import (
+    AuditMixinNullable,
+    CertificationMixin,
+    ExploreMixin,
+    QueryResult,
+)
 from superset.sql_parse import ParsedQuery, sanitize_clause
 from superset.superset_typing import (
     AdhocColumn,
@@ -149,12 +154,13 @@ class SqlaQuery(NamedTuple):
     prequeries: List[str]
     sqla_query: Select

+from superset.models.helpers import QueryStringExtended

-class QueryStringExtended(NamedTuple):
-    applied_template_filters: Optional[List[str]]
-    labels_expected: List[str]
-    prequeries: List[str]
-    sql: str
+# class QueryStringExtended(NamedTuple):
+#     applied_template_filters: Optional[List[str]]
+#     labels_expected: List[str]
+#     prequeries: List[str]
+#     sql: str


 @dataclass
@@ -534,7 +540,7 @@ def _process_sql_expression(
     return expression


-class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
+class SqlaTable(Model, BaseDatasource, ExploreMixin):  # pylint: disable=too-many-public-methods
     """An ORM object for SqlAlchemy table references"""

     type = "table"
@@ -980,7 +986,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho

         return self.make_sqla_column_compatible(sqla_metric, label)

-    def adhoc_column_to_sqla(
+    def adhoc_column_to_sqla(  # type: ignore
         self,
         col: AdhocColumn,
         template_processor: Optional[BaseTemplateProcessor] = None,
@@ -1118,649 +1124,649 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     def text(self, clause: str) -> TextClause:
         return self.db_engine_spec.get_text_clause(clause)

-    def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
-        self,
-        apply_fetch_values_predicate: bool = False,
-        columns: Optional[List[ColumnTyping]] = None,
-        extras: Optional[Dict[str, Any]] = None,
-        filter: Optional[  # pylint: disable=redefined-builtin
-            List[QueryObjectFilterClause]
-        ] = None,
-        from_dttm: Optional[datetime] = None,
-        granularity: Optional[str] = None,
-        groupby: Optional[List[Column]] = None,
-        inner_from_dttm: Optional[datetime] = None,
-        inner_to_dttm: Optional[datetime] = None,
-        is_rowcount: bool = False,
-        is_timeseries: bool = True,
-        metrics: Optional[List[Metric]] = None,
-        orderby: Optional[List[OrderBy]] = None,
-        order_desc: bool = True,
-        to_dttm: Optional[datetime] = None,
-        series_columns: Optional[List[Column]] = None,
-        series_limit: Optional[int] = None,
-        series_limit_metric: Optional[Metric] = None,
-        row_limit: Optional[int] = None,
-        row_offset: Optional[int] = None,
-        timeseries_limit: Optional[int] = None,
-        timeseries_limit_metric: Optional[Metric] = None,
-
time_shift: Optional[str] = None, - ) -> SqlaQuery: - """Querying any sqla table from this common interface""" - if granularity not in self.dttm_cols and granularity is not None: - granularity = self.main_dttm_col - - extras = extras or {} - time_grain = extras.get("time_grain_sqla") - - template_kwargs = { - "columns": columns, - "from_dttm": from_dttm.isoformat() if from_dttm else None, - "groupby": groupby, - "metrics": metrics, - "row_limit": row_limit, - "row_offset": row_offset, - "time_column": granularity, - "time_grain": time_grain, - "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.column_name for col in self.columns], - "filter": filter, - } - columns = columns or [] - groupby = groupby or [] - series_column_names = utils.get_column_names(series_columns or []) - # deprecated, to be removed in 2.0 - if is_timeseries and timeseries_limit: - series_limit = timeseries_limit - series_limit_metric = series_limit_metric or timeseries_limit_metric - template_kwargs.update(self.template_params_dict) - extra_cache_keys: List[Any] = [] - template_kwargs["extra_cache_keys"] = extra_cache_keys - removed_filters: List[str] = [] - applied_template_filters: List[str] = [] - template_kwargs["removed_filters"] = removed_filters - template_kwargs["applied_filters"] = applied_template_filters - template_processor = self.get_template_processor(**template_kwargs) - db_engine_spec = self.db_engine_spec - prequeries: List[str] = [] - orderby = orderby or [] - need_groupby = bool(metrics is not None or groupby) - metrics = metrics or [] - - # For backward compatibility - if granularity not in self.dttm_cols and granularity is not None: - granularity = self.main_dttm_col - - columns_by_name: Dict[str, TableColumn] = { - col.column_name: col for col in self.columns - } - - metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} - - if not granularity and is_timeseries: - raise QueryObjectValidationError( - _( - "Datetime column not provided as part table configuration " - "and is required by this type of chart" - ) - ) - if not metrics and not columns and not groupby: - raise QueryObjectValidationError(_("Empty query?")) - - metrics_exprs: List[ColumnElement] = [] - for metric in metrics: - if utils.is_adhoc_metric(metric): - assert isinstance(metric, dict) - metrics_exprs.append( - self.adhoc_metric_to_sqla( - metric=metric, - columns_by_name=columns_by_name, - template_processor=template_processor, - ) - ) - elif isinstance(metric, str) and metric in metrics_by_name: - metrics_exprs.append( - metrics_by_name[metric].get_sqla_col( - template_processor=template_processor - ) - ) - else: - raise QueryObjectValidationError( - _("Metric '%(metric)s' does not exist", metric=metric) - ) - - if metrics_exprs: - main_metric_expr = metrics_exprs[0] - else: - main_metric_expr, label = literal_column("COUNT(*)"), "ccount" - main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) - - # To ensure correct handling of the ORDER BY labeling we need to reference the - # metric instance if defined in the SELECT clause. 
- # use the key of the ColumnClause for the expected label - metrics_exprs_by_label = {m.key: m for m in metrics_exprs} - metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} - - # Since orderby may use adhoc metrics, too; we need to process them first - orderby_exprs: List[ColumnElement] = [] - for orig_col, ascending in orderby: - col: Union[AdhocMetric, ColumnElement] = orig_col - if isinstance(col, dict): - col = cast(AdhocMetric, col) - if col.get("sqlExpression"): - col["sqlExpression"] = _process_sql_expression( - expression=col["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) - if utils.is_adhoc_metric(col): - # add adhoc sort by column to columns_by_name if not exists - col = self.adhoc_metric_to_sqla(col, columns_by_name) - # if the adhoc metric has been defined before - # use the existing instance. - col = metrics_exprs_by_expr.get(str(col), col) - need_groupby = True - elif col in columns_by_name: - col = columns_by_name[col].get_sqla_col( - template_processor=template_processor - ) - elif col in metrics_exprs_by_label: - col = metrics_exprs_by_label[col] - need_groupby = True - elif col in metrics_by_name: - col = metrics_by_name[col].get_sqla_col( - template_processor=template_processor - ) - need_groupby = True - - if isinstance(col, ColumnElement): - orderby_exprs.append(col) - else: - # Could not convert a column reference to valid ColumnElement - raise QueryObjectValidationError( - _("Unknown column used in orderby: %(col)s", col=orig_col) - ) - - select_exprs: List[Union[Column, Label]] = [] - groupby_all_columns = {} - groupby_series_columns = {} - - # filter out the pseudo column __timestamp from columns - columns = [col for col in columns if col != utils.DTTM_ALIAS] - dttm_col = columns_by_name.get(granularity) if granularity else None - - if need_groupby: - # dedup columns while preserving order - columns = groupby or columns - for selected in columns: - if isinstance(selected, str): - # if groupby field/expr equals granularity field/expr - if selected == granularity: - table_col = columns_by_name[selected] - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) - # if groupby field equals a selected column - elif selected in columns_by_name: - outer = columns_by_name[selected].get_sqla_col( - template_processor=template_processor - ) - else: - selected = validate_adhoc_subquery( - selected, - self.database_id, - self.schema, - ) - outer = literal_column(f"({selected})") - outer = self.make_sqla_column_compatible(outer, selected) - else: - outer = self.adhoc_column_to_sqla( - col=selected, template_processor=template_processor - ) - groupby_all_columns[outer.name] = outer - if ( - is_timeseries and not series_column_names - ) or outer.name in series_column_names: - groupby_series_columns[outer.name] = outer - select_exprs.append(outer) - elif columns: - for selected in columns: - if is_adhoc_column(selected): - _sql = selected["sqlExpression"] - _column_label = selected["label"] - elif isinstance(selected, str): - _sql = selected - _column_label = selected - - selected = validate_adhoc_subquery( - _sql, - self.database_id, - self.schema, - ) - select_exprs.append( - columns_by_name[selected].get_sqla_col( - template_processor=template_processor - ) - if isinstance(selected, str) and selected in columns_by_name - else self.make_sqla_column_compatible( - literal_column(selected), _column_label - ) - ) - metrics_exprs = 
[] - - if granularity: - if granularity not in columns_by_name or not dttm_col: - raise QueryObjectValidationError( - _( - 'Time column "%(col)s" does not exist in dataset', - col=granularity, - ) - ) - time_filters = [] - - if is_timeseries: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) - # always put timestamp as the first column - select_exprs.insert(0, timestamp) - groupby_all_columns[timestamp.name] = timestamp - - # Use main dttm column to support index with secondary dttm columns. - if ( - db_engine_spec.time_secondary_columns - and self.main_dttm_col in self.dttm_cols - and self.main_dttm_col != dttm_col.column_name - ): - time_filters.append( - columns_by_name[self.main_dttm_col].get_time_filter( - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - ) - time_filters.append( - dttm_col.get_time_filter( - start_dttm=from_dttm, - end_dttm=to_dttm, - template_processor=template_processor, - ) - ) - - # Always remove duplicates by column name, as sometimes `metrics_exprs` - # can have the same name as a groupby column (e.g. when users use - # raw columns as custom SQL adhoc metric). - select_exprs = remove_duplicates( - select_exprs + metrics_exprs, key=lambda x: x.name - ) - - # Expected output columns - labels_expected = [c.key for c in select_exprs] - - # Order by columns are "hidden" columns, some databases require them - # always be present in SELECT if an aggregation function is used - if not db_engine_spec.allows_hidden_ordeby_agg: - select_exprs = remove_duplicates(select_exprs + orderby_exprs) - - qry = sa.select(select_exprs) - - tbl, cte = self.get_from_clause(template_processor) - - if groupby_all_columns: - qry = qry.group_by(*groupby_all_columns.values()) - - where_clause_and = [] - having_clause_and = [] - - for flt in filter: # type: ignore - if not all(flt.get(s) for s in ["col", "op"]): - continue - flt_col = flt["col"] - val = flt.get("val") - op = flt["op"].upper() - col_obj: Optional[TableColumn] = None - sqla_col: Optional[Column] = None - if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: - col_obj = dttm_col - elif is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) - else: - col_obj = columns_by_name.get(flt_col) - filter_grain = flt.get("grain") - - if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if get_column_name(flt_col) in removed_filters: - # Skip generating SQLA filter when the jinja template handles it. 
- continue - - if col_obj or sqla_col is not None: - if sqla_col is not None: - pass - elif col_obj and filter_grain: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, template_processor=template_processor - ) - elif col_obj: - sqla_col = col_obj.get_sqla_col( - template_processor=template_processor - ) - col_type = col_obj.type if col_obj else None - col_spec = db_engine_spec.get_column_spec( - native_type=col_type, - db_extra=self.database.get_extra(), - ) - is_list_target = op in ( - utils.FilterOperator.IN.value, - utils.FilterOperator.NOT_IN.value, - ) - - col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" - - if col_spec and not col_advanced_data_type: - target_generic_type = col_spec.generic_type - else: - target_generic_type = GenericDataType.STRING - eq = self.filter_values_handler( - values=val, - operator=op, - target_generic_type=target_generic_type, - target_native_type=col_type, - is_list_target=is_list_target, - db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), - ) - if ( - col_advanced_data_type != "" - and feature_flag_manager.is_feature_enabled( - "ENABLE_ADVANCED_DATA_TYPES" - ) - and col_advanced_data_type in ADVANCED_DATA_TYPES - ): - values = eq if is_list_target else [eq] # type: ignore - bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ - col_advanced_data_type - ].translate_type( - { - "type": col_advanced_data_type, - "values": values, - } - ) - if bus_resp["error_message"]: - raise AdvancedDataTypeResponseError( - _(bus_resp["error_message"]) - ) - - where_clause_and.append( - ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( - sqla_col, op, bus_resp["values"] - ) - ) - elif is_list_target: - assert isinstance(eq, (tuple, list)) - if len(eq) == 0: - raise QueryObjectValidationError( - _("Filter value list cannot be empty") - ) - if len(eq) > len( - eq_without_none := [x for x in eq if x is not None] - ): - is_null_cond = sqla_col.is_(None) - if eq: - cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) - else: - cond = is_null_cond - else: - cond = sqla_col.in_(eq) - if op == utils.FilterOperator.NOT_IN.value: - cond = ~cond - where_clause_and.append(cond) - elif op == utils.FilterOperator.IS_NULL.value: - where_clause_and.append(sqla_col.is_(None)) - elif op == utils.FilterOperator.IS_NOT_NULL.value: - where_clause_and.append(sqla_col.isnot(None)) - elif op == utils.FilterOperator.IS_TRUE.value: - where_clause_and.append(sqla_col.is_(True)) - elif op == utils.FilterOperator.IS_FALSE.value: - where_clause_and.append(sqla_col.is_(False)) - else: - if ( - op - not in { - utils.FilterOperator.EQUALS.value, - utils.FilterOperator.NOT_EQUALS.value, - } - and eq is None - ): - raise QueryObjectValidationError( - _( - "Must specify a value for filters " - "with comparison operators" - ) - ) - if op == utils.FilterOperator.EQUALS.value: - where_clause_and.append(sqla_col == eq) - elif op == utils.FilterOperator.NOT_EQUALS.value: - where_clause_and.append(sqla_col != eq) - elif op == utils.FilterOperator.GREATER_THAN.value: - where_clause_and.append(sqla_col > eq) - elif op == utils.FilterOperator.LESS_THAN.value: - where_clause_and.append(sqla_col < eq) - elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: - where_clause_and.append(sqla_col >= eq) - elif op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: - where_clause_and.append(sqla_col <= eq) - elif op == utils.FilterOperator.LIKE.value: - where_clause_and.append(sqla_col.like(eq)) - elif op == 
utils.FilterOperator.ILIKE.value: - where_clause_and.append(sqla_col.ilike(eq)) - elif ( - op == utils.FilterOperator.TEMPORAL_RANGE.value - and isinstance(eq, str) - and col_obj is not None - ): - _since, _until = get_since_until_from_time_range( - time_range=eq, - time_shift=time_shift, - extras=extras, - ) - where_clause_and.append( - col_obj.get_time_filter( - start_dttm=_since, - end_dttm=_until, - label=sqla_col.key, - template_processor=template_processor, - ) - ) - else: - raise QueryObjectValidationError( - _("Invalid filter operation type: %(op)s", op=op) - ) - where_clause_and += self.get_sqla_row_level_filters(template_processor) - if extras: - where = extras.get("where") - if where: - try: - where = template_processor.process_template(f"({where})") - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error in jinja expression in WHERE clause: %(msg)s", - msg=ex.message, - ) - ) from ex - where = _process_sql_expression( - expression=where, - database_id=self.database_id, - schema=self.schema, - ) - where_clause_and += [self.text(where)] - having = extras.get("having") - if having: - try: - having = template_processor.process_template(f"({having})") - except TemplateError as ex: - raise QueryObjectValidationError( - _( - "Error in jinja expression in HAVING clause: %(msg)s", - msg=ex.message, - ) - ) from ex - having = _process_sql_expression( - expression=having, - database_id=self.database_id, - schema=self.schema, - ) - having_clause_and += [self.text(having)] - - if apply_fetch_values_predicate and self.fetch_values_predicate: - qry = qry.where( - self.get_fetch_values_predicate(template_processor=template_processor) - ) - if granularity: - qry = qry.where(and_(*(time_filters + where_clause_and))) - else: - qry = qry.where(and_(*where_clause_and)) - qry = qry.having(and_(*having_clause_and)) - - self.make_orderby_compatible(select_exprs, orderby_exprs) - - for col, (orig_col, ascending) in zip(orderby_exprs, orderby): - if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): - # if engine does not allow using SELECT alias in ORDER BY - # revert to the underlying column - col = col.element - - if ( - db_engine_spec.allows_alias_in_select - and db_engine_spec.allows_hidden_cc_in_orderby - and col.name in [select_col.name for select_col in select_exprs] - ): - col = literal_column(col.name) - direction = asc if ascending else desc - qry = qry.order_by(direction(col)) - - if row_limit: - qry = qry.limit(row_limit) - if row_offset: - qry = qry.offset(row_offset) - - if series_limit and groupby_series_columns: - if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: - # some sql dialects require for order by expressions - # to also be in the select clause -- others, e.g. 
vertica, - # require a unique inner alias - inner_main_metric_expr = self.make_sqla_column_compatible( - main_metric_expr, "mme_inner__" - ) - inner_groupby_exprs = [] - inner_select_exprs = [] - for gby_name, gby_obj in groupby_series_columns.items(): - label = get_column_name(gby_name) - inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") - inner_groupby_exprs.append(inner) - inner_select_exprs.append(inner) - - inner_select_exprs += [inner_main_metric_expr] - subq = select(inner_select_exprs).select_from(tbl) - inner_time_filter = [] - - if dttm_col and not db_engine_spec.time_groupby_inline: - inner_time_filter = [ - dttm_col.get_time_filter( - start_dttm=inner_from_dttm or from_dttm, - end_dttm=inner_to_dttm or to_dttm, - template_processor=template_processor, - ) - ] - subq = subq.where(and_(*(where_clause_and + inner_time_filter))) - subq = subq.group_by(*inner_groupby_exprs) - - ob = inner_main_metric_expr - if series_limit_metric: - ob = self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ) - direction = desc if order_desc else asc - subq = subq.order_by(direction(ob)) - subq = subq.limit(series_limit) - - on_clause = [] - for gby_name, gby_obj in groupby_series_columns.items(): - # in this case the column name, not the alias, needs to be - # conditionally mutated, as it refers to the column alias in - # the inner query - col_name = db_engine_spec.make_label_compatible(gby_name + "__") - on_clause.append(gby_obj == column(col_name)) - - tbl = tbl.join(subq.alias(), and_(*on_clause)) - else: - if series_limit_metric: - orderby = [ - ( - self._get_series_orderby( - series_limit_metric=series_limit_metric, - metrics_by_name=metrics_by_name, - columns_by_name=columns_by_name, - template_processor=template_processor, - ), - not order_desc, - ) - ] - - # run prequery to get top groups - prequery_obj = { - "is_timeseries": False, - "row_limit": series_limit, - "metrics": metrics, - "granularity": granularity, - "groupby": groupby, - "from_dttm": inner_from_dttm or from_dttm, - "to_dttm": inner_to_dttm or to_dttm, - "filter": filter, - "orderby": orderby, - "extras": extras, - "columns": columns, - "order_desc": True, - } - - result = self.query(prequery_obj) - prequeries.append(result.query) - dimensions = [ - c - for c in result.df.columns - if c not in metrics and c in groupby_series_columns - ] - top_groups = self._get_top_groups( - result.df, dimensions, groupby_series_columns, columns_by_name - ) - qry = qry.where(top_groups) - - qry = qry.select_from(tbl) - - if is_rowcount: - if not db_engine_spec.allows_subqueries: - raise QueryObjectValidationError( - _("Database does not support subqueries") - ) - label = "rowcount" - col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - qry = select([col]).select_from(qry.alias("rowcount_qry")) - labels_expected = [label] - - return SqlaQuery( - applied_template_filters=applied_template_filters, - cte=cte, - extra_cache_keys=extra_cache_keys, - labels_expected=labels_expected, - sqla_query=qry, - prequeries=prequeries, - ) + # def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements + # self, + # apply_fetch_values_predicate: bool = False, + # columns: Optional[List[ColumnTyping]] = None, + # extras: Optional[Dict[str, Any]] = None, + # filter: Optional[ # pylint: disable=redefined-builtin + # List[QueryObjectFilterClause] + # 
] = None, + # from_dttm: Optional[datetime] = None, + # granularity: Optional[str] = None, + # groupby: Optional[List[Column]] = None, + # inner_from_dttm: Optional[datetime] = None, + # inner_to_dttm: Optional[datetime] = None, + # is_rowcount: bool = False, + # is_timeseries: bool = True, + # metrics: Optional[List[Metric]] = None, + # orderby: Optional[List[OrderBy]] = None, + # order_desc: bool = True, + # to_dttm: Optional[datetime] = None, + # series_columns: Optional[List[Column]] = None, + # series_limit: Optional[int] = None, + # series_limit_metric: Optional[Metric] = None, + # row_limit: Optional[int] = None, + # row_offset: Optional[int] = None, + # timeseries_limit: Optional[int] = None, + # timeseries_limit_metric: Optional[Metric] = None, + # time_shift: Optional[str] = None, + # ) -> SqlaQuery: + # """Querying any sqla table from this common interface""" + # if granularity not in self.dttm_cols and granularity is not None: + # granularity = self.main_dttm_col + + # extras = extras or {} + # time_grain = extras.get("time_grain_sqla") + + # template_kwargs = { + # "columns": columns, + # "from_dttm": from_dttm.isoformat() if from_dttm else None, + # "groupby": groupby, + # "metrics": metrics, + # "row_limit": row_limit, + # "row_offset": row_offset, + # "time_column": granularity, + # "time_grain": time_grain, + # "to_dttm": to_dttm.isoformat() if to_dttm else None, + # "table_columns": [col.column_name for col in self.columns], + # "filter": filter, + # } + # columns = columns or [] + # groupby = groupby or [] + # series_column_names = utils.get_column_names(series_columns or []) + # # deprecated, to be removed in 2.0 + # if is_timeseries and timeseries_limit: + # series_limit = timeseries_limit + # series_limit_metric = series_limit_metric or timeseries_limit_metric + # template_kwargs.update(self.template_params_dict) + # extra_cache_keys: List[Any] = [] + # template_kwargs["extra_cache_keys"] = extra_cache_keys + # removed_filters: List[str] = [] + # applied_template_filters: List[str] = [] + # template_kwargs["removed_filters"] = removed_filters + # template_kwargs["applied_filters"] = applied_template_filters + # template_processor = self.get_template_processor(**template_kwargs) + # db_engine_spec = self.db_engine_spec + # prequeries: List[str] = [] + # orderby = orderby or [] + # need_groupby = bool(metrics is not None or groupby) + # metrics = metrics or [] + + # # For backward compatibility + # if granularity not in self.dttm_cols and granularity is not None: + # granularity = self.main_dttm_col + + # columns_by_name: Dict[str, TableColumn] = { + # col.column_name: col for col in self.columns + # } + + # metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + + # if not granularity and is_timeseries: + # raise QueryObjectValidationError( + # _( + # "Datetime column not provided as part table configuration " + # "and is required by this type of chart" + # ) + # ) + # if not metrics and not columns and not groupby: + # raise QueryObjectValidationError(_("Empty query?")) + + # metrics_exprs: List[ColumnElement] = [] + # for metric in metrics: + # if utils.is_adhoc_metric(metric): + # assert isinstance(metric, dict) + # metrics_exprs.append( + # self.adhoc_metric_to_sqla( + # metric=metric, + # columns_by_name=columns_by_name, + # template_processor=template_processor, + # ) + # ) + # elif isinstance(metric, str) and metric in metrics_by_name: + # metrics_exprs.append( + # metrics_by_name[metric].get_sqla_col( + # 
template_processor=template_processor + # ) + # ) + # else: + # raise QueryObjectValidationError( + # _("Metric '%(metric)s' does not exist", metric=metric) + # ) + + # if metrics_exprs: + # main_metric_expr = metrics_exprs[0] + # else: + # main_metric_expr, label = literal_column("COUNT(*)"), "ccount" + # main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) + + # # To ensure correct handling of the ORDER BY labeling we need to reference the + # # metric instance if defined in the SELECT clause. + # # use the key of the ColumnClause for the expected label + # metrics_exprs_by_label = {m.key: m for m in metrics_exprs} + # metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} + + # # Since orderby may use adhoc metrics, too; we need to process them first + # orderby_exprs: List[ColumnElement] = [] + # for orig_col, ascending in orderby: + # col: Union[AdhocMetric, ColumnElement] = orig_col + # if isinstance(col, dict): + # col = cast(AdhocMetric, col) + # if col.get("sqlExpression"): + # col["sqlExpression"] = _process_sql_expression( + # expression=col["sqlExpression"], + # database_id=self.database_id, + # schema=self.schema, + # template_processor=template_processor, + # ) + # if utils.is_adhoc_metric(col): + # # add adhoc sort by column to columns_by_name if not exists + # col = self.adhoc_metric_to_sqla(col, columns_by_name) + # # if the adhoc metric has been defined before + # # use the existing instance. + # col = metrics_exprs_by_expr.get(str(col), col) + # need_groupby = True + # elif col in columns_by_name: + # col = columns_by_name[col].get_sqla_col( + # template_processor=template_processor + # ) + # elif col in metrics_exprs_by_label: + # col = metrics_exprs_by_label[col] + # need_groupby = True + # elif col in metrics_by_name: + # col = metrics_by_name[col].get_sqla_col( + # template_processor=template_processor + # ) + # need_groupby = True + + # if isinstance(col, ColumnElement): + # orderby_exprs.append(col) + # else: + # # Could not convert a column reference to valid ColumnElement + # raise QueryObjectValidationError( + # _("Unknown column used in orderby: %(col)s", col=orig_col) + # ) + + # select_exprs: List[Union[Column, Label]] = [] + # groupby_all_columns = {} + # groupby_series_columns = {} + + # # filter out the pseudo column __timestamp from columns + # columns = [col for col in columns if col != utils.DTTM_ALIAS] + # dttm_col = columns_by_name.get(granularity) if granularity else None + + # if need_groupby: + # # dedup columns while preserving order + # columns = groupby or columns + # for selected in columns: + # if isinstance(selected, str): + # # if groupby field/expr equals granularity field/expr + # if selected == granularity: + # table_col = columns_by_name[selected] + # outer = table_col.get_timestamp_expression( + # time_grain=time_grain, + # label=selected, + # template_processor=template_processor, + # ) + # # if groupby field equals a selected column + # elif selected in columns_by_name: + # outer = columns_by_name[selected].get_sqla_col( + # template_processor=template_processor + # ) + # else: + # selected = validate_adhoc_subquery( + # selected, + # self.database_id, + # self.schema, + # ) + # outer = literal_column(f"({selected})") + # outer = self.make_sqla_column_compatible(outer, selected) + # else: + # outer = self.adhoc_column_to_sqla( + # col=selected, template_processor=template_processor + # ) + # groupby_all_columns[outer.name] = outer + # if ( + # is_timeseries and not series_column_names + # ) or outer.name in 
series_column_names: + # groupby_series_columns[outer.name] = outer + # select_exprs.append(outer) + # elif columns: + # for selected in columns: + # if is_adhoc_column(selected): + # _sql = selected["sqlExpression"] + # _column_label = selected["label"] + # elif isinstance(selected, str): + # _sql = selected + # _column_label = selected + + # selected = validate_adhoc_subquery( + # _sql, + # self.database_id, + # self.schema, + # ) + # select_exprs.append( + # columns_by_name[selected].get_sqla_col( + # template_processor=template_processor + # ) + # if isinstance(selected, str) and selected in columns_by_name + # else self.make_sqla_column_compatible( + # literal_column(selected), _column_label + # ) + # ) + # metrics_exprs = [] + + # if granularity: + # if granularity not in columns_by_name or not dttm_col: + # raise QueryObjectValidationError( + # _( + # 'Time column "%(col)s" does not exist in dataset', + # col=granularity, + # ) + # ) + # time_filters = [] + + # if is_timeseries: + # timestamp = dttm_col.get_timestamp_expression( + # time_grain=time_grain, template_processor=template_processor + # ) + # # always put timestamp as the first column + # select_exprs.insert(0, timestamp) + # groupby_all_columns[timestamp.name] = timestamp + + # # Use main dttm column to support index with secondary dttm columns. + # if ( + # db_engine_spec.time_secondary_columns + # and self.main_dttm_col in self.dttm_cols + # and self.main_dttm_col != dttm_col.column_name + # ): + # time_filters.append( + # columns_by_name[self.main_dttm_col].get_time_filter( + # start_dttm=from_dttm, + # end_dttm=to_dttm, + # template_processor=template_processor, + # ) + # ) + # time_filters.append( + # dttm_col.get_time_filter( + # start_dttm=from_dttm, + # end_dttm=to_dttm, + # template_processor=template_processor, + # ) + # ) + + # # Always remove duplicates by column name, as sometimes `metrics_exprs` + # # can have the same name as a groupby column (e.g. when users use + # # raw columns as custom SQL adhoc metric). + # select_exprs = remove_duplicates( + # select_exprs + metrics_exprs, key=lambda x: x.name + # ) + + # # Expected output columns + # labels_expected = [c.key for c in select_exprs] + + # # Order by columns are "hidden" columns, some databases require them + # # always be present in SELECT if an aggregation function is used + # if not db_engine_spec.allows_hidden_ordeby_agg: + # select_exprs = remove_duplicates(select_exprs + orderby_exprs) + + # qry = sa.select(select_exprs) + + # tbl, cte = self.get_from_clause(template_processor) + + # if groupby_all_columns: + # qry = qry.group_by(*groupby_all_columns.values()) + + # where_clause_and = [] + # having_clause_and = [] + + # for flt in filter: # type: ignore + # if not all(flt.get(s) for s in ["col", "op"]): + # continue + # flt_col = flt["col"] + # val = flt.get("val") + # op = flt["op"].upper() + # col_obj: Optional[TableColumn] = None + # sqla_col: Optional[Column] = None + # if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: + # col_obj = dttm_col + # elif is_adhoc_column(flt_col): + # sqla_col = self.adhoc_column_to_sqla(flt_col) + # else: + # col_obj = columns_by_name.get(flt_col) + # filter_grain = flt.get("grain") + + # if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): + # if get_column_name(flt_col) in removed_filters: + # # Skip generating SQLA filter when the jinja template handles it. 
+ # continue + + # if col_obj or sqla_col is not None: + # if sqla_col is not None: + # pass + # elif col_obj and filter_grain: + # sqla_col = col_obj.get_timestamp_expression( + # time_grain=filter_grain, template_processor=template_processor + # ) + # elif col_obj: + # sqla_col = col_obj.get_sqla_col( + # template_processor=template_processor + # ) + # col_type = col_obj.type if col_obj else None + # col_spec = db_engine_spec.get_column_spec( + # native_type=col_type, + # db_extra=self.database.get_extra(), + # ) + # is_list_target = op in ( + # utils.FilterOperator.IN.value, + # utils.FilterOperator.NOT_IN.value, + # ) + + # col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" + + # if col_spec and not col_advanced_data_type: + # target_generic_type = col_spec.generic_type + # else: + # target_generic_type = GenericDataType.STRING + # eq = self.filter_values_handler( + # values=val, + # operator=op, + # target_generic_type=target_generic_type, + # target_native_type=col_type, + # is_list_target=is_list_target, + # db_engine_spec=db_engine_spec, + # db_extra=self.database.get_extra(), + # ) + # if ( + # col_advanced_data_type != "" + # and feature_flag_manager.is_feature_enabled( + # "ENABLE_ADVANCED_DATA_TYPES" + # ) + # and col_advanced_data_type in ADVANCED_DATA_TYPES + # ): + # values = eq if is_list_target else [eq] # type: ignore + # bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ + # col_advanced_data_type + # ].translate_type( + # { + # "type": col_advanced_data_type, + # "values": values, + # } + # ) + # if bus_resp["error_message"]: + # raise AdvancedDataTypeResponseError( + # _(bus_resp["error_message"]) + # ) + + # where_clause_and.append( + # ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( + # sqla_col, op, bus_resp["values"] + # ) + # ) + # elif is_list_target: + # assert isinstance(eq, (tuple, list)) + # if len(eq) == 0: + # raise QueryObjectValidationError( + # _("Filter value list cannot be empty") + # ) + # if len(eq) > len( + # eq_without_none := [x for x in eq if x is not None] + # ): + # is_null_cond = sqla_col.is_(None) + # if eq: + # cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) + # else: + # cond = is_null_cond + # else: + # cond = sqla_col.in_(eq) + # if op == utils.FilterOperator.NOT_IN.value: + # cond = ~cond + # where_clause_and.append(cond) + # elif op == utils.FilterOperator.IS_NULL.value: + # where_clause_and.append(sqla_col.is_(None)) + # elif op == utils.FilterOperator.IS_NOT_NULL.value: + # where_clause_and.append(sqla_col.isnot(None)) + # elif op == utils.FilterOperator.IS_TRUE.value: + # where_clause_and.append(sqla_col.is_(True)) + # elif op == utils.FilterOperator.IS_FALSE.value: + # where_clause_and.append(sqla_col.is_(False)) + # else: + # if ( + # op + # not in { + # utils.FilterOperator.EQUALS.value, + # utils.FilterOperator.NOT_EQUALS.value, + # } + # and eq is None + # ): + # raise QueryObjectValidationError( + # _( + # "Must specify a value for filters " + # "with comparison operators" + # ) + # ) + # if op == utils.FilterOperator.EQUALS.value: + # where_clause_and.append(sqla_col == eq) + # elif op == utils.FilterOperator.NOT_EQUALS.value: + # where_clause_and.append(sqla_col != eq) + # elif op == utils.FilterOperator.GREATER_THAN.value: + # where_clause_and.append(sqla_col > eq) + # elif op == utils.FilterOperator.LESS_THAN.value: + # where_clause_and.append(sqla_col < eq) + # elif op == utils.FilterOperator.GREATER_THAN_OR_EQUALS.value: + # where_clause_and.append(sqla_col >= eq) + # elif 
op == utils.FilterOperator.LESS_THAN_OR_EQUALS.value: + # where_clause_and.append(sqla_col <= eq) + # elif op == utils.FilterOperator.LIKE.value: + # where_clause_and.append(sqla_col.like(eq)) + # elif op == utils.FilterOperator.ILIKE.value: + # where_clause_and.append(sqla_col.ilike(eq)) + # elif ( + # op == utils.FilterOperator.TEMPORAL_RANGE.value + # and isinstance(eq, str) + # and col_obj is not None + # ): + # _since, _until = get_since_until_from_time_range( + # time_range=eq, + # time_shift=time_shift, + # extras=extras, + # ) + # where_clause_and.append( + # col_obj.get_time_filter( + # start_dttm=_since, + # end_dttm=_until, + # label=sqla_col.key, + # template_processor=template_processor, + # ) + # ) + # else: + # raise QueryObjectValidationError( + # _("Invalid filter operation type: %(op)s", op=op) + # ) + # where_clause_and += self.get_sqla_row_level_filters(template_processor) + # if extras: + # where = extras.get("where") + # if where: + # try: + # where = template_processor.process_template(f"({where})") + # except TemplateError as ex: + # raise QueryObjectValidationError( + # _( + # "Error in jinja expression in WHERE clause: %(msg)s", + # msg=ex.message, + # ) + # ) from ex + # where = _process_sql_expression( + # expression=where, + # database_id=self.database_id, + # schema=self.schema, + # ) + # where_clause_and += [self.text(where)] + # having = extras.get("having") + # if having: + # try: + # having = template_processor.process_template(f"({having})") + # except TemplateError as ex: + # raise QueryObjectValidationError( + # _( + # "Error in jinja expression in HAVING clause: %(msg)s", + # msg=ex.message, + # ) + # ) from ex + # having = _process_sql_expression( + # expression=having, + # database_id=self.database_id, + # schema=self.schema, + # ) + # having_clause_and += [self.text(having)] + + # if apply_fetch_values_predicate and self.fetch_values_predicate: + # qry = qry.where( + # self.get_fetch_values_predicate(template_processor=template_processor) + # ) + # if granularity: + # qry = qry.where(and_(*(time_filters + where_clause_and))) + # else: + # qry = qry.where(and_(*where_clause_and)) + # qry = qry.having(and_(*having_clause_and)) + + # self.make_orderby_compatible(select_exprs, orderby_exprs) + + # for col, (orig_col, ascending) in zip(orderby_exprs, orderby): + # if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label): + # # if engine does not allow using SELECT alias in ORDER BY + # # revert to the underlying column + # col = col.element + + # if ( + # db_engine_spec.allows_alias_in_select + # and db_engine_spec.allows_hidden_cc_in_orderby + # and col.name in [select_col.name for select_col in select_exprs] + # ): + # col = literal_column(col.name) + # direction = asc if ascending else desc + # qry = qry.order_by(direction(col)) + + # if row_limit: + # qry = qry.limit(row_limit) + # if row_offset: + # qry = qry.offset(row_offset) + + # if series_limit and groupby_series_columns: + # if db_engine_spec.allows_joins and db_engine_spec.allows_subqueries: + # # some sql dialects require for order by expressions + # # to also be in the select clause -- others, e.g. 
vertica, + # # require a unique inner alias + # inner_main_metric_expr = self.make_sqla_column_compatible( + # main_metric_expr, "mme_inner__" + # ) + # inner_groupby_exprs = [] + # inner_select_exprs = [] + # for gby_name, gby_obj in groupby_series_columns.items(): + # label = get_column_name(gby_name) + # inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__") + # inner_groupby_exprs.append(inner) + # inner_select_exprs.append(inner) + + # inner_select_exprs += [inner_main_metric_expr] + # subq = select(inner_select_exprs).select_from(tbl) + # inner_time_filter = [] + + # if dttm_col and not db_engine_spec.time_groupby_inline: + # inner_time_filter = [ + # dttm_col.get_time_filter( + # start_dttm=inner_from_dttm or from_dttm, + # end_dttm=inner_to_dttm or to_dttm, + # template_processor=template_processor, + # ) + # ] + # subq = subq.where(and_(*(where_clause_and + inner_time_filter))) + # subq = subq.group_by(*inner_groupby_exprs) + + # ob = inner_main_metric_expr + # if series_limit_metric: + # ob = self._get_series_orderby( + # series_limit_metric=series_limit_metric, + # metrics_by_name=metrics_by_name, + # columns_by_name=columns_by_name, + # template_processor=template_processor, + # ) + # direction = desc if order_desc else asc + # subq = subq.order_by(direction(ob)) + # subq = subq.limit(series_limit) + + # on_clause = [] + # for gby_name, gby_obj in groupby_series_columns.items(): + # # in this case the column name, not the alias, needs to be + # # conditionally mutated, as it refers to the column alias in + # # the inner query + # col_name = db_engine_spec.make_label_compatible(gby_name + "__") + # on_clause.append(gby_obj == column(col_name)) + + # tbl = tbl.join(subq.alias(), and_(*on_clause)) + # else: + # if series_limit_metric: + # orderby = [ + # ( + # self._get_series_orderby( + # series_limit_metric=series_limit_metric, + # metrics_by_name=metrics_by_name, + # columns_by_name=columns_by_name, + # template_processor=template_processor, + # ), + # not order_desc, + # ) + # ] + + # # run prequery to get top groups + # prequery_obj = { + # "is_timeseries": False, + # "row_limit": series_limit, + # "metrics": metrics, + # "granularity": granularity, + # "groupby": groupby, + # "from_dttm": inner_from_dttm or from_dttm, + # "to_dttm": inner_to_dttm or to_dttm, + # "filter": filter, + # "orderby": orderby, + # "extras": extras, + # "columns": columns, + # "order_desc": True, + # } + + # result = self.query(prequery_obj) + # prequeries.append(result.query) + # dimensions = [ + # c + # for c in result.df.columns + # if c not in metrics and c in groupby_series_columns + # ] + # top_groups = self._get_top_groups( + # result.df, dimensions, groupby_series_columns, columns_by_name + # ) + # qry = qry.where(top_groups) + + # qry = qry.select_from(tbl) + + # if is_rowcount: + # if not db_engine_spec.allows_subqueries: + # raise QueryObjectValidationError( + # _("Database does not support subqueries") + # ) + # label = "rowcount" + # col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) + # qry = select([col]).select_from(qry.alias("rowcount_qry")) + # labels_expected = [label] + + # return SqlaQuery( + # applied_template_filters=applied_template_filters, + # cte=cte, + # extra_cache_keys=extra_cache_keys, + # labels_expected=labels_expected, + # sqla_query=qry, + # prequeries=prequeries, + # ) def _get_series_orderby( self, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index fd0a1eff5c..26b07c6e54 100644 --- 
a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -26,8 +26,8 @@ from typing import ( Any, cast, Dict, + Hashable, List, - Mapping, NamedTuple, Optional, Set, @@ -87,7 +87,13 @@ from superset.superset_typing import ( QueryObjectDict, ) from superset.utils import core as utils -from superset.utils.core import get_user_id +from superset.utils.core import ( + GenericDataType, + get_column_name, + get_user_id, + is_adhoc_column, + remove_duplicates, +) if TYPE_CHECKING: from superset.connectors.sqla.models import SqlMetric, TableColumn @@ -680,7 +686,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods } @property - def query(self) -> str: + def fetch_value_predicate(self) -> str: + return "fix this!" + + def query(self, query_obj: QueryObjectDict) -> QueryResult: raise NotImplementedError() @property @@ -747,13 +756,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def get_fetch_values_predicate(self) -> List[Any]: raise NotImplementedError() - @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: + def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: raise NotImplementedError() def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: raise NotImplementedError() + def get_sqla_row_level_filters( + self, + template_processor: BaseTemplateProcessor, + ) -> List[TextClause]: + raise NotImplementedError() + def _process_sql_expression( # pylint: disable=no-self-use self, expression: Optional[str], @@ -1156,13 +1170,14 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Mapping[str, "SqlMetric"], - columns_by_name: Mapping[str, "TableColumn"], + metrics_by_name: Dict[str, "SqlMetric"], + columns_by_name: Dict[str, "TableColumn"], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) ob = self.adhoc_metric_to_sqla( - series_limit_metric, columns_by_name # type: ignore + series_limit_metric, columns_by_name ) elif ( isinstance(series_limit_metric, str) @@ -1180,23 +1195,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods col: Type["AdhocColumn"], # type: ignore template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: - """ - Turn an adhoc column into a sqlalchemy column. - - :param col: Adhoc column definition - :param template_processor: template_processor instance - :returns: The metric defined as a sqlalchemy column - :rtype: sqlalchemy.sql.column - """ - label = utils.get_column_name(col) # type: ignore - expression = self._process_sql_expression( - expression=col["sqlExpression"], - database_id=self.database_id, - schema=self.schema, - template_processor=template_processor, - ) - sqla_column = literal_column(expression) - return self.make_sqla_column_compatible(sqla_column, label) + raise NotImplementedError() + # """ + # Turn an adhoc column into a sqlalchemy column. 
+ + # :param col: Adhoc column definition + # :param template_processor: template_processor instance + # :returns: The metric defined as a sqlalchemy column + # :rtype: sqlalchemy.sql.column + # """ + # label = utils.get_column_name(col) # type: ignore + # expression = self._process_sql_expression( + # expression=col["sqlExpression"], + # database_id=self.database_id, + # schema=self.schema, + # template_processor=template_processor, + # ) + # sqla_column = literal_column(expression) + # return self.make_sqla_column_compatible(sqla_column, label) def _get_top_groups( self, @@ -1371,7 +1387,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods "time_column": granularity, "time_grain": time_grain, "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": [col.get("column_name") for col in self.columns], + "table_columns": [col.column_name for col in self.columns], "filter": filter, } columns = columns or [] @@ -1399,11 +1415,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - columns_by_name: Dict[str, "TableColumn"] = { - col.get("column_name"): col - for col in self.columns # col.column_name: col for col in self.columns + columns_by_name: Dict[str, TableColumn] = { + col.column_name: col for col in self.columns } + metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + if not granularity and is_timeseries: raise QueryObjectValidationError( _( @@ -1425,6 +1442,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods template_processor=template_processor, ) ) + elif isinstance(metric, str) and metric in metrics_by_name: + metrics_exprs.append( + metrics_by_name[metric].get_sqla_col( + template_processor=template_processor + ) + ) else: raise QueryObjectValidationError( _("Metric '%(metric)s' does not exist", metric=metric) @@ -1463,14 +1486,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods col = metrics_exprs_by_expr.get(str(col), col) need_groupby = True elif col in columns_by_name: - gb_column_obj = columns_by_name[col] - if isinstance(gb_column_obj, dict): - col = self.get_sqla_col(gb_column_obj) - else: - col = gb_column_obj.get_sqla_col() + col = columns_by_name[col].get_sqla_col( + template_processor=template_processor + ) elif col in metrics_exprs_by_label: col = metrics_exprs_by_label[col] need_groupby = True + elif col in metrics_by_name: + col = metrics_by_name[col].get_sqla_col( + template_processor=template_processor + ) + need_groupby = True if isinstance(col, ColumnElement): orderby_exprs.append(col) @@ -1496,33 +1522,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods # if groupby field/expr equals granularity field/expr if selected == granularity: table_col = columns_by_name[selected] - if isinstance(table_col, dict): - outer = self.get_timestamp_expression( - column=table_col, - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) - else: - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) + outer = table_col.get_timestamp_expression( + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) # if groupby field equals a selected column elif selected in columns_by_name: - if isinstance(columns_by_name[selected], dict): - outer = sa.column(f"{selected}") - outer = self.make_sqla_column_compatible(outer, selected) - else: - 
outer = columns_by_name[selected].get_sqla_col() + outer = columns_by_name[selected].get_sqla_col( + template_processor=template_processor + ) else: - selected = self.validate_adhoc_subquery( + selected = validate_adhoc_subquery( selected, self.database_id, self.schema, ) - outer = sa.column(f"{selected}") + outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( @@ -1536,19 +1552,27 @@ class ExploreMixin: # pylint: disable=too-many-public-methods select_exprs.append(outer) elif columns: for selected in columns: - selected = self.validate_adhoc_subquery( - selected, + if is_adhoc_column(selected): + _sql = selected["sqlExpression"] + _column_label = selected["label"] + elif isinstance(selected, str): + _sql = selected + _column_label = selected + + selected = validate_adhoc_subquery( + _sql, self.database_id, self.schema, ) - if isinstance(columns_by_name[selected], dict): - select_exprs.append(sa.column(f"{selected}")) - else: - select_exprs.append( - columns_by_name[selected].get_sqla_col() - if selected in columns_by_name - else self.make_sqla_column_compatible(literal_column(selected)) + select_exprs.append( + columns_by_name[selected].get_sqla_col( + template_processor=template_processor ) + if isinstance(selected, str) and selected in columns_by_name + else self.make_sqla_column_compatible( + literal_column(selected), _column_label + ) + ) metrics_exprs = [] if granularity: @@ -1559,57 +1583,41 @@ class ExploreMixin: # pylint: disable=too-many-public-methods col=granularity, ) ) - time_filters: List[Any] = [] + time_filters = [] if is_timeseries: - if isinstance(dttm_col, dict): - timestamp = self.get_timestamp_expression( - dttm_col, time_grain, template_processor=template_processor - ) - else: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) + timestamp = dttm_col.get_timestamp_expression( + time_grain=time_grain, template_processor=template_processor + ) # always put timestamp as the first column select_exprs.insert(0, timestamp) groupby_all_columns[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns. - if db_engine_spec.time_secondary_columns: - if isinstance(dttm_col, dict): - dttm_col_name = dttm_col.get("column_name") - else: - dttm_col_name = dttm_col.column_name - - if ( - self.main_dttm_col in self.dttm_cols - and self.main_dttm_col != dttm_col_name - ): - if isinstance(self.main_dttm_col, dict): - time_filters.append( - self.get_time_filter( - self.main_dttm_col, - from_dttm, - to_dttm, - ) - ) - else: - time_filters.append( - columns_by_name[self.main_dttm_col].get_time_filter( - from_dttm, - to_dttm, - ) - ) - - if isinstance(dttm_col, dict): - time_filters.append(self.get_time_filter(dttm_col, from_dttm, to_dttm)) - else: - time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) + if ( + db_engine_spec.time_secondary_columns + and self.main_dttm_col in self.dttm_cols + and self.main_dttm_col != dttm_col.column_name + ): + time_filters.append( + columns_by_name[self.main_dttm_col].get_time_filter( + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) + time_filters.append( + dttm_col.get_time_filter( + start_dttm=from_dttm, + end_dttm=to_dttm, + template_processor=template_processor, + ) + ) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. 
when users use # raw columns as custom SQL adhoc metric). - select_exprs = utils.remove_duplicates( + select_exprs = remove_duplicates( select_exprs + metrics_exprs, key=lambda x: x.name ) @@ -1619,7 +1627,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods # Order by columns are "hidden" columns, some databases require them # always be present in SELECT if an aggregation function is used if not db_engine_spec.allows_hidden_ordeby_agg: - select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs) + select_exprs = remove_duplicates(select_exprs + orderby_exprs) qry = sa.select(select_exprs) @@ -1637,18 +1645,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods flt_col = flt["col"] val = flt.get("val") op = flt["op"].upper() - col_obj: Optional["TableColumn"] = None + col_obj: Optional[TableColumn] = None sqla_col: Optional[Column] = None if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col - elif utils.is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore + elif is_adhoc_column(flt_col): + sqla_col = self.adhoc_column_to_sqla(col=flt_col, template_processor=template_processor) # type: ignore else: col_obj = columns_by_name.get(flt_col) filter_grain = flt.get("grain") if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"): - if utils.get_column_name(flt_col) in removed_filters: + if get_column_name(flt_col) in removed_filters: # Skip generating SQLA filter when the jinja template handles it. continue @@ -1656,44 +1664,29 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if sqla_col is not None: pass elif col_obj and filter_grain: - if isinstance(col_obj, dict): - sqla_col = self.get_timestamp_expression( - col_obj, time_grain, template_processor=template_processor - ) - else: - sqla_col = col_obj.get_timestamp_expression( - time_grain=filter_grain, - template_processor=template_processor, - ) - elif col_obj and isinstance(col_obj, dict): - sqla_col = sa.column(col_obj.get("column_name")) + sqla_col = col_obj.get_timestamp_expression( + time_grain=filter_grain, template_processor=template_processor + ) elif col_obj: - sqla_col = col_obj.get_sqla_col() - - if col_obj and isinstance(col_obj, dict): - col_type = col_obj.get("type") - else: - col_type = col_obj.type if col_obj else None + sqla_col = col_obj.get_sqla_col( + template_processor=template_processor + ) + col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, - db_extra=self.database.get_extra(), # type: ignore +# db_extra=self.database.get_extra(), ) is_list_target = op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ) - if col_obj and isinstance(col_obj, dict): - col_advanced_data_type = "" - else: - col_advanced_data_type = ( - col_obj.advanced_data_type if col_obj else "" - ) + col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" if col_spec and not col_advanced_data_type: target_generic_type = col_spec.generic_type else: - target_generic_type = utils.GenericDataType.STRING + target_generic_type = GenericDataType.STRING eq = self.filter_values_handler( values=val, operator=op, @@ -1701,7 +1694,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods target_native_type=col_type, is_list_target=is_list_target, db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), # type: ignore +# db_extra=self.database.get_extra(), ) if ( col_advanced_data_type != "" @@ -1757,7 +1750,14 @@ class ExploreMixin: # 
@@ -1757,7 +1750,14 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                     elif op == utils.FilterOperator.IS_FALSE.value:
                         where_clause_and.append(sqla_col.is_(False))
                     else:
-                        if eq is None:
+                        if (
+                            op
+                            not in {
+                                utils.FilterOperator.EQUALS.value,
+                                utils.FilterOperator.NOT_EQUALS.value,
+                            }
+                            and eq is None
+                        ):
                             raise QueryObjectValidationError(
                                 _(
                                     "Must specify a value for filters "
@@ -1791,23 +1791,23 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                                 extras=extras,
                             )
                             where_clause_and.append(
-                                self.get_time_filter(
-                                    time_col=col_obj,
+                                col_obj.get_time_filter(
                                     start_dttm=_since,
                                     end_dttm=_until,
+                                    label=sqla_col.key,
+                                    template_processor=template_processor,
                                 )
                             )
                         else:
                             raise QueryObjectValidationError(
                                 _("Invalid filter operation type: %(op)s", op=op)
                             )
-        # todo(hugh): fix this w/ template_processor
-        # where_clause_and += self.get_sqla_row_level_filters(template_processor)
+        where_clause_and += self.get_sqla_row_level_filters(template_processor)
         if extras:
             where = extras.get("where")
             if where:
                 try:
-                    where = template_processor.process_template(f"{where}")
+                    where = template_processor.process_template(f"({where})")
                 except TemplateError as ex:
                     raise QueryObjectValidationError(
                         _(
@@ -1815,11 +1815,17 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                             msg=ex.message,
                         )
                     ) from ex
+                where = self._process_sql_expression(
+                    expression=where,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                    template_processor=template_processor
+                )
                 where_clause_and += [self.text(where)]
             having = extras.get("having")
             if having:
                 try:
-                    having = template_processor.process_template(f"{having}")
+                    having = template_processor.process_template(f"({having})")
                 except TemplateError as ex:
                     raise QueryObjectValidationError(
                         _(
@@ -1827,9 +1833,18 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                             msg=ex.message,
                         )
                     ) from ex
+                having = self._process_sql_expression(
+                    expression=having,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                    template_processor=template_processor,
+                )
                 having_clause_and += [self.text(having)]
-        if apply_fetch_values_predicate and self.fetch_values_predicate:  # type: ignore
-            qry = qry.where(self.get_fetch_values_predicate())  # type: ignore
+
+        if apply_fetch_values_predicate and self.fetch_values_predicate:  # type: ignore
+            qry = qry.where(
+                self.get_fetch_values_predicate(template_processor=template_processor)  # type: ignore
+            )
         if granularity:
             qry = qry.where(and_(*(time_filters + where_clause_and)))
         else:
@@ -1869,7 +1884,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             inner_groupby_exprs = []
             inner_select_exprs = []
             for gby_name, gby_obj in groupby_series_columns.items():
-                label = utils.get_column_name(gby_name)
+                label = get_column_name(gby_name)
                 inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
                 inner_groupby_exprs.append(inner)
                 inner_select_exprs.append(inner)
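Wrapping the free-form WHERE/HAVING snippets in parentheses before embedding them is a precedence guard: SQLAlchemy does not parenthesize raw text() fragments inside and_(), so an OR in user-supplied SQL would otherwise bind against only part of the final clause. A small self-contained demonstration (table and column names here are made up):

# Hedged sketch of why the new f"({where})" wrapping matters.
from sqlalchemy import and_, column, select, table, text

t = table("t", column("a"), column("b"), column("c"))
raw = "a = 1 OR b = 2"  # stand-in for a user-supplied extras["where"] snippet

unsafe = select(t).where(and_(text(raw), t.c.c == 3))
safe = select(t).where(and_(text(f"({raw})"), t.c.c == 3))
print(unsafe)  # ... WHERE a = 1 OR b = 2 AND c = 3  -- the OR leaks out
print(safe)    # ... WHERE (a = 1 OR b = 2) AND c = 3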
@@ -1879,26 +1894,24 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                 inner_time_filter = []
                 if dttm_col and not db_engine_spec.time_groupby_inline:
-                    if isinstance(dttm_col, dict):
-                        inner_time_filter = [
-                            self.get_time_filter(
-                                dttm_col,
-                                inner_from_dttm or from_dttm,
-                                inner_to_dttm or to_dttm,
-                            )
-                        ]
-                    else:
-                        inner_time_filter = [
-                            dttm_col.get_time_filter(
-                                inner_from_dttm or from_dttm,
-                                inner_to_dttm or to_dttm,
-                            )
-                        ]
-
+                    inner_time_filter = [
+                        dttm_col.get_time_filter(
+                            start_dttm=inner_from_dttm or from_dttm,
+                            end_dttm=inner_to_dttm or to_dttm,
+                            template_processor=template_processor,
+                        )
+                    ]
                 subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
                 subq = subq.group_by(*inner_groupby_exprs)

                 ob = inner_main_metric_expr
+                if series_limit_metric:
+                    ob = self._get_series_orderby(
+                        series_limit_metric=series_limit_metric,
+                        metrics_by_name=metrics_by_name,
+                        columns_by_name=columns_by_name,
+                        template_processor=template_processor,
+                    )
                 direction = sa.desc if order_desc else sa.asc
                 subq = subq.order_by(direction(ob))
                 subq = subq.limit(series_limit)
@@ -1912,6 +1925,19 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                         on_clause.append(gby_obj == sa.column(col_name))

                 tbl = tbl.join(subq.alias(), and_(*on_clause))
+            else:
+                if series_limit_metric:
+                    orderby = [
+                        (
+                            self._get_series_orderby(
+                                series_limit_metric=series_limit_metric,
+                                metrics_by_name=metrics_by_name,
+                                columns_by_name=columns_by_name,
+                                template_processor=template_processor,
+                            ),
+                            not order_desc,
+                        )
+                    ]

                 # run prequery to get top groups
                 prequery_obj = {
@@ -1928,7 +1954,8 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                     "columns": columns,
                     "order_desc": True,
                 }
-                result = self.exc_query(prequery_obj)
+
+                result = self.query(prequery_obj)
                 prequeries.append(result.query)
                 dimensions = [
                     c
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index babea35baf..5ccba99975 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -33,6 +33,7 @@ from sqlalchemy import (
     DateTime,
     Enum,
     ForeignKey,
+    Hashable,
     Integer,
     Numeric,
     String,
@@ -307,7 +308,7 @@ class Query(
         return ""

     @staticmethod
-    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
+    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]:
         return []

     @property
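One caveat in the sql_lab.py hunk above: Hashable is added to the from sqlalchemy import (...) block, but SQLAlchemy exports no such name, so the module would fail at import time. The List[Hashable] annotation wants the standard-library type instead. A corrected sketch (the class body is abbreviated; only the import location changes):

# Hedged sketch: Hashable comes from typing (or collections.abc), not sqlalchemy.
from typing import Any, Dict, Hashable, List

class Query:  # abbreviated stand-in for superset.models.sql_lab.Query
    @staticmethod
    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]:
        # cache keys may be strings, ints, tuples... anything hashable
        return []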
diff --git a/superset/result_set.py b/superset/result_set.py
index 3d29673b9f..63d48b1e4b 100644
--- a/superset/result_set.py
+++ b/superset/result_set.py
@@ -70,9 +70,9 @@ def stringify_values(array: NDArray[Any]) -> NDArray[Any]:
         for obj in it:
             if na_obj := pd.isna(obj):
                 # pandas <NA> type cannot be converted to string
-                obj[na_obj] = None  # type: ignore
+                obj[na_obj] = None
             else:
-                obj[...] = stringify(obj)  # type: ignore
+                obj[...] = stringify(obj)

     return result
diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py
index 673c39ebf3..e2706345b1 100644
--- a/superset/utils/pandas_postprocessing/boxplot.py
+++ b/superset/utils/pandas_postprocessing/boxplot.py
@@ -57,10 +57,10 @@ def boxplot(
     """

     def quartile1(series: Series) -> float:
-        return np.nanpercentile(series, 25, interpolation="midpoint")  # type: ignore
+        return np.nanpercentile(series, 25, interpolation="midpoint")

     def quartile3(series: Series) -> float:
-        return np.nanpercentile(series, 75, interpolation="midpoint")  # type: ignore
+        return np.nanpercentile(series, 75, interpolation="midpoint")

     if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY:
@@ -99,8 +99,8 @@ def boxplot(
             return np.nanpercentile(series, low)

     else:
-        whisker_high = np.max  # type: ignore
-        whisker_low = np.min  # type: ignore
+        whisker_high = np.max
+        whisker_low = np.min

     def outliers(series: Series) -> Set[float]:
         above = series[series > whisker_high(series)]
diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py
index 1026164e45..db783c4bed 100644
--- a/superset/utils/pandas_postprocessing/flatten.py
+++ b/superset/utils/pandas_postprocessing/flatten.py
@@ -85,7 +85,7 @@ def flatten(
     _columns = []
     for series in df.columns.to_flat_index():
         _cells = []
-        for cell in series if is_sequence(series) else [series]:  # type: ignore
+        for cell in series if is_sequence(series) else [series]:
             if pd.notnull(cell):
                 # every cell should be converted to string and escape comma
                 _cells.append(escape_separator(str(cell)))
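For reference, the quartile helpers losing their type: ignore comments feed the Tukey whisker rule, which places whiskers 1.5 IQRs beyond the quartiles and flags everything outside as an outlier. A minimal sketch using the same numpy call as the diff (newer numpy spells this keyword method= instead of interpolation=):

# Hedged sketch of the Tukey whisker computation behind the boxplot helpers.
from typing import List, Tuple

import numpy as np
from pandas import Series

def tukey_whiskers(series: Series) -> Tuple[float, float]:
    # midpoint interpolation, matching the quartile1/quartile3 helpers above
    q1 = np.nanpercentile(series, 25, interpolation="midpoint")
    q3 = np.nanpercentile(series, 75, interpolation="midpoint")
    iqr = q3 - q1
    return q1 - 1.5 * iqr, q3 + 1.5 * iqr

low, high = tukey_whiskers(Series([1, 2, 3, 4, 100]))
outliers: List[float] = [x for x in (1, 2, 3, 4, 100) if x < low or x > high]
print(low, high, outliers)  # 100 falls outside the upper whisker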
