This is an automated email from the ASF dual-hosted git repository. lyndsi pushed a commit to branch lyndsi/sql-lab-new-explore-button-functionality-and-move-save-dataset-to-split-save-button in repository https://gitbox.apache.org/repos/asf/superset.git
commit fae0ac4542645b80e93dd987865ac8e0fa89aa18 Author: Hugh A. Miles II <[email protected]> AuthorDate: Wed Jun 8 13:27:47 2022 +0000 > fix integration point with frontend > allow for all records to be displayed > fix select with all columns queries > filters are now working --- superset/common/query_context_processor.py | 2 +- superset/models/helpers.py | 430 ++++++++++++++++++++--------- superset/views/core.py | 22 +- 3 files changed, 322 insertions(+), 132 deletions(-) diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 49977aaa55..e94af2d6ef 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -116,7 +116,7 @@ class QueryContextProcessor: and col != DTTM_ALIAS ) ] - breakpoint() + if invalid_columns: raise QueryObjectValidationError( _( diff --git a/superset/models/helpers.py b/superset/models/helpers.py index f7f9aeff33..9cfe56d551 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -67,16 +67,18 @@ from superset.jinja_context import ( get_template_processor, ) from superset.sql_parse import ( + has_table_query, extract_table_references, ParsedQuery, sanitize_clause, Table as TableName, ) + from superset.utils import core as utils VIRTUAL_TABLE_ALIAS = "virtual_table" - +config = app.config logger = logging.getLogger(__name__) CTE_ALIAS = "__cte" @@ -626,6 +628,7 @@ from sqlalchemy.sql.elements import ColumnElement, Label, literal_column from superset.exceptions import QueryObjectValidationError from superset.superset_typing import AdhocMetric, Metric, OrderBy, QueryObjectDict from superset.utils import core as utils +from superset.superset_typing import FilterValue, FilterValues, QueryObjectDict # todo(hugh): centralize where this code lives @@ -651,6 +654,15 @@ class ExploreMixin: to be used to power a chart inside /explore """ + sqla_aggregations = { + "COUNT_DISTINCT": lambda column_name: sa.func.COUNT(sa.distinct(column_name)), + "COUNT": sa.func.COUNT, + "SUM": sa.func.SUM, + "AVG": sa.func.AVG, + "MIN": sa.func.MIN, + "MAX": sa.func.MAX, + } + @property def data(self): return {"foo": "bar"} @@ -715,13 +727,78 @@ class ExploreMixin: sqla_col = sqla_col.label(label) sqla_col.key = label_expected return sqla_col + + def mutate_query_from_config(self, sql: str) -> str: + """Apply config's SQL_QUERY_MUTATOR + + Typically adds comments to the query with context""" + sql_query_mutator = config["SQL_QUERY_MUTATOR"] + if sql_query_mutator: + sql = sql_query_mutator( + sql, + user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0. + security_manager=security_manager, + database=self.database, + ) + return sql + + @staticmethod + def _apply_cte(sql: str, cte: Optional[str]) -> str: + """ + Append a CTE before the SELECT statement if defined + + :param sql: SELECT statement + :param cte: CTE statement + :return: + """ + if cte: + sql = f"{cte}\n{sql}" + return sql + + def validate_adhoc_subquery( + self, + sql: str, + database_id: int, + default_schema: str, + ) -> str: + """ + Check if adhoc SQL contains sub-queries or nested sub-queries with table. + + If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS + predicates to it. + + :param sql: adhoc sql expression + :raise SupersetSecurityException if sql contains sub-queries or + nested sub-queries with table + """ + # pylint: disable=import-outside-toplevel + from superset import is_feature_enabled + + statements = [] + for statement in sqlparse.parse(sql): + if has_table_query(statement): + if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"): + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR, + message=_("Custom SQL fields cannot contain sub-queries."), + level=ErrorLevel.ERROR, + ) + ) + statement = insert_rls(statement, database_id, default_schema) + statements.append(statement) + + return ";\n".join(str(statement) for statement in statements) def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) - # sql = self._apply_cte(sql, sqlaq.cte) + sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) - # sql = self.mutate_query_from_config(sql) + sql = self.mutate_query_from_config(sql) + # from pprint import pprint + # pprint(sql) + # breakpoint() return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, labels_expected=sqlaq.labels_expected, @@ -883,6 +960,48 @@ class ExploreMixin: return from_clause, cte + def adhoc_metric_to_sqla( + self, + metric: AdhocMetric, + columns_by_name: Dict[str, Dict], + template_processor: Optional[BaseTemplateProcessor] = None, + ) -> ColumnElement: + """ + Turn an adhoc metric into a sqlalchemy column. + + :param dict metric: Adhoc metric definition + :param dict columns_by_name: Columns for the current table + :param template_processor: template_processor instance + :returns: The metric defined as a sqlalchemy column + :rtype: sqlalchemy.sql.column + """ + expression_type = metric.get("expressionType") + label = utils.get_metric_name(metric) + + if expression_type == utils.AdhocMetricExpressionType.SIMPLE: + metric_column = metric.get("column") or {} + column_name = cast(str, metric_column.get("column_name")) + table_column: Optional[TableColumn] = columns_by_name.get(column_name) + sqla_column = sa.column(column_name) + # todo(hughhh): understand how this works? + # if table_column: + # sqla_column = table_column.get_sqla_col() + # else: + # sqla_column = column(column_name) + sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) + elif expression_type == utils.AdhocMetricExpressionType.SQL: + expression = _process_sql_expression( + expression=metric["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + sqla_metric = literal_column(expression) + else: + raise QueryObjectValidationError("Adhoc metric expressionType is invalid") + + return self.make_sqla_column_compatible(sqla_metric, label) + @property def template_params_dict(self) -> Dict[Any, Any]: return {} @@ -1069,7 +1188,7 @@ class ExploreMixin: row_offset: Optional[int] = None, timeseries_limit: Optional[int] = None, timeseries_limit_metric: Optional[Metric] = None, - ) -> Any: + ) -> SqlaQuery: """Querying any sqla table from this common interface""" if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col @@ -1087,7 +1206,7 @@ class ExploreMixin: "time_column": granularity, "time_grain": time_grain, "to_dttm": to_dttm.isoformat() if to_dttm else None, - "table_columns": self.column_names, + "table_columns": [col.get('column_name') for col in self.columns], # [col.column_name for col in self.columns], "filter": filter, } columns = columns or [] @@ -1097,16 +1216,14 @@ class ExploreMixin: if is_timeseries and timeseries_limit: series_limit = timeseries_limit series_limit_metric = series_limit_metric or timeseries_limit_metric - template_kwargs.update(self.template_params_dict) # todo + template_kwargs.update(self.template_params_dict) extra_cache_keys: List[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys removed_filters: List[str] = [] applied_template_filters: List[str] = [] template_kwargs["removed_filters"] = removed_filters template_kwargs["applied_filters"] = applied_template_filters - template_processor = ( - None # self.get_template_processor(**template_kwargs) #todo - ) + template_processor = None # self.get_template_processor(**template_kwargs) db_engine_spec = self.db_engine_spec prequeries: List[str] = [] orderby = orderby or [] @@ -1117,39 +1234,39 @@ class ExploreMixin: if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - # todo(hugh): fix this - columns_by_name = { - col.get('column_name'): col for col in self.columns + columns_by_name: Dict[str, TableColumn] = { + col.get('column_name'): col for col in self.columns # col.column_name: col for col in self.columns } - - # todo(hugh): how are we handling metrics - # metrics_by_name: Dict[str, Column] = { # todo column vs metric? - # m.metric_name: m for m in self.metrics - # } - metrics_by_name: Dict[str, Column] = {} - - # if not granularity and is_timeseries: - # raise QueryObjectValidationError( - # _( - # "Datetime column not provided as part table configuration " - # "and is required by this type of chart" - # ) - # ) - # if not metrics and not columns and not groupby: - # raise QueryObjectValidationError(_("Empty query?")) + + metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + + if not granularity and is_timeseries: + raise QueryObjectValidationError( + _( + "Datetime column not provided as part table configuration " + "and is required by this type of chart" + ) + ) + if not metrics and not columns and not groupby: + raise QueryObjectValidationError(_("Empty query?")) metrics_exprs: List[ColumnElement] = [] - # for metric in metrics: - # if utils.is_adhoc_metric(metric): - # assert isinstance(metric, dict) - # # metrics_exprs.append( - # # self.adhoc_metric_to_sqla(metric, columns_by_name)) - # elif isinstance(metric, str) and metric in metrics_by_name: - # metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) - # else: - # raise QueryObjectValidationError( - # _("Metric '%(metric)s' does not exist", metric=metric) - # ) + for metric in metrics: + if utils.is_adhoc_metric(metric): + assert isinstance(metric, dict) + metrics_exprs.append( + self.adhoc_metric_to_sqla( + metric=metric, + columns_by_name=columns_by_name, + template_processor=template_processor, + ) + ) + elif isinstance(metric, str) and metric in metrics_by_name: + metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) + else: + raise QueryObjectValidationError( + _("Metric '%(metric)s' does not exist", metric=metric) + ) if metrics_exprs: main_metric_expr = metrics_exprs[0] @@ -1169,10 +1286,16 @@ class ExploreMixin: col: Union[AdhocMetric, ColumnElement] = orig_col if isinstance(col, dict): col = cast(AdhocMetric, col) + if col.get("sqlExpression"): + col["sqlExpression"] = _process_sql_expression( + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists - # todo(hugh): figure out if we should have metrics - # col = self.adhoc_metric_to_sqla(col, columns_by_name) + col = self.adhoc_metric_to_sqla(col, columns_by_name) # if the adhoc metric has been defined before # use the existing instance. col = metrics_exprs_by_expr.get(str(col), col) @@ -1186,14 +1309,13 @@ class ExploreMixin: col = metrics_by_name[col].get_sqla_col() need_groupby = True - # todo(hugh): fix this - # if isinstance(col, ColumnElement): - # orderby_exprs.append(col) - # else: - # # Could not convert a column reference to valid ColumnElement - # raise QueryObjectValidationError( - # _("Unknown column used in orderby: %(col)s", col=orig_col) - # ) + if isinstance(col, ColumnElement): + orderby_exprs.append(col) + else: + # Could not convert a column reference to valid ColumnElement + raise QueryObjectValidationError( + _("Unknown column used in orderby: %(col)s", col=orig_col) + ) select_exprs: List[Union[Column, Label]] = [] groupby_all_columns = {} @@ -1207,71 +1329,90 @@ class ExploreMixin: # dedup columns while preserving order columns = groupby or columns for selected in columns: - if isinstance(selected, str): + if isinstance(selected, str): # if groupby field/expr equals granularity field/expr if selected == granularity: table_col = columns_by_name[selected] outer = table_col.get_timestamp_expression( time_grain=time_grain, label=selected, - # template_processor=template_processor, + template_processor=template_processor, ) # if groupby field equals a selected column - # elif selected in columns_by_name: - # outer = columns_by_name[selected].get_sqla_col() - # else: - # outer = literal_column(f"({selected})") - # outer = self.make_sqla_column_compatible(outer, selected) + elif selected in columns_by_name: + if isinstance(columns_by_name[selected], dict): + outer = literal_column(f"({selected})") + outer = self.make_sqla_column_compatible(outer, selected) + else: + outer = columns_by_name[selected].get_sqla_col() + else: + selected = self.validate_adhoc_subquery( + selected, + self.database_id, + self.schema, + ) + outer = literal_column(f"({selected})") + outer = self.make_sqla_column_compatible(outer, selected) else: outer = self.adhoc_column_to_sqla( - col=selected, # template_processor=template_processor + col=selected, template_processor=template_processor ) - # groupby_all_columns[outer.name] = outer - # if not series_column_names or outer.name in series_column_names: - # groupby_series_columns[outer.name] = outer - # select_exprs.append(outer) + groupby_all_columns[outer.name] = outer + if not series_column_names or outer.name in series_column_names: + groupby_series_columns[outer.name] = outer + select_exprs.append(outer) elif columns: for selected in columns: - # select_exprs.append( - # columns_by_name[selected].get_sqla_col() - # if selected in columns_by_name - # else self.make_sqla_column_compatible(literal_column(selected)) - # ) - select_exprs.append(selected) + selected = self.validate_adhoc_subquery( + selected, + self.database_id, + self.schema, + ) + if isinstance(columns_by_name[selected], dict): + select_exprs.append( + literal_column(f"({selected})") + ) + else: + select_exprs.append( + columns_by_name[selected].get_sqla_col() + if selected in columns_by_name + else self.make_sqla_column_compatible(literal_column(selected)) + ) metrics_exprs = [] - # todo(hugh): fix this - # if granularity: - # if granularity not in columns_by_name or not dttm_col: - # raise QueryObjectValidationError( - # _( - # 'Time column "%(col)s" does not exist in dataset', - # col=granularity, - # ) - # ) - # time_filters = [] - - # if is_timeseries: - # timestamp = dttm_col.get_timestamp_expression( - # time_grain=time_grain, template_processor=template_processor - # ) - # # always put timestamp as the first column - # select_exprs.insert(0, timestamp) - # groupby_all_columns[timestamp.name] = timestamp - - # # Use main dttm column to support index with secondary dttm columns. - # if ( - # db_engine_spec.time_secondary_columns - # and self.main_dttm_col in self.dttm_cols - # and self.main_dttm_col != dttm_col.column_name - # ): - # time_filters.append( - # columns_by_name[self.main_dttm_col].get_time_filter( - # from_dttm, - # to_dttm, - # ) - # ) - # time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) + if granularity: + if granularity not in columns_by_name or not dttm_col: + raise QueryObjectValidationError( + _( + 'Time column "%(col)s" does not exist in dataset', + col=granularity, + ) + ) + time_filters = [] + + if is_timeseries: + timestamp = dttm_col.get_timestamp_expression( + time_grain=time_grain, template_processor=template_processor + ) + # always put timestamp as the first column + select_exprs.insert(0, timestamp) + groupby_all_columns[timestamp.name] = timestamp + + # Use main dttm column to support index with secondary dttm columns. + if ( + db_engine_spec.time_secondary_columns + and self.main_dttm_col in self.dttm_cols + and self.main_dttm_col != dttm_col.column_name + ): + pass + # todo(hughhh): fix time filter + # time_filters.append( + # columns_by_name[self.main_dttm_col].get_time_filter( + # from_dttm, + # to_dttm, + # ) + # ) + # time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) # Always remove duplicates by column name, as sometimes `metrics_exprs` # can have the same name as a groupby column (e.g. when users use @@ -1288,10 +1429,8 @@ class ExploreMixin: if not db_engine_spec.allows_hidden_ordeby_agg: select_exprs = utils.remove_duplicates(select_exprs + orderby_exprs) - qry = sa.select([sa.column("YEAR")]) + qry = sa.select(select_exprs) - # todo(hugh) fix templating - # tbl, cte = self.get_from_clause(template_processor) tbl, cte = self.get_from_clause(template_processor) if groupby_all_columns: @@ -1306,8 +1445,8 @@ class ExploreMixin: flt_col = flt["col"] val = flt.get("val") op = flt["op"].upper() - col_obj: Optional[Column] = None - sqla_col: Optional[sa.Column] = None + col_obj: Optional[TableColumn] = None + sqla_col: Optional[Column] = None if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col elif utils.is_adhoc_column(flt_col): @@ -1328,35 +1467,79 @@ class ExploreMixin: sqla_col = col_obj.get_timestamp_expression( time_grain=filter_grain, template_processor=template_processor ) + elif col_obj and isinstance(col_obj, dict): + sqla_col = sa.column(col_obj.get('column_name')) elif col_obj: sqla_col = col_obj.get_sqla_col() + + if col_obj and isinstance(col_obj, dict): + col_type = col_obj.get('type') + else: + col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( - col_obj.type if col_obj else None + native_type=col_type, + db_extra=self.database.get_extra(), ) is_list_target = op in ( utils.FilterOperator.IN.value, utils.FilterOperator.NOT_IN.value, ) - if col_spec: - target_type = col_spec.generic_type + + if col_obj and isinstance(col_obj, dict): + col_advanced_data_type = "" else: - target_type = GenericDataType.STRING + col_advanced_data_type = col_obj.advanced_data_type if col_obj else "" + + if col_spec and not col_advanced_data_type: + target_generic_type = col_spec.generic_type + else: + target_generic_type = utils.GenericDataType.STRING eq = self.filter_values_handler( values=val, - target_column_type=target_type, + target_generic_type=target_generic_type, + target_native_type=col_type, is_list_target=is_list_target, + db_engine_spec=db_engine_spec, + db_extra=self.database.get_extra(), ) - if is_list_target: + if ( + col_advanced_data_type != "" + and feature_flag_manager.is_feature_enabled( + "ENABLE_ADVANCED_DATA_TYPES" + ) + and col_advanced_data_type in ADVANCED_DATA_TYPES + ): + values = eq if is_list_target else [eq] # type: ignore + bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ + col_advanced_data_type + ].translate_type( + { + "type": col_advanced_data_type, + "values": values, + } + ) + if bus_resp["error_message"]: + raise AdvancedDataTypeResponseError( + _(bus_resp["error_message"]) + ) + + where_clause_and.append( + ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( + sqla_col, op, bus_resp["values"] + ) + ) + elif is_list_target: assert isinstance(eq, (tuple, list)) if len(eq) == 0: raise QueryObjectValidationError( _("Filter value list cannot be empty") ) - if None in eq: - eq = [x for x in eq if x is not None] + if len(eq) > len( + eq_without_none := [x for x in eq if x is not None] + ): is_null_cond = sqla_col.is_(None) if eq: - cond = or_(is_null_cond, sqla_col.in_(eq)) + cond = or_(is_null_cond, sqla_col.in_(eq_without_none)) else: cond = is_null_cond else: @@ -1400,13 +1583,13 @@ class ExploreMixin: raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) ) - if is_feature_enabled("ROW_LEVEL_SECURITY"): - where_clause_and += self._get_sqla_row_level_filters(template_processor) + # todo(hugh): fix this + # where_clause_and += self.get_sqla_row_level_filters(template_processor) if extras: where = extras.get("where") if where: try: - where = template_processor.process_template(where) + where = template_processor.process_template(f"({where})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1414,11 +1597,11 @@ class ExploreMixin: msg=ex.message, ) ) from ex - where_clause_and += [self.text(f"({where})")] + where_clause_and += [self.text(where)] having = extras.get("having") if having: try: - having = template_processor.process_template(having) + having = template_processor.process_template(f"({having})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1426,13 +1609,10 @@ class ExploreMixin: msg=ex.message, ) ) from ex - having_clause_and += [self.text(f"({having})")] + having_clause_and += [self.text(having)] if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) if granularity: - time_filters = ( - [] - ) # todo(hugh): remove this once time filters are actually set qry = qry.where(and_(*(time_filters + where_clause_and))) else: qry = qry.where(and_(*where_clause_and)) @@ -1452,7 +1632,7 @@ class ExploreMixin: and col.name in [select_col.name for select_col in select_exprs] ): col = literal_column(col.name) - direction = asc if ascending else desc + direction = sa.asc if ascending else sa.desc qry = qry.order_by(direction(col)) if row_limit: @@ -1495,7 +1675,7 @@ class ExploreMixin: ob = self._get_series_orderby( series_limit_metric, metrics_by_name, columns_by_name ) - direction = desc if order_desc else asc + direction = sa.desc if order_desc else sa.asc subq = subq.order_by(direction(ob)) subq = subq.limit(series_limit) @@ -1558,7 +1738,7 @@ class ExploreMixin: ) label = "rowcount" col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) - qry = select([col]).select_from(qry.alias("rowcount_qry")) + qry = sa.select([col]).select_from(qry.alias("rowcount_qry")) labels_expected = [label] return SqlaQuery( diff --git a/superset/views/core.py b/superset/views/core.py index 75848a29c9..42681edb11 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -777,24 +777,29 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods initial_form_data = json.loads(value) if value else {} from superset.dao.datasource.dao import DatasourceDAO - from superset.utils.core import DatasourceType from superset.models.helpers import ExploreMixin + from superset.utils.core import DatasourceType # Handle SIP-68 Models or explore view # API will always use /explore/<datasource_type>/<int:datasource_id>/ to query # new models to power any viz in explore - if datasource_id and datasource_type: + datasource_id = request.args.get('datasource_id', datasource_id) + datasource_type = request.args.get('datasource_type', datasource_type) + + if datasource_id and datasource_type: # 1. Query datasource object by type and id datasource = DatasourceDAO.get_datasource( session=db.session, datasource_type=DatasourceType(datasource_type), datasource_id=datasource_id, ) - + # 2. Verify that it's an ExploreMixin if isinstance(datasource, ExploreMixin): # Handle Query object bootstrap - datasource_name = datasource.name if datasource else _("[Missing Dataset]") + datasource_name = ( + datasource.name if datasource else _("[Missing Dataset]") + ) form_data, slc = get_form_data( use_slice_data=True, initial_form_data=initial_form_data ) @@ -827,13 +832,17 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods if action == "overwrite" and not slice_overwrite_perm: return json_error_response( - _("You don't have the rights to ") + _("alter this ") + _("chart"), + _("You don't have the rights to ") + + _("alter this ") + + _("chart"), status=403, ) if action == "saveas" and not slice_add_perm: return json_error_response( - _("You don't have the rights to ") + _("create a ") + _("chart"), + _("You don't have the rights to ") + + _("create a ") + + _("chart"), status=403, ) @@ -875,6 +884,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods datasource_data["metrics"] = datasource.extra.get("metrics", []) datasource_data["id"] = datasource_id datasource_data["type"] = datasource_type + datasource_data["name"] = datasource.sql bootstrap_data = { "can_add": slice_add_perm,
