This is an automated email from the ASF dual-hosted git repository. lyndsi pushed a commit to branch lyndsi/sql-lab-new-explore-button-functionality-and-move-save-dataset-to-split-save-button in repository https://gitbox.apache.org/repos/asf/superset.git
commit 8d087f2d413a54698536c1585920b5b2b6d6cc24 Author: Hugh A. Miles II <[email protected]> AuthorDate: Fri Jun 24 12:46:39 2022 +0000 patch for pre-commit --- superset/models/helpers.py | 160 +++++++++++++++++++++++++++++---------------- superset/models/sql_lab.py | 14 +--- superset/utils/core.py | 2 +- 3 files changed, 106 insertions(+), 70 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 73f10741fb..62eec0ec57 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -29,15 +29,18 @@ from typing import ( Dict, Hashable, List, + Mapping, NamedTuple, Optional, Set, + Text, Tuple, Type, TYPE_CHECKING, Union, ) +import dateutil.parser import humanize import numpy as np import pandas as pd @@ -62,10 +65,15 @@ from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy_utils import UUIDType from superset import app, db, is_feature_enabled, security_manager +from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.common.db_query_status import QueryStatus from superset.constants import EMPTY_STRING, NULL_STRING from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import SupersetSecurityException +from superset.exceptions import ( + AdvancedDataTypeResponseError, + QueryClauseValidationException, + SupersetSecurityException, +) from superset.extensions import feature_flag_manager from superset.jinja_context import ( BaseTemplateProcessor, @@ -716,7 +724,7 @@ class ExploreMixin: raise NotImplementedError() @property - def main_dttm_col(self) -> str: + def main_dttm_col(self) -> Optional[str]: raise NotImplementedError() @property @@ -751,6 +759,26 @@ class ExploreMixin: def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: raise NotImplementedError() + def _process_sql_expression( # type: ignore + expression: Optional[str], + database_id: int, + schema: str, + template_processor: Optional[BaseTemplateProcessor], + ) -> Optional[str]: + if template_processor and expression: + expression = template_processor.process_template(expression) + if expression: + expression = validate_adhoc_subquery( + expression, + database_id, + schema, + ) + try: + expression = sanitize_clause(expression) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex + return expression + def make_sqla_column_compatible( self, sqla_col: ColumnElement, label: Optional[str] = None ) -> ColumnElement: @@ -832,13 +860,10 @@ class ExploreMixin: def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) - sql = self.database.compile_sqla_query(sqlaq.sqla_query) + sql = self.database.compile_sqla_query(sqlaq.sqla_query) # type: ignore sql = self._apply_cte(sql, sqlaq.cte) sql = sqlparse.format(sql, reindent=True) sql = self.mutate_query_from_config(sql) - # from pprint import pprint - # pprint(sql) - # breakpoint() return QueryStringExtended( applied_template_filters=sqlaq.applied_template_filters, labels_expected=sqlaq.labels_expected, @@ -846,6 +871,43 @@ class ExploreMixin: sql=sql, ) + def _normalize_prequery_result_type( + self, + row: pd.Series, + dimension: str, + columns_by_name: Dict[str, "TableColumn"], + ) -> Union[str, int, float, bool, Text]: + """ + Convert a prequery result type to its equivalent Python type. + + Some databases like Druid will return timestamps as strings, but do not perform + automatic casting when comparing these strings to a timestamp. For cases like + this we convert the value via the appropriate SQL transform. + + :param row: A prequery record + :param dimension: The dimension name + :param columns_by_name: The mapping of columns by name + :return: equivalent primitive python type + """ + + value = row[dimension] + + if isinstance(value, np.generic): + value = value.item() + + column_ = columns_by_name[dimension] + db_extra: Dict[str, Any] = self.database.get_extra() # type: ignore + + if column_.type and column_.is_temporal and isinstance(value, str): + sql = self.db_engine_spec.convert_dttm( + column_.type, dateutil.parser.parse(value), db_extra=db_extra + ) + + if sql: + value = self.text(sql) + + return value + def make_orderby_compatible( self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] ) -> None: @@ -883,49 +945,40 @@ class ExploreMixin: errors = None error_message = None - # def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: - # """ - # Some engines change the case or generate bespoke column names, either by - # default or due to lack of support for aliasing. This function ensures that - # the column names in the DataFrame correspond to what is expected by - # the viz components. - # Sometimes a query may also contain only order by columns that are not used - # as metrics or groupby columns, but need to present in the SQL `select`, - # filtering by `labels_expected` make sure we only return columns users want. - # :param df: Original DataFrame returned by the engine - # :return: Mutated DataFrame - # """ - # labels_expected = query_str_ext.labels_expected - # if df is not None and not df.empty: - # if len(df.columns) < len(labels_expected): - # raise QueryObjectValidationError( - # _("Db engine did not return all queried columns") - # ) - # if len(df.columns) > len(labels_expected): - # df = df.iloc[:, 0: len(labels_expected)] - # df.columns = labels_expected - # return df + def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: + """ + Some engines change the case or generate bespoke column names, either by + default or due to lack of support for aliasing. This function ensures that + the column names in the DataFrame correspond to what is expected by + the viz components. + Sometimes a query may also contain only order by columns that are not used + as metrics or groupby columns, but need to present in the SQL `select`, + filtering by `labels_expected` make sure we only return columns users want. + :param df: Original DataFrame returned by the engine + :return: Mutated DataFrame + """ + labels_expected = query_str_ext.labels_expected + if df is not None and not df.empty: + if len(df.columns) < len(labels_expected): + raise QueryObjectValidationError( + _("Db engine did not return all queried columns") + ) + if len(df.columns) > len(labels_expected): + df = df.iloc[:, 0 : len(labels_expected)] + df.columns = labels_expected + return df try: - # todo(hugh) fix this - # df = self.database.get_df( - # sql, self.schema, mutator=assign_column_label) - df = self.database.get_df(sql, self.schema) + df = self.database.get_df(sql, self.schema, mutator=assign_column_label) # type: ignore except Exception as ex: # pylint: disable=broad-except df = pd.DataFrame() status = QueryStatus.FAILED logger.warning( "Query %s on schema %s failed", sql, self.schema, exc_info=True ) - # todo(hugh): how are we handling errors - # db_engine_spec = self.db_engine_spec - # errors = [ - # dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex) - # ] error_message = utils.error_msg_from_exception(ex) return QueryResult( - # applied_template_filters=query_str_ext.applied_template_filters, status=status, df=df, duration=datetime.now() - qry_start_dttm, @@ -1016,16 +1069,11 @@ class ExploreMixin: if expression_type == utils.AdhocMetricExpressionType.SIMPLE: metric_column = metric.get("column") or {} column_name = cast(str, metric_column.get("column_name")) - table_column: Optional[TableColumn] = columns_by_name.get(column_name) + table_column: Optional[Dict[str, Any]] = columns_by_name.get(column_name) sqla_column = sa.column(column_name) - # todo(hughhh): understand how this works? - # if table_column: - # sqla_column = table_column.get_sqla_col() - # else: - # sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) elif expression_type == utils.AdhocMetricExpressionType.SQL: - expression = _process_sql_expression( + expression = _process_sql_expression( # type: ignore expression=metric["sqlExpression"], database_id=self.database_id, schema=self.schema, @@ -1295,7 +1343,7 @@ class ExploreMixin: metrics_exprs.append( self.adhoc_metric_to_sqla( metric=metric, - columns_by_name=columns_by_name, + columns_by_name=columns_by_name, # type: ignore template_processor=template_processor, ) ) @@ -1325,7 +1373,7 @@ class ExploreMixin: if isinstance(col, dict): col = cast(AdhocMetric, col) if col.get("sqlExpression"): - col["sqlExpression"] = _process_sql_expression( + col["sqlExpression"] = _process_sql_expression( # type: ignore expression=col["sqlExpression"], database_id=self.database_id, schema=self.schema, @@ -1333,7 +1381,7 @@ class ExploreMixin: ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists - col = self.adhoc_metric_to_sqla(col, columns_by_name) + col = self.adhoc_metric_to_sqla(col, columns_by_name) # type: ignore # if the adhoc metric has been defined before # use the existing instance. col = metrics_exprs_by_expr.get(str(col), col) @@ -1486,7 +1534,7 @@ class ExploreMixin: if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col: col_obj = dttm_col elif utils.is_adhoc_column(flt_col): - sqla_col = self.adhoc_column_to_sqla(flt_col) + sqla_col = self.adhoc_column_to_sqla(flt_col) # type: ignore else: col_obj = columns_by_name.get(flt_col) filter_grain = flt.get("grain") @@ -1514,7 +1562,7 @@ class ExploreMixin: col_type = col_obj.type if col_obj else None col_spec = db_engine_spec.get_column_spec( native_type=col_type, - db_extra=self.database.get_extra(), + db_extra=self.database.get_extra(), # type: ignore ) is_list_target = op in ( utils.FilterOperator.IN.value, @@ -1538,7 +1586,7 @@ class ExploreMixin: target_native_type=col_type, is_list_target=is_list_target, db_engine_spec=db_engine_spec, - db_extra=self.database.get_extra(), + db_extra=self.database.get_extra(), # type: ignore ) if ( col_advanced_data_type != "" @@ -1627,7 +1675,7 @@ class ExploreMixin: where = extras.get("where") if where: try: - where = template_processor.process_template(f"({where})") + where = template_processor.process_template(f"({where})") # type: ignore except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1639,7 +1687,7 @@ class ExploreMixin: having = extras.get("having") if having: try: - having = template_processor.process_template(f"({having})") + having = template_processor.process_template(f"({having})") # type: ignore except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1648,8 +1696,8 @@ class ExploreMixin: ) ) from ex having_clause_and += [self.text(having)] - if apply_fetch_values_predicate and self.fetch_values_predicate: - qry = qry.where(self.get_fetch_values_predicate()) + if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore + qry = qry.where(self.get_fetch_values_predicate()) # type: ignore if granularity: qry = qry.where(and_(*(time_filters + where_clause_and))) else: @@ -1755,7 +1803,7 @@ class ExploreMixin: "order_desc": True, } - result = self.query(prequery_obj) # ignore: typing + result = self.query(prequery_obj) # type: ignore prequeries.append(result.query) dimensions = [ c diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 6c6e10b514..d610410ded 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -225,18 +225,6 @@ class Query(Model, ExtraJSONMixin, ExploreMixin): "database": {"id": self.database_id, "backend": self.database.backend}, } - @property - def data(self) -> Dict[str, Any]: - return { - "columns": self.columns, - "metrics": [], - "id": self.id, - "type": self.type, - "sql": self.sql, - "owners": self.owners_data, - "database": {"id": self.database_id, "backend": self.database.backend}, - } - def raise_for_access(self) -> None: """ Raise an exception if the user cannot access the resource. @@ -282,7 +270,7 @@ class Query(Model, ExtraJSONMixin, ExploreMixin): def main_dttm_col(self) -> Optional[str]: for col in self.columns: if col.get("is_dttm"): - return col.get("column_name") + return col.get("column_name") # type: ignore return None @property diff --git a/superset/utils/core.py b/superset/utils/core.py index 0ecff893c5..f76995fb32 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1739,7 +1739,7 @@ def get_time_filter_status( # pylint: disable=too-many-branches # todo(hugh): fix this # temporal_columns = {col.column_name for col in datasource.columns if col.is_dttm} - temporal_columns = {} + temporal_columns: Dict[str, Any] = {} applied: List[Dict[str, str]] = [] rejected: List[Dict[str, str]] = [] time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)
